from contextlib import contextmanager import os import sys import torch from torch.testing import make_tensor from functools import partial, wraps import torch.testing._internal.common_device_type as cdt from torch.testing._internal.common_device_type import ( DeviceTypeTestBase, dtypes, instantiate_device_type_tests, onlyOn, onlyPRIVATEUSE1, ops, ) if sys.version_info > (3, 8): from torch.testing._internal.common_distributed import ( MultiProcessTestCase, init_multigpu_helper, skip_if_lt_x_gpu, get_timeout, #skip_if_rocm, with_dist_debug_levels, ) else: from torch.testing._internal.common_distributed import ( MultiProcessTestCase, init_multigpu_helper, skip_if_lt_x_gpu, get_timeout, skip_if_rocm, with_dist_debug_levels, ) from torch.testing._internal.common_utils import ( TestCase, load_tests, parametrize, run_tests, subtest, retry_on_connect_failures, instantiate_parametrized_tests, ) onlyVacc = onlyPRIVATEUSE1 class VaccTestBase(DeviceTypeTestBase): device_type = "vacc" if VaccTestBase not in cdt.device_type_test_bases: cdt.device_type_test_bases.append(VaccTestBase) @contextmanager def freeze_rng_state(): rng_state = torch.get_rng_state() yield torch.set_rng_state(rng_state)