init
This commit is contained in:
62
torch_vacc/testing/__init__.py
Normal file
62
torch_vacc/testing/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from functools import partial, wraps
|
||||
import torch.testing._internal.common_device_type as cdt
|
||||
from torch.testing._internal.common_device_type import (
|
||||
DeviceTypeTestBase,
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
onlyOn,
|
||||
onlyPRIVATEUSE1,
|
||||
ops,
|
||||
)
|
||||
|
||||
if sys.version_info > (3, 8):
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
init_multigpu_helper,
|
||||
skip_if_lt_x_gpu,
|
||||
get_timeout,
|
||||
#skip_if_rocm,
|
||||
with_dist_debug_levels,
|
||||
)
|
||||
else:
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
init_multigpu_helper,
|
||||
skip_if_lt_x_gpu,
|
||||
get_timeout,
|
||||
skip_if_rocm,
|
||||
with_dist_debug_levels,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
load_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
retry_on_connect_failures,
|
||||
instantiate_parametrized_tests,
|
||||
|
||||
)
|
||||
|
||||
onlyVacc = onlyPRIVATEUSE1
|
||||
|
||||
|
||||
class VaccTestBase(DeviceTypeTestBase):
|
||||
device_type = "vacc"
|
||||
|
||||
|
||||
if VaccTestBase not in cdt.device_type_test_bases:
|
||||
cdt.device_type_test_bases.append(VaccTestBase)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_rng_state():
|
||||
rng_state = torch.get_rng_state()
|
||||
yield
|
||||
torch.set_rng_state(rng_state)
|
||||
Reference in New Issue
Block a user