108 lines
3.5 KiB
Python
108 lines
3.5 KiB
Python
|
|
import threading
|
||
|
|
import traceback
|
||
|
|
from typing import List
|
||
|
|
|
||
|
|
from .._vacc_libs import _torch_vacc
|
||
|
|
|
||
|
|
# True once _lazy_init() has completed successfully in this process.
_initialized = False

# Thread-local storage; carries an ``is_initializing`` flag while
# _lazy_init() runs so re-entrant calls from queued callbacks bail out
# instead of deadlocking.
_tls = threading.local()

# Serializes initialization across threads (double-checked in _lazy_init).
_initialization_lock = threading.Lock()

# (callback, capture-time traceback) pairs queued by _lazy_call() to be
# replayed once the VACC state is initialized.
_queued_calls = []

# Reports whether this process was forked after VACC was already
# initialized; falls back to "never" when the C binding is unavailable.
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)
|
||
|
|
|
||
|
|
|
||
|
|
def is_initialized():
    r"""Returns whether PyTorch's VACC state has been initialized."""
    # Not initialized yet: short-circuit without probing the fork state,
    # matching the original evaluation order.
    if not _initialized:
        return False
    # A forked child of an initialized parent must not be treated as
    # initialized.
    return not _is_in_bad_fork()
|
||
|
|
|
||
|
|
|
||
|
|
class _LazySeedTracker:
|
||
|
|
# Since seeding is memory-less, only track the latest seed.
|
||
|
|
# Note: `manual_seed_all` followed by `manual_seed` overwrites
|
||
|
|
# the seed on current device. We track the order of **latest**
|
||
|
|
# calls between these two API.
|
||
|
|
def __init__(self):
|
||
|
|
self.manual_seed_all_cb = None
|
||
|
|
self.manual_seed_cb = None
|
||
|
|
self.call_order = []
|
||
|
|
|
||
|
|
def queue_seed_all(self, cb, traceback):
|
||
|
|
self.manual_seed_all_cb = (cb, traceback)
|
||
|
|
# update seed_all to be latest
|
||
|
|
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
||
|
|
|
||
|
|
def queue_seed(self, cb, traceback):
|
||
|
|
self.manual_seed_cb = (cb, traceback)
|
||
|
|
# update seed to be latest
|
||
|
|
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||
|
|
|
||
|
|
def get_calls(self) -> List:
|
||
|
|
return self.call_order
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level tracker for the latest manual_seed / manual_seed_all
# callbacks queued before initialization; drained by _lazy_init().
_lazy_seed_tracker = _LazySeedTracker()
|
||
|
|
|
||
|
|
|
||
|
|
def _lazy_call(callable, **kwargs):
    """Run *callable* immediately if VACC is initialized, else queue it.

    Queued callables are replayed by ``_lazy_init``.  Seed callbacks are
    routed through ``_lazy_seed_tracker`` (pass ``seed=True`` or
    ``seed_all=True``) so only the latest seed of each kind survives.

    Note: the parameter shadows the ``callable`` builtin; the name is kept
    unchanged for backward compatibility with existing callers.
    """
    if is_initialized():
        callable()
        return
    # TODO(torch_deploy): this accesses linecache, which attempts to read the
    # file system to get traceback info. Patch linecache or do something
    # else here if this ends up being important.
    # Capture the stack once; the original recomputed it in each branch.
    # (The previous ``global _lazy_seed_tracker`` was removed: the name is
    # never rebound here, so the statement had no effect.)
    stack = traceback.format_stack()
    if kwargs.get("seed_all", False):
        _lazy_seed_tracker.queue_seed_all(callable, stack)
    elif kwargs.get("seed", False):
        _lazy_seed_tracker.queue_seed(callable, stack)
    else:
        # Don't store the actual traceback to avoid memory cycle
        _queued_calls.append((callable, stack))
|
||
|
|
|
||
|
|
|
||
|
|
class DeferredVaccCallError(Exception):
    """Raised by ``_lazy_init`` when a queued lazy VACC call fails."""
|
||
|
|
|
||
|
|
|
||
|
|
def _lazy_init():
    """Initialize VACC device state.

    Idempotent and thread-safe: the first caller performs the real work
    under ``_initialization_lock`` while later callers return immediately.
    Re-entrant calls from the initializing thread itself (e.g. from a
    queued callback) are detected via the thread-local
    ``_tls.is_initializing`` flag and return without deadlocking.

    Raises:
        RuntimeError: when called in a subprocess forked after VACC was
            initialized in the parent (``_is_in_bad_fork``).
        DeferredVaccCallError: when a callback queued by ``_lazy_call``
            raises during replay; the message embeds the traceback captured
            at queue time.
    """

    global _initialized, _queued_calls
    # Fast path: already initialized, or this same thread is re-entering
    # while it is itself performing initialization.
    if _initialized or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # Double-check under the lock: another thread may have completed
        # initialization while we were waiting to acquire it.
        if _initialized:
            return

        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize VACC in forked subprocess. To use VACC with "
                "multiprocessing, you must use the 'spawn' start method"
            )

        # C-level device initialization.
        _torch_vacc._vacc_init()

        # Mark this thread as initializing so re-entrant calls bail out at
        # the fast-path check above.
        _tls.is_initializing = True

        # Append tracked seed callbacks (oldest first); entries may be
        # None when only one of the two seed APIs was ever called.
        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
                        f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredVaccCallError(msg) from e
        finally:
            # Always clear the re-entrancy flag, even when a callback fails.
            delattr(_tls, "is_initializing")
        # Only mark initialized after all queued calls replayed cleanly.
        _initialized = True
|