# NOTE(review): the lines below were web-UI residue from a repository file
# listing (title, timestamp, line count, size, language badge), not Python
# source — commented out so the module parses; original text preserved:
#   Files
#   2026-02-04 17:39:32 +08:00
#   47 lines
#   1.8 KiB
#   Python
#   Executable File
import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
from common_utils import *
import random
class TestQuantizeOp(BtTestCase):
    """Tests for ``tmo.quantize`` against a pure-PyTorch int8 reference."""

    def op_impl_base(self, *args):
        """Reference implementation: scale, round, saturate to int8.

        ``args`` is ``(x, smooth, zero)``. ``zero`` (a zero-point slot) is
        accepted for call compatibility with ``tmo.quantize`` but is unused —
        the reference performs symmetric quantization only.

        Returns the quantized tensor as ``torch.int8``, same shape as ``x``.
        """
        x, smooth, zero = args
        # Scale per element, round to nearest, clamp to the int8 range.
        return (x * smooth).round().clamp(-128.0, 127.0).to(torch.int8)

    def test_random_case(self):
        """Compare tmo.quantize with the reference on 100 unique (ci, co) shapes."""
        torch.manual_seed(0)
        # NOTE(review): only the torch RNG is seeded; `random` (which drives
        # the shape/dtype choices) is not, so the sampled cases differ across
        # runs — confirm whether that is intentional.

        # The candidate dtype list is loop-invariant; build it once instead of
        # reconstructing it on every iteration as the original did.
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)

        case_list = set()
        while len(case_list) < 100:
            dtype = random.choice(dtype_list)
            ci = random.randint(1, 4096)
            co = random.randint(1, 4096)
            case = (ci, co)
            if case in case_list:
                continue  # shape already covered; draw a fresh one
            case_list.add(case)

            x = torch.randn(ci, co, device="mlu", dtype=dtype)
            # Per-column float32 scale (one entry per output channel `co`).
            scale = torch.randn(co, device="mlu", dtype=torch.float32)
            print("ci={}, co={}, dtype={}, testing...".format(ci, co, dtype), flush=True)

            param = (x, scale, None)
            tmo_output = tmo.quantize(*param)
            torch_output = self.op_impl_base(*param)
            # Tolerance 0.01 with MSE/RAE metrics (BtTestCase helper).
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck the raw smooth_quant custom op (inductor/compile coverage)."""
        x = torch.randn(16, 128, 1024, device="mlu", dtype=torch.half)
        scale = torch.randn(1024, device="mlu", dtype=torch.float32)
        output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
        # Positional signature of torch.ops.torch_mlu_ops.smooth_quant:
        # the empty tensor and trailing Nones fill optional slots; 'per_token'
        # selects the quantization mode — presumably per-token scaling, and
        # False disables the final flag (TODO: confirm against the op schema).
        args = (x, scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
if __name__ == '__main__':
    # Run the suite via the common_utils helper and propagate its result as
    # the process exit status (presumably nonzero on failure — defined in
    # common_utils; verify there) so CI picks up test failures.
    exit(run_unittest(TestQuantizeOp))