"""Unit tests for the torch_mlu_ops ``active`` activation operator.

Compares the MLU kernel against a pure-PyTorch reference for the
gelu / silu / quick_gelu / swish activations, with and without gating.
"""
import torch
import torch_mlu
import unittest
# NOTE(review): the original aliased this module as ``ops`` while every
# call site below uses ``tmo.active`` -- alias renamed to match usage.
import torch_mlu_ops as tmo
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
import time


class TestActive(BtTestCase):
    """Accuracy tests comparing ``tmo.active`` with a torch reference."""

    def run_gen_case(self, dic):
        """Run one generated test case described by ``dic``.

        ``dic`` must contain a ``dump_data`` flag.  When it is true the
        remaining values are passed to :meth:`launch` as-is (NOTE: this
        relies on the dict's insertion order matching the launch
        signature); otherwise each entry is a descriptor dict that is
        materialized first.
        """
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            input = create_tensor_from_dic(dic['input'])
            act_mode = dic['act_mode']['data']
            is_gated = dic['is_gated']['data']
            active_coef = dic['active_coef']['data']
            self.launch(input, act_mode, is_gated, active_coef)

    def launch(self, *args):
        """Run the reference and the MLU kernel and compare outputs.

        Comparison is done on CPU in float32 with MSE/RAE tolerances of
        0.004, which covers half/bfloat16 rounding differences.
        """
        torch_out = self.op_impl_base(*args)
        tmo_out = tmo.active(*args)
        self.assertTensorsEqual(torch_out.cpu().float(),
                                tmo_out.cpu().float(),
                                0.004, use_MSE=True, use_RAE=True)

    def op_impl_base(self, *args):
        """Pure-PyTorch reference implementation of ``active``.

        Args (positional): ``input`` tensor, ``act_mode`` string,
        ``is_gated`` bool, ``active_coef`` float.  When ``is_gated`` the
        last dimension is split in half: the activation is applied to
        the first half and multiplied elementwise by the second half.
        Returns the activated tensor.
        """
        input, act_mode, is_gated, active_coef = args
        channel = input.size(-1)
        if act_mode == "gelu":
            if is_gated:
                out = torch.nn.functional.gelu(input[..., :channel // 2])
                out *= input[..., channel // 2:]
            else:
                out = torch.nn.functional.gelu(input)
        else:
            # silu and quick_gelu are swish with fixed coefficients;
            # the caller-provided coefficient is deliberately overridden.
            if act_mode == "silu":
                active_coef = 1.0
            elif act_mode == "quick_gelu":
                active_coef = 1.702

            def swish(input, coef):
                return input * torch.sigmoid(coef * input)

            if is_gated:
                out = swish(input[..., :channel // 2], active_coef) \
                    * input[..., channel // 2:]
            else:
                out = swish(input, active_coef)
        return out

    def test_active_random(self):
        """Fuzz test: random shapes, dtypes, modes, and coefficients."""
        for _ in range(500):
            dtype_list = [torch.float, torch.half]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            input_dtype = random.choice(dtype_list)
            batch = random.randint(1, 10)
            seq = random.randint(1, 2048)
            # hidden_size must be even so gated mode can split it in half.
            hidden_size = random.randrange(2, 8192, 2)
            is_gated = random.choice([True, False])
            act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish'])
            if act_mode == 'silu':
                active_coef = 1.0
            elif act_mode == 'quick_gelu':
                active_coef = 1.702
            else:
                active_coef = random.uniform(0, 1)
            print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format(
                [batch, seq, hidden_size], is_gated, act_mode, input_dtype),
                flush=True)
            input = torch.randn(batch, seq, hidden_size,
                                dtype=input_dtype, device="mlu")
            self.launch(input, act_mode, is_gated, active_coef)

    def test_inductor(self):
        """Schema/opcheck validation of the raw custom op for inductor."""
        input = torch.randn(3, 4, 12, dtype=torch.half, device="mlu")
        output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu")
        is_gated = True
        act_mode = 'silu'
        coef = 1.0
        args = (input, output, None, None, act_mode, is_gated, 0, 0, coef)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)


if __name__ == '__main__':
    exit(run_unittest(TestActive))