# Source listing metadata: 63 lines, 1.4 KiB, Python.
from contextlib import contextmanager
|
|
import os
|
|
import sys
|
|
import torch
|
|
from torch.testing import make_tensor
|
|
from functools import partial, wraps
|
|
import torch.testing._internal.common_device_type as cdt
|
|
from torch.testing._internal.common_device_type import (
|
|
DeviceTypeTestBase,
|
|
dtypes,
|
|
instantiate_device_type_tests,
|
|
onlyOn,
|
|
onlyPRIVATEUSE1,
|
|
ops,
|
|
)
|
|
|
|
if sys.version_info > (3, 8):
|
|
from torch.testing._internal.common_distributed import (
|
|
MultiProcessTestCase,
|
|
init_multigpu_helper,
|
|
skip_if_lt_x_gpu,
|
|
get_timeout,
|
|
#skip_if_rocm,
|
|
with_dist_debug_levels,
|
|
)
|
|
else:
|
|
from torch.testing._internal.common_distributed import (
|
|
MultiProcessTestCase,
|
|
init_multigpu_helper,
|
|
skip_if_lt_x_gpu,
|
|
get_timeout,
|
|
skip_if_rocm,
|
|
with_dist_debug_levels,
|
|
)
|
|
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase,
|
|
load_tests,
|
|
parametrize,
|
|
run_tests,
|
|
subtest,
|
|
retry_on_connect_failures,
|
|
instantiate_parametrized_tests,
|
|
|
|
)
|
|
|
|
onlyVacc = onlyPRIVATEUSE1
|
|
|
|
|
|
class VaccTestBase(DeviceTypeTestBase):
|
|
device_type = "vacc"
|
|
|
|
|
|
if VaccTestBase not in cdt.device_type_test_bases:
|
|
cdt.device_type_test_bases.append(VaccTestBase)
|
|
|
|
|
|
@contextmanager
def freeze_rng_state():
    """Snapshot the global torch RNG state and restore it when the block exits.

    Random draws made inside the ``with`` block therefore do not shift the
    random stream seen by code that runs afterwards. The state is restored
    even when the block raises, so a failing test cannot leak a perturbed RNG
    stream into later tests.

    Only the state returned by ``torch.get_rng_state()`` is saved and
    restored; generators of other devices are not touched by this helper.

    Yields:
        None.
    """
    rng_state = torch.get_rng_state()
    try:
        yield
    finally:
        # Restore in ``finally``: without it, an exception (e.g. an assertion
        # failure) inside the block would skip the restore entirely.
        torch.set_rng_state(rng_state)
|