164 lines
6.0 KiB
Python
164 lines
6.0 KiB
Python
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py
|
|
import math
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode
|
|
from sglang.test.test_utils import CustomTestCase
|
|
|
|
# Use the device type of the accelerator torch reports; getattr falls back to
# "cpu" when the returned object has no ``type`` attribute (e.g. when no
# accelerator is reported).
device_type = getattr(
    torch.accelerator.current_accelerator(),
    "type",
    "cpu",
)

# Every tensor created below without an explicit device lands on this device.
torch.set_default_device(device_type)

# Enter and immediately leave batch-invariant mode once at import time,
# just to get the logging out of the way before the tests run.
with set_batch_invariant_mode(True):
    pass
|
|
|
|
|
|
class TestBatchInvariantOps(CustomTestCase):
    """Compare ``torch.mm`` on a single row against the same row of a full batch.

    Each check multiplies only the first row of ``a`` by ``b`` and compares the
    result with the first row of the full ``a @ b`` product. With
    batch-invariant mode enabled the two must match exactly (max abs diff 0).
    """

    # Dtypes exercised for every shape in the size-class tests.
    _DTYPES = (torch.float32, torch.bfloat16)

    def _test_batch_invariance(self, M, K, N, dtype):
        """
        Return the max absolute difference between two ways of computing row 0:

        - Method 1: Matrix-vector multiplication (batch size 1)
        - Method 2: Matrix-matrix multiplication, then slice (full batch)

        Args:
            M: Number of rows of ``a`` (the batch dimension).
            K: Shared inner dimension.
            N: Number of columns of ``b``.
            dtype: torch dtype used for both operands.

        Returns:
            The maximum absolute element-wise difference as a Python float.
        """
        a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K)

        # Build b as (N, K) and transpose to (K, N) so the operand passed to
        # torch.mm is a non-contiguous view.
        b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K)
        b = b.transpose(0, 1)

        # Method 1: Matrix-vector multiplication (batch size 1)
        out1 = torch.mm(a[:1], b)

        # Method 2: Matrix-matrix multiplication, then slice (full batch)
        out2_pre = torch.mm(a, b)
        out2 = out2_pre[:1]

        # Largest element-wise deviation between the two results
        diff = (out1 - out2).abs().max()
        return diff.item()

    def _run_multiple_iterations(self, iters, M, K, N, dtype):
        """Run the comparison ``iters`` times and return the list of diffs."""
        return [self._test_batch_invariance(M, K, N, dtype) for _ in range(iters)]

    def _assert_batch_invariant_results(self, difflist, dtype, test_name):
        """
        Assert that in batch-invariant mode:

        1. At least one diff was collected
        2. All diffs must not be NaN
        3. All diffs must be exactly 0
        4. Max, min, and diff of diffs must all be 0

        Args:
            difflist: Diffs returned by ``_run_multiple_iterations``.
            dtype: dtype under test; only used in failure messages.
            test_name: Case label; only used in failure messages.
        """
        # max()/min() raise a bare ValueError on an empty sequence; fail with
        # a message that names the case instead.
        self.assertTrue(
            difflist, f"{test_name}: no diffs were collected for {dtype}"
        )

        # Check every element for NaN. max()/min() cannot be trusted for this:
        # comparisons with NaN are always False, so max([0.0, nan]) == 0.0 and
        # a NaN after the first element would otherwise go unnoticed.
        nan_iterations = [i for i, d in enumerate(difflist) if math.isnan(d)]
        self.assertEqual(
            nan_iterations,
            [],
            f"{test_name}: diff is NaN at iteration(s) {nan_iterations} for {dtype}",
        )

        max_diff = max(difflist)
        min_diff = min(difflist)
        diff_range = max_diff - min_diff

        # Check that all diffs are exactly 0
        self.assertEqual(
            max_diff,
            0.0,
            f"{test_name}: max_diff must be 0 in batch-invariant mode, got {max_diff} for {dtype}",
        )
        self.assertEqual(
            min_diff,
            0.0,
            f"{test_name}: min_diff must be 0 in batch-invariant mode, got {min_diff} for {dtype}",
        )
        self.assertEqual(
            diff_range,
            0.0,
            f"{test_name}: diff_range must be 0 in batch-invariant mode, got {diff_range} for {dtype}",
        )

    def _check_batch_invariance_cases(self, test_cases, iters=5):
        """Run every ``(name, M, K, N)`` case for each dtype in batch-invariant mode.

        Each (case, dtype) pair runs in its own subTest so one failure does
        not hide the remaining combinations.
        """
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in self._DTYPES:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=iters, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_small_matrices(self):
        """Test batch invariance with small matrix sizes"""
        self._check_batch_invariance_cases(
            [
                ("Small-1", 8, 64, 128),
                ("Small-2", 16, 128, 256),
                ("Small-3", 4, 32, 64),
            ]
        )

    def test_medium_matrices(self):
        """Test batch invariance with medium matrix sizes"""
        self._check_batch_invariance_cases(
            [
                ("Medium-1", 32, 128, 1024),
                ("Medium-2", 64, 512, 2048),
                ("Medium-3", 24, 192, 768),
            ]
        )

    def test_large_matrices(self):
        """Test batch invariance with large matrix sizes"""
        self._check_batch_invariance_cases(
            [
                ("Large-1", 128, 1024, 4096),
                ("Large-2", 256, 2048, 8192),
                ("Large-3", 96, 768, 3072),
            ]
        )

    def test_without_batch_invariant_mode(self):
        """
        Test that without batch-invariant mode, results may differ.

        This test demonstrates the difference batch-invariant mode makes.
        The diffs are printed rather than compared with 0, because a
        non-zero diff is allowed here; only NaN is treated as a failure.
        """
        M, K, N = 32, 128, 1024
        dtype = torch.float32

        # Run without batch-invariant mode
        with set_batch_invariant_mode(False):
            difflist = self._run_multiple_iterations(
                iters=5, M=M, K=K, N=N, dtype=dtype
            )
        print(f"Without batch-invariant mode, we get diffs: {difflist}")

        # The diffs may be non-zero, but they must still be real numbers.
        self.assertFalse(
            any(math.isnan(d) for d in difflist),
            f"diffs without batch-invariant mode contain NaN: {difflist}",
        )
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|