forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
88
torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py
Normal file
@@ -0,0 +1,88 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
import time
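
# Compares tmo.active against a pure-PyTorch reference (op_impl_base) for
# gelu / silu / quick_gelu / swish, in both gated and non-gated forms.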

class TestActive(BtTestCase):
    def run_gen_case(self, dic):
        # Dict-driven cases: 'dump_data' selects between forwarding the raw
        # values and materializing tensors via create_tensor_from_dic.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            input = create_tensor_from_dic(dic['input'])
            act_mode = dic['act_mode']['data']
            is_gated = dic['is_gated']['data']
            active_coef = dic['active_coef']['data']
            self.launch(input, act_mode, is_gated, active_coef)

    def launch(self, *args):
        torch_out = self.op_impl_base(*args)
        tmo_out = tmo.active(*args)
        # Compare on CPU in float32; tolerance 0.004 with MSE and RAE metrics.
        self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                0.004, use_MSE=True, use_RAE=True)

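    # Pure-PyTorch reference implementation used as ground truth.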
    def op_impl_base(self, *args):
        input, act_mode, is_gated, active_coef = args
        channel = input.size(-1)
        if act_mode == "gelu":
            if is_gated:
                # Gated layout: activate the first half of the last dim,
                # multiply by the second half as the elementwise gate.
                out = torch.nn.functional.gelu(input[..., :channel//2])
                out *= input[..., channel//2:]
            else:
                out = torch.nn.functional.gelu(input)
        else:
            # silu and quick_gelu are swish with fixed coefficients
            # (1.0 and 1.702 respectively); plain swish keeps the caller's coef.
            if act_mode == "silu":
                active_coef = 1.0
            elif act_mode == "quick_gelu":
                active_coef = 1.702

            def swish(input, coef):
                return input * torch.sigmoid(coef * input)

            if is_gated:
                out = swish(input[..., :channel//2], active_coef) * input[..., channel//2:]
            else:
                out = swish(input, active_coef)
        return out

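    # Fuzz 500 random shape/dtype/mode combinations; hidden_size stays even
    # so gated modes can split the last dimension in half.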
    def test_active_random(self):
        for _ in range(500):
            dtype_list = [torch.float, torch.half]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            input_dtype = random.choice(dtype_list)
            batch = random.randint(1, 10)
            seq = random.randint(1, 2048)
            hidden_size = random.randrange(2, 8192, 2)
            is_gated = random.choice([True, False])
            act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish'])
            if act_mode == 'silu':
                active_coef = 1.0
            elif act_mode == 'quick_gelu':
                active_coef = 1.702
            else:
                active_coef = random.uniform(0, 1)
            print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format(
                [batch, seq, hidden_size], is_gated, act_mode, input_dtype), flush=True)
            input = torch.randn(batch, seq, hidden_size, dtype=input_dtype, device="mlu")
            self.launch(input, act_mode, is_gated, active_coef)

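    # opcheck-style schema validation of the registered custom op, as used
    # under torch.compile/inductor. The raw torch.ops.torch_mlu_ops.active
    # signature takes more arguments than the tmo.active wrapper; the None/0
    # entries appear to be placeholders for optional inputs unused here.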
    def test_inductor(self):
        input = torch.randn(3, 4, 12, dtype=torch.half, device="mlu")
        output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu")
        is_gated = True
        act_mode = 'silu'
        coef = 1.0
        args = (input, output, None, None, act_mode, is_gated, 0, 0, coef)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)

if __name__ == '__main__':
    exit(run_unittest(TestActive))
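
For reviewers who want to sanity-check the math outside the MLU stack: the gated paths above split the last dimension in half, activate the first half, and multiply by the second half. Below is a minimal CPU-only sketch of the swish/silu path, mirroring op_impl_base; the function name is illustrative, not part of this diff.

import torch

def gated_swish_reference(x: torch.Tensor, coef: float = 1.0) -> torch.Tensor:
    # swish(v) = v * sigmoid(coef * v); coef=1.0 is silu, 1.702 is quick_gelu.
    half = x.size(-1) // 2
    act = x[..., :half] * torch.sigmoid(coef * x[..., :half])
    return act * x[..., half:]   # second half acts as the elementwise gate

x = torch.randn(2, 3, 8)
print(gated_swish_reference(x).shape)   # torch.Size([2, 3, 4])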