init
This commit is contained in:
107
torch_vacc/vacc/lazy_initialize.py
Normal file
107
torch_vacc/vacc/lazy_initialize.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import threading
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from .._vacc_libs import _torch_vacc
|
||||
|
||||
# True once _lazy_init() has completed successfully in this process.
_initialized = False

# Thread-local storage; `is_initializing` is set on it while _lazy_init runs
# so that re-entrant calls from this thread return immediately.
_tls = threading.local()

# Serializes _lazy_init across threads (double-checked before/after acquire).
_initialization_lock = threading.Lock()

# (callback, captured traceback) pairs to run once initialization completes.
_queued_calls = []

# Reports whether this process was forked after VACC was initialized; falls
# back to "never" when the extension does not export _vacc_in_bad_fork.
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)
|
||||
|
||||
|
||||
def is_initialized():
    r"""Return whether PyTorch's VACC state has been initialized.

    A process flagged as a "bad fork" reports ``False`` even if
    initialization had completed before the fork.
    """
    # Preserve short-circuit order: the bad-fork probe only runs once
    # initialization has actually happened.
    if not _initialized:
        return False
    return not _is_in_bad_fork()
|
||||
|
||||
|
||||
class _LazySeedTracker:
|
||||
# Since seeding is memory-less, only track the latest seed.
|
||||
# Note: `manual_seed_all` followed by `manual_seed` overwrites
|
||||
# the seed on current device. We track the order of **latest**
|
||||
# calls between these two API.
|
||||
def __init__(self):
|
||||
self.manual_seed_all_cb = None
|
||||
self.manual_seed_cb = None
|
||||
self.call_order = []
|
||||
|
||||
def queue_seed_all(self, cb, traceback):
|
||||
self.manual_seed_all_cb = (cb, traceback)
|
||||
# update seed_all to be latest
|
||||
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
||||
|
||||
def queue_seed(self, cb, traceback):
|
||||
self.manual_seed_cb = (cb, traceback)
|
||||
# update seed to be latest
|
||||
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||||
|
||||
def get_calls(self) -> List:
|
||||
return self.call_order
|
||||
|
||||
|
||||
# Module-level singleton holding the latest queued seeding callbacks;
# drained by _lazy_init into _queued_calls.
_lazy_seed_tracker = _LazySeedTracker()
|
||||
|
||||
|
||||
def _lazy_call(callable, **kwargs):
    """Run ``callable`` now if VACC is initialized, otherwise queue it.

    Seeding callbacks (``seed=True`` / ``seed_all=True`` keyword flags) go
    through the seed tracker so only the latest one survives; everything
    else is appended to ``_queued_calls`` for _lazy_init to drain.
    """
    if is_initialized():
        callable()
        return

    # TODO(torch_deploy): this accesses linecache, which attempts to read the
    # file system to get traceback info. Patch linecache or do something
    # else here if this ends up being important.
    stack = traceback.format_stack()
    if kwargs.get("seed_all", False):
        _lazy_seed_tracker.queue_seed_all(callable, stack)
    elif kwargs.get("seed", False):
        _lazy_seed_tracker.queue_seed(callable, stack)
    else:
        # Don't store the actual traceback to avoid memory cycle
        _queued_calls.append((callable, stack))
|
||||
|
||||
|
||||
class DeferredVaccCallError(Exception):
    """Raised when a lazily-queued VACC call fails during initialization."""
|
||||
|
||||
|
||||
def _lazy_init():
    """Initialize VACC device state.

    Thread-safe: uses a double-checked pattern around
    ``_initialization_lock`` and runs the queued callbacks (plus any pending
    seed callbacks from ``_lazy_seed_tracker``) exactly once.  Raises
    ``DeferredVaccCallError`` when a queued call fails, and ``RuntimeError``
    when called from a forked subprocess.
    """

    global _initialized, _queued_calls
    # Fast path: already initialized, or this thread is re-entering while a
    # queued call below is still running (is_initializing is set on this
    # thread's TLS).
    if _initialized or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # Re-check under the lock: another thread may have completed
        # initialization while we were waiting to acquire it.
        if _initialized:
            return

        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize VACC in forked subprocess. To use VACC with "
                "multiprocessing, you must use the 'spawn' start method"
            )

        _torch_vacc._vacc_init()

        # Mark this thread as initializing so a re-entrant _lazy_init call
        # (e.g. from one of the queued callbacks) returns immediately instead
        # of blocking on the non-reentrant lock above.
        _tls.is_initializing = True

        # Fold the latest seeding callbacks (if any) into the queue; entries
        # from the tracker may be None and are skipped.
        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    # Surface both the failure and where the call was
                    # originally queued from.
                    msg = (
                        f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
                        f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredVaccCallError(msg) from e
        finally:
            # Always clear the re-entrancy marker, even on failure.
            delattr(_tls, "is_initializing")
        # Only reached when every queued call succeeded.
        _initialized = True
|
||||
Reference in New Issue
Block a user