add fbgemm moe grouped gemm kernel benchmark (#6924)
benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py (new file, 366 lines)
@@ -0,0 +1,366 @@
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse

import torch
import triton
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
from fbgemm_grouped_gemm import (
    grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig

from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton as sglang_grouped_gemm,
)

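# get_model_config() maps a HuggingFace MoE config to the grouped-GEMM problem shape:
# the number of expert groups and the per-expert intermediate size for each supported
# architecture family. Note that tp_size is accepted for CLI symmetry but is not used
# to shard the sizes in the code below.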
def get_model_config(model_name: str, tp_size: int):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    if config.architectures[0] == "DbrxForCausalLM":
        num_groups = config.ffn_config.moe_num_experts
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif config.architectures[0] == "JambaForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.intermediate_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Qwen3MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        num_groups = (
            config.n_routed_experts + 1
            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
            else config.n_routed_experts
        )
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Llama4ForConditionalGeneration":
        num_groups = config.text_config.num_local_experts
        intermediate_size = config.text_config.intermediate_size
    elif config.architectures[0] in [
        "Grok1ForCausalLM",
        "Grok1ImgGen",
        "Grok1AForCausalLM",
    ]:
        num_groups = config.num_local_experts
        intermediate_size = config.moe_intermediate_size
    else:
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size

    shape_configs = {
        "num_groups": num_groups,
        "hidden_size": config.hidden_size,
        "intermediate_size": intermediate_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs

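# create_test_data() builds one uniform BF16 workload: batch_size tokens split evenly
# across num_groups experts. FBGEMM consumes the weights stacked as a single
# (num_groups * intermediate_size, hidden_size) matrix plus per-group row counts
# (m_sizes); the SGLang Triton kernel consumes the (num_groups, intermediate_size,
# hidden_size) view plus seg_indptr (prefix sums of m_sizes) and weight_indices.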
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    torch.manual_seed(42)

    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
    )

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")

    base_weights = torch.randn(
        num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
    )

    w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
    w_sglang = base_weights

    c_fbgemm = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )

    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
    for i in range(1, num_groups + 1):
        seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group

    weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda")

    return (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    )

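# create_fp8_test_data() casts FP16 activations/weights to torch.float8_e4m3fn and
# fabricates positive row-wise scales. This path likely needs a PyTorch build with
# float8 dtypes and an FP8-capable GPU; the benchmark below treats any failure here
# as "FP8 not supported" and skips that provider.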
def create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    torch.manual_seed(42)

    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
    )

    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
    )

    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)

    x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4

    return x_fp8, w_fp8, m_sizes, x_scale, w_scale

def get_benchmark_config(use_fp8_w8a8=False):
    if use_fp8_w8a8:
        return {
            "line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"],
            "line_names": ["FBGEMM Grouped GEMM FP8", "SGLang Grouped GEMM FP8"],
            "styles": [("blue", "-"), ("red", "-")],
        }
    else:
        return {
            "line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"],
            "line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"],
            "styles": [("blue", "-"), ("green", "-")],
        }

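# run_benchmark() sweeps batch_size over 1..4096 and times each provider's run_func with
# triton.testing.do_bench(quantiles=[0.5, 0.2, 0.8]), so every point reports the median
# latency in ms with the 20th/80th percentiles as the spread; failures are reported as inf.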
def run_benchmark(
    model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
    config = get_benchmark_config(use_fp8_w8a8)

    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Time (ms)",
        plot_name="grouped-gemm-performance",
        args={},
    )

    @triton.testing.perf_report(benchmark_config)
    def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
        print(f"Benchmarking {provider} with batch_size={batch_size}")
        torch.cuda.manual_seed_all(0)

        num_groups = model_config["num_groups"]
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_grouped_gemm_fp8":
            try:
                test_data = create_fp8_test_data(
                    batch_size, num_groups, hidden_size, intermediate_size
                )
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
                    )

            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return float("inf"), float("inf"), float("inf")
        else:
            test_data = create_test_data(
                batch_size, num_groups, hidden_size, intermediate_size
            )
            (
                x,
                w_fbgemm,
                w_sglang,
                c_fbgemm,
                c_sglang,
                m_sizes,
                seg_indptr,
                weight_indices,
            ) = test_data

            if provider == "fbgemm_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
                        x, w_fbgemm, m_sizes, use_fast_accum=True
                    )

            else:

                def run_func():
                    return sglang_grouped_gemm(
                        x,
                        w_sglang,
                        c_sglang,
                        num_groups,
                        weight_column_major=True,
                        seg_indptr=seg_indptr,
                        weight_indices=weight_indices,
                        c_dtype=c_sglang.dtype,
                    )

        for _ in range(10):
            try:
                run_func()
            except Exception as e:
                print(f"Error during warmup for {provider}: {e}")
                return float("inf"), float("inf"), float("inf")

        torch.cuda.synchronize()

        try:
            quantiles = [0.5, 0.2, 0.8]
            ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
            return ms, min_ms, max_ms
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return float("inf"), float("inf"), float("inf")

    dynamic_benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=save_path,
        model_config=model_config,
        use_fp8_w8a8=use_fp8_w8a8,
    )

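# verify_correctness() runs both BF16 kernels once at batch_size=128 and requires
# torch.allclose(rtol=1e-3, atol=1e-3) agreement; the FP8 path is only smoke-tested
# for output shape, since there is no FP8 reference implementation here.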
def verify_correctness(model_config, use_fp8_w8a8):
    print("Verifying correctness...")
    batch_size = 128
    num_groups = model_config["num_groups"]
    hidden_size = model_config["hidden_size"]
    intermediate_size = model_config["intermediate_size"]

    test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
    (x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = (
        test_data
    )

    try:
        result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)

        result_sglang = sglang_grouped_gemm(
            x,
            w_sglang,
            c_sglang,
            num_groups,
            weight_column_major=True,
            seg_indptr=seg_indptr,
            weight_indices=weight_indices,
            c_dtype=c_sglang.dtype,
        )

        if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
            print("✓ BF16 Correctness verification passed!")
        else:
            max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
            print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
            return False

        if use_fp8_w8a8:
            try:
                fp8_data = create_fp8_test_data(
                    batch_size, num_groups, hidden_size, intermediate_size
                )
                x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data

                result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
                    x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True
                )

                assert result_fp8.shape == (batch_size, intermediate_size)
                print("✓ FP8 functionality test passed!")
            except Exception as e:
                print(f"FP8 test failed (possibly unsupported): {e}")
                return False

        return True

    except Exception as e:
        print(f"✗ Error during correctness verification: {e}")
        return False

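# Example invocations (flags from the argparse definitions below; assumes the script is
# run from the repository root where this file is added):
#   python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --verify-correctness
#   python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py \
#       --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8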
def main():
    parser = argparse.ArgumentParser(
        description="Benchmark FBGEMM vs SGLang Grouped GEMM"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name to get configuration from",
    )
    parser.add_argument(
        "--tp-size", type=int, default=1, help="Tensor parallelism size"
    )
    parser.add_argument(
        "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./benchmark_grouped_gemm/",
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--verify-correctness",
        action="store_true",
        help="Verify correctness before benchmarking",
    )

    args = parser.parse_args()

    try:
        model_config = get_model_config(args.model, args.tp_size)
    except Exception as e:
        print(f"Failed to get model config: {e}")
        print("Using default configuration...")
        model_config = {
            "num_groups": 8,
            "hidden_size": 4096,
            "intermediate_size": 14336,
            "dtype": torch.bfloat16,
        }

    print("Running benchmark with:")
    print(f"  num_groups: {model_config['num_groups']}")
    print(f"  hidden_size: {model_config['hidden_size']}")
    print(f"  intermediate_size: {model_config['intermediate_size']}")
    print(f"  use_fp8_w8a8: {args.use_fp8_w8a8}")

    if args.verify_correctness:
        if not verify_correctness(model_config, args.use_fp8_w8a8):
            print("Correctness verification failed. Exiting...")
            return

    try:
        run_benchmark(
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            save_path=args.save_path,
        )
    except Exception as e:
        print(f"Benchmark failed: {e}")


if __name__ == "__main__":
    main()
benchmark/fbgemm/fbgemm_grouped_gemm.py (new file, 1294 lines)
File diff suppressed because it is too large.
benchmark/fbgemm/test_grouped_gemm.py (new file, 323 lines)
@@ -0,0 +1,323 @@
import os
import sys

import pytest
import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

try:
    from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
    from fbgemm_grouped_gemm import (
        grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
    )

    FBGEMM_AVAILABLE = True
    print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
    FBGEMM_AVAILABLE = False

try:
    from sglang.srt.layers.moe.ep_moe.kernels import (
        grouped_gemm_triton as sglang_grouped_gemm,
    )

    SGLANG_AVAILABLE = True
    print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import SGLang grouped GEMM: {e}")
    SGLANG_AVAILABLE = False

def create_uniform_groups(batch_size, num_groups, device):
    tokens_per_group = batch_size // num_groups
    return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device)

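# create_non_uniform_groups() randomly partitions batch_size tokens across num_groups
# so that every group receives at least one token; the last group absorbs the remainder.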
def create_non_uniform_groups(batch_size, num_groups, device):
    remaining = batch_size
    m_sizes = []

    for i in range(num_groups - 1):
        if remaining <= 1:
            size = 1
        else:
            max_size = remaining - (num_groups - i - 1) + 1
            size = torch.randint(1, max_size, (1,)).item()
        m_sizes.append(size)
        remaining -= size

    m_sizes.append(remaining)
    return torch.tensor(m_sizes, dtype=torch.int64, device=device)

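# create_sglang_inputs() derives the SGLang-side arguments from the FBGEMM-style inputs.
# seg_indptr is the prefix sum of m_sizes with a leading zero; the loop below is
# equivalent to the vectorized form (sketch, not part of the original):
#   seg_indptr[1:] = torch.cumsum(m_sizes, dim=0)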
def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
    batch_size = x.shape[0]

    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )

    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
    current_pos = 0
    for i, size in enumerate(m_sizes):
        current_pos += size
        seg_indptr[i + 1] = current_pos

    weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
    w_sglang = w.view(num_groups, intermediate_size, -1)

    return c_sglang, seg_indptr, weight_indices, w_sglang

def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
    torch.manual_seed(42)

    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
    )

    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)

    x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4

    return x_fp8, w_fp8, x_scale, w_scale

@pytest.fixture
def device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return torch.device("cuda")

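# The paired tests below run FBGEMM and SGLang on identical BF16 inputs and require
# torch.allclose(rtol=1e-3, atol=1e-3) agreement. For example, one might run just this
# file with plain pytest (path taken from this diff):
#   pytest benchmark/fbgemm/test_grouped_gemm.py -v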
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)

    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

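# The FP8 tests below exercise only the FBGEMM rowwise FP8 kernel; they assert output
# shape and BF16 dtype rather than comparing against a reference, and skip when FP8 is
# unsupported on the current PyTorch/GPU combination.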
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)

    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
    torch.manual_seed(42)

    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16

@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
    torch.manual_seed(42)

    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )

    result = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16

def test_imports():
    assert (
        FBGEMM_AVAILABLE or SGLANG_AVAILABLE
    ), "Neither FBGEMM nor SGLang is available"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])