#!/usr/bin/env python3
import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
from common_utils import *
import random


class TestQuantizeOp(BtTestCase):
    """Tests for tmo.quantize against a pure-PyTorch reference implementation."""

    def op_impl_base(self, *args):
        """Reference quantize: scale by `smooth`, round, clamp to int8 range.

        Expects ``args == (x, smooth, zero)``; ``zero`` is accepted for
        signature parity with ``tmo.quantize`` but is unused here.
        """
        x, smooth, zero = args
        return (x * smooth).round().clamp(-128.0, 127.0).to(torch.int8)

    def test_random_case(self):
        """Compare tmo.quantize with the reference on 100 random (ci, co) shapes."""
        torch.manual_seed(0)
        # Seed `random` too so the sampled shapes/dtypes are reproducible
        # across runs (torch alone was seeded before, leaving shape choice
        # nondeterministic).
        random.seed(0)

        # The supported dtype set does not change between iterations, so
        # probe bf16 support once instead of once per loop iteration.
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)

        case_list = set()
        while len(case_list) < 100:
            dtype = random.choice(dtype_list)
            ci = random.randint(1, 4096)
            co = random.randint(1, 4096)
            case = (ci, co)
            if case in case_list:
                # Already tested this shape; draw a fresh one.
                continue
            case_list.add(case)

            x = torch.randn(ci, co, device="mlu", dtype=dtype)
            scale = torch.randn(co, device="mlu", dtype=torch.float32)
            print("ci={}, co={}, dtype={}, testing...".format(ci, co, dtype), flush=True)

            param = (x, scale, None)
            tmo_output = tmo.quantize(*param)
            torch_output = self.op_impl_base(*param)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck the raw smooth_quant custom op for inductor compatibility."""
        x = torch.randn(16, 128, 1024, device="mlu", dtype=torch.half)
        scale = torch.randn(1024, device="mlu", dtype=torch.float32)
        output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
        # Positional layout follows the smooth_quant op schema; trailing Nones
        # are optional tensors/params, 'per_token' selects quantization mode.
        args = (x, scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
|
|
|
if __name__ == '__main__':
    # Use SystemExit directly: the bare `exit()` builtin is installed by the
    # `site` module and is absent under `python -S`; raising SystemExit is the
    # portable script idiom and needs no extra import.
    raise SystemExit(run_unittest(TestQuantizeOp))