import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
import numpy as np


class TestMatMulOp(BtTestCase):
    """Tests for the torch_mlu_ops fused matmul op.

    Covers the float/half/bfloat16 path, the per-tensor int8 quantized
    path, and the torch.library opcheck (inductor) path. Each case is
    compared against a pure-PyTorch reference (``op_impl_base``).
    """

    def op_impl_base(self, *args):
        """Pure-PyTorch reference for the fused matmul.

        Computes ``act(alpha * (a @ b) + bias + beta * c)`` where the
        optional quantization scales are undone first (``a / a_scale``,
        ``b / b_scale``) and ``trans_a``/``trans_b`` transpose the 2-D
        operands before the matmul.

        NOTE(review): ``fast_act``, ``approximate`` and ``d`` are accepted
        to mirror the op's signature but are intentionally ignored here —
        the reference always uses the exact activation.
        """
        a, b, bias, c, act_mode, alpha, beta, fast_act, approximate, d, \
            a_scale, b_scale, trans_a, trans_b = args
        # Dequantize (or no-op divide by 1.0 on the float path).
        if a_scale is not None:
            a = a / a_scale
        if b_scale is not None:
            b = b / b_scale
        if trans_a:
            a = a.transpose(0, 1)
        if trans_b:
            b = b.transpose(0, 1)
        mul_out = alpha * torch.matmul(a, b)
        if bias is not None:
            mul_out += bias
        if c is not None:
            mul_out += beta * c
        # act_mode_dict (from common_utils) maps mode name -> activation fn;
        # 'none' is expected to be absent, leaving mul_out untouched.
        if act_mode in act_mode_dict.keys():
            active = act_mode_dict[act_mode]
            # Activation is evaluated in fp32 then cast back for accuracy.
            mul_out = active(mul_out.float()).to(a.dtype)
        return mul_out

    def test_matmul(self):
        """Float/half/bf16 matmul vs. reference, strided and contiguous."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [False, True]
        has_bias_list = [True, False]
        trans_a_list = [False, True]
        trans_b_list = [False, True]
        act_mode_list = ['none', 'relu', 'gelu', 'silu']
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        alpha = 0.625
        beta = 1.0
        args = product(
            mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list,
            act_mode_list, dtype_list, trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, \
                trans_a, trans_b in args:
            torch.manual_seed(1)
            print("m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype,
                trans_a, trans_b), flush=True)
            # beta only matters when a residual is added.
            if has_res:
                beta = 1.0
            else:
                beta = 0.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            # Slice out the middle "channel" so the operands are
            # deliberately non-contiguous (strided) views.
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            bias = torch.randn((mat_n,), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            output = self.op_impl_base(
                input, weight,
                alpha * bias if has_bias else None,
                residual if has_res else None,
                act_mode, alpha, beta, False, False, None,
                1.0, 1.0, trans_a, trans_b)
            tmo_output = ops.matmul(
                input, weight,
                alpha * bias if has_bias else None,
                residual if has_res else None,
                act_mode, alpha, beta, False, False, None,
                1.0, 1.0, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(
                input.contiguous(), weight.contiguous(),
                alpha * bias if has_bias else None,
                residual if has_res else None,
                act_mode, alpha, beta, False, False, None,
                1.0, 1.0, trans_a, trans_b)
            if act_mode == 'gelu':
                # Also exercise the high-precision (approximate=True) gelu.
                tmo_output_high = ops.matmul(
                    input.contiguous(), weight.contiguous(),
                    alpha * bias if has_bias else None,
                    residual if has_res else None,
                    act_mode, alpha, beta, False, True, None,
                    1.0, 1.0, trans_a, trans_b)
                self.assertTensorsEqual(
                    tmo_output_high.cpu().float(), output.cpu().float(),
                    0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(
                tmo_output_contiguous.cpu().float(), output.cpu().float(),
                0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(
                tmo_output.cpu().float(), output.cpu().float(),
                0.004, use_MSE=True, use_RAE=True)

    # @unittest.skip("not test")
    def test_matmul_int8(self):
        """Per-tensor int8 quantized matmul vs. dequantized reference."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [True, False]
        has_bias_list = [True, False]
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        act_mode_list = ['none', 'relu', 'silu', 'gelu']
        dtype_list = [torch.half, torch.float]
        alpha = 0.625
        beta = 1.0
        args = product(
            mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list,
            act_mode_list, dtype_list, trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, \
                trans_a, trans_b in args:
            print("int8 test: m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype,
                trans_a, trans_b), flush=True)
            torch.manual_seed(1)
            if has_res:
                beta = 1.0
            else:
                beta = 0.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            # QuantByTensor (from common_utils) returns the quantized
            # tensor and its per-tensor scale — presumably 8-bit symmetric;
            # TODO(review): confirm its contract in common_utils.
            input8, a_scale = QuantByTensor(input, 8)
            weight8, b_scale = QuantByTensor(weight, 8)
            bias = torch.randn((mat_n,), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            output = self.op_impl_base(
                input8, weight8,
                alpha * bias if has_bias else None,
                residual if has_res else None,
                act_mode, alpha, beta, False, False, None,
                a_scale, b_scale, trans_a, trans_b)
            # The fused op receives the output dtype explicitly on the
            # int8 path (the reference derives it from the inputs).
            tmo_output = ops.matmul(
                input8, weight8,
                alpha * bias if has_bias else None,
                residual if has_res else None,
                act_mode, alpha, beta, False, False, dtype,
                a_scale, b_scale, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(
                input8.contiguous(), weight8.contiguous(),
                alpha * bias if has_bias else None,
                residual if has_res else None,
                act_mode, alpha, beta, False, False, dtype,
                a_scale, b_scale, trans_a, trans_b)
            self.assertTensorsEqual(
                tmo_output.cpu().float(), output.cpu().float(),
                0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(
                tmo_output_contiguous.cpu().float(), output.cpu().float(),
                0.003, use_MSE=True, use_RAE=True)
            if act_mode == 'gelu':
                tmo_output_high = ops.matmul(
                    input8.contiguous(), weight8.contiguous(),
                    alpha * bias if has_bias else None,
                    residual if has_res else None,
                    act_mode, alpha, beta, False, True, dtype,
                    a_scale, b_scale, trans_a, trans_b)
                self.assertTensorsEqual(
                    tmo_output_high.cpu().float(), output.cpu().float(),
                    0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """torch.library opcheck of the registered custom op schema."""
        mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = \
            32, 256, 128, 0.8, 0.3, 'silu', True, True
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        dtype_list = [torch.half, torch.float]
        cases = product(trans_a_list, trans_b_list, dtype_list)
        for trans_a, trans_b, dtype in cases:
            print("trans_a: {}, trans_b: {}, dtype: {}".format(
                trans_a, trans_b, dtype))
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            a = input0[:, 1, :]
            b = weight0[:, 0, :]
            bias = torch.randn((mat_n,), dtype=dtype, device='mlu')
            c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            # Float path: dtype argument is None, scales are 1.0.
            op_args = (a, b, None, bias, c, None, act_mode, alpha, beta,
                       fast_act, approximate, 1.0, 1.0, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, op_args)
            # Int8 path: dtype is passed as a string in the schema.
            a8, a_scale = QuantByTensor(a, 8)
            b8, b_scale = QuantByTensor(b, 8)
            str_dtype = "half"
            if dtype == torch.float:
                str_dtype = "float"
            elif dtype == torch.bfloat16:
                str_dtype = "bfloat16"
            op_args = (a8, b8, None, bias, c, str_dtype, act_mode, alpha,
                       beta, fast_act, approximate, a_scale, b_scale,
                       trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, op_args)


if __name__ == '__main__':
    exit(run_unittest(TestMatMulOp))