Fix batch invariant ops (#11368)
This commit is contained in:
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
|
||||
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
|
||||
tile_id_c = start_pid - NUM_SMS
|
||||
|
||||
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
|
||||
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
|
||||
)
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
pid_m, pid_n = _compute_pid(
|
||||
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||
)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if C_LARGE:
|
||||
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
|
||||
accumulator += bias
|
||||
if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||
c = accumulator.to(tl.float8e4nv)
|
||||
elif c_ptr.dtype.element_ty == tl.bfloat16:
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
elif c_ptr.dtype.element_ty == tl.float32:
|
||||
c = accumulator.to(tl.float32)
|
||||
else:
|
||||
c = accumulator.to(tl.float16)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
163
test/srt/batch_invariant/test_batch_invariant_ops.py
Normal file
163
test/srt/batch_invariant/test_batch_invariant_ops.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
# Ask torch which accelerator is active; if the returned object has no
# ``type`` attribute (e.g. nothing was returned), fall back to "cpu".
_current_accelerator = torch.accelerator.current_accelerator()
device_type = getattr(_current_accelerator, "type", "cpu")

# Every tensor created below without an explicit device lands on this one.
torch.set_default_device(device_type)

# Enter and immediately leave batch-invariant mode once at import time.
# Per the original note this is only done to get the logging out of the way
# before the tests run.
with set_batch_invariant_mode(True):
    pass
|
||||
|
||||
|
||||
class TestBatchInvariantOps(CustomTestCase):
    """Check that ``torch.mm`` gives batch-size-independent results.

    Each check computes the first output row of ``a @ b`` twice, once from
    a single-row input and once from the full input, and compares the two.
    """

    def _test_batch_invariance(self, M, K, N, dtype):
        """
        Return the largest absolute difference between two computations of
        the first row of ``a @ b``:
        - Method 1: multiply only the first row of ``a`` (batch size 1)
        - Method 2: multiply the full ``a``, then slice out the first row

        Args:
            M: number of rows of ``a``.
            K: shared inner dimension.
            N: number of columns of ``b``.
            dtype: torch dtype used for both operands.

        Returns:
            The max absolute elementwise difference as a Python float.
        """
        a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K)

        # Create non-contiguous tensor
        b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K)
        b = b.transpose(0, 1)

        # Method 1: batch size 1
        out1 = torch.mm(a[:1], b)

        # Method 2: full batch, then slice
        out2 = torch.mm(a, b)[:1]

        return (out1 - out2).abs().max().item()

    def _run_multiple_iterations(self, iters, M, K, N, dtype):
        """Run the comparison ``iters`` times and return the list of diffs."""
        return [self._test_batch_invariance(M, K, N, dtype) for _ in range(iters)]

    def _assert_batch_invariant_results(self, difflist, dtype, test_name):
        """
        Assert that in batch-invariant mode:
        1. At least one diff was collected
        2. No diff is NaN
        3. Max, min, and range of the diffs are all exactly 0

        Args:
            difflist: floats returned by ``_run_multiple_iterations``.
            dtype: dtype under test (used in failure messages only).
            test_name: case label (used in failure messages only).
        """
        self.assertTrue(
            difflist, f"{test_name}: no diffs were collected for {dtype}"
        )

        # Check every element for NaN. ``max()``/``min()`` compare with
        # ``>``/``<``, which are always False against NaN, so a NaN that is
        # not the first element is silently skipped and would otherwise slip
        # through the equality checks below.
        nan_iterations = [i for i, d in enumerate(difflist) if math.isnan(d)]
        self.assertFalse(
            nan_iterations,
            f"{test_name}: diff is NaN at iteration(s) {nan_iterations} for {dtype}",
        )

        # With NaN excluded, max/min are reliable.
        max_diff = max(difflist)
        min_diff = min(difflist)
        diff_range = max_diff - min_diff

        # Check that all diffs are exactly 0
        self.assertEqual(
            max_diff,
            0.0,
            f"{test_name}: max_diff must be 0 in batch-invariant mode, got {max_diff} for {dtype}",
        )
        self.assertEqual(
            min_diff,
            0.0,
            f"{test_name}: min_diff must be 0 in batch-invariant mode, got {min_diff} for {dtype}",
        )
        self.assertEqual(
            diff_range,
            0.0,
            f"{test_name}: diff_range must be 0 in batch-invariant mode, got {diff_range} for {dtype}",
        )

    def _check_batch_invariant_cases(self, test_cases):
        """Run every ``(name, M, K, N)`` case for float32 and bfloat16.

        Each case/dtype pair gets its own ``subTest`` so that one failure
        does not hide the others. The matmuls run inside batch-invariant
        mode; the assertions run after the mode has been exited.
        """
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=5, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_small_matrices(self):
        """Test batch invariance with small matrix sizes"""
        self._check_batch_invariant_cases(
            [
                ("Small-1", 8, 64, 128),
                ("Small-2", 16, 128, 256),
                ("Small-3", 4, 32, 64),
            ]
        )

    def test_medium_matrices(self):
        """Test batch invariance with medium matrix sizes"""
        self._check_batch_invariant_cases(
            [
                ("Medium-1", 32, 128, 1024),
                ("Medium-2", 64, 512, 2048),
                ("Medium-3", 24, 192, 768),
            ]
        )

    def test_large_matrices(self):
        """Test batch invariance with large matrix sizes"""
        self._check_batch_invariant_cases(
            [
                ("Large-1", 128, 1024, 4096),
                ("Large-2", 256, 2048, 8192),
                ("Large-3", 96, 768, 3072),
            ]
        )

    def test_without_batch_invariant_mode(self):
        """
        Run the same comparison with batch-invariant mode disabled.

        Results may differ here, so nothing is asserted; the diffs are only
        printed to show the difference batch-invariant mode makes.
        """
        M, K, N = 32, 128, 1024
        dtype = torch.float32

        # Run without batch-invariant mode
        with set_batch_invariant_mode(False):
            difflist = self._run_multiple_iterations(
                iters=5, M=M, K=K, N=N, dtype=dtype
            )
        print(f"Without batch-invariant mode, we get diffs: {difflist}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test file directly as a script.
    unittest.main()
|
||||
@@ -33,6 +33,7 @@ suites = {
|
||||
TestFile("models/test_generation_models.py", 103),
|
||||
TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
|
||||
TestFile("models/test_qwen_models.py", 82),
|
||||
TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
|
||||
TestFile("models/test_reward_models.py", 132),
|
||||
TestFile("models/test_transformers_models.py", 320),
|
||||
TestFile("models/test_vlm_models.py", 741),
|
||||
|
||||
Reference in New Issue
Block a user