"""Lazy initialization helpers for the VACC device backend.

Calls made before the device runtime is initialized are queued by
``_lazy_call`` and replayed by ``_lazy_init``.
"""

import threading
import traceback
from typing import List

from .._vacc_libs import _torch_vacc

# Set to True once `_lazy_init` has replayed every queued call successfully.
_initialized = False
# Per-thread storage; carries `is_initializing` while this thread is
# replaying queued calls inside `_lazy_init`.
_tls = threading.local()
# Serializes `_lazy_init` across threads.
_initialization_lock = threading.Lock()
# (callable, formatted stack) pairs recorded by `_lazy_call`.
_queued_calls = []

# Falls back to "never in a bad fork" when the extension module does not
# expose `_vacc_in_bad_fork`.
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)


def is_initialized():
    r"""Returns whether PyTorch's VACC state has been initialized."""
    return _initialized and not _is_in_bad_fork()


class _LazySeedTracker:
    """Remember the most recent deferred seeding calls and their order."""

    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on current device. We track the order of **latest**
    # calls between these two API.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        """Record ``cb`` as the latest seed-all call, ordered last."""
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        """Record ``cb`` as the latest per-device seed call, ordered last."""
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        """Return the tracked entries, oldest first.

        Each entry is a ``(cb, traceback)`` tuple, or ``None`` for a kind of
        seeding call that has not been queued.
        """
        return self.call_order


_lazy_seed_tracker = _LazySeedTracker()


def _lazy_call(callable, **kwargs):
    """Run ``callable`` now if VACC is initialized, otherwise defer it.

    Args:
        callable: Zero-argument callable to run or queue.
        **kwargs: ``seed_all=True`` or ``seed=True`` routes the call to the
            seed tracker (which keeps only the latest call of each kind);
            anything else is appended to the general queue.
    """
    if is_initialized():
        callable()
        return

    # TODO(torch_deploy): this accesses linecache, which attempts to read the
    # file system to get traceback info. Patch linecache or do something
    # else here if this ends up being important.
    if kwargs.get("seed_all", False):
        _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
    elif kwargs.get("seed", False):
        _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
    else:
        # Don't store the actual traceback to avoid memory cycle
        _queued_calls.append((callable, traceback.format_stack()))


class DeferredVaccCallError(Exception):
    """Raised by ``_lazy_init`` when a queued call fails while being replayed."""


def _lazy_init():
    """Initialize VACC device state.

    Returns immediately if the state is already initialized or if the current
    thread is already replaying queued calls.

    Raises:
        RuntimeError: If the process is reported to be in a bad fork.
        DeferredVaccCallError: If a queued call raises while being replayed;
            the message includes the stack recorded when it was queued.
    """
    global _initialized
    # Use `is_initialized()` rather than the raw flag: a forked child inherits
    # `_initialized = True`, and checking only the flag would return silently
    # here instead of reaching the bad-fork error below.
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # Re-check under the lock; another thread may have finished
        # initialization while this one was waiting.
        if is_initialized():
            return
        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize VACC in forked subprocess. To use VACC with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        _torch_vacc._vacc_init()
        # Marks this thread so that a nested `_lazy_init` call made from a
        # queued call returns early instead of blocking on the lock.
        _tls.is_initializing = True

        # Seeding calls are replayed after the generally queued calls.
        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
                        f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredVaccCallError(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True