Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py
2026-02-04 17:39:32 +08:00

66 lines
3.5 KiB
Python
Executable File

import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
class TestFFNOp(BtTestCase):
    """Accuracy tests for ``ops.ffn`` against a reference built from ``F.linear``."""

    def op_impl_base(self, *args):
        """Compute the reference feed-forward output.

        Args:
            *args: The tuple ``(input, w1, bias1, w2, bias2, w3, bias3,
                act_mode)``. ``w3``/``bias3`` feed the optional gate branch,
                which is skipped when ``w3`` is ``None``. ``act_mode`` is
                used as a key into ``act_mode_dict``.

        Returns:
            ``linear(act(linear(input, w1, bias1)) * gate, w2, bias2)``, where
            ``gate = linear(input, w3, bias3)`` and the multiplication is
            omitted when ``w3`` is ``None``.
        """
        x, w1, bias1, w2, bias2, w3, bias3, act_mode = args
        up = F.linear(x, w1, bias1)
        # The activation is evaluated in float32 and cast back to the input dtype.
        hidden = act_mode_dict[act_mode](up.float()).to(x.dtype)
        if w3 is not None:
            hidden = hidden * F.linear(x, w3, bias3)
        return F.linear(hidden, w2, bias2)

    def _check_ffn_case(self, batch, seq_len, input_size, hidden_size, use_gate, dtype):
        """Run one FFN configuration and compare both device paths to the reference.

        Random weights, biases and input are created on the "mlu" device, then
        ``ops.ffn`` and an ``ops.matmul`` / ``ops.active`` / ``ops.matmul``
        composition are each compared with ``op_impl_base``.
        """
        def randn(*shape):
            return torch.randn(shape, dtype=dtype, device="mlu")

        w1 = randn(hidden_size, input_size)
        b1 = randn(hidden_size)
        w2 = randn(input_size, hidden_size)
        b2 = randn(input_size)
        w3 = randn(hidden_size, input_size) if use_gate else None
        b3 = randn(hidden_size) if use_gate else None
        x = randn(batch, seq_len, input_size)

        args = (x, w1, b1, w2, b2, w3, b3, 'silu')
        expected = self.op_impl_base(*args).cpu().float()

        # Path 1: the fused ffn op.
        fused = ops.ffn(*args)
        # NOTE(review): 0.005 is presumably the error threshold applied to the
        # MSE/RAE metrics by assertTensorsEqual -- confirm in common_utils.
        self.assertTensorsEqual(expected, fused.cpu().float(),
                                0.005, use_MSE=True, use_RAE=True)

        # Path 2: use matmul to implement ffn. With a gate, the up and gate
        # projections are stacked so a single matmul produces both halves.
        f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1
        f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1
        pre_gemm_out = ops.matmul(x.view(-1, input_size), f1_weight, f1_bias,
                                  None, "none", 1.0, 0)
        # NOTE(review): the third argument presumably tells ops.active that the
        # input holds concatenated up/gate halves -- confirm against its docs.
        act_out = ops.active(pre_gemm_out, 'silu', use_gate)
        composed = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0)
        composed = composed.view(batch, seq_len, input_size)
        self.assertTensorsEqual(expected, composed.cpu().float(),
                                0.005, use_MSE=True, use_RAE=True)

    def test_ffn(self):
        """Check ``ops.ffn`` for every shape pair, gate setting and dtype.

        Shapes are the element-wise pairs of ``input_size_list`` and
        ``seq_len_list``; each pair is run both gated and ungated, in
        ``torch.half`` and, when the device reports support, ``torch.bfloat16``.
        """
        input_size_list = [128, 256, 512]
        hidden_size = 1024
        seq_len_list = [10, 16, 20]
        bool_value_list = [True, False]
        batch = 5
        # b1/b2 are always supplied, so the progress message reports that
        # rather than echoing the gate flag.
        use_bias = True
        dtype_list = [torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)

        # The gate flags must not be zipped together with the shapes: zip()
        # stops at the shortest iterable, which silently dropped the last
        # (input_size, seq_len) pair. Cross the shape pairs with the flags.
        for input_size, seq_len in zip(input_size_list, seq_len_list):
            for use_gate in bool_value_list:
                print("input_size={}, seq_len={}, bias={}, gated={}, testing...".format(
                    input_size, seq_len, use_bias, use_gate), flush=True)
                for dtype in dtype_list:
                    self._check_ffn_case(batch, seq_len, input_size, hidden_size,
                                         use_gate, dtype)

    def test_inductor(self):
        """Pass ``torch.ops.torch_mlu_ops.ffn`` and one argument tuple to ``base_opcheck``.

        The tuple uses half-precision tensors with only the two weight
        matrices set.
        NOTE(review): the positional layout of ``args`` follows the registered
        op schema, which is not visible in this file -- confirm against the op
        definition before editing.
        """
        batch, seq_len, input_size, hidden_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
        x = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
        up_fc_weight = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
        down_proj_weight = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
        args = (x, up_fc_weight, None, down_proj_weight, None, None, None, None, None,
                act_mode, "none", 1e-5, 1., 0.)
        self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)
if __name__ == '__main__':
    # Propagate the runner's result as the process exit status. SystemExit is
    # raised directly because the exit() helper is installed by the site
    # module for interactive use and is not guaranteed to exist (e.g. -S).
    raise SystemExit(run_unittest(TestFFNOp))