forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
65
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py
Executable file
65
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py
Executable file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import unittest
|
||||
import torch_mlu_ops as ops
|
||||
from common_utils import *
|
||||
|
||||
class TestFFNOp(BtTestCase):
|
||||
def op_impl_base(self, *args):
|
||||
input, w1, bias1, w2, bias2, w3, bias3, act_mode = args
|
||||
up = F.linear(input, w1, bias1)
|
||||
act = act_mode_dict[act_mode](up.float()).to(input.dtype)
|
||||
if w3 is not None:
|
||||
gate = F.linear(input, w3, bias3)
|
||||
act = act * gate
|
||||
output = F.linear(act, w2, bias2)
|
||||
return output
|
||||
|
||||
def test_ffn(self):
|
||||
input_size_list = [128, 256, 512]
|
||||
hidden_size = 1024
|
||||
seq_len_list = [10, 16, 20]
|
||||
bool_value_list = [True, False]
|
||||
batch = 5
|
||||
dtype_list = [torch.half]
|
||||
if torch_mlu.mlu.is_bf16_supported():
|
||||
dtype_list.append(torch.bfloat16)
|
||||
for input_size, seq_len, bool_value in zip(input_size_list,
|
||||
seq_len_list,
|
||||
bool_value_list):
|
||||
print("input_size={}, seq_len={}, bias={}, gated={}, testing...".format(
|
||||
input_size, seq_len, bool_value, bool_value), flush=True)
|
||||
use_gate = bool_value
|
||||
for dtype in dtype_list:
|
||||
w1 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
|
||||
b1 = torch.randn((hidden_size), dtype=dtype, device="mlu")
|
||||
w2 = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
|
||||
b2 = torch.randn((input_size), dtype=dtype, device="mlu")
|
||||
w3 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") if use_gate else None
|
||||
b3 = torch.randn((hidden_size), dtype=dtype, device="mlu") if use_gate else None
|
||||
input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
|
||||
args = (input, w1, b1, w2, b2, w3, b3, 'silu')
|
||||
output = self.op_impl_base(*args)
|
||||
tmo_output1 = ops.ffn(*args)
|
||||
self.assertTensorsEqual(output.cpu().float(), tmo_output1.cpu().float(),
|
||||
0.005, use_MSE=True, use_RAE=True)
|
||||
# use matmul to implement ffn
|
||||
f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1
|
||||
f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1
|
||||
pre_gemm_out = ops.matmul(input.view(-1, input_size), f1_weight, f1_bias, None, "none", 1.0, 0)
|
||||
act_out = ops.active(pre_gemm_out, 'silu', use_gate)
|
||||
tmo_output2 = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0)
|
||||
tmo_output2 = tmo_output2.view(batch, seq_len, input_size)
|
||||
self.assertTensorsEqual(output.cpu().float(), tmo_output2.cpu().float(),
|
||||
0.005, use_MSE=True, use_RAE=True)
|
||||
|
||||
def test_inductor(self):
|
||||
batch, seq_len, input_size, hidden_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
|
||||
input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
|
||||
up_fc_weight = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
|
||||
down_proj_weight = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
|
||||
args = (input, up_fc_weight, None, down_proj_weight, None, None, None, None, None, act_mode, "none", 1e-5, 1., 0.)
|
||||
self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(run_unittest(TestFFNOp))
|
||||
Reference in New Issue
Block a user