#!/usr/bin/env python
# Unit tests for the torch_mlu_ops fused FFN operator.
import unittest

import torch
import torch_mlu
import torch_mlu_ops as ops

from common_utils import *
class TestFFNOp(BtTestCase):
    """Tests for the fused feed-forward network (FFN) MLU op.

    The fused op ``ops.ffn`` is checked against a PyTorch reference
    implementation, and against an equivalent decomposition built from
    ``ops.matmul`` + ``ops.active``.
    """

    def op_impl_base(self, *args):
        """Reference FFN: up-projection -> activation (optionally gated)
        -> down-projection.

        Args (positional, unpacked from ``args``):
            input:  activation tensor of shape (batch, seq_len, input_size).
            w1, bias1: up-projection weight/bias.
            w2, bias2: down-projection weight/bias.
            w3, bias3: gate-projection weight/bias; both may be None, in
                which case no gating branch is applied.
            act_mode: key into ``act_mode_dict`` (e.g. 'silu') selecting
                the activation function.

        Returns:
            Output tensor of shape (batch, seq_len, input_size), same
            dtype as ``input``.
        """
        input, w1, bias1, w2, bias2, w3, bias3, act_mode = args
        up = F.linear(input, w1, bias1)
        # Activation is evaluated in fp32 for accuracy, then cast back to
        # the working dtype (half/bfloat16).
        act = act_mode_dict[act_mode](up.float()).to(input.dtype)
        if w3 is not None:
            # Gated variant (e.g. SwiGLU): multiply by the gate projection.
            gate = F.linear(input, w3, bias3)
            act = act * gate
        output = F.linear(act, w2, bias2)
        return output

    def test_ffn(self):
        """Compare ops.ffn and a matmul/active decomposition against the
        reference implementation over several shapes and dtypes."""
        input_size_list = [128, 256, 512]
        hidden_size = 1024
        seq_len_list = [10, 16, 20]
        bool_value_list = [True, False]
        batch = 5
        dtype_list = [torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        # zip pairs the lists positionally, so only len(bool_value_list)=2
        # (input_size, seq_len, use_gate) combinations are exercised.
        for input_size, seq_len, bool_value in zip(input_size_list,
                                                   seq_len_list,
                                                   bool_value_list):
            # NOTE: b1/b2 are always present; bool_value only toggles the
            # gate branch, so the log reports just the gating state.
            print("input_size={}, seq_len={}, gated={}, testing...".format(
                input_size, seq_len, bool_value), flush=True)
            use_gate = bool_value
            for dtype in dtype_list:
                w1 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
                b1 = torch.randn((hidden_size,), dtype=dtype, device="mlu")
                w2 = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
                b2 = torch.randn((input_size,), dtype=dtype, device="mlu")
                # Gate projection only exists in the gated configuration.
                w3 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") if use_gate else None
                b3 = torch.randn((hidden_size,), dtype=dtype, device="mlu") if use_gate else None
                input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
                args = (input, w1, b1, w2, b2, w3, b3, 'silu')
                output = self.op_impl_base(*args)
                # 1) Fused op vs. reference.
                tmo_output1 = ops.ffn(*args)
                self.assertTensorsEqual(output.cpu().float(), tmo_output1.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)
                # 2) Decomposed implementation: a single fused up/gate GEMM
                # (weights and biases concatenated along the output dim),
                # then a gated activation, then the down-projection GEMM.
                f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1
                f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1
                pre_gemm_out = ops.matmul(input.view(-1, input_size), f1_weight, f1_bias, None, "none", 1.0, 0)
                act_out = ops.active(pre_gemm_out, 'silu', use_gate)
                tmo_output2 = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0)
                tmo_output2 = tmo_output2.view(batch, seq_len, input_size)
                self.assertTensorsEqual(output.cpu().float(), tmo_output2.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck the raw torch.ops.torch_mlu_ops.ffn schema (no bias,
        no gate) for inductor/compile compatibility."""
        batch, seq_len, input_size, hidden_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
        input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
        up_fc_weight = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
        down_proj_weight = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
        # Positional schema of the raw op; trailing scalars are presumably
        # (norm eps, alpha, beta) — TODO confirm against the op registration.
        args = (input, up_fc_weight, None, down_proj_weight, None, None, None, None, None, act_mode, "none", 1e-5, 1., 0.)
        self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)
|
|
|
|
if __name__ == '__main__':
    # raise SystemExit propagates the suite's status code to the shell;
    # the builtin exit() is a site.py convenience intended only for
    # interactive sessions.
    raise SystemExit(run_unittest(TestFFNOp))
|