[benchmark] fbgemm benchmark support bandwidth report and support fbgemm_cutlass_gmm (#7422)
This commit is contained in:
29
benchmark/fbgemm/README.md
Normal file
29
benchmark/fbgemm/README.md
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
## Benchmark FBGEMM Grouped GEMM
|
||||||
|
|
||||||
|
Benchmark FBGEMM Grouped GEMM in both Triton and CUDA version and SGLang Triton Grouped GEMM, it will be used to compare the bandwidth of different implementations.
|
||||||
|
|
||||||
|
### Requirements
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install fbgemm-gpu-genai
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, in H200, the Qwen2-57B-A14B-Instruct TP4 fp8w8a8 grouped gemm bandwidth result is as follows:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
grouped-gemm-performance:
|
||||||
|
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
|
||||||
|
0 256.0 3704.841339 3042.626402 2254.725030
|
||||||
|
1 512.0 3691.426346 3029.065684 2269.504543
|
||||||
|
2 1024.0 3653.938629 2258.471467 2358.319020
|
||||||
|
3 2048.0 3596.644313 2271.611904 2476.895397
|
||||||
|
4 4096.0 3468.496435 2231.283986 2179.473910
|
||||||
|
```
|
||||||
|
|
||||||
|
The theoretical peak bandwidth of H200 is 4.8 TB/s. Taking batch_size 256 as an example, the bandwidth of FBGEMM Triton Grouped GEMM FP8 is 3704.841339 GB/s, the bandwidth of FBGEMM CUTLASS F8F8BF16 Rowwise is 3042.626402 GB/s, and the bandwidth of SGLang Grouped GEMM FP8 is 2254.725030 GB/s. Therefore, FBGEMM Triton Grouped GEMM FP8 achieves 77.9% of H200's theoretical peak bandwidth, FBGEMM CUTLASS F8F8BF16 Rowwise achieves 63.4% of H200's theoretical peak bandwidth, and SGLang Grouped GEMM FP8 achieves 46.9% of H200's theoretical peak bandwidth.
|
||||||
@@ -1,10 +1,16 @@
|
|||||||
# python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
|
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
|
||||||
from fbgemm_grouped_gemm import (
|
quantize_fp8_row,
|
||||||
|
triton_quantize_fp8_row,
|
||||||
|
)
|
||||||
|
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
||||||
|
grouped_gemm as fbgemm_grouped_gemm,
|
||||||
|
)
|
||||||
|
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
||||||
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
|
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
|
||||||
)
|
)
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
@@ -29,12 +35,11 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||||
num_groups = config.num_experts
|
num_groups = config.num_experts
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
elif config.architectures[0] in [
|
||||||
num_groups = (
|
"DeepseekV2ForCausalLM",
|
||||||
config.n_routed_experts + 1
|
"DeepseekV3ForCausalLM",
|
||||||
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
]:
|
||||||
else config.n_routed_experts
|
num_groups = config.n_routed_experts
|
||||||
)
|
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||||
num_groups = config.text_config.num_local_experts
|
num_groups = config.text_config.num_local_experts
|
||||||
@@ -65,7 +70,7 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
|
|||||||
|
|
||||||
tokens_per_group = batch_size // num_groups
|
tokens_per_group = batch_size // num_groups
|
||||||
m_sizes = torch.full(
|
m_sizes = torch.full(
|
||||||
(num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
|
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
|
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
|
||||||
@@ -84,11 +89,11 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
|
|||||||
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
|
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
|
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
|
||||||
for i in range(1, num_groups + 1):
|
for i in range(1, num_groups + 1):
|
||||||
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
|
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
|
||||||
|
|
||||||
weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda")
|
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
return (
|
return (
|
||||||
x,
|
x,
|
||||||
@@ -102,39 +107,144 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size):
|
def create_fp8_test_data(
|
||||||
|
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create test data for FP8 grouped GEMM operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Total batch size
|
||||||
|
num_groups: Number of groups
|
||||||
|
hidden_size: Hidden dimension size
|
||||||
|
intermediate_size: Intermediate dimension size
|
||||||
|
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
|
||||||
|
For cutlass: (x, wq, w_scale, m_sizes)
|
||||||
|
"""
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
tokens_per_group = batch_size // num_groups
|
tokens_per_group = batch_size // num_groups
|
||||||
m_sizes = torch.full(
|
|
||||||
(num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
|
# Create weight matrices for each group
|
||||||
w_fp16 = torch.randn(
|
w_list = []
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
|
for _ in range(num_groups):
|
||||||
)
|
w = torch.randn(
|
||||||
|
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
|
||||||
|
)
|
||||||
|
w_list.append(w)
|
||||||
|
|
||||||
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
|
# Quantize weights using quantize_fp8_row for each group
|
||||||
w_fp8 = w_fp16.to(torch.float8_e4m3fn)
|
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
|
||||||
|
|
||||||
x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4
|
if backend == "triton":
|
||||||
w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4
|
# Triton format: concatenated weights
|
||||||
|
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
|
||||||
|
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
|
||||||
|
|
||||||
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
|
# Create m_sizes as int32 for triton
|
||||||
|
m_sizes = torch.full(
|
||||||
|
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and quantize input
|
||||||
|
x_fp16 = torch.randn(
|
||||||
|
batch_size, hidden_size, dtype=torch.float16, device="cuda"
|
||||||
|
)
|
||||||
|
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
|
||||||
|
x_scale = x_scale.view(batch_size, -1)
|
||||||
|
|
||||||
|
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
|
||||||
|
|
||||||
|
elif backend == "cutlass":
|
||||||
|
# CUTLASS format: stacked weights
|
||||||
|
wq = torch.stack(wq_list, dim=0).contiguous()
|
||||||
|
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
|
||||||
|
|
||||||
|
# Create m_sizes as int64 for cutlass
|
||||||
|
m_values = [tokens_per_group] * num_groups
|
||||||
|
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
|
||||||
|
|
||||||
|
# Create input data - separate for each group then concat
|
||||||
|
x_list = []
|
||||||
|
for _ in range(num_groups):
|
||||||
|
x = torch.randn(
|
||||||
|
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
|
||||||
|
)
|
||||||
|
x_list.append(x)
|
||||||
|
|
||||||
|
# Concatenate inputs into single tensor
|
||||||
|
x = torch.concat(x_list, dim=0).contiguous()
|
||||||
|
|
||||||
|
return x, wq, w_scale, m_sizes
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
|
||||||
|
"""
|
||||||
|
Calculate memory bandwidth based on accessed expert weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
m_sizes: Tensor containing batch sizes for each group
|
||||||
|
hidden_size: Hidden dimension size
|
||||||
|
intermediate_size: Intermediate dimension size
|
||||||
|
dtype: Data type of weights
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Memory size in bytes for accessed expert weights
|
||||||
|
"""
|
||||||
|
# Count non-zero groups (active experts)
|
||||||
|
if hasattr(m_sizes, "cpu"):
|
||||||
|
active_experts = torch.count_nonzero(m_sizes).item()
|
||||||
|
else:
|
||||||
|
active_experts = sum(1 for m in m_sizes if m > 0)
|
||||||
|
|
||||||
|
# Calculate bytes per element based on dtype
|
||||||
|
if dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
bytes_per_element = 2
|
||||||
|
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||||
|
bytes_per_element = 1
|
||||||
|
elif dtype == torch.float32:
|
||||||
|
bytes_per_element = 4
|
||||||
|
else:
|
||||||
|
# Default to 2 bytes for unknown dtypes
|
||||||
|
bytes_per_element = 2
|
||||||
|
|
||||||
|
# Memory per expert weight matrix
|
||||||
|
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
|
||||||
|
|
||||||
|
# Total memory for active experts
|
||||||
|
total_memory_bytes = active_experts * memory_per_expert
|
||||||
|
|
||||||
|
return total_memory_bytes
|
||||||
|
|
||||||
|
|
||||||
def get_benchmark_config(use_fp8_w8a8=False):
|
def get_benchmark_config(use_fp8_w8a8=False):
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
return {
|
return {
|
||||||
"line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"],
|
"line_vals": [
|
||||||
"line_names": ["FBGEMM Grouped GEMM FP8", "SGLang Grouped GEMM FP8"],
|
"fbgemm_triton_grouped_gemm_fp8",
|
||||||
"styles": [("blue", "-"), ("red", "-")],
|
"fbgemm_cutlass_f8f8bf16_rowwise",
|
||||||
|
"sglang_grouped_gemm",
|
||||||
|
],
|
||||||
|
"line_names": [
|
||||||
|
"FBGEMM Triton Grouped GEMM FP8",
|
||||||
|
"FBGEMM CUTLASS F8F8BF16 Rowwise",
|
||||||
|
"SGLang Grouped GEMM FP8",
|
||||||
|
],
|
||||||
|
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"],
|
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
|
||||||
"line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"],
|
"line_names": [
|
||||||
|
"FBGEMM Triton Grouped GEMM BF16",
|
||||||
|
"SGLang Grouped GEMM BF16",
|
||||||
|
],
|
||||||
"styles": [("blue", "-"), ("green", "-")],
|
"styles": [("blue", "-"), ("green", "-")],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,12 +256,12 @@ def run_benchmark(
|
|||||||
|
|
||||||
benchmark_config = triton.testing.Benchmark(
|
benchmark_config = triton.testing.Benchmark(
|
||||||
x_names=["batch_size"],
|
x_names=["batch_size"],
|
||||||
x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
|
x_vals=[256, 512, 1024, 2048, 4096],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=config["line_vals"],
|
line_vals=config["line_vals"],
|
||||||
line_names=config["line_names"],
|
line_names=config["line_names"],
|
||||||
styles=config["styles"],
|
styles=config["styles"],
|
||||||
ylabel="Time (ms)",
|
ylabel="Bandwidth (GB/s)",
|
||||||
plot_name="grouped-gemm-performance",
|
plot_name="grouped-gemm-performance",
|
||||||
args={},
|
args={},
|
||||||
)
|
)
|
||||||
@@ -165,13 +275,22 @@ def run_benchmark(
|
|||||||
hidden_size = model_config["hidden_size"]
|
hidden_size = model_config["hidden_size"]
|
||||||
intermediate_size = model_config["intermediate_size"]
|
intermediate_size = model_config["intermediate_size"]
|
||||||
|
|
||||||
if provider == "fbgemm_grouped_gemm_fp8":
|
if provider == "fbgemm_triton_grouped_gemm_fp8":
|
||||||
try:
|
try:
|
||||||
test_data = create_fp8_test_data(
|
test_data = create_fp8_test_data(
|
||||||
batch_size, num_groups, hidden_size, intermediate_size
|
batch_size,
|
||||||
|
num_groups,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
backend="triton",
|
||||||
)
|
)
|
||||||
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
|
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
|
||||||
|
|
||||||
|
# Calculate memory bandwidth
|
||||||
|
memory_bytes = calculate_memory_bandwidth(
|
||||||
|
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
def run_func():
|
def run_func():
|
||||||
return fbgemm_grouped_gemm_fp8_rowwise(
|
return fbgemm_grouped_gemm_fp8_rowwise(
|
||||||
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
|
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
|
||||||
@@ -180,6 +299,38 @@ def run_benchmark(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"FP8 not supported, skipping: {e}")
|
print(f"FP8 not supported, skipping: {e}")
|
||||||
return float("inf"), float("inf"), float("inf")
|
return float("inf"), float("inf"), float("inf")
|
||||||
|
|
||||||
|
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
|
||||||
|
try:
|
||||||
|
test_data = create_fp8_test_data(
|
||||||
|
batch_size,
|
||||||
|
num_groups,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
backend="cutlass",
|
||||||
|
)
|
||||||
|
x, wq, w_scale, m_sizes = test_data
|
||||||
|
|
||||||
|
# Calculate memory bandwidth
|
||||||
|
memory_bytes = calculate_memory_bandwidth(
|
||||||
|
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quantize input using triton_quantize_fp8_row
|
||||||
|
xq, x_scale = triton_quantize_fp8_row(x)
|
||||||
|
x_scale = x_scale.view(batch_size, -1)
|
||||||
|
|
||||||
|
def run_func():
|
||||||
|
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
|
||||||
|
xq, wq, x_scale, w_scale, m_sizes
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
|
||||||
|
f"skipping: {e}"
|
||||||
|
)
|
||||||
|
return float("inf"), float("inf"), float("inf")
|
||||||
else:
|
else:
|
||||||
test_data = create_test_data(
|
test_data = create_test_data(
|
||||||
batch_size, num_groups, hidden_size, intermediate_size
|
batch_size, num_groups, hidden_size, intermediate_size
|
||||||
@@ -195,7 +346,12 @@ def run_benchmark(
|
|||||||
weight_indices,
|
weight_indices,
|
||||||
) = test_data
|
) = test_data
|
||||||
|
|
||||||
if provider == "fbgemm_grouped_gemm":
|
# Calculate memory bandwidth for BF16 operations
|
||||||
|
memory_bytes = calculate_memory_bandwidth(
|
||||||
|
m_sizes, hidden_size, intermediate_size, torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider == "fbgemm_triton_grouped_gemm":
|
||||||
|
|
||||||
def run_func():
|
def run_func():
|
||||||
return fbgemm_grouped_gemm(
|
return fbgemm_grouped_gemm(
|
||||||
@@ -228,10 +384,19 @@ def run_benchmark(
|
|||||||
try:
|
try:
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
|
||||||
return ms, min_ms, max_ms
|
|
||||||
|
# Convert time (ms) to bandwidth (GB/s)
|
||||||
|
# Bandwidth = Memory (bytes) / Time (seconds)
|
||||||
|
# Convert ms to seconds and bytes to GB (1e9)
|
||||||
|
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
|
||||||
|
# min bandwidth = max time, max bandwidth = min time
|
||||||
|
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
|
||||||
|
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
|
||||||
|
|
||||||
|
return gb_per_s, min_gb_per_s, max_gb_per_s
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during benchmarking for {provider}: {e}")
|
print(f"Error during benchmarking for {provider}: {e}")
|
||||||
return float("inf"), float("inf"), float("inf")
|
return 0.0, 0.0, 0.0
|
||||||
|
|
||||||
dynamic_benchmark.run(
|
dynamic_benchmark.run(
|
||||||
show_plots=True,
|
show_plots=True,
|
||||||
@@ -242,7 +407,7 @@ def run_benchmark(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_correctness(model_config, use_fp8_w8a8):
|
def verify_correctness(model_config):
|
||||||
print("Verifying correctness...")
|
print("Verifying correctness...")
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
num_groups = model_config["num_groups"]
|
num_groups = model_config["num_groups"]
|
||||||
@@ -250,54 +415,39 @@ def verify_correctness(model_config, use_fp8_w8a8):
|
|||||||
intermediate_size = model_config["intermediate_size"]
|
intermediate_size = model_config["intermediate_size"]
|
||||||
|
|
||||||
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
|
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
|
||||||
(x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = (
|
(
|
||||||
test_data
|
x,
|
||||||
|
w_fbgemm,
|
||||||
|
w_sglang,
|
||||||
|
c_fbgemm,
|
||||||
|
c_sglang,
|
||||||
|
m_sizes,
|
||||||
|
seg_indptr,
|
||||||
|
weight_indices,
|
||||||
|
) = test_data
|
||||||
|
|
||||||
|
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
|
||||||
|
|
||||||
|
result_sglang = sglang_grouped_gemm(
|
||||||
|
x,
|
||||||
|
w_sglang,
|
||||||
|
c_sglang,
|
||||||
|
num_groups,
|
||||||
|
weight_column_major=True,
|
||||||
|
seg_indptr=seg_indptr,
|
||||||
|
weight_indices=weight_indices,
|
||||||
|
c_dtype=c_sglang.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
|
||||||
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
|
print("✓ BF16 Correctness verification passed!")
|
||||||
|
else:
|
||||||
result_sglang = sglang_grouped_gemm(
|
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
|
||||||
x,
|
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
|
||||||
w_sglang,
|
|
||||||
c_sglang,
|
|
||||||
num_groups,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices,
|
|
||||||
c_dtype=c_sglang.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
|
|
||||||
print("✓ BF16 Correctness verification passed!")
|
|
||||||
else:
|
|
||||||
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
|
|
||||||
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if use_fp8_w8a8:
|
|
||||||
try:
|
|
||||||
fp8_data = create_fp8_test_data(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size
|
|
||||||
)
|
|
||||||
x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data
|
|
||||||
|
|
||||||
result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
|
|
||||||
x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result_fp8.shape == (batch_size, intermediate_size)
|
|
||||||
print("✓ FP8 functionality test passed!")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"FP8 test failed (possibly unsupported): {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Error during correctness verification: {e}")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@@ -348,7 +498,7 @@ def main():
|
|||||||
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
|
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
|
||||||
|
|
||||||
if args.verify_correctness:
|
if args.verify_correctness:
|
||||||
if not verify_correctness(model_config, args.use_fp8_w8a8):
|
if not verify_correctness(model_config):
|
||||||
print("Correctness verification failed. Exiting...")
|
print("Correctness verification failed. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,323 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
|
|
||||||
from fbgemm_grouped_gemm import (
|
|
||||||
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
|
|
||||||
)
|
|
||||||
|
|
||||||
FBGEMM_AVAILABLE = True
|
|
||||||
print("✓ Successfully imported FBGEMM grouped GEMM")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
|
|
||||||
FBGEMM_AVAILABLE = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
||||||
grouped_gemm_triton as sglang_grouped_gemm,
|
|
||||||
)
|
|
||||||
|
|
||||||
SGLANG_AVAILABLE = True
|
|
||||||
print("✓ Successfully imported SGLang grouped GEMM")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"✗ Failed to import SGLang grouped GEMM: {e}")
|
|
||||||
SGLANG_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
def create_uniform_groups(batch_size, num_groups, device):
|
|
||||||
tokens_per_group = batch_size // num_groups
|
|
||||||
return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def create_non_uniform_groups(batch_size, num_groups, device):
|
|
||||||
remaining = batch_size
|
|
||||||
m_sizes = []
|
|
||||||
|
|
||||||
for i in range(num_groups - 1):
|
|
||||||
if remaining <= 1:
|
|
||||||
size = 1
|
|
||||||
else:
|
|
||||||
max_size = remaining - (num_groups - i - 1) + 1
|
|
||||||
size = torch.randint(1, max_size, (1,)).item()
|
|
||||||
m_sizes.append(size)
|
|
||||||
remaining -= size
|
|
||||||
|
|
||||||
m_sizes.append(remaining)
|
|
||||||
return torch.tensor(m_sizes, dtype=torch.int64, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
|
|
||||||
c_sglang = torch.empty(
|
|
||||||
batch_size, intermediate_size, dtype=torch.bfloat16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
|
|
||||||
current_pos = 0
|
|
||||||
for i, size in enumerate(m_sizes):
|
|
||||||
current_pos += size
|
|
||||||
seg_indptr[i + 1] = current_pos
|
|
||||||
|
|
||||||
weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
|
|
||||||
w_sglang = w.view(num_groups, intermediate_size, -1)
|
|
||||||
|
|
||||||
return c_sglang, seg_indptr, weight_indices, w_sglang
|
|
||||||
|
|
||||||
|
|
||||||
def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
|
|
||||||
w_fp16 = torch.randn(
|
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
|
|
||||||
w_fp8 = w_fp16.to(torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
|
|
||||||
w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4
|
|
||||||
|
|
||||||
return x_fp8, w_fp8, x_scale, w_scale
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def device():
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
pytest.skip("CUDA not available")
|
|
||||||
return torch.device("cuda")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
|
|
||||||
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
|
|
||||||
@pytest.mark.parametrize("batch_size", [32])
|
|
||||||
@pytest.mark.parametrize("num_groups", [2, 4, 8])
|
|
||||||
@pytest.mark.parametrize("hidden_size", [512, 1024])
|
|
||||||
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
|
|
||||||
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
|
|
||||||
if batch_size % num_groups != 0:
|
|
||||||
pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")
|
|
||||||
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
m_sizes = create_uniform_groups(batch_size, num_groups, device)
|
|
||||||
|
|
||||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
|
||||||
w = torch.randn(
|
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
|
|
||||||
|
|
||||||
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
|
|
||||||
x, w, m_sizes, num_groups, intermediate_size, device
|
|
||||||
)
|
|
||||||
|
|
||||||
result_sglang = sglang_grouped_gemm(
|
|
||||||
x,
|
|
||||||
w_sglang,
|
|
||||||
c_sglang,
|
|
||||||
num_groups,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices,
|
|
||||||
c_dtype=c_sglang.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
|
|
||||||
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
|
|
||||||
@pytest.mark.parametrize("batch_size", [63, 100, 127])
|
|
||||||
@pytest.mark.parametrize("num_groups", [3, 5, 7])
|
|
||||||
@pytest.mark.parametrize("hidden_size", [512, 1024])
|
|
||||||
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
|
|
||||||
def test_non_uniform_groups(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size, device
|
|
||||||
):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
|
|
||||||
|
|
||||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
|
||||||
w = torch.randn(
|
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
|
|
||||||
|
|
||||||
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
|
|
||||||
x, w, m_sizes, num_groups, intermediate_size, device
|
|
||||||
)
|
|
||||||
|
|
||||||
result_sglang = sglang_grouped_gemm(
|
|
||||||
x,
|
|
||||||
w_sglang,
|
|
||||||
c_sglang,
|
|
||||||
num_groups,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices,
|
|
||||||
c_dtype=c_sglang.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
|
|
||||||
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
|
|
||||||
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
|
|
||||||
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
|
|
||||||
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
|
|
||||||
def test_large_dimensions(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size, device
|
|
||||||
):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
m_sizes = create_uniform_groups(batch_size, num_groups, device)
|
|
||||||
|
|
||||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
|
||||||
w = torch.randn(
|
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
|
|
||||||
|
|
||||||
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
|
|
||||||
x, w, m_sizes, num_groups, intermediate_size, device
|
|
||||||
)
|
|
||||||
|
|
||||||
result_sglang = sglang_grouped_gemm(
|
|
||||||
x,
|
|
||||||
w_sglang,
|
|
||||||
c_sglang,
|
|
||||||
num_groups,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices,
|
|
||||||
c_dtype=c_sglang.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
|
|
||||||
@pytest.mark.parametrize("batch_size", [32, 64])
|
|
||||||
@pytest.mark.parametrize("num_groups", [2, 4])
|
|
||||||
@pytest.mark.parametrize("hidden_size", [512, 1024])
|
|
||||||
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
|
|
||||||
def test_fp8_uniform_groups(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size, device
|
|
||||||
):
|
|
||||||
if batch_size % num_groups != 0:
|
|
||||||
pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")
|
|
||||||
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
m_sizes = create_uniform_groups(batch_size, num_groups, device)
|
|
||||||
x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size, device
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
|
|
||||||
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
|
|
||||||
)
|
|
||||||
assert result_fp8.shape == (batch_size, intermediate_size)
|
|
||||||
assert result_fp8.dtype == torch.bfloat16
|
|
||||||
except Exception as e:
|
|
||||||
pytest.skip(f"FP8 test failed (possibly unsupported): {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
|
|
||||||
@pytest.mark.parametrize("batch_size", [63, 100])
|
|
||||||
@pytest.mark.parametrize("num_groups", [3, 5])
|
|
||||||
@pytest.mark.parametrize("hidden_size", [512, 1024])
|
|
||||||
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
|
|
||||||
def test_fp8_non_uniform_groups(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size, device
|
|
||||||
):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
|
|
||||||
x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
|
|
||||||
batch_size, num_groups, hidden_size, intermediate_size, device
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
|
|
||||||
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
|
|
||||||
)
|
|
||||||
assert result_fp8.shape == (batch_size, intermediate_size)
|
|
||||||
assert result_fp8.dtype == torch.bfloat16
|
|
||||||
except Exception as e:
|
|
||||||
pytest.skip(f"FP8 test failed (possibly unsupported): {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
|
|
||||||
def test_fbgemm_only_uniform(device):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
batch_size, num_groups = 64, 4
|
|
||||||
hidden_size, intermediate_size = 512, 1024
|
|
||||||
|
|
||||||
m_sizes = create_uniform_groups(batch_size, num_groups, device)
|
|
||||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
|
||||||
w = torch.randn(
|
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
|
|
||||||
|
|
||||||
assert result.shape == (batch_size, intermediate_size)
|
|
||||||
assert result.dtype == torch.bfloat16
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
|
|
||||||
def test_sglang_only_uniform(device):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
batch_size, num_groups = 64, 4
|
|
||||||
hidden_size, intermediate_size = 512, 1024
|
|
||||||
|
|
||||||
m_sizes = create_uniform_groups(batch_size, num_groups, device)
|
|
||||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
|
||||||
w = torch.randn(
|
|
||||||
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
|
|
||||||
x, w, m_sizes, num_groups, intermediate_size, device
|
|
||||||
)
|
|
||||||
|
|
||||||
result = sglang_grouped_gemm(
|
|
||||||
x,
|
|
||||||
w_sglang,
|
|
||||||
c_sglang,
|
|
||||||
num_groups,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices,
|
|
||||||
c_dtype=c_sglang.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.shape == (batch_size, intermediate_size)
|
|
||||||
assert result.dtype == torch.bfloat16
|
|
||||||
|
|
||||||
|
|
||||||
def test_imports():
|
|
||||||
assert (
|
|
||||||
FBGEMM_AVAILABLE or SGLANG_AVAILABLE
|
|
||||||
), "Neither FBGEMM nor SGLang is available"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
Reference in New Issue
Block a user