diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 6944f9a44..c87f8d548 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -80,7 +80,8 @@ jobs: - name: Install run: | - pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 + bash scripts/ci_install_dependency.sh + pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2 pip3 uninstall sgl-kernel -y || true pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps pip3 list | grep sgl-kernel @@ -89,7 +90,7 @@ jobs: timeout-minutes: 30 run: | cd sgl-kernel - find tests -name "test_*.py" | xargs -n 1 python3 + pytest tests/ - name: Uninstall dependencies run: | diff --git a/sgl-kernel/tests/speculative/test_eagle_utils.py b/sgl-kernel/tests/speculative/test_eagle_utils.py index 1514029ec..12aa2e498 100644 --- a/sgl-kernel/tests/speculative/test_eagle_utils.py +++ b/sgl-kernel/tests/speculative/test_eagle_utils.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.nn.functional as F from sgl_kernel import verify_tree_greedy @@ -85,14 +86,14 @@ def test_verify_tree_greedy(): print(f"{accept_index=}") print(f"{accept_token_num=}") - return predicts, accept_index, accept_token_num - - -if __name__ == "__main__": - predicts, accept_index, accept_token_num = test_verify_tree_greedy() + # Check the expected output. assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] assert accept_index.tolist() == [ [0, 3, 4, 5], [6, 10, 11, -1], ] assert accept_token_num.tolist() == [3, 2] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/speculative/test_speculative_sampling.py b/sgl-kernel/tests/speculative/test_speculative_sampling.py index 2d45db2d0..93f3f5093 100644 --- a/sgl-kernel/tests/speculative/test_speculative_sampling.py +++ b/sgl-kernel/tests/speculative/test_speculative_sampling.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.nn.functional as F from sgl_kernel import tree_speculative_sampling_target_only @@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc print(f"{accept_index=}") print(f"{accept_token_num=}") - return predicts, accept_index, accept_token_num + if threshold_single == 1 and threshold_acc == 1: + assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 3, 4, 5], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [3, 2] + elif threshold_single == 0 and threshold_acc == 0: + assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 1, 2, -1], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [2, 2] if __name__ == "__main__": - predicts, accept_index, accept_token_num = ( - test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1) - ) - assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] - assert accept_index.tolist() == [ - [0, 3, 4, 5], - [6, 10, 11, -1], - ] - assert accept_token_num.tolist() == [3, 2] - - predicts, accept_index, accept_token_num = ( - test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0) - ) - assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18] - assert accept_index.tolist() == [ - [0, 1, 2, -1], - [6, 10, 11, -1], - ] - assert accept_token_num.tolist() == [2, 2] + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_awq_dequant.py b/sgl-kernel/tests/test_awq_dequant.py index bad3e2c10..33380180b 100644 --- a/sgl-kernel/tests/test_awq_dequant.py +++ b/sgl-kernel/tests/test_awq_dequant.py @@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations( if __name__ == "__main__": - # Run the specific test function directly pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_cublas_grouped_gemm.py b/sgl-kernel/tests/test_cublas_grouped_gemm.py index 9aac569f2..70b3dc5cf 100644 --- a/sgl-kernel/tests/test_cublas_grouped_gemm.py +++ b/sgl-kernel/tests/test_cublas_grouped_gemm.py @@ -1,49 +1,40 @@ -import unittest - +import pytest import torch from sgl_kernel import cublas_grouped_gemm def torch_grouped_gemm(a_array, b_array, out_dtype): - c_array = [] - for a, b in zip(a_array, b_array): - c_array.append(torch.matmul(a, b.t()).to(out_dtype)) - return c_array + return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)] -class TestGroupedGemm(unittest.TestCase): - def _test_accuracy(self, Ms, Ns, Ks, out_dtype): - group_count = len(Ms) - a_array = [] - b_array = [] - c_array_cublas = [] - for i in range(group_count): - M, N, K = Ms[i], Ns[i], Ks[i] - a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5) - b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5) - c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype)) +skip_condition = not torch.cuda.is_available() or ( + torch.version.cuda is None + or tuple(map(int, torch.version.cuda.split("."))) < (12, 5) +) - c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype) - cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype) - for i in range(group_count): - M, N, K = Ms[i], Ns[i], Ks[i] - torch.testing.assert_close(c_array_torch[i], c_array_cublas[i]) - print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK") +@pytest.mark.skipif( + skip_condition, reason="CUDA not available or CUDA version lower than 12.5" +) +@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024]) +@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096]) +@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192]) +def test_grouped_gemm_accuracy(out_dtype, M, N, K): + a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5 + b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5 + expected = torch.matmul(a, b.t()).to(out_dtype) - def test_accuracy(self): - Ms = [1, 16, 32, 256, 1024] - Ns = [2, 16, 128, 256, 4096] - Ks = [3, 16, 32, 512, 8192] - out_dtypes = [torch.float16, torch.bfloat16] - for out_dtype in out_dtypes: - self._test_accuracy(Ms, Ns, Ks, out_dtype) + a_array = [a] + b_array = [b] + c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)] + + result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0] + cublas_grouped_gemm(a_array, b_array, c_array, out_dtype) + + torch.testing.assert_close(result_torch, expected) + torch.testing.assert_close(c_array[0], expected) if __name__ == "__main__": - if torch.cuda.is_available(): - cuda_version = tuple(map(int, torch.version.cuda.split("."))) - if cuda_version >= (12, 5): - unittest.main() - else: - print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.") + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_blockwise_gemm.py b/sgl-kernel/tests/test_fp8_blockwise_gemm.py index 4ae7ae035..c9ca01350 100644 --- a/sgl-kernel/tests/test_fp8_blockwise_gemm.py +++ b/sgl-kernel/tests/test_fp8_blockwise_gemm.py @@ -1,12 +1,13 @@ -import unittest +import os +import random from typing import Optional, Type +import pytest import torch from sgl_kernel import fp8_blockwise_scaled_mm def cdiv(a: int, b: int) -> int: - """Ceiling division.""" return -(a // -b) @@ -23,7 +24,6 @@ def baseline_scaled_mm( out_dtype: Type[torch.dtype], bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # We treat N-dimensional group scaling as extended numpy-style broadcasting # in numpy simply stretches dimensions with an extent of 1 to match the # the target shape by repeating the data along that dimension (broadcasting) @@ -51,62 +51,44 @@ def baseline_scaled_mm( scale_a = group_broadcast(scale_a, a.shape) scale_b = group_broadcast(scale_b, b.shape) - output = torch.mm( (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) ).to(out_dtype) - if bias is not None: output = output + bias - return output -class TestFp8Gemm(unittest.TestCase): - def _test_accuracy_once(self, M, N, K, out_dtype, device): - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min +def _test_accuracy_once(M, N, K, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t() + scale_a_group_shape = (1, 128) + scale_b_group_shape = (128, 128) + scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape) + scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape) + scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001 + scale_a = scale_a.t().contiguous().t() + scale_b = scale_b.t().contiguous().t() + o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) + o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK") - a_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - b_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t() - - scale_a_group_shape = (1, 128) - scale_b_group_shape = (128, 128) - scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape) - scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape) - - scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001 - scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001 - scale_a = scale_a.t().contiguous().t() - scale_b = scale_b.t().contiguous().t() - - o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) - o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) - o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype) - - rtol = 0.02 - atol = 1 - torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) - print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK") - - def test_accuracy(self): - Ms = [1, 128, 512, 1024, 4096] - Ns = [128, 512, 1024, 4096] - Ks = [512, 1024, 4096, 8192, 16384] - out_dtypes = [torch.bfloat16, torch.float16] - for M in Ms: - for N in Ns: - for K in Ks: - for out_dtype in out_dtypes: - self._test_accuracy_once(M, N, K, out_dtype, "cuda") +@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("N", [128, 512, 1024, 4096]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_accuracy(M, N, K, out_dtype): + _test_accuracy_once(M, N, K, out_dtype, "cuda") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py index 1a7318659..e70e62af2 100644 --- a/sgl-kernel/tests/test_fp8_gemm.py +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -1,67 +1,49 @@ -import unittest - +import pytest import torch from sgl_kernel import fp8_scaled_mm def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) - o = o.to(torch.float32) temp1 = o * scale_a.view(-1, 1) temp2 = temp1 * scale_b.view(1, -1) final = temp2.to(out_dtype) if bias is not None: final = final + bias.view(1, -1) - return final -class TestFp8Gemm(unittest.TestCase): - def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min +def _test_accuracy_once(M, N, K, with_bias, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 + if with_bias: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + b_fp8 = b_fp8.t() + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") - a_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - b_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - ) - b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 - scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 - if with_bias: - bias = torch.randn((N,), device=device, dtype=out_dtype) - else: - bias = None - o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) - b_fp8 = b_fp8.t() - o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) - o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) - rtol = 0.02 - atol = 1 - torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) - print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") - - def test_accuracy(self): - Ms = [1, 128, 512, 1024, 4096] - Ns = [16, 128, 512, 1024, 4096] - Ks = [512, 1024, 4096, 8192, 16384] - bias_opts = [True, False] - out_dtypes = [torch.bfloat16, torch.float16] - for M in Ms: - for N in Ns: - for K in Ks: - for with_bias in bias_opts: - for out_dtype in out_dtypes: - self._test_accuracy_once( - M, N, K, with_bias, out_dtype, "cuda" - ) +@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("with_bias", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_accuracy(M, N, K, with_bias, out_dtype): + _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py index 951de314e..d87a9a5aa 100644 --- a/sgl-kernel/tests/test_int8_gemm.py +++ b/sgl-kernel/tests/test_int8_gemm.py @@ -1,5 +1,4 @@ -import unittest - +import pytest import torch from sgl_kernel import int8_scaled_mm from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm @@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): return o.to(out_dtype) -class TestInt8Gemm(unittest.TestCase): - def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): - a = to_int8(torch.randn((M, K), device=device) * 5) - b = to_int8(torch.randn((N, K), device=device).t() * 5) - scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) - scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) - if with_bias: - bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 - else: - bias = None +def _test_accuracy_once(M, N, K, with_bias, out_dtype, device): + a = to_int8(torch.randn((M, K), device=device) * 5) + b = to_int8(torch.randn((N, K), device=device).t() * 5) + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + if with_bias: + bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 + else: + bias = None + o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + torch.testing.assert_close(o, o1) + torch.testing.assert_close(o, o2) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") - o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - torch.testing.assert_close(o, o1) - torch.testing.assert_close(o, o2) - print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") - def test_accuracy(self): - Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192] - Ns = [16, 128, 512, 1024, 4096, 8192, 16384] - Ks = [512, 1024, 4096, 8192, 16384] - bias_opts = [True, False] - out_dtypes = [torch.float16, torch.bfloat16] - for M in Ms: - for N in Ns: - for K in Ks: - for with_bias in bias_opts: - for out_dtype in out_dtypes: - self._test_accuracy_once( - M, N, K, with_bias, out_dtype, "cuda" - ) +@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]) +@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) +@pytest.mark.parametrize("with_bias", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +def test_accuracy(M, N, K, with_bias, out_dtype): + _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda") if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py index 20b2722fc..fe1e0afe3 100644 --- a/sgl-kernel/tests/test_per_token_quant_fp8.py +++ b/sgl-kernel/tests/test_per_token_quant_fp8.py @@ -51,5 +51,4 @@ def test_per_token_quant_compare_implementations( if __name__ == "__main__": - # Run the specific test function directly pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 910bcb253..242f226be 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -13,155 +13,186 @@ from torch.distributed import ProcessGroup from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + ranks = list(range(world_size)) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = dist.group.WORLD + + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + buffer_ptrs = None + tmp_result_buffer_ptrs = None + barrier_in_ptrs = None + barrier_out_ptrs = None + custom_ptr = None + + try: + buffer_ptrs = TestCustomAllReduce.create_shared_buffer( + buffer_max_size, group=group + ) + tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer( + buffer_max_size, group=group + ) + barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer( + barrier_max_size, group=group + ) + barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer( + barrier_max_size, group=group + ) + + rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device) + + custom_ptr = custom_ops.init_custom_reduce( + rank, + world_size, + rank_data, + buffer_ptrs, + tmp_result_buffer_ptrs, + barrier_in_ptrs, + barrier_out_ptrs, + ) + + test_loop = 10 + for sz in test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(test_loop): + inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device) + inp1_ref = inp1.clone() + out1 = torch.empty_like(inp1) + + custom_ops.custom_reduce(custom_ptr, inp1, out1) + + dist.all_reduce(inp1_ref, group=group) + + torch.testing.assert_close(out1, inp1_ref) + + finally: + dist.barrier(group=group) + if custom_ptr is not None: + custom_ops.custom_dispose(custom_ptr) + if buffer_ptrs: + TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group) + if tmp_result_buffer_ptrs: + TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group) + if barrier_in_ptrs: + TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group) + if barrier_out_ptrs: + TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group) + + dist.destroy_process_group(group=group) + + def get_open_port() -> int: - # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] except OSError: - # try ipv6 with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) + s.bind(("::1", 0)) return s.getsockname()[1] def multi_process_parallel( - world_size: int, - test_target: Any, + world_size: int, test_target: Any, target_args: tuple = () ) -> None: + mp.set_start_method("spawn", force=True) + procs = [] distributed_init_port = get_open_port() for i in range(world_size): - proc = mp.Process( - target=test_target, - args=(world_size, i, distributed_init_port), - ) + proc_args = (world_size, i, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") proc.start() procs.append(proc) for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0 + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" class TestCustomAllReduce(unittest.TestCase): - @classmethod - def setUpClass(cls): - random.seed(42) - cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] - cls.world_sizes = [2, 4, 8] + test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] + world_sizes = [2, 4, 8] @staticmethod def create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None ) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. - """ lib = CudaRTLibrary() pointer = lib.cudaMalloc(size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) + if group is None: + group = dist.group.WORLD world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) + + handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle)) + input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}") + gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, input_tensor, group=group) + + handles = [] + handle_type = type(handle) + for tensor in gathered_tensors: + bytes_list = tensor.cpu().tolist() + bytes_data = bytes(bytes_list) + handle_obj = handle_type() + ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data)) + handles.append(handle_obj) pointers: List[int] = [] for i, h in enumerate(handles): if i == rank: - pointers.append(pointer.value) # type: ignore + pointers.append(pointer.value) else: - pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + try: + opened_ptr = lib.cudaIpcOpenMemHandle(h) + pointers.append(opened_ptr.value) + except Exception as e: + print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}") + raise + dist.barrier(group=group) return pointers @staticmethod def free_shared_buffer( pointers: List[int], group: Optional[ProcessGroup] = None ) -> None: + if group is None: + group = dist.group.WORLD rank = dist.get_rank(group=group) lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) + if pointers and len(pointers) > rank and pointers[rank] is not None: + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + dist.barrier(group=group) def test_correctness(self): for world_size in self.world_sizes: - if world_size > torch.cuda.device_count(): + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + print( + f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}" + ) continue - multi_process_parallel(world_size, self.correctness) + + print(f"Running test for world_size={world_size}") + multi_process_parallel( + world_size, _run_correctness_worker, target_args=(self.test_sizes,) + ) print(f"custom allreduce tp = {world_size}: OK") - def init_custom_allreduce(self, rank, world_size, group): - buffer_max_size = 8 * 1024 * 1024 - barrier_max_size = 8 * (24 + 2) * 8 - - self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) - self.tmp_result_buffer_ptrs = self.create_shared_buffer( - buffer_max_size, group=group - ) - self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) - self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) - self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0") - ) - - self.custom_ptr = custom_ops.init_custom_reduce( - rank, - world_size, - self.rank_data, - self.buffer_ptrs, - self.tmp_result_buffer_ptrs, - self.barrier_in_ptrs, - self.barrier_out_ptrs, - ) - - def custom_allreduce(self, inp, out): - custom_ops.custom_reduce(self.custom_ptr, inp, out) - - def free_custom_allreduce(self, group): - self.free_shared_buffer(self.buffer_ptrs, group) - self.free_shared_buffer(self.tmp_result_buffer_ptrs, group) - self.free_shared_buffer(self.barrier_in_ptrs, group) - self.free_shared_buffer(self.barrier_out_ptrs, group) - custom_ops.custom_dispose(self.custom_ptr) - - @staticmethod - def init_distributed_env(world_size, rank, distributed_init_port): - device = torch.device("cuda:0") - torch.cuda.set_device(device) - ranks = [i for i in range(world_size)] - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - group = torch.distributed.new_group(ranks, backend="gloo") - return group - - # compare result with torch.distributed - def correctness(self, world_size, rank, distributed_init_port): - group = self.init_distributed_env(world_size, rank, distributed_init_port) - - self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) - - test_loop = 10 - for sz in self.test_sizes: - for dtype in [torch.float32, torch.float16, torch.bfloat16]: - for _ in range(test_loop): - inp1 = torch.randint( - 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() - ) - out1 = torch.empty_like(inp1) - self.custom_allreduce(inp1, out1) - - dist.all_reduce(inp1, group=group) - torch.testing.assert_close(out1, inp1) - - self.free_custom_allreduce(group) - if __name__ == "__main__": unittest.main()