# sglang/benchmark/fbgemm/test_grouped_gemm.py

import os
import sys

import pytest
import torch

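# Allow importing the local fbgemm_grouped_gemm module from this file's directory.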
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

try:
    from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
    from fbgemm_grouped_gemm import (
        grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
    )

    FBGEMM_AVAILABLE = True
    print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
    FBGEMM_AVAILABLE = False

try:
    from sglang.srt.layers.moe.ep_moe.kernels import (
        grouped_gemm_triton as sglang_grouped_gemm,
    )

    SGLANG_AVAILABLE = True
    print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import SGLang grouped GEMM: {e}")
    SGLANG_AVAILABLE = False


def create_uniform_groups(batch_size, num_groups, device):
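    """Split batch_size evenly across num_groups (assumes exact divisibility)."""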
    tokens_per_group = batch_size // num_groups
    return torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device=device
    )


def create_non_uniform_groups(batch_size, num_groups, device):
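    """Randomly partition batch_size into num_groups parts of at least one token each."""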
    remaining = batch_size
    m_sizes = []
    for i in range(num_groups - 1):
        if remaining <= 1:
            size = 1
        else:
            # Reserve at least one token for every group that still needs a size.
            max_size = remaining - (num_groups - i - 1) + 1
            size = torch.randint(1, max_size, (1,)).item()
        m_sizes.append(size)
        remaining -= size
    m_sizes.append(remaining)
    return torch.tensor(m_sizes, dtype=torch.int64, device=device)


def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
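    """Build the output buffer, segment pointers, weight indices, and reshaped
    weight tensor expected by the SGLang grouped GEMM kernel."""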
    batch_size = x.shape[0]
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )
    # seg_indptr[i] is the row offset in x where group i starts (prefix sums of m_sizes).
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
    current_pos = 0
    for i, size in enumerate(m_sizes):
        current_pos += size
        seg_indptr[i + 1] = current_pos
    weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
    w_sglang = w.view(num_groups, intermediate_size, -1)
    return c_sglang, seg_indptr, weight_indices, w_sglang


def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
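    """Create random FP8 (e4m3) activations and weights with per-row and per-group scales."""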
    torch.manual_seed(42)
    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
    )
    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)
    # Scales must be strictly positive; the 1e-4 floor guards against zeros.
    x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4
    return x_fp8, w_fp8, x_scale, w_scale


@pytest.fixture
def device():
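    """Provide a CUDA device, skipping the test when CUDA is unavailable."""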
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return torch.device("cuda")


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
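    """FBGEMM and SGLang should agree on evenly sized groups."""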
    if batch_size % num_groups != 0:
        pytest.skip(
            f"batch_size {batch_size} not divisible by num_groups {num_groups}"
        )
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
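    """FBGEMM and SGLang should agree when group sizes are random and unequal."""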
    torch.manual_seed(42)
    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
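    """FBGEMM and SGLang should agree at larger, MoE-scale problem sizes."""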
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
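    """Smoke-test the FBGEMM FP8 row-wise grouped GEMM on uniform groups."""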
    if batch_size % num_groups != 0:
        pytest.skip(
            f"batch_size {batch_size} not divisible by num_groups {num_groups}"
        )
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )
    # Skip only when the kernel itself fails (e.g. FP8 unsupported on this GPU);
    # the shape/dtype assertions below must not be swallowed by the except.
    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
    except Exception as e:
        pytest.skip(f"FP8 grouped GEMM failed (possibly unsupported): {e}")
    assert result_fp8.shape == (batch_size, intermediate_size)
    assert result_fp8.dtype == torch.bfloat16


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
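    """Smoke-test the FBGEMM FP8 row-wise grouped GEMM on non-uniform groups."""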
    torch.manual_seed(42)
    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )
    # As above, keep the assertions outside the try so real failures surface.
    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
    except Exception as e:
        pytest.skip(f"FP8 grouped GEMM failed (possibly unsupported): {e}")
    assert result_fp8.shape == (batch_size, intermediate_size)
    assert result_fp8.dtype == torch.bfloat16


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
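    """FBGEMM alone produces output of the expected shape and dtype."""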
    torch.manual_seed(42)
    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16


@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
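    """SGLang alone produces output of the expected shape and dtype."""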
    torch.manual_seed(42)
    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )
    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )
    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16


def test_imports():
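    """At least one of the two implementations must be importable."""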
    assert (
        FBGEMM_AVAILABLE or SGLANG_AVAILABLE
    ), "Neither FBGEMM nor SGLang is available"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])