Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_preload.py
2026-02-04 17:39:32 +08:00

22 lines
682 B
Python

import torch
import unittest
import torch_mlu_ops as ops
from common_utils import *
class TestPreloadOp(BtTestCase):
    """Tests for the ``torch_mlu_ops.preload`` operator."""

    # Shape of the random half-precision weight tensor passed to preload.
    _WEIGHT_SHAPE = (1024, 8, 5, 1024)

    def _make_weight(self):
        """Return a random fp16 weight tensor on the MLU and its size in bytes."""
        weight = torch.randn(self._WEIGHT_SHAPE).half().mlu()
        # Preload the whole tensor: bytes per element times element count.
        size = weight.element_size() * weight.numel()
        return weight, size

    def op_impl_base(self, *args):
        """Reference-implementation hook; defers entirely to ``BtTestCase``.

        ``args`` is unpacked only to enforce that the op is called with exactly
        two positional arguments, ``(weight, size)``; the values are unused.
        """
        _weight, _size = args
        return super().op_impl_base(*args)

    def test_preload(self):
        """Call ``ops.preload`` on a device tensor and wait for the device.

        No output is compared; the test passes if the op call and the
        subsequent synchronize complete without raising.
        """
        weight, size = self._make_weight()
        ops.preload(weight, size)
        # Block until queued MLU work finishes; presumably so that any error
        # from asynchronous execution is reported inside this test.
        torch.mlu.synchronize()

    def test_inductor(self):
        """Run ``base_opcheck`` on the registered ``torch.ops`` preload op."""
        weight, size = self._make_weight()
        self.base_opcheck(torch.ops.torch_mlu_ops.preload, (weight, size))
if __name__ == '__main__':
    import sys

    # Use sys.exit rather than the site-provided ``exit()`` builtin, which is
    # not available when Python runs with -S or in embedded interpreters.
    sys.exit(run_unittest(TestPreloadOp))