# Source listing: 26 lines, 1.0 KiB, Python.
import sys
import os

# Directory containing locally built artifacts: `build/lib` next to this file.
# NOTE(review): presumably populated by a `setup.py build`-style step — confirm.
build_lib_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "build", "lib"
)

# Appended (lowest import priority). The membership check keeps repeated
# imports/reloads of this module from piling duplicate entries onto sys.path.
if build_lib_dir not in sys.path:
    sys.path.append(build_lib_dir)

import torch


class UnitTest:
    """Relative-error metrics for comparing a computed tensor to a baseline.

    Every metric flattens both inputs, casts them to float32 and moves them to
    the CPU before comparing, so tensors on any device and of any real dtype
    can be compared as long as they hold the same number of elements.
    """

    # Added to a non-positive denominator (e.g. an all-zero baseline) so the
    # metric never divides by exactly zero.
    _EPS = 1e-9

    @staticmethod
    def _prepare(result: torch.Tensor, baseline: torch.Tensor):
        """Return ``(result, baseline)`` flattened, as float32, on the CPU.

        Raises:
            AssertionError: if the two tensors hold different numbers of
                elements. An explicit ``raise`` (rather than ``assert``) keeps
                the check active under ``python -O`` while preserving the
                exception type callers already see.
        """
        result = result.flatten().float().to('cpu')
        baseline = baseline.flatten().float().to('cpu')
        if result.shape != baseline.shape:
            raise AssertionError(
                f"shape mismatch after flatten: result {tuple(result.shape)} "
                f"vs baseline {tuple(baseline.shape)}"
            )
        return result, baseline

    @classmethod
    def _guard(cls, denominator: float) -> float:
        """Return *denominator*, plus ``_EPS`` when it is not positive."""
        return denominator + (0.0 if denominator > 0 else cls._EPS)

    def diff1(self, result: torch.Tensor, baseline: torch.Tensor) -> float:
        """Return the relative L1 error between *result* and *baseline*.

        Computed as ``sum(|baseline - result|) / sum(|baseline|)`` over all
        elements. When the baseline's absolute sum is zero, ``1e-9`` is added
        to the denominator to avoid dividing by zero.

        Raises:
            AssertionError: if the inputs differ in element count.
        """
        result, baseline = self._prepare(result, baseline)
        error = torch.abs(baseline - result)
        denominator = torch.sum(torch.abs(baseline)).item()
        return (torch.sum(error) / self._guard(denominator)).item()

    def diff2(self, result: torch.Tensor, baseline: torch.Tensor) -> float:
        """Return the relative L2 error between *result* and *baseline*.

        Computed as ``sqrt(sum((baseline - result)**2) / sum(baseline**2))``
        over all elements. When the baseline's squared sum is zero, ``1e-9``
        is added to the denominator to avoid dividing by zero.

        Raises:
            AssertionError: if the inputs differ in element count. Without
                this check a single-element *result* would silently broadcast
                against the whole baseline.
        """
        result, baseline = self._prepare(result, baseline)
        # Squaring makes an explicit abs() unnecessary.
        error = baseline - result
        denominator = torch.sum(baseline**2).item()
        return torch.sqrt(torch.sum(error**2) / self._guard(denominator)).item()