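"""Parity and smoke tests for FBGEMM vs. SGLang grouped GEMM kernels.

The tests compare FBGEMM's grouped_gemm (bf16, plus an fp8 row-wise variant)
against SGLang's grouped_gemm_triton on uniform, non-uniform, and large group
configurations. Each backend is optional: tests are skipped when the
corresponding import fails or CUDA is unavailable.
"""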
import os
import sys

import pytest
import torch

# Make modules that sit next to this file (e.g. fbgemm_grouped_gemm) importable.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Both backends are optional; the availability flags gate the tests below.
try:
    from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
    from fbgemm_grouped_gemm import (
        grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
    )

    FBGEMM_AVAILABLE = True
    print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
    FBGEMM_AVAILABLE = False

try:
    from sglang.srt.layers.moe.ep_moe.kernels import (
        grouped_gemm_triton as sglang_grouped_gemm,
    )

    SGLANG_AVAILABLE = True
    print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import SGLang grouped GEMM: {e}")
    SGLANG_AVAILABLE = False


def create_uniform_groups(batch_size, num_groups, device):
    """Split batch_size evenly into num_groups (assumes divisibility)."""
    tokens_per_group = batch_size // num_groups
    return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device)


def create_non_uniform_groups(batch_size, num_groups, device):
    """Randomly split batch_size across num_groups, each group getting at least one row
    when batch_size >= num_groups."""
    remaining = batch_size
    m_sizes = []

    for i in range(num_groups - 1):
        if remaining <= 1:
            size = 1
        else:
            # Leave at least one row for every group that still has to be filled.
            max_size = remaining - (num_groups - i - 1) + 1
            size = torch.randint(1, max_size, (1,)).item()
        m_sizes.append(size)
        remaining -= size

    m_sizes.append(remaining)
    return torch.tensor(m_sizes, dtype=torch.int64, device=device)


def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
    """Build the output buffer, segment pointers, and per-group weight view for SGLang."""
    batch_size = x.shape[0]

    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )

    # seg_indptr[i]..seg_indptr[i + 1] delimits the rows of x belonging to group i.
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
    current_pos = 0
    for i, size in enumerate(m_sizes):
        current_pos += size
        seg_indptr[i + 1] = current_pos

    weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
    # View the stacked weights as one (intermediate_size, hidden_size) slice per group.
    w_sglang = w.view(num_groups, intermediate_size, -1)

    return c_sglang, seg_indptr, weight_indices, w_sglang


def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
    """Create fp8 activations and weights plus synthetic per-row and per-group scales."""
    torch.manual_seed(42)

    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
    )

    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)

    # Random positive scales (not derived from real quantization) suffice for a smoke test.
    x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4

    return x_fp8, w_fp8, x_scale, w_scale


@pytest.fixture
def device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return torch.device("cuda")


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
    """Compare FBGEMM and SGLang grouped GEMM outputs on evenly sized groups."""
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Compare FBGEMM and SGLang grouped GEMM outputs on unevenly sized groups."""
    torch.manual_seed(42)

    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Compare the two backends on larger batch, hidden, and intermediate sizes."""
    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Smoke-test FBGEMM's fp8 row-wise grouped GEMM on uniform groups."""
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Smoke-test FBGEMM's fp8 row-wise grouped GEMM on non-uniform groups."""
    torch.manual_seed(42)

    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
    torch.manual_seed(42)

    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16


@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
    torch.manual_seed(42)

    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16


def test_imports():
    assert (
        FBGEMM_AVAILABLE or SGLANG_AVAILABLE
    ), "Neither FBGEMM nor SGLang is available"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])