add ops
This commit is contained in:
46
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py
Executable file
46
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py
Executable file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as tmo
|
||||
from common_utils import *
|
||||
import random
|
||||
|
||||
|
||||
class TestQuantizeOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
x, smooth, zero = args
|
||||
return (x * smooth).round().clamp(-128.0, 127.0).to(torch.int8)
|
||||
|
||||
def test_random_case(self):
|
||||
torch.manual_seed(0)
|
||||
case_list = set()
|
||||
while(len(case_list) < 100):
|
||||
dtype_list = [torch.half, torch.float]
|
||||
if torch_mlu.mlu.is_bf16_supported():
|
||||
dtype_list.append(torch.bfloat16)
|
||||
dtype = random.choice(dtype_list)
|
||||
ci = random.randint(1, 4096)
|
||||
co = random.randint(1, 4096)
|
||||
case = (ci, co)
|
||||
if case in case_list:
|
||||
continue
|
||||
else:
|
||||
case_list.add((ci, co))
|
||||
x = torch.randn(ci, co, device="mlu", dtype=dtype)
|
||||
scale = torch.randn(co, device="mlu", dtype=torch.float32)
|
||||
print("ci={}, co={}, dtype={}, testing...".format(ci, co, dtype), flush=True)
|
||||
param = (x, scale, None)
|
||||
tmo_output = tmo.quantize(*param)
|
||||
torch_output = self.op_impl_base(*param)
|
||||
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
|
||||
x = torch.randn(16,128, 1024, device="mlu", dtype=torch.half)
|
||||
scale = torch.randn(1024, device="mlu", dtype=torch.float32)
|
||||
output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
|
||||
args = (x, scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(run_unittest(TestQuantizeOp))
|
||||
Reference in New Issue
Block a user