"""Tests for the ``torch_mlu_ops.preload`` operator."""

import sys
import unittest

import torch
import torch_mlu_ops as ops
# Expected to provide BtTestCase and run_unittest used below.
from common_utils import *


class TestPreloadOp(BtTestCase):
    """Exercise ``preload`` on a half-precision tensor placed on the MLU device."""

    # Shape of the random tensor handed to preload in every test.
    WEIGHT_SHAPE = (1024, 8, 5, 1024)

    def op_impl_base(self, *args):
        """Check the ``(weight, size)`` argument pair, then defer to the base class.

        Args:
            *args: exactly two positional values, ``weight`` and ``size``.

        Returns:
            Whatever ``BtTestCase.op_impl_base`` returns for the same args.

        Raises:
            ValueError: if ``args`` does not contain exactly two values.
        """
        # The unpack enforces the two-argument arity; the values are not
        # needed here, hence the underscore-prefixed names.
        _weight, _size = args
        return super().op_impl_base(*args)

    def _make_weight(self):
        """Return a random fp16 MLU tensor of WEIGHT_SHAPE and its size in bytes."""
        weight = torch.randn(self.WEIGHT_SHAPE).half().mlu()
        # Total byte count of the tensor: bytes per element times element count.
        size = weight.element_size() * weight.numel()
        return weight, size

    def test_preload(self):
        """Call preload over the whole tensor, then synchronize the MLU device."""
        weight, size = self._make_weight()
        ops.preload(weight, size)
        torch.mlu.synchronize()

    def test_inductor(self):
        """Run BtTestCase.base_opcheck on the torch.ops preload entry point."""
        weight, size = self._make_weight()
        self.base_opcheck(torch.ops.torch_mlu_ops.preload, (weight, size))


if __name__ == '__main__':
    # sys.exit, unlike the site-injected exit(), is always available
    # (e.g. under `python -S`) and is intended for programs.
    sys.exit(run_unittest(TestPreloadOp))