89 lines
2.9 KiB
Python
89 lines
2.9 KiB
Python
import torch
|
|
import torch_mlu
|
|
import unittest
|
|
import torch_mlu_ops as ops
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
|
from common_utils import *
|
|
import random
|
|
from itertools import product
|
|
import time
|
|
import os
|
|
|
|
class TestActive(BtTestCase):
    """Tests for the fused `active` op against a pure-torch reference.

    NOTE(review): this file imports `torch_mlu_ops as ops`, yet the class
    calls `tmo.active` -- `tmo` presumably comes from the
    `from common_utils import *` star import; confirm and unify.
    """

    # Activation modes whose coefficient is fixed by definition; any other
    # non-gelu mode (e.g. "swish") uses the caller-provided coefficient.
    FIXED_COEF = {'silu': 1.0, 'quick_gelu': 1.702}

    def run_gen_case(self, dic):
        """Run one generated case described by `dic`.

        `dic` maps argument names to descriptors. When `dump_data` is truthy
        the stored values are forwarded as-is; otherwise the input tensor is
        rebuilt from its descriptor and scalar params are unwrapped from
        their `{'data': ...}` envelopes.
        """
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            inp = create_tensor_from_dic(dic['input'])
            act_mode = dic['act_mode']['data']
            is_gated = dic['is_gated']['data']
            active_coef = dic['active_coef']['data']
            self.launch(inp, act_mode, is_gated, active_coef)

    def launch(self, *args):
        """Run the device op and the torch reference, then compare them.

        args: (input, act_mode, is_gated, active_coef), forwarded verbatim
        to both implementations.
        """
        torch_out = self.op_impl_base(*args)
        tmo_out = tmo.active(*args)
        # 0.004 MSE/RAE tolerance absorbs half/bfloat16 rounding noise.
        self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                0.004, use_MSE=True, use_RAE=True)

    def op_impl_base(self, *args):
        """Pure-torch reference implementation of `active`.

        args: (input, act_mode, is_gated, active_coef). When `is_gated`,
        the last dimension is split in half: activation(first_half) is
        multiplied elementwise by the second half, so the output's last
        dimension is half the input's.
        """
        inp, act_mode, is_gated, active_coef = args
        channel = inp.size(-1)

        if act_mode == "gelu":
            if is_gated:
                return (torch.nn.functional.gelu(inp[..., :channel // 2])
                        * inp[..., channel // 2:])
            return torch.nn.functional.gelu(inp)

        # silu / quick_gelu are swish with a fixed coefficient; any other
        # mode keeps the coefficient supplied by the caller.
        if act_mode in self.FIXED_COEF:
            active_coef = self.FIXED_COEF[act_mode]

        def swish(x, coef):
            return x * torch.sigmoid(coef * x)

        if is_gated:
            return swish(inp[..., :channel // 2], active_coef) * inp[..., channel // 2:]
        return swish(inp, active_coef)

    def test_active_random(self):
        """Fuzz the op over random shapes, dtypes, gating and modes."""
        # Loop-invariant: bf16 support cannot change between iterations,
        # so build the candidate dtype list once instead of 500 times.
        dtype_list = [torch.float, torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)

        for _ in range(500):
            input_dtype = random.choice(dtype_list)
            batch = random.randint(1, 10)
            seq = random.randint(1, 2048)
            # Even hidden size so gated modes can split the last dim in half.
            hidden_size = random.randrange(2, 8192, 2)
            is_gated = random.choice([True, False])
            act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish'])
            # Keep coefficient selection consistent with op_impl_base.
            if act_mode in self.FIXED_COEF:
                active_coef = self.FIXED_COEF[act_mode]
            else:
                active_coef = random.uniform(0, 1)
            print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format( \
                [batch, seq, hidden_size], is_gated, act_mode, input_dtype), flush=True)
            inp = torch.randn(batch, seq, hidden_size, dtype=input_dtype, device="mlu")
            self.launch(inp, act_mode, is_gated, active_coef)

    def test_inductor(self):
        """Schema-check the raw custom op via the shared opcheck helper."""
        inp = torch.randn(3, 4, 12, dtype=torch.half, device="mlu")
        output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu")
        is_gated = True
        act_mode = 'silu'
        coef = 1.0
        # NOTE(review): the two None slots and the two 0 slots mirror the op
        # schema; their semantics are not visible in this file -- confirm
        # against the torch_mlu_ops.active registration.
        args = (inp, output, None, None, act_mode, is_gated, 0, 0, coef)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)
|
if __name__ == '__main__':
    # Delegate to the shared runner and propagate its status as the
    # process exit code.
    sys.exit(run_unittest(TestActive))