This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,107 @@
import threading
import traceback
from typing import List
from .._vacc_libs import _torch_vacc
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = []
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)
def is_initialized():
r"""Returns whether PyTorch's VACC state has been initialized."""
return _initialized and not _is_in_bad_fork()
class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
# Note: `manual_seed_all` followed by `manual_seed` overwrites
# the seed on current device. We track the order of **latest**
# calls between these two API.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
return self.call_order
_lazy_seed_tracker = _LazySeedTracker()
def _lazy_call(callable, **kwargs):
if is_initialized():
callable()
else:
# TODO(torch_deploy): this accesses linecache, which attempts to read the
# file system to get traceback info. Patch linecache or do something
# else here if this ends up being important.
global _lazy_seed_tracker
if kwargs.get("seed_all", False):
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
elif kwargs.get("seed", False):
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
else:
# Don't store the actual traceback to avoid memory cycle
_queued_calls.append((callable, traceback.format_stack()))
class DeferredVaccCallError(Exception):
pass
def _lazy_init():
"""Initialize VACC device state."""
global _initialized, _queued_calls
if _initialized or hasattr(_tls, "is_initializing"):
return
with _initialization_lock:
if _initialized:
return
# It is important to prevent other threads from entering _lazy_init
# immediately, while we are still guaranteed to have the GIL, because some
# of the C calls we make below will release the GIL
if _is_in_bad_fork():
raise RuntimeError(
"Cannot re-initialize VACC in forked subprocess. To use VACC with "
"multiprocessing, you must use the 'spawn' start method"
)
_torch_vacc._vacc_init()
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)
try:
for queued_call, orig_traceback in _queued_calls:
try:
queued_call()
except Exception as e:
msg = (
f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
)
raise DeferredVaccCallError(msg) from e
finally:
delattr(_tls, "is_initializing")
_initialized = True