### What this PR does / why we need it?
Add a custom op to accelerate the DeepSeek model. The fused op combines `bmm` and `transpose` into a single kernel and is applied to the MLA module. Cherry-picked from commit c68ddc11ce53334fc9a17bad58342148cbf14e86.

### Does this PR introduce _any_ user-facing change?
No

---------

Signed-off-by: hust17yixuan <303660421@qq.com>
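For context, the unfused path the op replaces is a `torch.bmm` over a transposed input followed by a transposed write-back, which is exactly what `compute_golden` in the test below does. A minimal sketch of that reference semantics (`bmm_transpose_reference` is an illustrative helper, not part of this PR):

```python
import torch


def bmm_transpose_reference(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Unfused reference: a is (b, m, k), w is (m, k, n) -> (b, m, n)."""
    # Transpose a to (m, b, k), batch-matmul with w to get (m, b, n),
    # then transpose back so the result is materialized as (b, m, n).
    return torch.bmm(a.transpose(0, 1), w).transpose(0, 1).contiguous()
```

`torch.ops._C_ascend.batch_matmul_transpose(a, w, out)` produces the same `(b, m, n)` result in a single fused kernel, without materializing the intermediate transpose.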
import random
import unittest

import torch

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

torch.set_printoptions(threshold=float("inf"))


class TestMatrixMultiplication(unittest.TestCase):

    def compute_golden(self, a, b, res1, m, n):
        """Compute reference result (golden)"""
        torch.bmm(a.transpose(0, 1),
                  b,
                  out=res1.view(-1, m, n).transpose(0, 1))

    def assert_tensors_almost_equal(self, actual, expected, dtype):
        """Check if two tensors are approximately equal (considering floating point errors)"""
        self.assertEqual(actual.shape, expected.shape, "Shape mismatch")

        # Check for NaN
        self.assertFalse(
            torch.isnan(actual).any(), "Actual result contains NaN")
        self.assertFalse(
            torch.isnan(expected).any(), "Expected result contains NaN")

        # Check for Inf
        self.assertFalse(
            torch.isinf(actual).any(), "Actual result contains Inf")
        self.assertFalse(
            torch.isinf(expected).any(), "Expected result contains Inf")

        # Set different tolerances based on data type
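        # bfloat16 has fewer mantissa bits than float16 (7 vs. 10),
        # so it gets a slightly looser tolerance.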
        if dtype == torch.float16:
            rtol, atol = 1e-5, 1e-5
        else:  # bfloat16
            rtol, atol = 1.5e-5, 1.5e-5

        # Compare values
        diff = torch.abs(actual - expected)
        max_diff = diff.max().item()
        max_expected = torch.abs(expected).max().item()

        # Check relative and absolute errors
        if max_expected > 0:
            relative_diff = max_diff / max_expected
            self.assertLessEqual(
                relative_diff,
                rtol,
                f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
            )

        self.assertLessEqual(max_diff, atol,
                             f"Absolute error too large: {max_diff} > {atol}")

    def test_boundary_conditions(self):
        """Test boundary conditions"""
        test_cases = [
            # (b, m, k, n)
            (1, 1, 1, 1),  # Minimum size
            (1, 10, 1, 1),  # b=1
            (10, 1, 1, 10),  # m=1
            (5, 5, 1, 5),  # k=1
            (2, 2, 2, 1),  # n=1
            (100, 1, 1, 100),  # Flat case
            (1, 100, 100, 1),  # Flat case
            (2, 3, 4, 5),  # Random small size
            (10, 20, 30, 40),  # Medium size
            (36, 128, 512, 128),  # Target case
            (8, 160, 512, 128),
        ]

        dtypes = [torch.float16, torch.bfloat16]

        for dtype in dtypes:
            for b, m, k, n in test_cases:
                with self.subTest(dtype=dtype, shape=f"({b}, {m}, {k}, {n})"):
                    a = torch.randn(b, m, k, dtype=dtype, device="npu")
                    b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu")
                    res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
                    res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

                    self.compute_golden(a, b_tensor, res1, m, n)
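                    # The fused kernel computes the bmm and the transposed
                    # write-back in one step, producing (b, m, n) directly.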
                    torch.ops._C_ascend.batch_matmul_transpose(
                        a, b_tensor, res2)

                    self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
                                                     dtype)

    def test_random_shapes(self):
        """Test randomly generated shapes"""
        num_tests = 1
        dtypes = [torch.float16, torch.bfloat16]

        for dtype in dtypes:
            for _ in range(num_tests):
                # Generate reasonable random sizes
                b = random.randint(1, 500)
                m = random.randint(1, 500)
                k = random.randint(1, 500)
                n = random.randint(1, 500)

                with self.subTest(dtype=dtype,
                                  shape=f"Random ({b}, {m}, {k}, {n})"):
                    a = torch.randn(b, m, k, dtype=dtype, device="npu")
                    b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu")
                    res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
                    res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

                    self.compute_golden(a, b_tensor, res1, m, n)
                    torch.ops._C_ascend.batch_matmul_transpose(
                        a, b_tensor, res2)
                    self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
                                                     dtype)

    def test_zero_values(self):
        """Test zero input values"""
        dtypes = [torch.float16, torch.bfloat16]
        b, m, k, n = 5, 4, 3, 2

        for dtype in dtypes:
            with self.subTest(dtype=dtype):
                a = torch.zeros(b, m, k, dtype=dtype, device="npu")
                b_tensor = torch.zeros(m, k, n, dtype=dtype, device="npu")
                res1 = torch.empty((b, m * n), dtype=dtype, device="npu")
                res2 = torch.empty((b, m, n), dtype=dtype, device="npu")

                self.compute_golden(a, b_tensor, res1, m, n)
                torch.ops._C_ascend.batch_matmul_transpose(a, b_tensor, res2)

                self.assert_tensors_almost_equal(res1.view(-1, m, n), res2,
                                                 dtype)
                self.assertTrue(torch.all(res2 == 0))


if __name__ == "__main__":
    unittest.main(verbosity=2)
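Note: these tests allocate tensors with `device="npu"` and call `torch.ops._C_ascend.batch_matmul_transpose`, so they must run in an Ascend NPU environment where vllm-ascend's custom ops have been built and registered via `enable_custom_op()`.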