Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py
2026-02-04 17:39:32 +08:00

89 lines
2.9 KiB
Python

import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
import time
import os
class TestActive(BtTestCase):
    """Compare ``torch_mlu_ops.active`` against a pure-PyTorch reference.

    Exercises the ``gelu``, ``silu``, ``quick_gelu`` and ``swish`` modes,
    both plain and gated. In gated mode the activation is applied to the
    first half of the last dimension and multiplied by the second half.
    """

    def run_gen_case(self, dic):
        """Replay one generated test case described by ``dic``.

        If ``dic['dump_data']`` is truthy, the remaining values of ``dic`` are
        forwarded positionally to :meth:`launch` in dict order. Otherwise the
        input tensor is rebuilt with ``create_tensor_from_dic`` and the scalar
        arguments are read from each entry's ``'data'`` field.

        Note: ``'dump_data'`` is popped, so the caller's dict is mutated.
        """
        dump_data = dic.pop('dump_data')
        if dump_data:
            # NOTE(review): relies on dic's insertion order being
            # (input, act_mode, is_gated, active_coef) -- confirm with the
            # case generator.
            self.launch(*dic.values())
        else:
            x = create_tensor_from_dic(dic['input'])
            act_mode = dic['act_mode']['data']
            is_gated = dic['is_gated']['data']
            active_coef = dic['active_coef']['data']
            self.launch(x, act_mode, is_gated, active_coef)

    def launch(self, *args):
        """Run the reference and the MLU op on ``args`` and compare results.

        ``args`` is ``(input, act_mode, is_gated, active_coef)``. Both outputs
        are moved to CPU and upcast to float32 before calling
        ``assertTensorsEqual`` with 0.004 and ``use_MSE``/``use_RAE`` enabled
        (threshold semantics are defined by ``assertTensorsEqual`` in
        common_utils).
        """
        torch_out = self.op_impl_base(*args)
        # Call the op through ``ops``, the name this file binds to
        # torch_mlu_ops, rather than a ``tmo`` name that is only reachable
        # through the star import of common_utils.
        tmo_out = ops.active(*args)
        self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                0.004, use_MSE=True, use_RAE=True)

    def op_impl_base(self, *args):
        """Pure-PyTorch reference implementation of ``active``.

        Positional args: ``(input, act_mode, is_gated, active_coef)``.

        ``"gelu"`` uses ``torch.nn.functional.gelu`` with its default
        settings. Every other mode is computed as swish,
        ``x * sigmoid(coef * x)``:

        - ``silu`` forces ``coef = 1.0``.
        - ``quick_gelu`` forces ``coef = 1.702``.
        - Any other mode uses ``active_coef`` as given.

        When ``is_gated`` is true, the activation is applied to
        ``x[..., :C//2]`` and multiplied by ``x[..., C//2:]``, where ``C`` is
        the last-dim size. The result's last dimension is therefore ``C//2``.
        """
        x, act_mode, is_gated, active_coef = args
        channel = x.size(-1)
        if act_mode == "gelu":
            if is_gated:
                out = torch.nn.functional.gelu(x[..., :channel//2])
                # In place is safe: ``out`` is a fresh tensor returned by gelu.
                out *= x[..., channel//2:]
            else:
                out = torch.nn.functional.gelu(x)
        else:
            # silu and quick_gelu are swish with a fixed coefficient.
            if act_mode == "silu":
                active_coef = 1.0
            elif act_mode == "quick_gelu":
                active_coef = 1.702

            def swish(t, coef):
                return t * torch.sigmoid(coef * t)

            if is_gated:
                out = swish(x[..., :channel//2], active_coef) * x[..., channel//2:]
            else:
                out = swish(x, active_coef)
        return out

    def test_active_random(self):
        """Check ``active`` on 500 randomly drawn shapes, dtypes and modes.

        The hidden size is always even so that gated mode can split it.
        """
        # The set of dtypes to try does not change between iterations, so
        # build it once instead of re-querying bf16 support 500 times.
        dtype_list = [torch.float, torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for _ in range(500):
            input_dtype = random.choice(dtype_list)
            batch = random.randint(1, 10)
            seq = random.randint(1, 2048)
            hidden_size = random.randrange(2, 8192, 2)
            is_gated = random.choice([True, False])
            act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish'])
            if act_mode == 'silu':
                active_coef = 1.0
            elif act_mode == 'quick_gelu':
                active_coef = 1.702
            else:
                # Only consumed by 'swish'; the gelu reference ignores it.
                active_coef = random.uniform(0, 1)
            print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format(
                [batch, seq, hidden_size], is_gated, act_mode, input_dtype), flush=True)
            x = torch.randn(batch, seq, hidden_size, dtype=input_dtype, device="mlu")
            self.launch(x, act_mode, is_gated, active_coef)

    def test_inductor(self):
        """Run ``base_opcheck`` on the raw ``torch.ops.torch_mlu_ops.active`` op.

        The raw op takes positional arguments ``(input, output, None, None,
        act_mode, is_gated, 0, 0, coef)``. The meaning of the two ``None``
        and two ``0`` slots is not visible in this file.
        """
        x = torch.randn(3, 4, 12, dtype=torch.half, device="mlu")
        # NOTE(review): with is_gated=True, op_impl_base produces a last dim
        # of 12 // 2 = 6. Here output is allocated with the full 12; confirm
        # the op's expected output shape.
        output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu")
        is_gated = True
        act_mode = 'silu'
        coef = 1.0
        args = (x, output, None, None, act_mode, is_gated, 0, 0, coef)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)
if __name__ == '__main__':
    # Use sys.exit: the bare ``exit`` builtin is injected by the ``site``
    # module and is not available when Python runs with ``-S``.
    sys.exit(run_unittest(TestActive))