22 lines
682 B
Python
22 lines
682 B
Python
import torch
|
|
import unittest
|
|
import torch_mlu_ops as ops
|
|
from common_utils import *
|
|
|
|
class TestPreloadOp(BtTestCase):
    """Tests for the ``preload`` operator exposed by ``torch_mlu_ops``."""

    def op_impl_base(self, *args):
        """Forward the call to the parent class's ``op_impl_base``.

        Args:
            *args: Exactly two positional values, ``(weight, size)``.

        Returns:
            Whatever the parent ``op_impl_base`` returns for the same
            arguments.

        Raises:
            ValueError: If ``args`` does not contain exactly two items.
        """
        # The unpacked values are not used here; the unpack is kept because
        # it rejects calls that do not pass exactly (weight, size).
        weight, size = args  # noqa: F841
        return super().op_impl_base(*args)

    def test_preload(self):
        """Call ``ops.preload`` on a half-precision MLU tensor, then synchronize."""
        weight = torch.randn((1024, 8, 5, 1024)).half().mlu()
        # The second argument is the tensor's total size in bytes.
        ops.preload(weight, weight.element_size() * weight.numel())
        # Nothing is asserted on a result; the test passes when neither the
        # op nor the synchronize call raises.
        torch.mlu.synchronize()

    def test_inductor(self):
        """Run ``base_opcheck`` on the ``torch.ops.torch_mlu_ops.preload`` entry."""
        weight = torch.randn((1024, 8, 5, 1024)).half().mlu()
        # Same byte count that test_preload passes to the op.
        nbytes = weight.element_size() * weight.numel()
        # NOTE(review): base_opcheck is not defined in this file; presumably
        # inherited from BtTestCase -- confirm what it validates.
        self.base_opcheck(torch.ops.torch_mlu_ops.preload, (weight, nbytes))
|
|
|
|
if __name__ == '__main__':
    # Run only this module's test case and hand the runner's result to exit().
    status = run_unittest(TestPreloadOp)
    exit(status)
|