# File: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_matmul.py
# Last modified: 2026-02-04 17:39:32 +08:00
# 194 lines, 9.2 KiB, Python, executable
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
import numpy as np
class TestMatMulOp(BtTestCase):
    """Tests for ``torch_mlu_ops.matmul`` against a pure-PyTorch reference.

    Covers the float/half/bf16 path, the int8 quantized path, and opcheck
    registration for torch.inductor, across transpose / bias / residual /
    activation combinations. All tensors live on the ``'mlu'`` device.
    """

    def op_impl_base(self, *args):
        """Reference implementation of the fused matmul.

        Positional args (same order as ``ops.matmul``):
            a, b, bias, c, act_mode, alpha, beta, fast_act, approximate, d,
            a_scale, b_scale, trans_a, trans_b

        Computes ``act(alpha * (a @ b) + bias + beta * c)``. ``fast_act``,
        ``approximate`` and ``d`` are accepted for signature parity with the
        op but are ignored by this reference. Returns a tensor in ``a.dtype``.
        """
        a, b, bias, c, act_mode, alpha, beta, fast_act, approximate, d, \
            a_scale, b_scale, trans_a, trans_b = args
        # NOTE(review): dequantization divides by the scale returned from
        # QuantByTensor — assumes scale is defined as quantized/original;
        # confirm against QuantByTensor in common_utils.
        if a_scale is not None:
            a = a / a_scale
        if b_scale is not None:
            b = b / b_scale
        if trans_a:
            a = a.transpose(0, 1)
        if trans_b:
            b = b.transpose(0, 1)
        mul_out = alpha * torch.matmul(a, b)
        if bias is not None:
            mul_out += bias
        if c is not None:
            mul_out += beta * c
        # act_mode_dict (from common_utils) maps mode name -> activation fn;
        # 'none' is expected to be absent so the output passes through.
        if act_mode in act_mode_dict:
            active = act_mode_dict[act_mode]
            # Activations are evaluated in fp32 for accuracy, then cast back.
            mul_out = active(mul_out.float()).to(a.dtype)
        return mul_out

    def test_matmul(self):
        """Float/half/bf16 matmul: strided and contiguous inputs vs reference."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [False, True]
        has_bias_list = [True, False]
        trans_a_list = [False, True]
        trans_b_list = [False, True]
        act_mode_list = ['none', 'relu', 'gelu', 'silu']
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        alpha = 0.625
        beta = 1.0
        args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list,
                       has_res_list, act_mode_list, dtype_list,
                       trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args:
            # Fixed seed per combination keeps each case reproducible.
            torch.manual_seed(1)
            print("m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            # beta only contributes when a residual is added.
            if has_res:
                beta = 1.0
            else:
                beta = 0.
            # Middle dim of 4/3 exists only so slicing below yields
            # non-contiguous (strided) 2-D views.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            output = self.op_impl_base(input,
                                       weight,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            # Op under test: strided inputs.
            tmo_output = ops.matmul(input,
                                    weight,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            # Op under test: contiguous copies must match too.
            tmo_output_contiguous = ops.matmul(input.contiguous(), weight.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            if act_mode == 'gelu':
                # approximate=True exercises the high-precision gelu variant.
                tmo_output_high = ops.matmul(input.contiguous(), weight.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True, None, 1.0, 1.0, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)

    # @unittest.skip("not test")
    def test_matmul_int8(self):
        """Int8-quantized matmul: per-tensor quantized inputs vs dequantized reference."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [True, False]
        has_bias_list = [True, False]
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        act_mode_list = ['none', 'relu', 'silu', 'gelu']
        dtype_list = [torch.half, torch.float]
        alpha = 0.625
        beta = 1.0
        args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list,
                       has_res_list, act_mode_list, dtype_list,
                       trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args:
            print("int8 test: m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            torch.manual_seed(1)
            if has_res:
                beta = 1.0
            else:
                beta = 0.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            # Per-tensor 8-bit quantization; QuantByTensor returns (q, scale).
            input8, a_scale = QuantByTensor(input, 8)
            weight8, b_scale = QuantByTensor(weight, 8)
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            output = self.op_impl_base(input8, weight8,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False, None, a_scale, b_scale, trans_a, trans_b)
            # Int8 path must be told the output dtype explicitly.
            tmo_output = ops.matmul(input8, weight8,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(input8.contiguous(), weight8.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            if act_mode == 'gelu':
                tmo_output_high = ops.matmul(input8.contiguous(), weight8.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True, dtype, a_scale, b_scale, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck registration of torch.ops.torch_mlu_ops.matmul for inductor."""
        mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = \
            32, 256, 128, 0.8, 0.3, 'silu', True, True
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        dtype_list = [torch.half, torch.float]
        args = product(trans_a_list, trans_b_list, dtype_list)
        for trans_a, trans_b, dtype in args:
            print("trans_a: {}, trans_b: {}, dtype: {}".format(trans_a, trans_b, dtype))
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            a = input0[:, 1, :]
            b = weight0[:, 0, :]
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            # Float path: dtype arg is None, scales are 1.0.
            args = (a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args)
            # Int8 path: raw op takes the output dtype as a string.
            a8, a_scale = QuantByTensor(a, 8)
            b8, b_scale = QuantByTensor(b, 8)
            str_dtype = "half"
            if dtype == torch.float:
                str_dtype = "float"
            elif dtype == torch.bfloat16:
                str_dtype = "bfloat16"
            args = (a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args)
if __name__ == '__main__':
    # Restored indentation: the guarded statement must be inside the `if`.
    # run_unittest (from common_utils) returns the process exit status.
    exit(run_unittest(TestMatMulOp))