From eae9a9fb9daaea5863c2e274c53fd30e48466699 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Fri, 10 Oct 2025 20:49:08 -0700 Subject: [PATCH] Fix batch invariant ops (#11368) --- .../batch_invariant_ops.py | 10 +- .../test_batch_invariant_ops.py | 163 ++++++++++++++++++ test/srt/run_suite.py | 1 + 3 files changed, 168 insertions(+), 6 deletions(-) create mode 100644 test/srt/batch_invariant/test_batch_invariant_ops.py diff --git a/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py b/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py index 465d53ee0..be0bb3dcf 100644 --- a/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py +++ b/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py @@ -77,8 +77,6 @@ def matmul_kernel_persistent( k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - tile_id_c = start_pid - NUM_SMS - offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -120,10 +118,6 @@ def matmul_kernel_persistent( ) accumulator = tl.dot(a, b, accumulator) - tile_id_c += NUM_SMS - pid_m, pid_n = _compute_pid( - tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS - ) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if C_LARGE: @@ -137,6 +131,10 @@ def matmul_kernel_persistent( accumulator += bias if c_ptr.dtype.element_ty == tl.float8e4nv: c = accumulator.to(tl.float8e4nv) + elif c_ptr.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif c_ptr.dtype.element_ty == tl.float32: + c = accumulator.to(tl.float32) else: c = accumulator.to(tl.float16) tl.store(c_ptrs, c, mask=c_mask) diff --git a/test/srt/batch_invariant/test_batch_invariant_ops.py b/test/srt/batch_invariant/test_batch_invariant_ops.py new file mode 100644 index 000000000..7acee6cfc --- /dev/null +++ b/test/srt/batch_invariant/test_batch_invariant_ops.py @@ -0,0 +1,163 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py +import math +import unittest + +import torch + +from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode +from sglang.test.test_utils import CustomTestCase + +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + +# Just to get the logging out of the way +with set_batch_invariant_mode(True): + pass + + +class TestBatchInvariantOps(CustomTestCase): + def _test_batch_invariance(self, M, K, N, dtype): + """ + Test that matrix operations produce identical results for: + - Method 1: Matrix-vector multiplication (batch size 1) + - Method 2: Matrix-matrix multiplication, then slice (full batch) + """ + a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K) + + # Create non-contiguous tensor + b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K) + b = b.transpose(0, 1) + + # Method 1: Matrix-vector multiplication (batch size 1) + out1 = torch.mm(a[:1], b) + + # Method 2: Matrix-matrix multiplication, then slice (full batch) + out2_pre = torch.mm(a, b) + out2 = out2_pre[:1] + + # Check if results are identical + diff = (out1 - out2).abs().max() + return diff.item() + + def _run_multiple_iterations(self, iters, M, K, N, dtype): + """Run multiple iterations and collect diff statistics""" + difflist = [] + for _ in range(iters): + diff = self._test_batch_invariance(M, K, N, dtype) + difflist.append(diff) + return difflist + + def _assert_batch_invariant_results(self, difflist, dtype, test_name): + """ + Assert that in batch-invariant mode: + 1. All diffs must not be NaN + 2. All diffs must be exactly 0 + 3. Max, min, and diff of diffs must all be 0 + """ + max_diff = max(difflist) + min_diff = min(difflist) + diff_range = max_diff - min_diff + + # Check for NaN values + self.assertFalse( + math.isnan(max_diff), f"{test_name}: max_diff is NaN for {dtype}" + ) + self.assertFalse( + math.isnan(min_diff), f"{test_name}: min_diff is NaN for {dtype}" + ) + self.assertFalse( + math.isnan(diff_range), f"{test_name}: diff_range is NaN for {dtype}" + ) + + # Check that all diffs are exactly 0 + self.assertEqual( + max_diff, + 0.0, + f"{test_name}: max_diff must be 0 in batch-invariant mode, got {max_diff} for {dtype}", + ) + self.assertEqual( + min_diff, + 0.0, + f"{test_name}: min_diff must be 0 in batch-invariant mode, got {min_diff} for {dtype}", + ) + self.assertEqual( + diff_range, + 0.0, + f"{test_name}: diff_range must be 0 in batch-invariant mode, got {diff_range} for {dtype}", + ) + + def test_small_matrices(self): + """Test batch invariance with small matrix sizes""" + test_cases = [ + ("Small-1", 8, 64, 128), + ("Small-2", 16, 128, 256), + ("Small-3", 4, 32, 64), + ] + + for name, M, K, N in test_cases: + with self.subTest(name=name, M=M, K=K, N=N): + for dtype in [torch.float32, torch.bfloat16]: + with self.subTest(dtype=dtype): + # Run with batch-invariant mode + with set_batch_invariant_mode(True): + difflist = self._run_multiple_iterations( + iters=5, M=M, K=K, N=N, dtype=dtype + ) + self._assert_batch_invariant_results(difflist, dtype, name) + + def test_medium_matrices(self): + """Test batch invariance with medium matrix sizes""" + test_cases = [ + ("Medium-1", 32, 128, 1024), + ("Medium-2", 64, 512, 2048), + ("Medium-3", 24, 192, 768), + ] + + for name, M, K, N in test_cases: + with self.subTest(name=name, M=M, K=K, N=N): + for dtype in [torch.float32, torch.bfloat16]: + with self.subTest(dtype=dtype): + # Run with batch-invariant mode + with set_batch_invariant_mode(True): + difflist = self._run_multiple_iterations( + iters=5, M=M, K=K, N=N, dtype=dtype + ) + self._assert_batch_invariant_results(difflist, dtype, name) + + def test_large_matrices(self): + """Test batch invariance with large matrix sizes""" + test_cases = [ + ("Large-1", 128, 1024, 4096), + ("Large-2", 256, 2048, 8192), + ("Large-3", 96, 768, 3072), + ] + + for name, M, K, N in test_cases: + with self.subTest(name=name, M=M, K=K, N=N): + for dtype in [torch.float32, torch.bfloat16]: + with self.subTest(dtype=dtype): + # Run with batch-invariant mode + with set_batch_invariant_mode(True): + difflist = self._run_multiple_iterations( + iters=5, M=M, K=K, N=N, dtype=dtype + ) + self._assert_batch_invariant_results(difflist, dtype, name) + + def test_without_batch_invariant_mode(self): + """ + Test that without batch-invariant mode, results may differ. + This test demonstrates the difference batch-invariant mode makes. + """ + M, K, N = 32, 128, 1024 + dtype = torch.float32 + + # Run without batch-invariant mode + with set_batch_invariant_mode(False): + difflist = self._run_multiple_iterations( + iters=5, M=M, K=K, N=N, dtype=dtype + ) + print(f"Without batch-invariant mode, we get diffs: {difflist}") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 02bbecac6..995a8dc98 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -33,6 +33,7 @@ suites = { TestFile("models/test_generation_models.py", 103), TestFile("models/test_nvidia_nemotron_nano_v2.py", 180), TestFile("models/test_qwen_models.py", 82), + TestFile("batch_invariant/test_batch_invariant_ops.py", 10), TestFile("models/test_reward_models.py", 132), TestFile("models/test_transformers_models.py", 320), TestFile("models/test_vlm_models.py", 741),