import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
from common_utils import *
import random


class TestQuantizeOp(BtTestCase):
    """Tests for tmo.quantize against a pure-PyTorch reference implementation."""

    def op_impl_base(self, *args):
        """Reference quantization: scale by `smooth`, round, clamp to int8 range.

        ``args`` is ``(x, smooth, zero)``; ``zero`` is accepted only for
        signature parity with ``tmo.quantize`` and is unused here
        (symmetric quantization — no zero point is applied).
        Returns an int8 tensor with the same shape as ``x``.
        """
        x, smooth, zero = args
        return (x * smooth).round().clamp(-128.0, 127.0).to(torch.int8)

    def test_random_case(self):
        """Compare tmo.quantize with the reference on 100 unique random shapes."""
        torch.manual_seed(0)
        # torch.manual_seed does NOT seed the `random` module; seed it too so
        # the generated (ci, co, dtype) case list is reproducible across runs.
        random.seed(0)

        # Device bf16 support is invariant — query once, outside the loop.
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)

        seen = set()
        while len(seen) < 100:
            dtype = random.choice(dtype_list)
            ci = random.randint(1, 4096)
            co = random.randint(1, 4096)
            case = (ci, co)
            # Dedup on shape only (dtype intentionally excluded, matching the
            # original behavior: each shape is tested with one random dtype).
            if case in seen:
                continue
            seen.add(case)

            x = torch.randn(ci, co, device="mlu", dtype=dtype)
            # Per-channel scale over the last dim; float32 regardless of x's dtype.
            scale = torch.randn(co, device="mlu", dtype=torch.float32)
            print(f"ci={ci}, co={co}, dtype={dtype}, testing...", flush=True)

            param = (x, scale, None)
            tmo_output = tmo.quantize(*param)
            torch_output = self.op_impl_base(*param)
            self.assertTensorsEqual(
                torch_output.cpu().float(),
                tmo_output.cpu().float(),
                0.01,
                use_MSE=True,
                use_RAE=True,
            )

    def test_inductor(self):
        """Opcheck the raw smooth_quant custom op (per-token mode, no smoothing)."""
        x = torch.randn(16, 128, 1024, device="mlu", dtype=torch.half)
        scale = torch.randn(1024, device="mlu", dtype=torch.float32)
        output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
        # torch.empty(0) replaces the deprecated legacy torch.Tensor()
        # constructor; both yield an empty float32 CPU tensor placeholder.
        args = (x, scale, output, torch.empty(0),
                None, None, None, None, 'per_token', False)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)


if __name__ == '__main__':
    exit(run_unittest(TestQuantizeOp))