194 lines
9.2 KiB
Python
Executable File
194 lines
9.2 KiB
Python
Executable File
import torch
|
|
import torch_mlu
|
|
import unittest
|
|
import torch_mlu_ops as ops
|
|
from common_utils import *
|
|
from itertools import product
|
|
import numpy as np
|
|
|
|
class TestMatMulOp(BtTestCase):
    """Accuracy and opcheck tests for the fused ``torch_mlu_ops.matmul``.

    Every case compares the fused MLU kernel against a plain-torch
    reference (``op_impl_base``) on both strided and contiguous operands.
    """

    def op_impl_base(self, *args):
        """Plain-torch reference for the fused matmul.

        Expects 14 positional values: a, b, bias, c, act_mode, alpha,
        beta, fast_act, approximate, d, a_scale, b_scale, trans_a,
        trans_b.  ``fast_act``, ``approximate`` and ``d`` exist only to
        mirror the fused op's signature and are ignored here.
        """
        (a, b, bias, c, act_mode, alpha, beta,
         _fast_act, _approximate, _d,
         a_scale, b_scale, trans_a, trans_b) = args
        # Undo per-tensor quantization when scales are supplied.
        if a_scale is not None:
            a = a / a_scale
        if b_scale is not None:
            b = b / b_scale
        # Honor the requested operand layouts before multiplying.
        if trans_a:
            a = a.transpose(0, 1)
        if trans_b:
            b = b.transpose(0, 1)
        result = alpha * torch.matmul(a, b)
        if bias is not None:
            result += bias
        if c is not None:
            result += beta * c
        if act_mode in act_mode_dict:
            # Activations are evaluated in fp32 and cast back to the
            # (possibly dequantized) input dtype.
            result = act_mode_dict[act_mode](result.float()).to(a.dtype)
        return result

    def _rand_operands(self, m, n, k, dtype, trans_a, trans_b):
        """Draw random A/B operands and return non-contiguous slices.

        One plane is sliced out of a 3-D buffer so each operand carries
        non-trivial strides, exercising the strided kernel path.
        """
        shape_a = (k, 4, m) if trans_a else (m, 4, k)
        shape_b = (n, 3, k) if trans_b else (k, 3, n)
        buf_a = torch.randn(shape_a, dtype=dtype, device='mlu')
        buf_b = torch.randn(shape_b, dtype=dtype, device='mlu')
        return buf_a[:, 1, :], buf_b[:, 0, :]

    def test_matmul(self):
        """Float path: sweep bias/residual/activation/transpose/dtype."""
        alpha = 0.625
        dtypes = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtypes.append(torch.bfloat16)
        cases = product([32], [256], [128],
                        [True, False],                      # has_bias
                        [False, True],                      # has_res
                        ['none', 'relu', 'gelu', 'silu'],
                        dtypes,
                        [False, True],                      # trans_a
                        [False, True])                      # trans_b
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in cases:
            torch.manual_seed(1)
            print("m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            # beta only matters when a residual is actually added.
            beta = 1.0 if has_res else 0.
            lhs, rhs = self._rand_operands(mat_m, mat_n, mat_k, dtype, trans_a, trans_b)
            bias = torch.randn(mat_n, dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')

            output = self.op_impl_base(lhs, rhs,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False,
                                       None, 1.0, 1.0, trans_a, trans_b)
            tmo_output = ops.matmul(lhs, rhs,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False,
                                    None, 1.0, 1.0, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(lhs.contiguous(), rhs.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False,
                                               None, 1.0, 1.0, trans_a, trans_b)

            if act_mode == 'gelu':
                # approximate=True exercises the alternate gelu variant;
                # it must still agree with the reference within tolerance.
                tmo_output_high = ops.matmul(lhs.contiguous(), rhs.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True,
                                             None, 1.0, 1.0, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.004, use_MSE=True, use_RAE=True)

            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)

    # @unittest.skip("not test")
    def test_matmul_int8(self):
        """Int8 path: operands are per-tensor quantized, scales passed through."""
        alpha = 0.625
        cases = product([32], [256], [128],
                        [True, False],                      # has_bias
                        [True, False],                      # has_res
                        ['none', 'relu', 'silu', 'gelu'],
                        [torch.half, torch.float],
                        [True, False],                      # trans_a
                        [True, False])                      # trans_b
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in cases:
            print("int8 test: m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            torch.manual_seed(1)
            beta = 1.0 if has_res else 0.
            lhs, rhs = self._rand_operands(mat_m, mat_n, mat_k, dtype, trans_a, trans_b)
            lhs8, a_scale = QuantByTensor(lhs, 8)
            rhs8, b_scale = QuantByTensor(rhs, 8)
            bias = torch.randn(mat_n, dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')

            # Reference dequantizes via the scales; the fused op receives
            # the scales plus an explicit output dtype.
            output = self.op_impl_base(lhs8, rhs8,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False,
                                       None, a_scale, b_scale, trans_a, trans_b)
            tmo_output = ops.matmul(lhs8, rhs8,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False,
                                    dtype, a_scale, b_scale, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(lhs8.contiguous(), rhs8.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False,
                                               dtype, a_scale, b_scale, trans_a, trans_b)

            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

            if act_mode == 'gelu':
                tmo_output_high = ops.matmul(lhs8.contiguous(), rhs8.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True,
                                             dtype, a_scale, b_scale, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """opcheck coverage for the registered torch_mlu_ops.matmul schema."""
        mat_m, mat_n, mat_k = 32, 256, 128
        alpha, beta = 0.8, 0.3
        act_mode, fast_act, approximate = 'silu', True, True
        for trans_a, trans_b, dtype in product([True, False], [True, False],
                                               [torch.half, torch.float]):
            print("trans_a: {}, trans_b: {}, dtype: {}".format(trans_a, trans_b, dtype))
            a, b = self._rand_operands(mat_m, mat_n, mat_k, dtype, trans_a, trans_b)
            bias = torch.randn(mat_n, dtype=dtype, device='mlu')
            c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            # Float path: no output-dtype override, unit scales.
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul,
                              (a, b, None, bias, c, None, act_mode, alpha, beta,
                               fast_act, approximate, 1.0, 1.0, trans_a, trans_b))

            # Int8 path: quantized operands plus the dtype as a string.
            a8, a_scale = QuantByTensor(a, 8)
            b8, b_scale = QuantByTensor(b, 8)
            if dtype == torch.float:
                str_dtype = "float"
            elif dtype == torch.bfloat16:
                str_dtype = "bfloat16"
            else:
                str_dtype = "half"
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul,
                              (a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta,
                               fast_act, approximate, a_scale, b_scale, trans_a, trans_b))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # exit() is a convenience injected by the `site` module and may be
    # absent (e.g. under `python -S`); raising SystemExit is always
    # available and propagates the suite's status code identically.
    raise SystemExit(run_unittest(TestMatMulOp))
|