forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
25
torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py
Normal file
25
torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import sys
|
||||
import os
|
||||
build_lib_dir = os.path.dirname(os.path.abspath(__file__)) + "/build/lib"
|
||||
sys.path.append(build_lib_dir)
|
||||
import torch
|
||||
|
||||
class UnitTest:
|
||||
def diff1(self, result: torch.Tensor, baseline: torch.Tensor):
|
||||
result = result.flatten().float().to('cpu')
|
||||
baseline = baseline.flatten().float().to('cpu')
|
||||
assert result.shape == baseline.shape
|
||||
error = torch.abs(baseline - result)
|
||||
denominator = torch.sum(torch.abs(baseline)).item()
|
||||
eps = 0.0 if denominator > 0 else 1e-9
|
||||
diff1 = torch.sum(error) / (denominator + eps)
|
||||
return diff1.item()
|
||||
|
||||
def diff2(self, result: torch.Tensor, baseline: torch.Tensor):
|
||||
result = result.flatten().float().to('cpu')
|
||||
baseline = baseline.flatten().float().to('cpu')
|
||||
error = torch.abs(baseline - result)
|
||||
denominator = torch.sum(baseline**2).item()
|
||||
eps = 0.0 if denominator > 0 else 1e-9
|
||||
diff2 = torch.sqrt(torch.sum(error**2) / (denominator + eps))
|
||||
return diff2.item()
|
||||
Reference in New Issue
Block a user