init
184
torch_vacc/vacc/__init__.py
Normal file
@@ -0,0 +1,184 @@
from __future__ import annotations
from typing import Tuple

import torch

from ._device import (
    current_device,
    device,
    device_count,
    get_device_capability,
    get_device_name,
    get_device_properties,
    is_available,
    is_bf16_supported,
    set_device,
    synchronize,
)
from .amp import (
    get_amp_supported_dtype,
    get_autocast_dtype,
    is_autocast_enabled,
    set_autocast_dtype,
    set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import (  # caching_allocator_alloc,; caching_allocator_delete,
    empty_cache,
    get_allocator_backend,
    max_memory_allocated,
    max_memory_cached,
    max_memory_reserved,
    mem_get_info,
    memory_allocated,
    memory_cached,
    memory_reserved,
    memory_snapshot,
    memory_stats,
    memory_stats_as_nested_dict,
    memory_summary,
    reset_accumulated_memory_stats,
    reset_max_memory_allocated,
    reset_max_memory_cached,
    reset_peak_memory_stats,
    set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream


def init():
    r"""Initialize PyTorch's VACC state. You may need to call
    this explicitly if you are interacting with PyTorch via
    its C API, as Python bindings for VACC functionality will not
    be available until this initialization takes place. Ordinary users
    should not need this, as all of PyTorch's VACC methods
    automatically initialize VACC state on-demand.

    Does nothing if the VACC state is already initialized.
    """
    _lazy_init()


# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]

from .custom_ops import *
from .custom_qwen3_ops import *
from .random import *  # noqa: F403

__all__ = [
    "device",
    "is_available",
    "is_bf16_supported",
    "current_device",
    "set_device",
    "device_count",
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "synchronize",
    "amp",
    "get_amp_supported_dtype",
    "is_autocast_enabled",
    "set_autocast_enabled",
    "get_autocast_dtype",
    "set_autocast_dtype",
    "_is_in_bad_fork",
    "_lazy_call",
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "set_stream",
    "current_stream",
    "default_generators",
    "default_stream",
    "stream",
    "Stream",
    "Event",
    "mem_get_info",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "get_allocator_backend",
    "rms_norm",
    "RotaryPosEmbedding",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_cp_forward",
    "scaled_dot_product_attention_cp_backward",
    "swiglu",
    "paged_attention",
    "reshape_and_cache_attention",
    "concat_and_cache_attention",
    "w8a8_block_fp8_matmul",
    "moe_expert_token_group_reassign",
    "fused_mlp_mm_fp8",
    "fused_mlp_fp8",
    "fused_moe_preprocess",
    "fused_residual_rmsnorm",
    "parallel_embedding",
    "all_reduce",
    "all_gather",
    "broadcast",
    "fused_mlp_moe_with_rmsnorm",
    "fuse_moe_decode_v2_allreduce",
    "topk_topp",
    "fused_mla",
    "fused_mla_allreduce",
    "fused_mlp_with_rmsnorm",
    "fused_mlp_allreduce",
    "ds3_sampler",
    "sampler_v1",
    "rejection_sampler",
    "rejection_sampler_update_hidden_states",
    "rejection_sampler_v1",
    "fused_matmul_allgather",
    "fused_mla_v2",
    "fused_mla_allreduce_v2",
    "mla_matmul_scale",
    "mla_matmul",
    "fused_mla_prefill_stage0",
    "fused_mla_prefill_stage1",
    "fused_mla_prefill_stage0_allreduce",
    "fuse_moe_prefill_stage0",
    "fuse_mla_mlp_v2_allreduce_decode",
    "fuse_mla_moe_v2_allreduce_decode",
    "fuse_mla_mlp_v2_allreduce_decode_layers",
    "fuse_mla_moe_v2_allreduce_decode_layers",
    "fuse_mla_mlp_v2_allreduce_decode_layers_v2",
    "fuse_mla_moe_v2_allreduce_decode_layers_v2",
    "fuse_mlp_qwen_int4",
    "fuse_mlp_qwen_int4_reduce",
    "w4a8_block_int4_matmul",
    "fuse_atten_qwen3",
    "fuse_atten_qwen2",
    "qwen3_fuse_attention_moe_decode",
    "fuse_mtp_stage0",
    "fuse_mtp_allreduce",
    "roll_out",
    "fused_experts_int4_prefill",
    "fuse_bge_embedding_stage1",
    "l2_norm",
    "fuse_mlp_vision",
    "patch_merger_vision",
    "fuse_atten_vit",
    "apply_penalties",
]
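
A minimal usage sketch of this entry point (illustrative, not part of the
commit; assumes a host with at least one VACC device, torch_vacc installed,
and the backend registered with PyTorch under the "vacc" device type):

    import torch
    import torch_vacc

    torch_vacc.vacc.init()  # optional: every call below also lazy-initializes
    if torch_vacc.vacc.is_available():
        print(torch_vacc.vacc.device_count())
        print(torch_vacc.vacc.get_device_name(0))
        x = torch.ones(2, 2, device="vacc")
        torch_vacc.vacc.synchronize()
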
BIN
torch_vacc/vacc/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/_device.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/custom_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/custom_ops_cpu.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/custom_qwen3_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/lazy_initialize.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/memory.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/random.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/streams.cpython-312.pyc
Normal file
Binary file not shown.
106
torch_vacc/vacc/_device.py
Normal file
@@ -0,0 +1,106 @@
# Device information
# replacing `torch.cuda.func` with `torch_vacc.vacc.func`.
# see https://pytorch.org/docs/stable/cuda.html

from typing import Any
import warnings

import torch
import torch_vacc
from torch._utils import _get_device_index
from torch_vacc._vacc_libs import _torch_vacc

from .lazy_initialize import _lazy_init


if hasattr(_torch_vacc, "_exchange_device"):
    _exchange_device = _torch_vacc._exchange_device
else:

    def _exchange_device(device: int) -> int:
        # Pure-Python fallback used when the C extension does not expose
        # _exchange_device: switch to `device` and return the previous index.
        if device < 0:
            return -1
        prev_device = current_device()
        if device != prev_device:
            set_device(device)
        return prev_device


class device(object):
    """Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = _exchange_device(self.idx)

    def __exit__(self, *args):
        _exchange_device(self.prev_idx)
        return False


def is_available() -> bool:
    r"""Returns whether VACC is available."""
    return device_count() > 0


def is_bf16_supported() -> bool:
    r"""Returns a bool indicating if the current VACC device supports dtype bfloat16."""
    return True


def current_device() -> int:
    r"""Returns the index of the currently selected VACC device."""
    _lazy_init()
    return _torch_vacc._current_device()


def set_device(device: torch.device):
    device_index = _get_device_index(device, optional=True)
    if device_index >= 0:
        _torch_vacc._set_device(device_index)


def get_device_capability(device=None):
    r"""Query the major and minor compute capability of a device. VACC has no
    corresponding concept, so this function is not supported and always raises.
    """
    _infos = (
        "torch.vacc.get_device_capability isn't implemented! Unlike CUDA's "
        "(major, minor) capability, VACC has no equivalent; please check the "
        "device version in another way."
    )
    raise AssertionError(_infos)


def get_device_name(device_name=None):
    device_id = _get_device_index(device_name, optional=True)
    if device_id < 0 or device_id >= device_count():
        raise AssertionError("Invalid device id")
    _lazy_init()
    device_prop = _torch_vacc._vacc_getDeviceProperties(device_id)
    return device_prop.name


def get_device_properties(device_name=None):
    device_id = _get_device_index(device_name, optional=True)
    if device_id < 0 or device_id >= device_count():
        raise AssertionError("Invalid device id")
    _lazy_init()
    return _torch_vacc._vacc_getDeviceProperties(device_id)


def device_count():
    r"""Returns the number of available VACC devices."""
    return _torch_vacc._device_count()


def synchronize(device=None) -> None:
    """Waits for all operations in all streams on a VACC device to complete."""
    _lazy_init()
    with torch_vacc.vacc.device(device):
        return _torch_vacc._device_synchronize()


# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)
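
A short sketch of the device context-manager above (illustrative; assumes at
least two VACC devices; like torch.cuda.device, the previous device is
restored on exit):

    import torch_vacc

    torch_vacc.vacc.set_device(0)
    with torch_vacc.vacc.device(1):
        assert torch_vacc.vacc.current_device() == 1
    assert torch_vacc.vacc.current_device() == 0
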
26
torch_vacc/vacc/amp/__init__.py
Normal file
@@ -0,0 +1,26 @@
from typing import List
import torch
from torch_vacc._vacc_libs import _torch_vacc

from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd


def get_amp_supported_dtype() -> List[torch.dtype]:
    return [torch.float16, torch.bfloat16]


def is_autocast_enabled() -> bool:
    return _torch_vacc.is_autocast_enabled()


def set_autocast_enabled(enable: bool):
    _torch_vacc.set_autocast_enabled(enable)


def get_autocast_dtype() -> torch.dtype:
    return _torch_vacc.get_autocast_dtype()


def set_autocast_dtype(dtype: torch.dtype):
    return _torch_vacc.set_autocast_dtype(dtype)
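
Sketch of the autocast helpers above (illustrative; assumes torch_vacc is
importable and a VACC device is present):

    import torch
    from torch_vacc.vacc import amp

    assert torch.bfloat16 in amp.get_amp_supported_dtype()
    amp.set_autocast_dtype(torch.bfloat16)
    with amp.autocast(dtype=torch.bfloat16):
        pass  # eligible floating-point ops run in bfloat16 here
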
BIN
torch_vacc/vacc/amp/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/amp/__pycache__/autocast_mode.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/amp/__pycache__/common.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/amp/__pycache__/grad_scaler.cpython-312.pyc
Normal file
Binary file not shown.
144
torch_vacc/vacc/amp/autocast_mode.py
Normal file
@@ -0,0 +1,144 @@
import collections
import functools
from typing import Any

import torch

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
    HAS_NUMPY = False

__all__ = ["autocast", "custom_fwd", "custom_bwd"]


class autocast(torch.amp.autocast_mode.autocast):
    r"""See :class:`torch.autocast`.

    ``torch.vacc.amp.autocast(args...)`` is equivalent to ``torch.autocast("vacc", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "vacc"
            self.fast_dtype = dtype
            return
        super().__init__(
            "vacc", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)


# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_vacc
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value


# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point VACC Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        args[0]._dtype = torch.get_autocast_gpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.is_autocast_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd


# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """Create a helper decorator for backward methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
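
Sketch of the decorator pair in use, following the torch.cuda.amp custom-op
pattern (MyMM is a hypothetical example, not part of this package):

    import torch
    from torch_vacc.vacc.amp import custom_fwd, custom_bwd

    class MyMM(torch.autograd.Function):
        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)  # run in fp32 inside autocast regions
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @custom_bwd  # rerun backward under forward's autocast state
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)
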
7
torch_vacc/vacc/amp/common.py
Normal file
@@ -0,0 +1,7 @@
import torch

__all__ = ["amp_definitely_not_available"]


def amp_definitely_not_available():
    return not torch.vacc.is_available()

667
torch_vacc/vacc/amp/grad_scaler.py
Normal file
@@ -0,0 +1,667 @@
import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple

import torch

from .common import amp_definitely_not_available


__all__ = ["OptState", "GradScaler"]


class _MultiDeviceReplicator:
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert (
            master_tensor.is_cuda
            or master_tensor.device.type == "xla"
            or master_tensor.device.type == "vacc"
        )
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
#   causes a circular reference, which we'd rather avoid.
class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_per_optimizer_state():
    return {"stage": OptState.READY, "found_inf_per_device": {}}


class GradScaler:
    """
    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
    conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
    and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
    a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """

    _scale: Optional[torch.Tensor]
    _growth_tracker: Optional[torch.Tensor]
    _per_optimizer_states: Dict[int, Dict[str, Any]]

    def __init__(
        self,
        init_scale=2.0**16,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=2000,
        enabled=True,
    ):
        if enabled and amp_definitely_not_available():
            warnings.warn(
                "torch.vacc.amp.GradScaler is enabled, but VACC device is not available. Disabling."
            )
            self._enabled = False
        else:
            self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(
        self, funcname
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, (
            f"Attempted {funcname} but _scale is None. " + fix
        )
        assert self._growth_tracker is not None, (
            f"Attempted {funcname} but _growth_tracker is None. " + fix
        )
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full(
            (), self._init_growth_tracker, dtype=torch.int32, device=dev
        )

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert (
                outputs.is_cuda
                or outputs.device.type == "xla"
                or outputs.device.type == "vacc"
            )
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[
            _MultiDeviceReplicator
        ] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert (
                    val.is_cuda or val.device.type == "xla" or val.device.type == "vacc"
                )
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads,
                        per_device_found_inf.get(device),
                        per_device_inv_scale.get(device),
                    )

        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, False
        )
        optimizer_state["stage"] = OptState.UNSCALED

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
            retval = optimizer.step(*args, **kwargs)
        return retval

    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
           earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
           gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if not self._enabled:
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError(
                "Closure use is not currently supported if GradScaler is enabled."
            )

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        retval = None

        if (
            hasattr(optimizer, "_step_supports_amp_scaling")
            and optimizer._step_supports_amp_scaling
        ):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
            # and `found_inf` to the passed optimizer so that the optimizer can utilize those
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
            # while the method is expected to be called by users side, i.e. their optimizers.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = (
                "grad_scaler" in inspect.signature(optimizer.step).parameters
            )
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                    FutureWarning,
                )
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    torch.Tensor,
                    sum(
                        [
                            t.to(scaler.device, non_blocking=True)
                            for t in optimizer_state["found_inf_per_device"].values()
                        ]
                    ),
                )
                optimizer.grad_scale = (
                    None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                )
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert (
            len(optimizer_state["found_inf_per_device"]) > 0
        ), "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.vacc.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.

        .. warning::
            For performance reasons, we do not check the scale factor value to avoid synchronizations,
            so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
            you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
            bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.vacc.FloatTensor with requires_grad=False."
                # assert isinstance(new_scale, torch.vacc.FloatTensor), reason  # type: ignore[attr-defined]
                assert (
                    isinstance(new_scale, torch.Tensor)
                    and new_scale.dtype == torch.float32
                ), reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return (
                self._init_scale
                if self._scale is None
                else self._get_scale_async().item()
            )
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            return (
                self._init_growth_tracker
                if self._growth_tracker is None
                else self._growth_tracker.item()
            )
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return (
            {
                "scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker(),
            }
            if self._enabled
            else {}
        )

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The source state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])

    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, (
                "A GradScaler instance may only be pickled at the beginning "
                "of an iteration, or at the end after scaler.update()."
            )
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state["_init_scale"] = self.get_scale()
            state["_init_growth_tracker"] = self._get_growth_tracker()
            state["_scale"] = None
            state["_growth_tracker"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)][
            "found_inf_per_device"
        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
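
Checkpointing sketch for the scaler above (illustrative; model, optimizer and
scaler are the objects from the class docstring's training loop, and
state_dict() is taken after scaler.update() as the docstring advises):

    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
    }
    torch.save(ckpt, "checkpoint.pt")

    # ... later, after rebuilding model/optimizer/scaler, to resume:
    ckpt = torch.load("checkpoint.pt")
    scaler.load_state_dict(ckpt["scaler"])
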
2819
torch_vacc/vacc/custom_ops.py
Normal file
File diff suppressed because it is too large.
306
torch_vacc/vacc/custom_ops_cpu.py
Normal file
@@ -0,0 +1,306 @@
from typing import Tuple, Union, Optional, List
import torch
from torch.nn import functional as F


def split_last_two_dims_into_blocks(x, h, w):
    # Tile the last two dims (H, W) into non-overlapping (h, w) blocks,
    # producing shape (..., H // h, W // w, h, w).
    leading_dims = x.shape[:-2]
    H, W = x.shape[-2:]
    assert (
        H % h == 0 and W % w == 0
    ), "The last two dimensions must be divisible by block size."
    x_reshaped = x.view(-1, 1, H, W)

    unfolded = F.unfold(x_reshaped, kernel_size=(h, w), stride=(h, w))
    unfolded = unfolded.view(-1, 1, h, w, H // h, W // w)
    unfolded = unfolded.permute(0, 1, 4, 5, 2, 3)
    final_shape = leading_dims + (H // h, W // w, h, w)
    result = unfolded.view(final_shape)

    return result


def merge_blocks_to_original_layout(x, h, w):
    # Inverse of split_last_two_dims_into_blocks: fold (..., H//h, W//w, h, w)
    # blocks back into (..., H, W). Block sizes are re-read from the shape.
    leading_dims = x.shape[:-4]
    H_div_h, W_div_w, h, w = x.shape[-4:]
    H = H_div_h * h
    W = W_div_w * w

    x_reshaped = x.view(-1, 1, H_div_h, W_div_w, h, w)
    x_reshaped = x_reshaped.permute(0, 1, 4, 5, 2, 3)
    x_reshaped = x_reshaped.view(-1, h * w, H_div_h * W_div_w)
    folded = F.fold(x_reshaped, output_size=(H, W), kernel_size=(h, w), stride=(h, w))

    final_shape = leading_dims + (H, W)
    result = folded.view(final_shape)

    return result
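
Round-trip sketch for the two helpers above (runnable on CPU with stock
PyTorch): split_last_two_dims_into_blocks tiles the last two dims into
(H // h, W // w, h, w) blocks, and merge_blocks_to_original_layout inverts it.

    import torch

    x = torch.arange(2 * 8 * 6, dtype=torch.float32).view(2, 8, 6)
    blocks = split_last_two_dims_into_blocks(x, 4, 3)  # -> (2, 2, 2, 4, 3)
    y = merge_blocks_to_original_layout(blocks, 4, 3)  # -> (2, 8, 6)
    assert torch.equal(x, y)
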
def w8a8_block_fp8_matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    is_linear_weight: bool = False,
    output_opt: Optional[torch.Tensor] = None,
    **kwargs
):
    # Dequantize a block-quantized weight (one scale per (b0, b1) block, padding
    # the edges up to whole blocks first) and run a plain matmul against it.
    b0, b1 = block_size
    dim0, dim1 = weight.shape
    dim0pad, dim1pad = 0, 0

    if dim0 % b0 != 0:
        dim0pad = b0 - dim0 % b0
    if dim1 % b1 != 0:
        dim1pad = b1 - dim1 % b1

    dim0_origin, dim1_origin = dim0, dim1
    dim0 += dim0pad
    dim1 += dim1pad

    bs0, bs1 = dim0 // b0, dim1 // b1
    weight_dequant = torch.nn.functional.pad(weight, (0, dim1pad, 0, dim0pad), value=0)
    weight_dequant = weight_dequant.cpu().view(bs0, b0, bs1, b1).permute(
        0, 2, 1, 3
    ).reshape(bs0, bs1, -1).float().to(input.device) * weight_scale.unsqueeze(-1)
    weight_dequant = (
        weight_dequant.reshape(bs0, bs1, b0, b1)
        .permute(0, 2, 1, 3)
        .reshape(dim0, dim1)
        .to(input.dtype)
    )
    weight_dequant = weight_dequant[:dim0_origin, :dim1_origin]
    output = torch.matmul(
        input, weight_dequant.T if is_linear_weight else weight_dequant
    )
    if output_opt is not None:
        output = output_opt.copy_(output)
    return output


def w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    **kwargs
):
    assert input_scale is None, "w8a8_block_fp8_matmul only supports weight quantization for now"
    return w8a8_block_fp8_matmul(
        input, weight, None, weight_scale, block_size, is_linear_weight=True
    )
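
Shape sketch for the block-dequant matmul above (illustrative sizes; in the
real path weight holds fp8 values, but the function only dequantizes and
matmuls, so plain float tensors work for a CPU check):

    import torch

    M, K, N = 4, 256, 512
    inp = torch.randn(M, K)
    w = torch.randn(K, N)                     # (dim0, dim1) = (K, N)
    w_scale = torch.rand(K // 128, N // 128)  # one scale per 128x128 block
    out = w8a8_block_fp8_matmul(inp, w, None, w_scale, [128, 128])
    assert out.shape == (M, N)
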
def fused_experts(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    decode_with_batch: bool = False,
) -> torch.Tensor:
    batch_seq_all, hidden_dims = hidden_states.shape
    intermediate_size = w2_weight.shape[-1]
    num_experts = w13_weight.shape[0]
    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()

    final_hidden_states = torch.zeros_like(hidden_states)

    w1_scale = w13_scale

    _, bs0_w13, bs1_w13 = w1_scale.shape
    _, bs0_w2, bs1_w2 = w2_scale.shape

    sel_experts = topk_ids.shape[1]
    if hidden_states.shape[0] == 1:
        # Single-token path: dequantize and apply only the selected experts.
        for id in range(sel_experts):
            expert_idx = topk_ids[0][id]
            expert_w1 = w13_weight[expert_idx].contiguous()
            expert_w2 = w2_weight[expert_idx].contiguous()
            ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
            ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()

            dim0, dim1 = expert_w1.shape
            b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
            # assert (bs0, bs1, 1) == ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
            expert_w1 = (
                expert_w1
                .view(bs0_w13, b0, bs1_w13, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w13, bs1_w13, -1)
                .float()
                .to(hidden_states.device)
                * ws1
            )
            expert_w1 = (
                expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            dim0, dim1 = expert_w2.shape
            b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
            # assert (bs0, bs1, 1) == ws2.shape
            expert_w2 = (
                expert_w2
                .view(bs0_w2, b0, bs1_w2, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w2, bs1_w2, -1)
                .float()
                .to(hidden_states.device)
                * ws2
            )
            expert_w2 = (
                expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            expert_weights = topk_weights[0][id].to(hidden_states.dtype)

            x = hidden_states
            x = F.linear(x, expert_w1)
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)

            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            final_hidden_states += current_hidden_states
    else:
        for expert_idx in range(num_experts):
            # topk_ids [tokens, experts] => sample: [10, 8]
            # expert_mask [tokens, experts] => sample: [10, 8]
            expert_mask = topk_ids == expert_idx

            idx = torch.where(expert_mask)[0]
            if idx.numel() == 0:
                continue

            expert_w1 = w13_weight[expert_idx].contiguous()
            expert_w2 = w2_weight[expert_idx].contiguous()
            ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
            ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()

            dim0, dim1 = expert_w1.shape
            b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
            # assert (bs0, bs1, 1) == ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
            expert_w1 = (
                expert_w1
                .view(bs0_w13, b0, bs1_w13, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w13, bs1_w13, -1)
                .float()
                .to(hidden_states.device)
                * ws1
            )
            expert_w1 = (
                expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            dim0, dim1 = expert_w2.shape
            b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
            # assert (bs0, bs1, 1) == ws2.shape
            expert_w2 = (
                expert_w2
                .view(bs0_w2, b0, bs1_w2, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w2, bs1_w2, -1)
                .float()
                .to(hidden_states.device)
                * ws2
            )
            expert_w2 = (
                expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            # [seq, experts]
            expert_weights = (
                topk_weights.masked_select(expert_mask)
                .unsqueeze(1)
                .to(hidden_states.dtype)
            )

            x = hidden_states[idx]
            x = F.linear(x, expert_w1)
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)

            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            # final_hidden_states[idx] += current_hidden_states
            final_hidden_states.index_add_(0, idx, current_hidden_states)

    final_hidden_states = final_hidden_states.reshape(batch_seq_all, hidden_dims)
    return final_hidden_states


def fused_mlp_mm_fp8(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape_w13: Optional[List[int]] = None,
    block_shape_w2: Optional[List[int]] = None,
):
    def fp8_to_fp16(inp, scale, block_size, trans_type):
        inp_t = inp.to(trans_type)
        inp_t = split_last_two_dims_into_blocks(inp_t, block_size[0], block_size[1])
        assert scale.size(0) == inp_t.size(-4)
        assert scale.size(1) == inp_t.size(-3)
        inp_t = inp_t * scale.unsqueeze(-1).unsqueeze(-1)
        inp_t = merge_blocks_to_original_layout(inp_t, block_size[0], block_size[1])
        return inp_t.to(trans_type)

    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()
    w13_fp = fp8_to_fp16(w13_weight, w13_scale, block_shape_w13, hidden_states.dtype)
    w2_fp = fp8_to_fp16(w2_weight, w2_scale, block_shape_w2, hidden_states.dtype)
    out = hidden_states @ w13_fp
    out = torch.chunk(out, 2, dim=-1)
    out = F.silu(out[0]) * out[1]
    out = out @ w2_fp

    return out


def mla_matmul_scale(input: torch.Tensor, weight: torch.Tensor, scale: float):
    output = torch.matmul(input, weight)
    output = output * scale
    output = output.to(input.dtype)
    return output


def mla_matmul(input: torch.Tensor, weight: torch.Tensor):
    output = torch.matmul(input, weight)
    output = output.to(input.dtype)
    return output
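
Call sketch for the CPU MoE reference above (illustrative sizes: 4 experts,
top-2 routing, float32 weights standing in for dequantized fp8 blocks, with
16x16 scale blocks):

    import torch

    T, D, I, E = 3, 64, 32, 4        # tokens, hidden, intermediate, experts
    hs = torch.randn(T, D)
    w13 = torch.randn(E, 2 * I, D)   # gate and up projections, stacked
    w2 = torch.randn(E, D, I)
    w13_s = torch.ones(E, (2 * I) // 16, D // 16)
    w2_s = torch.ones(E, D // 16, I // 16)
    topk_w = torch.rand(T, 2)
    topk_ids = torch.randint(0, E, (T, 2))
    out = fused_experts(hs, w13, w2, topk_w, topk_ids,
                        w13_scale=w13_s, w2_scale=w2_s)
    assert out.shape == (T, D)
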
146
torch_vacc/vacc/custom_qwen3_ops.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Generator
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
|
||||
|
||||
def fuse_moe_prefill_stage0_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
gate_weight,
|
||||
rms_hidden_state_opt: Optional[torch.Tensor] = None,
|
||||
zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
|
||||
topk_ids_opt: Optional[torch.Tensor] = None,
|
||||
topk_weight_opt: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return _torch_vacc.fuse_moe_prefill_stage0_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
gate_weight,
|
||||
rms_hidden_state_opt,
|
||||
zero_moe_hidden_state_opt,
|
||||
topk_ids_opt,
|
||||
topk_weight_opt,
|
||||
)
|
||||
|
||||
|
||||
def fuse_moe_decode_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
moe_weight_13,
|
||||
moe_weight_2,
|
||||
moe_weight_13_dequant,
|
||||
moe_weight_2_dequant,
|
||||
gate_weight,
|
||||
block_size_13,
|
||||
block_size_2,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
group_id: int,
|
||||
dev_info: Optional[List[int]] = None,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if dev_info is None or len(dev_info) == 0:
|
||||
dev_info = [i | (i << 16) for i in range(world_size)]  # entry i carries i in both 16-bit halves
|
||||
return _torch_vacc.fuse_moe_decode_qwen(
|
||||
hidden_states,
|
||||
rms_residual,
|
||||
rms_weight,
|
||||
moe_weight_13,
|
||||
moe_weight_2,
|
||||
moe_weight_13_dequant,
|
||||
moe_weight_2_dequant,
|
||||
gate_weight,
|
||||
block_size_13,
|
||||
block_size_2,
|
||||
world_size,
|
||||
rank,
|
||||
group_id,
|
||||
dev_info,
|
||||
output,
|
||||
)
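
# dev_info packing example: with world_size=2 the default table is
# [0x00000000, 0x00010001]; entry i repeats i in the low and high 16-bit halves.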
|
||||
|
||||
|
||||
def rot_pos_emb_qwenvl(grid_thw: List[List[int]],
|
||||
hidden_size: int,
|
||||
head_num: int,
|
||||
spatial_merge_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: Union[int, str, torch.device] = "vacc"):
|
||||
#assert out_tensor.device.type == "vacc", f"please target vacc device, now is {out_tensor.device}"
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("vacc", device)
|
||||
|
||||
# Flatten grid_thw ([[t, h, w], ...]) into [t0, h0, w0, t1, h1, w1, ...].
thws = []
|
||||
for i in grid_thw:
|
||||
thws.extend(i)
|
||||
return _torch_vacc.rot_pos_emb_qwenvl(thws,
|
||||
hidden_size,
|
||||
head_num,
|
||||
spatial_merge_size,
|
||||
dtype,
|
||||
device)
|
||||
|
||||
def fast_pos_embed_interpolate_qwenvl(weight: torch.Tensor,
|
||||
grid_thw: List[List[int]],
|
||||
num_grid_per_side: int,
|
||||
spatial_merge_size: int,
|
||||
hidden_dim: int):
|
||||
# Flatten grid_thw ([[t, h, w], ...]) into [t0, h0, w0, t1, h1, w1, ...].
thws = []
|
||||
for i in grid_thw:
|
||||
thws.extend(i)
|
||||
return _torch_vacc.fast_pos_embed_interpolate_qwenvl(weight,
|
||||
thws,
|
||||
num_grid_per_side,
|
||||
spatial_merge_size,
|
||||
hidden_dim)
|
||||
# The qwen2_vl and qwen3_vl image preprocess op is the same.
|
||||
def qwen2vl_img_preprocess(
|
||||
image: "torch.Tensor",
|
||||
do_resize: bool,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
resized_height: int,
|
||||
resized_width: int,
|
||||
interpolation: int,  # Optional["F.InterpolationMode"]
|
||||
patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
merge_size: int,
|
||||
image_mean0: float,
|
||||
image_mean1: float,
|
||||
image_mean2: float,
|
||||
image_std0: float,
|
||||
image_std1: float,
|
||||
image_std2: float,
|
||||
# batch_size: int = 1,
|
||||
# grid_t: int = 1,
|
||||
# channel: int = 3,
|
||||
# output: Optional[torch.Tensor] = None
|
||||
):
|
||||
assert image.device.type == "vacc", f"please target vacc device, now is {image.device}"
|
||||
return _torch_vacc.qwen2vl_img_preprocess(
|
||||
image,
|
||||
do_resize,
|
||||
min_pixels,
|
||||
max_pixels,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
resized_height,
|
||||
resized_width,
|
||||
interpolation,
|
||||
patch_size,
|
||||
temporal_patch_size,
|
||||
merge_size,
|
||||
image_mean0, image_mean1, image_mean2,
|
||||
image_std0, image_std1, image_std2
|
||||
)
|
||||
107
torch_vacc/vacc/lazy_initialize.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import threading
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from .._vacc_libs import _torch_vacc
|
||||
|
||||
_initialized = False
|
||||
_tls = threading.local()
|
||||
_initialization_lock = threading.Lock()
|
||||
_queued_calls = []
|
||||
|
||||
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)
|
||||
|
||||
|
||||
def is_initialized():
|
||||
r"""Returns whether PyTorch's VACC state has been initialized."""
|
||||
return _initialized and not _is_in_bad_fork()
|
||||
|
||||
|
||||
class _LazySeedTracker:
|
||||
# Since seeding is memory-less, only track the latest seed.
|
||||
# Note: `manual_seed_all` followed by `manual_seed` overwrites
|
||||
# the seed on current device. We track the order of **latest**
|
||||
# calls between these two APIs.
|
||||
def __init__(self):
|
||||
self.manual_seed_all_cb = None
|
||||
self.manual_seed_cb = None
|
||||
self.call_order = []
|
||||
|
||||
def queue_seed_all(self, cb, traceback):
|
||||
self.manual_seed_all_cb = (cb, traceback)
|
||||
# update seed_all to be latest
|
||||
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
||||
|
||||
def queue_seed(self, cb, traceback):
|
||||
self.manual_seed_cb = (cb, traceback)
|
||||
# update seed to be latest
|
||||
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||||
|
||||
def get_calls(self) -> List:
|
||||
return self.call_order
|
||||
|
||||
|
||||
_lazy_seed_tracker = _LazySeedTracker()
|
||||
|
||||
|
||||
def _lazy_call(callable, **kwargs):
|
||||
if is_initialized():
|
||||
callable()
|
||||
else:
|
||||
# TODO(torch_deploy): this accesses linecache, which attempts to read the
|
||||
# file system to get traceback info. Patch linecache or do something
|
||||
# else here if this ends up being important.
|
||||
global _lazy_seed_tracker
|
||||
if kwargs.get("seed_all", False):
|
||||
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
|
||||
elif kwargs.get("seed", False):
|
||||
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
|
||||
else:
|
||||
# Don't store the actual traceback to avoid memory cycle
|
||||
_queued_calls.append((callable, traceback.format_stack()))
|
||||
|
||||
|
||||
class DeferredVaccCallError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _lazy_init():
|
||||
"""Initialize VACC device state."""
|
||||
|
||||
global _initialized, _queued_calls
|
||||
if _initialized or hasattr(_tls, "is_initializing"):
|
||||
return
|
||||
with _initialization_lock:
|
||||
if _initialized:
|
||||
return
|
||||
|
||||
# It is important to prevent other threads from entering _lazy_init
|
||||
# immediately, while we are still guaranteed to have the GIL, because some
|
||||
# of the C calls we make below will release the GIL
|
||||
if _is_in_bad_fork():
|
||||
raise RuntimeError(
|
||||
"Cannot re-initialize VACC in forked subprocess. To use VACC with "
|
||||
"multiprocessing, you must use the 'spawn' start method"
|
||||
)
|
||||
|
||||
_torch_vacc._vacc_init()
|
||||
|
||||
_tls.is_initializing = True
|
||||
|
||||
for calls in _lazy_seed_tracker.get_calls():
|
||||
if calls:
|
||||
_queued_calls.append(calls)
|
||||
|
||||
try:
|
||||
for queued_call, orig_traceback in _queued_calls:
|
||||
try:
|
||||
queued_call()
|
||||
except Exception as e:
|
||||
msg = (
|
||||
f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
|
||||
f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
|
||||
)
|
||||
raise DeferredVaccCallError(msg) from e
|
||||
finally:
|
||||
delattr(_tls, "is_initializing")
|
||||
_initialized = True
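
# Illustrative sketch (not part of the module API; assumes a VACC device is
# present): callbacks registered via _lazy_call before the first VACC use are
# queued and replayed exactly once by _lazy_init().
def _demo_lazy_call():
    ran = []
    _lazy_call(lambda: ran.append(True))  # queued while uninitialized
    _lazy_init()                          # flushes the queue
    return ran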
|
||||
535
torch_vacc/vacc/memory.py
Normal file
@@ -0,0 +1,535 @@
|
||||
import collections
|
||||
import contextlib
|
||||
import warnings
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch._utils import _get_device_index
|
||||
|
||||
import torch_vacc
|
||||
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
from .lazy_initialize import is_initialized, _lazy_init
|
||||
|
||||
__all__ = [
|
||||
"mem_get_info",
|
||||
# "caching_allocator_alloc",
|
||||
# "caching_allocator_delete",
|
||||
"set_per_process_memory_fraction",
|
||||
"empty_cache",
|
||||
"memory_stats",
|
||||
"memory_stats_as_nested_dict",
|
||||
"reset_accumulated_memory_stats",
|
||||
"reset_peak_memory_stats",
|
||||
"reset_max_memory_allocated",
|
||||
"reset_max_memory_cached",
|
||||
"memory_allocated",
|
||||
"max_memory_allocated",
|
||||
"memory_reserved",
|
||||
"max_memory_reserved",
|
||||
"memory_cached",
|
||||
"max_memory_cached",
|
||||
"memory_snapshot",
|
||||
"memory_summary",
|
||||
"get_allocator_backend",
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _free_mutex():
|
||||
_torch_vacc._vacc_lock_mutex()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_torch_vacc._vacc_unlock_mutex()
|
||||
|
||||
|
||||
# def caching_allocator_alloc(size, device=None, stream=None):
|
||||
# r"""Performs a memory allocation using the VACC memory allocator.
|
||||
|
||||
# Memory is allocated for a given device and a stream, this
|
||||
# function is intended to be used for interoperability with other
|
||||
# frameworks. Allocated memory is released through
|
||||
# :func:`~torch_vacc.vacc.caching_allocator_delete`.
|
||||
|
||||
# Arguments:
|
||||
# size (int): number of bytes to be allocated.
|
||||
# device (torch.device or int, optional): selected device. If it is
|
||||
# ``None`` the default VACC device is used.
|
||||
# stream (torch_vacc.vacc.Stream or int, optional): selected stream. If is ``None`` then
|
||||
# the default stream for the selected device is used.
|
||||
# """
|
||||
# if device is None:
|
||||
# device = torch_vacc.vacc.current_device()
|
||||
# device = _get_device_index(device)
|
||||
# if stream is None:
|
||||
# stream = torch_vacc.vacc.current_stream(device)
|
||||
# if isinstance(stream, torch_vacc.vacc.streams.Stream):
|
||||
# stream = stream.vacc_stream
|
||||
# if not isinstance(stream, int):
|
||||
# raise TypeError(
|
||||
# "Invalid type for stream argument, must be "
|
||||
# "`torch_vacc.vacc.Stream` or `int` representing a pointer "
|
||||
# "to a exisiting stream"
|
||||
# )
|
||||
# with torch_vacc.vacc.device(device):
|
||||
# return _torch_vacc._vacc_vaccCachingAllocator_raw_alloc(size, stream)
|
||||
|
||||
|
||||
# def caching_allocator_delete(mem_ptr):
|
||||
# r"""Deletes memory allocated using the VACC memory allocator.
|
||||
|
||||
# Memory allocated with :func:`~torch_vacc.vacc.caching_allocator_alloc`.
|
||||
# is freed here. The associated device and stream are tracked inside
|
||||
# the allocator.
|
||||
|
||||
# Arguments:
|
||||
# mem_ptr (int): memory address to be freed by the allocator.
|
||||
# """
|
||||
# _torch_vacc._vacc_vaccCachingAllocator_raw_delete(mem_ptr)
|
||||
|
||||
|
||||
def set_per_process_memory_fraction(fraction, device=None) -> None:
|
||||
r"""Set memory fraction for a process.
|
||||
The fraction limits how much memory the caching allocator may allocate on a VACC device.
|
||||
The allowed amount equals the total visible memory multiplied by the fraction.
|
||||
If a process tries to allocate more than the allowed amount, the allocator
|
||||
raises an out-of-memory error.
|
||||
Arguments:
|
||||
fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
|
||||
device (torch.device or int, optional): selected device. If it is
|
||||
``None`` the default VACC device is used.
|
||||
.. note::
|
||||
In general, the total available free memory is less than the total capacity.
|
||||
"""
|
||||
_lazy_init()
|
||||
if device is None:
|
||||
device = torch_vacc.vacc.current_device()
|
||||
device = _get_device_index(device)
|
||||
if not isinstance(fraction, float):
|
||||
raise TypeError("Invalid type for fraction argument, must be `float`")
|
||||
if fraction < 0 or fraction > 1:
|
||||
raise ValueError(
|
||||
"Invalid fraction value: {}. " "Allowed range: 0~1".format(fraction)
|
||||
)
|
||||
|
||||
_torch_vacc._vacc_setMemoryFraction(fraction, device)
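
# Usage sketch (assumes an initialized VACC device): cap this process at half
# of device 0's total memory.
#
#   set_per_process_memory_fraction(0.5, device=0)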
|
||||
|
||||
|
||||
def empty_cache():
|
||||
r"""Releases all unoccupied cached memory currently held by the caching
|
||||
allocator so that it can be used by other VACC applications and becomes
|
||||
visible to device monitoring tools.
|
||||
|
||||
.. note::
|
||||
:func:`~torch_vacc.vacc.empty_cache` doesn't increase the amount of VACC
|
||||
memory available for PyTorch. However, it may help reduce fragmentation
|
||||
of VACC memory in certain cases.
|
||||
"""
|
||||
if is_initialized():
|
||||
_torch_vacc._vacc_emptyCache()
|
||||
|
||||
|
||||
def memory_stats(device=None):
|
||||
"""Returns a dictionary of VACC memory allocator statistics for a
|
||||
given device.
|
||||
The return value of this function is a dictionary of statistics, each of
|
||||
which is a non-negative integer.
|
||||
Core statistics:
|
||||
- ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
number of allocation requests received by the memory allocator.
|
||||
- ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
amount of allocated memory.
|
||||
- ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
number of reserved segments from ``vaccMalloc()``.
|
||||
- ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
amount of reserved memory.
|
||||
- ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
number of active memory blocks.
|
||||
- ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
amount of active memory.
|
||||
- ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
number of inactive, non-releasable memory blocks.
|
||||
- ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
||||
amount of inactive, non-releasable memory.
|
||||
For these core statistics, values are broken down as follows.
|
||||
Pool type:
|
||||
- ``all``: combined statistics across all memory pools.
|
||||
- ``large_pool``: statistics for the large allocation pool
|
||||
(as of October 2019, for size >= 1MB allocations).
|
||||
- ``small_pool``: statistics for the small allocation pool
|
||||
(as of October 2019, for size < 1MB allocations).
|
||||
Metric type:
|
||||
- ``current``: current value of this metric.
|
||||
- ``peak``: maximum value of this metric.
|
||||
- ``allocated``: historical total increase in this metric.
|
||||
- ``freed``: historical total decrease in this metric.
|
||||
In addition to the core statistics, we also provide some simple event
|
||||
counters:
|
||||
- ``"num_alloc_retries"``: number of failed ``vaccMalloc`` calls that
|
||||
result in a cache flush and retry.
|
||||
- ``"num_ooms"``: number of out-of-memory errors thrown.
|
||||
The caching allocator can be configured via an environment variable not to split blocks larger than a
|
||||
defined size (see the Memory Management section of the CUDA Semantics documentation).
|
||||
This helps avoid memory fragmentation but may carry a performance
|
||||
penalty. Additional outputs to assist with tuning and evaluating impact:
|
||||
- ``"max_split_size"``: blocks above this size will not be split.
|
||||
- ``"oversize_allocations.{current,peak,allocated,freed}"``:
|
||||
number of over-size allocation requests received by the memory allocator.
|
||||
- ``"oversize_segments.{current,peak,allocated,freed}"``:
|
||||
number of over-size reserved segments from ``vaccMalloc()``.
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistics for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
result = []
|
||||
|
||||
def _recurse_add_to_result(prefix, obj):
|
||||
if isinstance(obj, dict):
|
||||
if len(prefix) > 0:
|
||||
prefix += "."
|
||||
for k, v in obj.items():
|
||||
_recurse_add_to_result(prefix + k, v)
|
||||
else:
|
||||
result.append((prefix, obj))
|
||||
|
||||
stats = memory_stats_as_nested_dict(device=device)
|
||||
_recurse_add_to_result("", stats)
|
||||
result.sort()
|
||||
|
||||
return collections.OrderedDict(result)
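
# Illustrative helper (not part of the public API): the flattening above turns
# {"allocated_bytes": {"all": {"current": 1024}}} into
# {"allocated_bytes.all.current": 1024}, so callers index with dotted keys.
def _demo_current_allocated_bytes(device=None):
    return memory_stats(device=device)["allocated_bytes.all.current"]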
|
||||
|
||||
|
||||
def memory_stats_as_nested_dict(device=None):
|
||||
r"""Returns the result of :func:`~torch_vacc.vacc.memory_stats` as a nested dictionary."""
|
||||
device = _get_device_index(device, optional=True)
|
||||
return _torch_vacc._vacc_memoryStats(device)
|
||||
|
||||
|
||||
def reset_accumulated_memory_stats(device=None):
|
||||
r"""Resets the "accumulated" (historical) stats tracked by the VACC memory allocator.
|
||||
|
||||
See :func:`~torch_vacc.vacc.memory_stats` for details. Accumulated stats correspond to
|
||||
the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
|
||||
`"num_alloc_retries"` and `"num_ooms"`.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
device = _get_device_index(device, optional=True)
|
||||
return _torch_vacc._vacc_resetAccumulatedMemoryStats(device)
|
||||
|
||||
|
||||
def reset_peak_memory_stats(device=None):
|
||||
r"""Resets the "peak" stats tracked by the VACC memory allocator.
|
||||
|
||||
See :func:`~torch_vacc.vacc.memory_stats` for details. Peak stats correspond to the
|
||||
`"peak"` key in each individual stat dict.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
device = _get_device_index(device, optional=True)
|
||||
return _torch_vacc._vacc_resetPeakMemoryStats(device)
|
||||
|
||||
|
||||
def reset_max_memory_allocated(device=None):
|
||||
r"""Resets the starting point in tracking maximum VACC memory occupied by
|
||||
tensors for a given device.
|
||||
|
||||
See :func:`~torch_vacc.vacc.max_memory_allocated` for details.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
|
||||
.. warning::
|
||||
This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
|
||||
/all/ peak memory stats.
|
||||
"""
|
||||
# warnings.warn(
|
||||
# "torch_vacc.vacc.reset_max_memory_allocated now calls torch_vacc.vacc.reset_peak_memory_stats, "
|
||||
# "which resets /all/ peak memory stats.",
|
||||
# DeprecationWarning,
|
||||
# )
|
||||
return reset_peak_memory_stats(device=device)
|
||||
|
||||
|
||||
def reset_max_memory_cached(device=None):
|
||||
r"""Resets the starting point in tracking maximum VACC memory managed by the
|
||||
caching allocator for a given device.
|
||||
|
||||
See :func:`~torch_vacc.vacc.max_memory_cached` for details.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
|
||||
.. warning::
|
||||
This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
|
||||
/all/ peak memory stats.
|
||||
"""
|
||||
# warnings.warn(
|
||||
# "torch_vacc.vacc.reset_max_memory_cached now calls torch_vacc.vacc.reset_peak_memory_stats, "
|
||||
# "which resets /all/ peak memory stats.",
|
||||
# DeprecationWarning,
|
||||
# )
|
||||
return reset_peak_memory_stats(device=device)
|
||||
|
||||
|
||||
def memory_allocated(device=None):
|
||||
r"""Returns the current VACC memory occupied by tensors in bytes for a given
|
||||
device.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
return memory_stats(device=device)["allocated_bytes.all.current"]
|
||||
|
||||
|
||||
def max_memory_allocated(device=None):
|
||||
r"""Returns the maximum VACC memory occupied by tensors in bytes for a given
|
||||
device.
|
||||
|
||||
By default, this returns the peak allocated memory since the beginning of
|
||||
this program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to
|
||||
reset the starting point in tracking this metric. For example, these two
|
||||
functions can measure the peak allocated memory usage of each iteration in a
|
||||
training loop.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
return memory_stats(device=device)["allocated_bytes.all.peak"]
|
||||
|
||||
|
||||
def memory_reserved(device=None):
|
||||
r"""Returns the current VACC memory managed by the caching allocator in bytes
|
||||
for a given device.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
return memory_stats(device=device)["reserved_bytes.all.current"]
|
||||
|
||||
|
||||
def max_memory_reserved(device=None):
|
||||
r"""Returns the maximum VACC memory managed by the caching allocator in bytes
|
||||
for a given device.
|
||||
|
||||
By default, this returns the peak cached memory since the beginning of this
|
||||
program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to reset
|
||||
the starting point in tracking this metric. For example, these two functions
|
||||
can measure the peak cached memory amount of each iteration in a training
|
||||
loop.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
return memory_stats(device=device)["reserved_bytes.all.peak"]
|
||||
|
||||
|
||||
def memory_cached(device=None):
|
||||
r"""Deprecated; see :func:`~torch_vacc.vacc.memory_reserved`."""
|
||||
# warnings.warn(
|
||||
# "torch_vacc.vacc.memory_cached has been renamed to torch_vacc.vacc.memory_reserved",
|
||||
# DeprecationWarning,
|
||||
# )
|
||||
return memory_reserved(device=device)
|
||||
|
||||
|
||||
def max_memory_cached(device=None):
|
||||
r"""Deprecated; see :func:`~torch_vacc.vacc.max_memory_reserved`."""
|
||||
# warnings.warn(
|
||||
# "torch_vacc.vacc.max_memory_cached has been renamed to torch_vacc.vacc.max_memory_reserved",
|
||||
# DeprecationWarning,
|
||||
# )
|
||||
return max_memory_reserved(device=device)
|
||||
|
||||
|
||||
def memory_snapshot():
|
||||
r"""Returns a snapshot of the VACC memory allocator state across all devices.
|
||||
|
||||
Interpreting the output of this function requires familiarity with the
|
||||
memory allocator internals.
|
||||
"""
|
||||
return _torch_vacc._vacc_memorySnapshot()
|
||||
|
||||
|
||||
def _format_size(sz, pref_sz):
|
||||
prefixes = ["B ", "KB", "MB", "GB", "TB", "PB"]
|
||||
prefix = prefixes[0]
|
||||
for new_prefix in prefixes[1:]:
|
||||
if pref_sz < 768 * 1024:
|
||||
break
|
||||
prefix = new_prefix
|
||||
sz //= 1024
|
||||
pref_sz /= 1024
|
||||
return "{:7d} {}".format(sz, prefix)
|
||||
|
||||
|
||||
def _format_count(cnt, pref_cnt):
|
||||
prefixes = [" ", "K", "M"]
|
||||
prefix = prefixes[0]
|
||||
for new_prefix in prefixes[1:]:
|
||||
if pref_cnt < 750 * 1000:
|
||||
break
|
||||
prefix = new_prefix
|
||||
cnt //= 1000
|
||||
pref_cnt /= 1000
|
||||
return "{:7d} {} ".format(cnt, prefix)
|
||||
|
||||
|
||||
def create_metrics_to_display():
|
||||
metrics_to_display = [
|
||||
("allocated_bytes", "Allocated memory", _format_size),
|
||||
("active_bytes", "Active memory", _format_size),
|
||||
("reserved_bytes", "VACC reserved memory", _format_size),
|
||||
("inactive_split_bytes", "Non-releasable memory", _format_size),
|
||||
("allocation", "Allocations", _format_count),
|
||||
("active", "Active allocs", _format_count),
|
||||
("segment", "VACC reserved segments", _format_count),
|
||||
("inactive_split", "Non-releasable allocs", _format_count),
|
||||
]
|
||||
|
||||
lines = []
|
||||
lines.append("=" * 75)
|
||||
lines.append(" {_:16} PyTorch VACC memory summary, device ID {device:<18d} ")
|
||||
lines.append("-" * 75)
|
||||
lines.append(
|
||||
" {_:9} VACC OOMs: {num_ooms:<13d} | {_:6} vaccMalloc retries: {num_alloc_retries:<9d} "
|
||||
)
|
||||
lines.append("=" * 75)
|
||||
lines.append(
|
||||
" Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
|
||||
)
|
||||
return metrics_to_display, lines
|
||||
|
||||
|
||||
def memory_summary(device=None, abbreviated=False):
|
||||
r"""Returns a human-readable printout of the current memory allocator
|
||||
statistics for a given device.
|
||||
|
||||
This can be useful to display periodically during training, or when
|
||||
handling out-of-memory exceptions.
|
||||
|
||||
Arguments:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
printout for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
abbreviated (bool, optional): whether to return an abbreviated summary
|
||||
(default: False).
|
||||
"""
|
||||
device = _get_device_index(device, optional=True)
|
||||
stats = memory_stats(device=device)
|
||||
metrics_to_display, lines = create_metrics_to_display()
|
||||
|
||||
for metric_key, metric_name, formatter in metrics_to_display:
|
||||
lines.append("-" * 75)
|
||||
submetrics = [("all", metric_name)]
|
||||
if not abbreviated:
|
||||
submetrics.append(("large_pool", " from large pool"))
|
||||
submetrics.append(("small_pool", " from small pool"))
|
||||
|
||||
current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
for submetric_key, submetric_name in submetrics:
|
||||
prefix = metric_key + "." + submetric_key + "."
|
||||
|
||||
current = stats[prefix + "current"]
|
||||
peak = stats[prefix + "peak"]
|
||||
allocated = stats[prefix + "allocated"]
|
||||
freed = stats[prefix + "freed"]
|
||||
|
||||
if current_prefval is None:
|
||||
current_prefval = current
|
||||
peak_prefval = peak
|
||||
allocated_prefval = allocated
|
||||
freed_prefval = freed
|
||||
|
||||
lines.append(
|
||||
" {:<21} | {} | {} | {} | {} ".format(
|
||||
submetric_name,
|
||||
formatter(current, current_prefval),
|
||||
formatter(peak, peak_prefval),
|
||||
formatter(allocated, allocated_prefval),
|
||||
formatter(freed, freed_prefval),
|
||||
),
|
||||
)
|
||||
|
||||
metrics_to_display = [
|
||||
("oversize_allocations", "Oversize allocations", _format_count),
|
||||
("oversize_segments", "Oversize VACC segments", _format_count),
|
||||
]
|
||||
|
||||
for metric_key, metric_name, formatter in metrics_to_display:
|
||||
lines.append("-" * 75)
|
||||
|
||||
prefix = metric_key + "."
|
||||
|
||||
current = stats[prefix + "current"]
|
||||
peak = stats[prefix + "peak"]
|
||||
allocated = stats[prefix + "allocated"]
|
||||
freed = stats[prefix + "freed"]
|
||||
|
||||
lines.append(
|
||||
" {:<21} | {} | {} | {} | {} ".format(
|
||||
metric_name,
|
||||
formatter(current, current),
|
||||
formatter(peak, peak),
|
||||
formatter(allocated, allocated),
|
||||
formatter(freed, freed),
|
||||
),
|
||||
)
|
||||
|
||||
lines.append("=" * 75)
|
||||
|
||||
fmt_dict = {"_": "", "device": device}
|
||||
for k, v in stats.items():
|
||||
fmt_dict[k.replace(".", "-")] = v
|
||||
return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
|
||||
|
||||
|
||||
def mem_get_info(device=None) -> Tuple[int, int]:
|
||||
r"""Returns the global free and total VACC memory for a given
|
||||
device using ``vaccrtMemGetInfo``.
|
||||
|
||||
Args:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||
if :attr:`device` is ``None`` (default).
|
||||
"""
|
||||
_lazy_init()
|
||||
if device is None:
|
||||
device = torch_vacc.vacc.current_device()
|
||||
device = _get_device_index(device)
|
||||
return _torch_vacc._vacc_getDeviceMemories(device)
|
||||
|
||||
|
||||
def get_allocator_backend() -> str:
|
||||
r"""Returns a string describing the active allocator backend as set by
|
||||
``PYTORCH_VACC_ALLOC_CONF``. The only currently available backend is
|
||||
``native`` (PyTorch's native caching allocator).
|
||||
"""
|
||||
return _torch_vacc._vacc_getAllocatorBackend()
|
||||
179
torch_vacc/vacc/random.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from typing import Union, List, Iterable
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from . import _lazy_call, _lazy_init, current_device, device_count
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_rng_state",
|
||||
"get_rng_state_all",
|
||||
"set_rng_state",
|
||||
"set_rng_state_all",
|
||||
"manual_seed",
|
||||
"manual_seed_all",
|
||||
"seed",
|
||||
"seed_all",
|
||||
"initial_seed",
|
||||
]
|
||||
|
||||
# Random Number Generator related functions (https://pytorch.org/docs/stable/cuda.html#random-number-generator)
|
||||
|
||||
|
||||
def get_rng_state(device: Union[int, str, torch.device] = "vacc") -> Tensor:
|
||||
r"""Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
|
||||
Args:
|
||||
device (torch.device or int, optional): The device to return the RNG state of.
|
||||
Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
|
||||
|
||||
.. warning::
|
||||
This function eagerly initializes VACC.
|
||||
"""
|
||||
_lazy_init()
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("vacc", device)
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = current_device()
|
||||
default_generator = torch.vacc.default_generators[idx]
|
||||
return default_generator.get_state()
|
||||
|
||||
|
||||
def get_rng_state_all() -> List[Tensor]:
|
||||
r"""Returns a list of ByteTensor representing the random number states of all devices."""
|
||||
|
||||
results = []
|
||||
for i in range(device_count()):
|
||||
results.append(get_rng_state(i))
|
||||
return results
|
||||
|
||||
|
||||
def set_rng_state(
|
||||
new_state: Tensor, device: Union[int, str, torch.device] = "vacc"
|
||||
) -> None:
|
||||
r"""Sets the random number generator state of the specified GPU.
|
||||
|
||||
Args:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
device (torch.device or int, optional): The device to set the RNG state.
|
||||
Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
|
||||
"""
|
||||
with torch._C._DisableFuncTorch():
|
||||
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("vacc", device)
|
||||
|
||||
def cb():
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = current_device()
|
||||
default_generator = torch.vacc.default_generators[idx]
|
||||
default_generator.set_state(new_state_copy)
|
||||
|
||||
_lazy_call(cb)
|
||||
|
||||
|
||||
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
|
||||
r"""Sets the random number generator state of all devices.
|
||||
|
||||
Args:
|
||||
new_states (Iterable of torch.ByteTensor): The desired state for each device"""
|
||||
for i, state in enumerate(new_states):
|
||||
set_rng_state(state, i)
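
# Round-trip sketch (assumes at least one VACC device; mirrors the torch.cuda
# pattern): capture the generator state, perturb it, then restore it.
def _demo_rng_round_trip():
    state = get_rng_state()
    torch.randn(4, device="vacc")  # advances the generator
    set_rng_state(state)           # restores the captured state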
|
||||
|
||||
|
||||
def manual_seed(seed: int) -> None:
|
||||
r"""Sets the seed for generating random numbers for the current GPU.
|
||||
It's safe to call this function if VACC is not available; in that
|
||||
case, it is silently ignored.
|
||||
|
||||
Args:
|
||||
seed (int): The desired seed.
|
||||
|
||||
.. warning::
|
||||
If you are working with a multi-device model, this function is insufficient
|
||||
to get determinism. To seed all devices, use :func:`manual_seed_all`.
|
||||
"""
|
||||
seed = int(seed)
|
||||
|
||||
def cb():
|
||||
idx = current_device()
|
||||
default_generator = torch.vacc.default_generators[idx]
|
||||
default_generator.manual_seed(seed)
|
||||
|
||||
_lazy_call(cb, seed=True)
|
||||
|
||||
|
||||
def manual_seed_all(seed: int) -> None:
|
||||
r"""Sets the seed for generating random numbers on all GPUs.
|
||||
It's safe to call this function if VACC is not available; in that
|
||||
case, it is silently ignored.
|
||||
|
||||
Args:
|
||||
seed (int): The desired seed.
|
||||
"""
|
||||
seed = int(seed)
|
||||
|
||||
def cb():
|
||||
for i in range(device_count()):
|
||||
default_generator = torch.vacc.default_generators[i]
|
||||
default_generator.manual_seed(seed)
|
||||
|
||||
_lazy_call(cb, seed_all=True)
|
||||
|
||||
|
||||
def seed() -> None:
|
||||
r"""Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
It's safe to call this function if VACC is not available; in that
|
||||
case, it is silently ignored.
|
||||
|
||||
.. warning::
|
||||
If you are working with a multi-device model, this function will only initialize
|
||||
the seed on one device. To initialize all devices, use :func:`seed_all`.
|
||||
"""
|
||||
|
||||
def cb():
|
||||
idx = current_device()
|
||||
default_generator = torch.vacc.default_generators[idx]
|
||||
default_generator.seed()
|
||||
|
||||
_lazy_call(cb)
|
||||
|
||||
|
||||
def seed_all() -> None:
|
||||
r"""Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
It's safe to call this function if VACC is not available; in that
|
||||
case, it is silently ignored.
|
||||
"""
|
||||
|
||||
def cb():
|
||||
random_seed = 0
|
||||
seeded = False
|
||||
for i in range(device_count()):
|
||||
default_generator = torch.vacc.default_generators[i]
|
||||
if not seeded:
|
||||
default_generator.seed()
|
||||
random_seed = default_generator.initial_seed()
|
||||
seeded = True
|
||||
else:
|
||||
default_generator.manual_seed(random_seed)
|
||||
|
||||
_lazy_call(cb)
|
||||
|
||||
|
||||
def initial_seed() -> int:
|
||||
r"""Returns the current random seed of the current GPU.
|
||||
|
||||
.. warning::
|
||||
This function eagerly initializes VACC.
|
||||
"""
|
||||
_lazy_init()
|
||||
idx = current_device()
|
||||
default_generator = torch.vacc.default_generators[idx]
|
||||
return default_generator.initial_seed()
|
||||
327
torch_vacc/vacc/streams.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import ctypes
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch._utils import _get_device_index
|
||||
|
||||
try:
|
||||
from torch._streambase import _StreamBase, _EventBase
|
||||
except ImportError:
|
||||
# torch <= 2.1
|
||||
_StreamBase = _EventBase = object
|
||||
|
||||
import torch_vacc
|
||||
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
from ._device import device
|
||||
from .lazy_initialize import _lazy_init
|
||||
|
||||
|
||||
# Strip the torch version's local build suffix (e.g. "+cpu").
|
||||
torch_version = torch.__version__.split('+')[0]
|
||||
|
||||
class _StreamCommon:
|
||||
"""Wrapper around a VACC stream.
|
||||
|
||||
A VACC stream is a linear sequence of execution that belongs to a specific
|
||||
device, independent from other streams.
|
||||
|
||||
Args:
|
||||
device(torch.device or int, optional): a device on which to allocate
|
||||
the stream. If :attr:`device` is ``None`` (default) or a negative
|
||||
integer, this will use the current device.
|
||||
priority(int, optional): priority of the stream. Can be either
|
||||
-1 (high priority) or 0 (low priority). By default, streams have
|
||||
priority 0.
|
||||
"""
|
||||
|
||||
def __new__(cls, device=None, priority=0, **kwargs):
|
||||
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
|
||||
return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
|
||||
else:
|
||||
with torch_vacc.vacc.device(device):
|
||||
return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
|
||||
|
||||
def wait_event(self, event):
|
||||
event.wait(self)
|
||||
|
||||
def record_event(self, event=None):
|
||||
"""Records an event.
|
||||
|
||||
Args:
|
||||
event (torch_vacc.Event, optional): event to record. If not given, a new one
|
||||
will be allocated.
|
||||
|
||||
Returns:
|
||||
Recorded event.
|
||||
"""
|
||||
if event is None:
|
||||
event = Event()
|
||||
event.record(self)
|
||||
return event
|
||||
|
||||
def wait_stream(self, stream):
|
||||
"""Synchronizes with another stream.
|
||||
|
||||
All future work submitted to this stream will wait until all kernels
|
||||
submitted to a given stream at the time of call complete.
|
||||
|
||||
Args:
|
||||
stream (Stream): a stream to synchronize.
|
||||
"""
|
||||
self.wait_event(stream.record_event())
|
||||
|
||||
def query(self):
|
||||
return super().query()
|
||||
|
||||
def synchronize(self):
|
||||
super().synchronize()
|
||||
|
||||
@property
|
||||
def _as_parameter_(self):
|
||||
return ctypes.c_void_p(self.vacc_stream)
|
||||
|
||||
def __eq__(self, o):
|
||||
if isinstance(o, Stream):
|
||||
return super().__eq__(o)
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.vacc_stream, self.device))
|
||||
|
||||
def __repr__(self):
|
||||
return f"torch_vacc.vacc.Stream device={self.device} vacc_stream={self.vacc_stream:#x}"
|
||||
|
||||
if version.parse(torch_version) <= version.parse("2.1"):
|
||||
# torch <= 2.1
|
||||
class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
|
||||
pass
|
||||
elif version.parse(torch_version) < version.parse("2.6"):
|
||||
# torch < 2.6
|
||||
class Stream(_torch_vacc._VACCStreamBase, _StreamBase, _StreamCommon):
|
||||
pass
|
||||
else:
|
||||
# torch >= 2.6
|
||||
class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
|
||||
pass
|
||||
|
||||
|
||||
class _EventCommon:
|
||||
"""Wrapper around a VACC event.
|
||||
|
||||
VACC events are synchronization markers that can be used to monitor the
|
||||
device's progress, to accurately measure timing, and to synchronize VACC
|
||||
streams.
|
||||
|
||||
The underlying VACC events are lazily initialized when the event is first
|
||||
recorded or exported to another process. After creation, only streams on the
|
||||
same device may record the event. However, streams on any device can wait on
|
||||
the event.
|
||||
|
||||
Args:
|
||||
enable_timing (bool, optional): indicates if the event should measure time
|
||||
(default: ``False``)
|
||||
blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
|
||||
"""
|
||||
|
||||
def __new__(cls, enable_timing=False, blocking=False):
|
||||
return super(Event, cls).__new__(
|
||||
cls,
|
||||
calc_time=enable_timing,
|
||||
blocking=blocking,
|
||||
)
|
||||
|
||||
def record(self, stream=None):
|
||||
"""Records the event in a given stream.
|
||||
|
||||
Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified. The
|
||||
stream's device must match the event's device."""
|
||||
if stream is None:
|
||||
stream = torch_vacc.vacc.current_stream()
|
||||
super().record(stream)
|
||||
|
||||
def wait(self, stream=None):
|
||||
"""Makes all future work submitted to the given stream wait for this
|
||||
event.
|
||||
|
||||
Use ``torch_vacc.vacc.current_stream()`` if no stream is specified.
|
||||
|
||||
.. note:: This is a wrapper around ``vaccrtStreamWaitEvent()``
|
||||
"""
|
||||
if stream is None:
|
||||
stream = torch_vacc.vacc.current_stream()
|
||||
super().wait(stream)
|
||||
|
||||
def query(self):
|
||||
"""Checks if all work currently captured by event has completed.
|
||||
|
||||
Returns:
|
||||
A boolean indicating if all work currently captured by event has
|
||||
completed.
|
||||
"""
|
||||
return super().query()
|
||||
|
||||
def elapsed_time(self, end_event):
|
||||
"""Returns the time elapsed in milliseconds after the event was
|
||||
recorded and before the end_event was recorded.
|
||||
"""
|
||||
return super().elapsed_time(end_event)
|
||||
|
||||
def synchronize(self):
|
||||
r"""Waits for the event to complete.
|
||||
|
||||
Waits until the completion of all work currently captured in this event.
|
||||
This prevents the CPU thread from proceeding until the event completes.
|
||||
|
||||
.. note:: This is a wrapper around ``vaccEventSynchronize()``.
|
||||
"""
|
||||
super().synchronize()
|
||||
|
||||
@property
|
||||
def _as_parameter_(self):
|
||||
return ctypes.c_void_p(self.vacc_event)
|
||||
|
||||
def __repr__(self):
|
||||
if self.vacc_event:
|
||||
return f"<torch_vacc.vacc.Event {self._as_parameter_.value:#x}>"
|
||||
else:
|
||||
return "<torch_vacc.vacc.Event uninitialized>"
|
||||
|
||||
if version.parse(torch_version) <= version.parse("2.1"):
|
||||
# torch <= 2.1
|
||||
class Event(_torch_vacc._VACCEventBase, _EventCommon):
|
||||
pass
|
||||
elif version.parse(torch_version) < version.parse("2.6"):
|
||||
# torch < 2.6
|
||||
class Event(_torch_vacc._VACCEventBase, _EventBase, _EventCommon):
|
||||
pass
|
||||
else:
|
||||
# torch >= 2.6
|
||||
class Event(_torch_vacc._VACCEventBase, _EventCommon):
|
||||
pass
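
# Timing sketch (assumes a VACC device; `enable_timing` maps onto the backend's
# `calc_time` flag): measure the device time spent in `fn`.
def _demo_event_timing(fn):
    start, end = Event(enable_timing=True), Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    end.synchronize()
    return start.elapsed_time(end)  # milliseconds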
|
||||
|
||||
class StreamContext:
|
||||
r"""Context-manager that selects a given stream.
|
||||
|
||||
All VACC kernels queued within its context will be enqueued on a selected
|
||||
stream.
|
||||
|
||||
Args:
|
||||
stream (stream): selected stream. This manager is a no-op if it's
|
||||
``None``.
|
||||
.. note:: Streams are per-device.
|
||||
"""
|
||||
cur_stream: Optional["torch_vacc.vacc.Stream"]
|
||||
|
||||
def __init__(self, stream: Optional["torch_vacc.vacc.Stream"]):
|
||||
self.stream = stream
|
||||
self.idx = _get_device_index(None, True)
|
||||
if not torch.jit.is_scripting():
|
||||
if self.idx is None:
|
||||
self.idx = -1
|
||||
|
||||
self.src_prev_stream = (
|
||||
None
|
||||
if not torch.jit.is_scripting()
|
||||
else torch_vacc.vacc.default_stream(None)
|
||||
)
|
||||
self.dst_prev_stream = (
|
||||
None
|
||||
if not torch.jit.is_scripting()
|
||||
else torch_vacc.vacc.default_stream(None)
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
# Local cur_stream variable for type refinement
|
||||
cur_stream = self.stream
|
||||
# Return if stream is None or VACC device not available
|
||||
if cur_stream is None or self.idx == -1:
|
||||
return
|
||||
self.src_prev_stream = torch_vacc.vacc.current_stream(None)
|
||||
|
||||
# If the stream is not on the current device, then
|
||||
# set the current stream on the device
|
||||
if self.src_prev_stream.device != cur_stream.device:
|
||||
with device(cur_stream.device):
|
||||
self.dst_prev_stream = torch_vacc.vacc.current_stream(cur_stream.device)
|
||||
torch_vacc.vacc.set_stream(cur_stream)
|
||||
|
||||
def __exit__(self, type: Any, value: Any, traceback: Any):
|
||||
# Local cur_stream variable for type refinement
|
||||
cur_stream = self.stream
|
||||
# If stream is None or no VACC device available, return
|
||||
if cur_stream is None or self.idx == -1:
|
||||
return
|
||||
|
||||
# Reset the stream on the original device
|
||||
# and destination device
|
||||
if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
|
||||
torch_vacc.vacc.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
|
||||
torch_vacc.vacc.set_stream(self.src_prev_stream) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def stream(stream: Optional["torch_vacc.vacc.Stream"]) -> StreamContext:
|
||||
r"""Wrapper around the Context-manager StreamContext that
|
||||
selects a given stream.
|
||||
|
||||
Arguments:
|
||||
stream (Stream): selected stream. This manager is a no-op if it's
|
||||
``None``.
|
||||
"""
|
||||
return StreamContext(stream)
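
# Usage sketch (assumes a VACC device and a device tensor `x`): enqueue work on
# a side stream, then make the default stream wait for it.
def _demo_stream_handoff(x):
    s = Stream()
    with stream(s):
        y = x * 2  # kernel enqueued on side stream s
    current_stream().wait_stream(s)
    return y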
|
||||
|
||||
|
||||
def set_stream(stream: Stream):
|
||||
r"""Sets the current stream.This is a wrapper API to set the stream.
|
||||
Usage of this function is discouraged in favor of the ``stream``
|
||||
context manager.
|
||||
|
||||
Args:
|
||||
stream (Stream): selected stream. This function is a no-op
|
||||
if this argument is ``None``.
|
||||
"""
|
||||
if stream is None:
|
||||
return
|
||||
_torch_vacc._vacc_setStream(
|
||||
stream_id=stream.stream_id,
|
||||
device_index=stream.device_index,
|
||||
device_type=stream.device_type,
|
||||
)
|
||||
|
||||
|
||||
def current_stream(device=None) -> Stream:
|
||||
r"""Returns the currently selected :class:`Stream` for a given device.
|
||||
|
||||
Args:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
the currently selected :class:`Stream` for the current device, given
|
||||
by :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
|
||||
(default).
|
||||
"""
|
||||
_lazy_init()
|
||||
streamdata = _torch_vacc._vacc_getCurrentStream(
|
||||
_get_device_index(device, optional=True)
|
||||
)
|
||||
return Stream(
|
||||
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
|
||||
)
|
||||
|
||||
|
||||
def default_stream(device=None) -> Stream:
|
||||
r"""Returns the default :class:`Stream` for a given device.
|
||||
|
||||
Args:
|
||||
device (torch.device or int, optional): selected device. Returns
|
||||
the default :class:`Stream` for the current device, given by
|
||||
:func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
|
||||
(default).
|
||||
"""
|
||||
_lazy_init()
|
||||
streamdata = _torch_vacc._vacc_getDefaultStream(
|
||||
_get_device_index(device, optional=True)
|
||||
)
|
||||
|
||||
return Stream(
|
||||
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
|
||||
)
|
||||