import itertools
import unittest

import torch
# Explicit import: op_impl_base uses F.linear. Placed BEFORE the star import so
# that if common_utils also exports an `F`, the original name resolution wins.
import torch.nn.functional as F
import torch_mlu
import torch_mlu_ops as ops

from common_utils import *


class TestFFNOp(BtTestCase):
    """Tests for the fused FFN op against a plain-PyTorch reference.

    Covers both the gated (SwiGLU-style, three projections) and non-gated
    (two projections) variants, and cross-checks the fused `ops.ffn` kernel
    against an equivalent matmul + activation + matmul decomposition.
    """

    def op_impl_base(self, *args):
        """Reference FFN built from plain torch ops.

        args = (input, w1, bias1, w2, bias2, w3, bias3, act_mode):
            input:  (batch, seq_len, input_size) activations
            w1/b1:  up projection (hidden_size, input_size)
            w2/b2:  down projection (input_size, hidden_size)
            w3/b3:  optional gate projection; None disables gating
            act_mode: key into act_mode_dict (e.g. 'silu')

        Returns the FFN output with the same shape/dtype as `input`.
        """
        input, w1, bias1, w2, bias2, w3, bias3, act_mode = args
        up = F.linear(input, w1, bias1)
        # Activation is evaluated in fp32 for accuracy, then cast back.
        act = act_mode_dict[act_mode](up.float()).to(input.dtype)
        if w3 is not None:
            # Gated variant: elementwise gate from the third projection.
            gate = F.linear(input, w3, bias3)
            act = act * gate
        output = F.linear(act, w2, bias2)
        return output

    def test_ffn(self):
        """Check ops.ffn and a matmul-based decomposition vs the reference."""
        input_size_list = [128, 256, 512]
        hidden_size = 1024
        seq_len_list = [10, 16, 20]
        bool_value_list = [True, False]
        batch = 5
        dtype_list = [torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        # FIX: a plain zip() truncated to len(bool_value_list) == 2, silently
        # skipping the (input_size=512, seq_len=20) configuration. Cycling the
        # bool list exercises every (input_size, seq_len) pair.
        for input_size, seq_len, use_gate in zip(
                input_size_list, seq_len_list, itertools.cycle(bool_value_list)):
            # b1/b2 are always supplied below, so bias is unconditionally on;
            # only the gate projection is toggled — report just that.
            print("input_size={}, seq_len={}, gated={}, testing...".format(
                input_size, seq_len, use_gate), flush=True)
            for dtype in dtype_list:
                w1 = torch.randn((hidden_size, input_size), dtype=dtype,
                                 device="mlu")
                b1 = torch.randn((hidden_size,), dtype=dtype, device="mlu")
                w2 = torch.randn((input_size, hidden_size), dtype=dtype,
                                 device="mlu")
                b2 = torch.randn((input_size,), dtype=dtype, device="mlu")
                w3 = torch.randn((hidden_size, input_size), dtype=dtype,
                                 device="mlu") if use_gate else None
                b3 = torch.randn((hidden_size,), dtype=dtype,
                                 device="mlu") if use_gate else None
                input = torch.randn((batch, seq_len, input_size), dtype=dtype,
                                    device="mlu")
                args = (input, w1, b1, w2, b2, w3, b3, 'silu')
                output = self.op_impl_base(*args)

                # 1) Fused kernel vs reference.
                tmo_output1 = ops.ffn(*args)
                self.assertTensorsEqual(output.cpu().float(),
                                        tmo_output1.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)

                # 2) Use matmul to implement ffn: fuse w1/w3 (and b1/b3) into a
                # single GEMM, apply the (optionally gated) activation, then
                # the down projection.
                f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1
                f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1
                pre_gemm_out = ops.matmul(input.view(-1, input_size), f1_weight,
                                          f1_bias, None, "none", 1.0, 0)
                act_out = ops.active(pre_gemm_out, 'silu', use_gate)
                tmo_output2 = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0)
                tmo_output2 = tmo_output2.view(batch, seq_len, input_size)
                self.assertTensorsEqual(output.cpu().float(),
                                        tmo_output2.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck (schema/autograd registration) for torch.ops.torch_mlu_ops.ffn."""
        batch, seq_len, input_size, hidden_size, act_mode, dtype = (
            1, 10, 128, 1024, 'silu', torch.half)
        input = torch.randn((batch, seq_len, input_size), dtype=dtype,
                            device="mlu")
        up_fc_weight = torch.randn((hidden_size, input_size), dtype=dtype,
                                   device="mlu")
        down_proj_weight = torch.randn((input_size, hidden_size), dtype=dtype,
                                       device="mlu")
        # Positional layout of the registered op: weights/biases, optional
        # gate/norm tensors (None here), act_mode, norm mode, eps, scales.
        args = (input, up_fc_weight, None, down_proj_weight, None, None, None,
                None, None, act_mode, "none", 1e-5, 1., 0.)
        self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)


if __name__ == '__main__':
    exit(run_unittest(TestFFNOp))