[Feature] use pytest for sgl-kernel (#4896)
parent 4ede6770cd
commit 9fccda3111

.github/workflows/pr-test-sgl-kernel.yml (vendored)
@@ -80,7 +80,8 @@ jobs:
       - name: Install
         run: |
-          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
+          bash scripts/ci_install_dependency.sh
+          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
           pip3 uninstall sgl-kernel -y || true
           pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
           pip3 list | grep sgl-kernel
@@ -89,7 +90,7 @@ jobs:
         timeout-minutes: 30
         run: |
           cd sgl-kernel
-          find tests -name "test_*.py" | xargs -n 1 python3
+          pytest tests/

       - name: Uninstall dependencies
         run: |
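With this change CI collects every `test_*.py` under `sgl-kernel/tests/` in a single pytest session instead of launching one interpreter per file. The converted test modules below follow a common pattern, sketched here with an illustrative module and test name: assertions live inside `test_*` functions so pytest can collect them, while the `__main__` guard keeps direct `python3 test_foo.py` runs working.

```python
import pytest


def test_example():
    # Real tests call into sgl_kernel and assert on the results.
    assert 1 + 1 == 2


if __name__ == "__main__":
    pytest.main([__file__])
```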
@@ -1,3 +1,4 @@
+import pytest
 import torch
 import torch.nn.functional as F
 from sgl_kernel import verify_tree_greedy

@@ -85,14 +86,14 @@ def test_verify_tree_greedy():
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")

-    return predicts, accept_index, accept_token_num
-
-
-if __name__ == "__main__":
-    predicts, accept_index, accept_token_num = test_verify_tree_greedy()
     # Check the expected output.
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
         [0, 3, 4, 5],
         [6, 10, 11, -1],
     ]
     assert accept_token_num.tolist() == [3, 2]
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

@@ -1,3 +1,4 @@
+import pytest
 import torch
 import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only

@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")

-    return predicts, accept_index, accept_token_num
+    if threshold_single == 1 and threshold_acc == 1:
+        assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
+        assert accept_index.tolist() == [
+            [0, 3, 4, 5],
+            [6, 10, 11, -1],
+        ]
+        assert accept_token_num.tolist() == [3, 2]
+    elif threshold_single == 0 and threshold_acc == 0:
+        assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
+        assert accept_index.tolist() == [
+            [0, 1, 2, -1],
+            [6, 10, 11, -1],
+        ]
+        assert accept_token_num.tolist() == [2, 2]


 if __name__ == "__main__":
-    predicts, accept_index, accept_token_num = (
-        test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
-    )
-    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
-    assert accept_index.tolist() == [
-        [0, 3, 4, 5],
-        [6, 10, 11, -1],
-    ]
-    assert accept_token_num.tolist() == [3, 2]
-
-    predicts, accept_index, accept_token_num = (
-        test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
-    )
-    assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
-    assert accept_index.tolist() == [
-        [0, 1, 2, -1],
-        [6, 10, 11, -1],
-    ]
-    assert accept_token_num.tolist() == [2, 2]
+    pytest.main([__file__])
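The converted test keeps the expected outputs behind `if`/`elif` branches on the function's default arguments. An alternative, shown only as a hedged sketch and not what this commit does, would be to parametrize the thresholds together with their expected accept counts:

```python
import pytest


@pytest.mark.parametrize(
    "threshold_single, threshold_acc, expected_accept_token_num",
    [
        (1, 1, [3, 2]),
        (0, 0, [2, 2]),
    ],
)
def test_thresholds_sketch(threshold_single, threshold_acc, expected_accept_token_num):
    # Placeholder body; a real test would call
    # tree_speculative_sampling_target_only and compare its outputs.
    assert len(expected_accept_token_num) == 2
```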

@@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations(


 if __name__ == "__main__":
-    # Run the specific test function directly
     pytest.main([__file__])

@@ -1,49 +1,40 @@
-import unittest
-
+import pytest
 import torch
 from sgl_kernel import cublas_grouped_gemm


 def torch_grouped_gemm(a_array, b_array, out_dtype):
-    c_array = []
-    for a, b in zip(a_array, b_array):
-        c_array.append(torch.matmul(a, b.t()).to(out_dtype))
-    return c_array
+    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]


-class TestGroupedGemm(unittest.TestCase):
-    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
-        group_count = len(Ms)
-        a_array = []
-        b_array = []
-        c_array_cublas = []
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
-            b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
-            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
+skip_condition = not torch.cuda.is_available() or (
+    torch.version.cuda is None
+    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
+)

-        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
-        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)

-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
-            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
+@pytest.mark.skipif(
+    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
+)
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
+@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
+@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
+def test_grouped_gemm_accuracy(out_dtype, M, N, K):
+    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
+    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
+    expected = torch.matmul(a, b.t()).to(out_dtype)

-    def test_accuracy(self):
-        Ms = [1, 16, 32, 256, 1024]
-        Ns = [2, 16, 128, 256, 4096]
-        Ks = [3, 16, 32, 512, 8192]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for out_dtype in out_dtypes:
-            self._test_accuracy(Ms, Ns, Ks, out_dtype)
+    a_array = [a]
+    b_array = [b]
+    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]
+
+    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
+    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)
+
+    torch.testing.assert_close(result_torch, expected)
+    torch.testing.assert_close(c_array[0], expected)


 if __name__ == "__main__":
-    if torch.cuda.is_available():
-        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
-        if cuda_version >= (12, 5):
-            unittest.main()
-        else:
-            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
+    pytest.main([__file__])
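The grouped GEMM test above gates execution with a module-level `skip_condition` and `@pytest.mark.skipif`, replacing the old runtime CUDA-version check under `if __name__`. A minimal, self-contained sketch of that pattern (the test name and body here are illustrative only):

```python
import pytest
import torch

# Evaluated once when the module is imported; tests carrying the marker are
# reported as skipped, not failed, when the condition holds.
skip_condition = not torch.cuda.is_available() or (
    torch.version.cuda is None
    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)


@pytest.mark.skipif(
    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
def test_needs_cuda_12_5():
    # Placeholder body; the real test calls cublas_grouped_gemm.
    assert torch.version.cuda is not None
```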

@@ -1,12 +1,13 @@
-import unittest
 import os
 import random
 from typing import Optional, Type

+import pytest
 import torch
 from sgl_kernel import fp8_blockwise_scaled_mm


 def cdiv(a: int, b: int) -> int:
     """Ceiling division."""
     return -(a // -b)

@@ -23,7 +24,6 @@ def baseline_scaled_mm(
     out_dtype: Type[torch.dtype],
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-
     # We treat N-dimensional group scaling as extended numpy-style broadcasting
     # in numpy simply stretches dimensions with an extent of 1 to match the
     # the target shape by repeating the data along that dimension (broadcasting)
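The comment above describes block-wise scales being broadcast back to the full matrix shape before the reference matmul. A small sketch of that idea, using `repeat_interleave`; the test's actual `group_broadcast` helper may be implemented differently:

```python
import torch

a = torch.randn(4, 256)      # activation with K = 256
scale_a = torch.randn(4, 2)  # one scale per (1, 128) group along K

# Stretch each group scale across its 128 columns so it lines up with `a`.
scale_full = scale_a.repeat_interleave(128, dim=1)
assert scale_full.shape == a.shape

scaled = scale_full * a
```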
@@ -51,62 +51,44 @@ def baseline_scaled_mm(
     scale_a = group_broadcast(scale_a, a.shape)
     scale_b = group_broadcast(scale_b, b.shape)

     output = torch.mm(
         (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
     ).to(out_dtype)

     if bias is not None:
         output = output + bias

     return output


-class TestFp8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, out_dtype, device):
-        fp8_info = torch.finfo(torch.float8_e4m3fn)
-        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+def _test_accuracy_once(M, N, K, out_dtype, device):
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    scale_a_group_shape = (1, 128)
+    scale_b_group_shape = (128, 128)
+    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
+    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
+    scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
+    scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
+    scale_a = scale_a.t().contiguous().t()
+    scale_b = scale_b.t().contiguous().t()
+    o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
+    o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
+    rtol = 0.02
+    atol = 1
+    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
+    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

-        a_fp32 = (
-            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-        b_fp32 = (
-            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
-
-        scale_a_group_shape = (1, 128)
-        scale_b_group_shape = (128, 128)
-        scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
-        scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
-
-        scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
-        scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
-        scale_a = scale_a.t().contiguous().t()
-        scale_b = scale_b.t().contiguous().t()
-
-        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
-        o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
-        o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
-
-        rtol = 0.02
-        atol = 1
-        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-        print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [128, 512, 1024, 4096]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        out_dtypes = [torch.bfloat16, torch.float16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for out_dtype in out_dtypes:
-                        self._test_accuracy_once(M, N, K, out_dtype, "cuda")
+@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+def test_accuracy(M, N, K, out_dtype):
+    _test_accuracy_once(M, N, K, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
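Stacked `@pytest.mark.parametrize` decorators, as used above, replace the old nested `for` loops: pytest collects one test case per element of the Cartesian product of the parameter lists. A tiny sketch with illustrative names:

```python
import pytest


@pytest.mark.parametrize("m", [1, 128])
@pytest.mark.parametrize("n", [128, 512])
def test_cartesian(m, n):
    # 2 x 2 = 4 cases are collected, one per (m, n) combination.
    assert m > 0 and n > 0


if __name__ == "__main__":
    pytest.main([__file__])
```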

@@ -1,67 +1,49 @@
-import unittest
-
+import pytest
 import torch
 from sgl_kernel import fp8_scaled_mm


 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = torch.matmul(a.to(torch.float32), b.to(torch.float32))

     o = o.to(torch.float32)
     temp1 = o * scale_a.view(-1, 1)
     temp2 = temp1 * scale_b.view(1, -1)
     final = temp2.to(out_dtype)
     if bias is not None:
         final = final + bias.view(1, -1)

     return final


-class TestFp8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        fp8_info = torch.finfo(torch.float8_e4m3fn)
-        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
+    scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
+    if with_bias:
+        bias = torch.randn((N,), device=device, dtype=out_dtype)
+    else:
+        bias = None
+    b_fp8 = b_fp8.t()
+    o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+    o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+    rtol = 0.02
+    atol = 1
+    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
+    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

-        a_fp32 = (
-            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-        b_fp32 = (
-            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
-        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
-        if with_bias:
-            bias = torch.randn((N,), device=device, dtype=out_dtype)
-        else:
-            bias = None
-        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
-        b_fp8 = b_fp8.t()
-        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
-        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
-        rtol = 0.02
-        atol = 1
-        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [16, 128, 512, 1024, 4096]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        bias_opts = [True, False]
-        out_dtypes = [torch.bfloat16, torch.float16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for with_bias in bias_opts:
-                        for out_dtype in out_dtypes:
-                            self._test_accuracy_once(
-                                M, N, K, with_bias, out_dtype, "cuda"
-                            )
+@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("with_bias", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+def test_accuracy(M, N, K, with_bias, out_dtype):
+    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])

@@ -1,5 +1,4 @@
-import unittest
-
+import pytest
 import torch
 from sgl_kernel import int8_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm

@@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     return o.to(out_dtype)


-class TestInt8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        a = to_int8(torch.randn((M, K), device=device) * 5)
-        b = to_int8(torch.randn((N, K), device=device).t() * 5)
-        scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
-        scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
-        if with_bias:
-            bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
-        else:
-            bias = None
+def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
+    a = to_int8(torch.randn((M, K), device=device) * 5)
+    b = to_int8(torch.randn((N, K), device=device).t() * 5)
+    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
+    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
+    if with_bias:
+        bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
+    else:
+        bias = None
+    o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    torch.testing.assert_close(o, o1)
+    torch.testing.assert_close(o, o2)
+    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

-        o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        torch.testing.assert_close(o, o1)
-        torch.testing.assert_close(o, o2)
-        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
-        Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        bias_opts = [True, False]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for with_bias in bias_opts:
-                        for out_dtype in out_dtypes:
-                            self._test_accuracy_once(
-                                M, N, K, with_bias, out_dtype, "cuda"
-                            )
+@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
+@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("with_bias", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+def test_accuracy(M, N, K, with_bias, out_dtype):
+    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])

@@ -51,5 +51,4 @@ def test_per_token_quant_compare_implementations(


 if __name__ == "__main__":
-    # Run the specific test function directly
    pytest.main([__file__])

@@ -13,155 +13,186 @@ from torch.distributed import ProcessGroup
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


+def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    ranks = list(range(world_size))
+    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+    dist.init_process_group(
+        backend="nccl",
+        init_method=distributed_init_method,
+        rank=rank,
+        world_size=world_size,
+    )
+    group = dist.group.WORLD
+
+    buffer_max_size = 8 * 1024 * 1024
+    barrier_max_size = 8 * (24 + 2) * 8
+    buffer_ptrs = None
+    tmp_result_buffer_ptrs = None
+    barrier_in_ptrs = None
+    barrier_out_ptrs = None
+    custom_ptr = None
+
+    try:
+        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
+            buffer_max_size, group=group
+        )
+        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
+            buffer_max_size, group=group
+        )
+        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
+            barrier_max_size, group=group
+        )
+        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
+            barrier_max_size, group=group
+        )
+
+        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
+
+        custom_ptr = custom_ops.init_custom_reduce(
+            rank,
+            world_size,
+            rank_data,
+            buffer_ptrs,
+            tmp_result_buffer_ptrs,
+            barrier_in_ptrs,
+            barrier_out_ptrs,
+        )
+
+        test_loop = 10
+        for sz in test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(test_loop):
+                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
+                    inp1_ref = inp1.clone()
+                    out1 = torch.empty_like(inp1)
+
+                    custom_ops.custom_reduce(custom_ptr, inp1, out1)
+
+                    dist.all_reduce(inp1_ref, group=group)
+
+                    torch.testing.assert_close(out1, inp1_ref)
+
+    finally:
+        dist.barrier(group=group)
+        if custom_ptr is not None:
+            custom_ops.custom_dispose(custom_ptr)
+        if buffer_ptrs:
+            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
+        if tmp_result_buffer_ptrs:
+            TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
+        if barrier_in_ptrs:
+            TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
+        if barrier_out_ptrs:
+            TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
+
+        dist.destroy_process_group(group=group)
+
+
 def get_open_port() -> int:
     # try ipv4
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             s.bind(("127.0.0.1", 0))
             return s.getsockname()[1]
     except OSError:
         # try ipv6
         with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("127.0.0.1", 0))
+            s.bind(("::1", 0))
             return s.getsockname()[1]


 def multi_process_parallel(
-    world_size: int,
-    test_target: Any,
+    world_size: int, test_target: Any, target_args: tuple = ()
 ) -> None:
     mp.set_start_method("spawn", force=True)

     procs = []
     distributed_init_port = get_open_port()
     for i in range(world_size):
-        proc = mp.Process(
-            target=test_target,
-            args=(world_size, i, distributed_init_port),
-        )
+        proc_args = (world_size, i, distributed_init_port) + target_args
+        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
         proc.start()
         procs.append(proc)

     for i in range(world_size):
         procs[i].join()
-        assert procs[i].exitcode == 0
+        assert (
+            procs[i].exitcode == 0
+        ), f"Process {i} failed with exit code {procs[i].exitcode}"


 class TestCustomAllReduce(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        random.seed(42)
-        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
-        cls.world_sizes = [2, 4, 8]
+    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+    world_sizes = [2, 4, 8]

     @staticmethod
     def create_shared_buffer(
         size_in_bytes: int, group: Optional[ProcessGroup] = None
     ) -> List[int]:
         """
         Creates a shared buffer and returns a list of pointers
         representing the buffer on all processes in the group.
         """
         lib = CudaRTLibrary()
         pointer = lib.cudaMalloc(size_in_bytes)
         handle = lib.cudaIpcGetMemHandle(pointer)
         if group is None:
             group = dist.group.WORLD
         world_size = dist.get_world_size(group=group)
         rank = dist.get_rank(group=group)
-        handles = [None] * world_size
-        dist.all_gather_object(handles, handle, group=group)
+
+        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
+        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
+        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
+        dist.all_gather(gathered_tensors, input_tensor, group=group)
+
+        handles = []
+        handle_type = type(handle)
+        for tensor in gathered_tensors:
+            bytes_list = tensor.cpu().tolist()
+            bytes_data = bytes(bytes_list)
+            handle_obj = handle_type()
+            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
+            handles.append(handle_obj)

         pointers: List[int] = []
         for i, h in enumerate(handles):
             if i == rank:
-                pointers.append(pointer.value)  # type: ignore
+                pointers.append(pointer.value)
             else:
-                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
+                try:
+                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
+                    pointers.append(opened_ptr.value)
+                except Exception as e:
+                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
+                    raise

         dist.barrier(group=group)
         return pointers
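`create_shared_buffer` now serializes the CUDA IPC handle to raw bytes, all-gathers the bytes as a tensor, and rebuilds a handle object on each rank. A minimal sketch of that ctypes round-trip, using a stand-in structure rather than the real `cudaIpcMemHandle_t`:

```python
import ctypes


class FakeHandle(ctypes.Structure):
    # Stand-in for cudaIpcMemHandle_t, which is an opaque 64-byte blob.
    _fields_ = [("reserved", ctypes.c_char * 64)]


def handle_to_bytes(handle: ctypes.Structure) -> bytes:
    # Copy the structure's raw memory into a bytes object.
    return ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))


def bytes_to_handle(data: bytes, handle_type=FakeHandle) -> ctypes.Structure:
    # Rebuild an identical structure from the raw bytes.
    handle = handle_type()
    ctypes.memmove(ctypes.addressof(handle), data, len(data))
    return handle


if __name__ == "__main__":
    original = FakeHandle(reserved=b"\x01" * 64)
    restored = bytes_to_handle(handle_to_bytes(original))
    assert restored.reserved == original.reserved
```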

     @staticmethod
     def free_shared_buffer(
         pointers: List[int], group: Optional[ProcessGroup] = None
     ) -> None:
         if group is None:
             group = dist.group.WORLD
         rank = dist.get_rank(group=group)
         lib = CudaRTLibrary()
-        lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+        if pointers and len(pointers) > rank and pointers[rank] is not None:
+            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
         dist.barrier(group=group)

     def test_correctness(self):
         for world_size in self.world_sizes:
-            if world_size > torch.cuda.device_count():
+            available_gpus = torch.cuda.device_count()
+            if world_size > available_gpus:
+                print(
+                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
+                )
                 continue
-            multi_process_parallel(world_size, self.correctness)
+
+            print(f"Running test for world_size={world_size}")
+            multi_process_parallel(
+                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
+            )
             print(f"custom allreduce tp = {world_size}: OK")

-    def init_custom_allreduce(self, rank, world_size, group):
-        buffer_max_size = 8 * 1024 * 1024
-        barrier_max_size = 8 * (24 + 2) * 8
-
-        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
-        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
-        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
-        )
-
-        self.custom_ptr = custom_ops.init_custom_reduce(
-            rank,
-            world_size,
-            self.rank_data,
-            self.buffer_ptrs,
-            self.tmp_result_buffer_ptrs,
-            self.barrier_in_ptrs,
-            self.barrier_out_ptrs,
-        )
-
-    def custom_allreduce(self, inp, out):
-        custom_ops.custom_reduce(self.custom_ptr, inp, out)
-
-    def free_custom_allreduce(self, group):
-        self.free_shared_buffer(self.buffer_ptrs, group)
-        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
-        self.free_shared_buffer(self.barrier_in_ptrs, group)
-        self.free_shared_buffer(self.barrier_out_ptrs, group)
-        custom_ops.custom_dispose(self.custom_ptr)
-
-    @staticmethod
-    def init_distributed_env(world_size, rank, distributed_init_port):
-        device = torch.device("cuda:0")
-        torch.cuda.set_device(device)
-        ranks = [i for i in range(world_size)]
-        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
-        dist.init_process_group(
-            backend="nccl",
-            init_method=distributed_init_method,
-            rank=rank,
-            world_size=world_size,
-        )
-        group = torch.distributed.new_group(ranks, backend="gloo")
-        return group
-
-    # compare result with torch.distributed
-    def correctness(self, world_size, rank, distributed_init_port):
-        group = self.init_distributed_env(world_size, rank, distributed_init_port)
-
-        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-
-        test_loop = 10
-        for sz in self.test_sizes:
-            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-                for _ in range(test_loop):
-                    inp1 = torch.randint(
-                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
-                    )
-                    out1 = torch.empty_like(inp1)
-                    self.custom_allreduce(inp1, out1)
-
-                    dist.all_reduce(inp1, group=group)
-                    torch.testing.assert_close(out1, inp1)
-
-        self.free_custom_allreduce(group)
-

 if __name__ == "__main__":
     unittest.main()