Files
2026-04-02 04:55:00 +00:00

63 lines
1.4 KiB
Python

from contextlib import contextmanager
import os
import sys
import torch
from torch.testing import make_tensor
from functools import partial, wraps
import torch.testing._internal.common_device_type as cdt
from torch.testing._internal.common_device_type import (
DeviceTypeTestBase,
dtypes,
instantiate_device_type_tests,
onlyOn,
onlyPRIVATEUSE1,
ops,
)
if sys.version_info > (3, 8):
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
#skip_if_rocm,
with_dist_debug_levels,
)
else:
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
skip_if_rocm,
with_dist_debug_levels,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
parametrize,
run_tests,
subtest,
retry_on_connect_failures,
instantiate_parametrized_tests,
)
# Decorator restricting a device-generic test to the vacc backend. It is a
# plain alias of onlyPRIVATEUSE1 (presumably vacc is exposed to PyTorch as the
# PrivateUse1 backend -- confirm against the backend registration code).
onlyVacc = onlyPRIVATEUSE1


class VaccTestBase(DeviceTypeTestBase):
    """Device-type test base whose generated test classes target ``"vacc"``."""

    device_type = "vacc"


# Add this base to common_device_type's module-level list of test bases,
# presumably so instantiate_device_type_tests also emits vacc variants.
# Only this exact class object is checked for; a reloaded module would create
# a new class object that this guard does not recognize.
if not any(base is VaccTestBase for base in cdt.device_type_test_bases):
    cdt.device_type_test_bases.append(VaccTestBase)
@contextmanager
def freeze_rng_state():
    """Snapshot the global RNG state on entry and restore it on exit.

    Random ops executed inside the ``with`` block do not change what the
    global generator produces after the block. The restore runs in a
    ``finally`` clause, so it also happens when the block raises.

    Only the state read by ``torch.get_rng_state()`` is saved and restored;
    generators of other devices are not touched by this helper.

    Yields:
        None.
    """
    rng_state = torch.get_rng_state()
    try:
        yield
    finally:
        # Without ``finally``, an exception raised in the with-body would be
        # thrown into the generator at ``yield`` and skip this restore,
        # leaking an advanced RNG state into later tests.
        torch.set_rng_state(rng_state)