adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
19
benchmark/kernels/deepseek/README.md
Normal file
19
benchmark/kernels/deepseek/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## DeepSeek kernels benchmark
|
||||
|
||||
|
||||
### Prerequisites
|
||||
- You should install [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) from source before running `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.
|
||||
|
||||
### Benchmark
|
||||
- `benchmark_deepgemm_fp8_gemm.py`
|
||||
```bash
|
||||
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1
|
||||
```
|
||||
|
||||
- `benchmark_deepgemm_fp8_group_gemm.py`
|
||||
```bash
|
||||
python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1
|
||||
```
|
||||
|
||||
- You can use the `--run_correctness` parameter to verify the correctness of all kernel results.
|
||||
- You can use the `--tp_size` parameter to benchmark all FP8 w8a8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the current tensor parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) implementation with SGLang's and vLLM's Triton implementations.
|
||||
400
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
Normal file
400
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
|
||||
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
def tl_gemm(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    """Build a TileLang prim_func computing the block-scaled FP8 GEMM C = A @ B^T.

    A is (M, K) fp8 with per-token scales scales_a of shape (M, ceil(K/128));
    B is (N, K) fp8 with per-block scales scales_b of shape
    (ceil(N/128), ceil(K/128)); C is (M, N) in ``out_dtype``.
    Uses a two-level ("2xAcc") accumulation: each 128-wide K slice is
    accumulated in ``accum_dtype``, rescaled, and folded into a running sum,
    mirroring DeepGEMM's promotion scheme.
    """
    assert in_dtype in [
        "e4m3_float8",
    ], "Currently only e4m3_float8 is supported"
    assert out_dtype in [
        "bfloat16",
        "float16",
    ], "Currently only bfloat16 and float16 are supported"

    # One kernel instance computes a 128x128 output tile over 128-wide K slices;
    # 128 also matches the quantization group size of the scale tensors.
    TILE_SIZE = (128, 128, 128)
    block_M = TILE_SIZE[0]
    block_N = TILE_SIZE[1]
    block_K = TILE_SIZE[2]

    A_shape = (M, K)
    Scales_A_shape = (M, T.ceildiv(K, block_K))
    B_shape = (N, K)
    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M, block_N)

    @T.prim_func
    def main(
        A: T.Buffer(A_shape, in_dtype),
        scales_a: T.Buffer(Scales_A_shape, "float32"),
        B: T.Buffer(B_shape, in_dtype),
        scales_b: T.Buffer(Scales_B_shape, "float32"),
        C: T.Buffer((M, N), out_dtype),
    ):
        # Grid: bx iterates N tiles, by iterates M tiles; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
            # Per-row combined scale = (A row scale) * (B tile scale) for the
            # current K slice.
            Scale_C_shared = T.alloc_shared((block_M), "float32")
            # C_local: fresh accumulator per K slice; C_local_accum: running
            # rescaled sum across all K slices (the "2xAcc" promotion).
            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
                Scale_B = scales_b[bx, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc: rescale this slice's partial product
                # and fold it into the running accumulator, then reset C_local.
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||
m, n
|
||||
), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to fp8 (e4m3) with one scale per 128x128 block.

    The input is zero-padded up to multiples of 128 in both dimensions, each
    128x128 block is scaled so its absolute maximum maps to 448 (the largest
    finite e4m3 value), and the padding is sliced away again.

    Returns:
        (quantized, scales): quantized has x's shape and dtype float8_e4m3fn;
        scales has shape (ceil(rows/128), ceil(cols/128)) in float32.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    # Round both dimensions up to the next multiple of 128.
    pad_rows = -(-rows // 128) * 128
    pad_cols = -(-cols // 128) * 128
    padded = torch.zeros((pad_rows, pad_cols), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    # View as (row blocks, 128, col blocks, 128) so amax reduces per block.
    blocks = padded.view(-1, 128, pad_cols // 128, 128)
    # Clamp avoids division by ~0 for all-zero blocks.
    amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / amax)).to(torch.float8_e4m3fn)
    trimmed = quantized.view_as(padded)[:rows, :cols].contiguous()
    return trimmed, (amax / 448.0).view(blocks.size(0), blocks.size(2))
|
||||
|
||||
|
||||
def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """DeepGEMM implementation of the block-scaled FP8 GEMM.

    Launches deep_gemm's NT fp8 kernel, writing a freshly allocated
    (m, n) bf16 tensor on the current CUDA device, and returns it.
    x_scale must already be in DeepGEMM's column-major TMA-aligned layout.
    """
    result = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), result)
    return result
|
||||
|
||||
|
||||
def fp8_gemm_sglang(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """SGLang implementation of the block-scaled FP8 GEMM.

    Returns the (m, n) bf16 product computed by SGLang's DeepGEMM-backed
    w8a8 block-fp8 matmul kernel.
    """
    # 128x128 blocks — must match the granularity used by per_block_cast_to_fp8.
    block_size = [128, 128]
    return w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
    )
|
||||
|
||||
|
||||
def fp8_gemm_vllm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """vLLM implementation of the block-scaled FP8 GEMM.

    Returns the (m, n) bf16 product computed by vLLM's Triton w8a8
    block-fp8 matmul kernel.
    """
    # 128x128 blocks — must match the granularity used by per_block_cast_to_fp8.
    block_size = [128, 128]
    return vllm_w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
    )
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int):
    """Run one (m, n, k) FP8 GEMM with DeepGEMM, SGLang, and TileLang and
    print pairwise mean-absolute differences plus an allclose verdict.
    """
    # Random bf16 operands; y is stored (n, k) because all three kernels use
    # the NT layout (B transposed inside the kernel).
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    # Quantize activations per-token (1x128 groups) and weights per 128x128 block.
    x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
    y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
    # DeepGEMM requires activation scales in a column-major TMA-aligned layout.
    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())

    out_deepgemm = fp8_gemm_deepgemm(
        x_fp8.clone(),
        x_scale_col_major.clone(),
        y_fp8.clone(),
        y_scale.clone(),
        m,
        n,
        k,
    )
    out_sglang = fp8_gemm_sglang(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
    )

    # TileLang: JIT-compile a kernel specialized for this exact shape, then run it.
    tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
    tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
    out_tilelang = tilelang_kernel(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
    )

    diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
    diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
    diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"SGLang output: {out_sglang[0, 0:5]}")
    print(f"TileLang output: {out_tilelang[0, 0:5]}")
    print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
    print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
    print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")

    # Loose tolerances (1e-2) — fp8 quantization plus bf16 output rounding
    # dominate any kernel-level differences.
    sglang_deepgemm_match = torch.allclose(
        out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
    )
    tilelang_deepgemm_match = torch.allclose(
        out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
    )
    tilelang_sglang_match = torch.allclose(
        out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
    )

    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        print(f"  - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
        print(f"  - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
        print(f"  - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
    """Return the (N, K) weight shapes of DeepSeek V3/R1 GEMMs under ``tp_size``.

    Shapes fall into three groups: weights that are never sharded, weights
    sharded along N, and weights sharded along K. The sharded dimensions are
    divided by ``tp_size``.
    """
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    shapes = list(total)
    shapes.extend((n // tp_size, k) for n, k in n_tp)
    shapes.extend((n, k // tp_size) for n, k in k_tp)
    return shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
    """Cross every (n, k) weight shape at ``tp_size`` with each batch size.

    Returns a list of (m, n, k, tp_size) tuples, shapes outermost.
    """
    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
    return [
        (m, n, k, tp_size)
        for n, k in get_weight_shapes(tp_size)
        for m in batch_sizes
    ]
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
    """Build a triton perf_report benchmark comparing DeepGEMM, SGLang, and
    TileLang block-scaled FP8 GEMMs over every config for ``tp_size``.

    Returns the decorated benchmark function; call ``.run(...)`` on it.
    """
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "tp_size"],
            x_vals=[list(config) for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "sglang", "tilelang"],
            line_names=["DeepGEMM", "SGLang", "TileLang"],
            styles=[("blue", "-"), ("red", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, tp_size, provider):
        print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

        # Quantize once outside the timed region so only the GEMM is measured.
        x_fp8, x_scale = per_token_cast_to_fp8(x)
        y_fp8, y_scale = per_block_cast_to_fp8(y)
        # DeepGEMM needs activation scales in column-major TMA-aligned layout.
        x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())

        quantiles = [0.5, 0.2, 0.8]

        if provider == "deepgemm":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_deepgemm(
                    x_fp8.clone(),
                    x_scale_col_major.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                    m,
                    n,
                    k,
                ),
                quantiles=quantiles,
            )
        elif provider == "sglang":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_sglang(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                    m,
                    n,
                    k,
                ),
                quantiles=quantiles,
            )
        else:  # tilelang
            # Compile outside the timed lambda; only execution is measured.
            tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: tilelang_kernel(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                ),
                quantiles=quantiles,
            )

        # Calculate TFLOPS. triton.testing.do_bench reports milliseconds.
        flops = 2 * m * n * k  # multiply-adds
        tflops = flops / (ms * 1e-3) / 1e12

        # Fixed: do_bench already returns milliseconds; the original multiplied
        # by 1000 and printed/returned microseconds while labeling them "ms"
        # (the plot ylabel is "ms").
        print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")
        return ms, max_ms, min_ms

    return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_gemm/",
        help="Path to save fp8 gemm benchmark results",
    )
    # Fixed: the original combined action="store_true" with default=True, which
    # made the flag a no-op — correctness always ran and could never be turned
    # off. store_true defaults to False, matching the README usage and the
    # group-GEMM benchmark script.
    parser.add_argument(
        "--run_correctness",
        action="store_true",
        help="Whether to run correctness test",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Run correctness tests on a few examples
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(64, 512, 7168)  # Small test
        calculate_diff(64, 7168, 16384)  # Medium test
        calculate_diff(64, 18432, 7168)  # Large test

    # Get the benchmark function with the specified tp_size
    benchmark = get_benchmark(args.tp_size)

    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)
|
||||
486
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
Normal file
486
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
Normal file
@@ -0,0 +1,486 @@
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
||||
|
||||
# Import shared functionality from the regular GEMM benchmark
|
||||
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||
per_block_cast_to_fp8,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def construct_grouped_and_flat_fp8(
    x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
) -> Tuple[
    Tuple[torch.Tensor, torch.Tensor],  # grouped x_fp8
    Tuple[torch.Tensor, torch.Tensor],  # grouped y_fp8
    Tuple[torch.Tensor, torch.Tensor],  # flat x_fp8
    Tuple[torch.Tensor, torch.Tensor],  # flat y_fp8
    torch.Tensor,  # output
    torch.Tensor,  # reference output
]:
    """Quantize x/y into both grouped and flat fp8 forms for grouped GEMM tests.

    x is (m, k) activations split into ``num_groups`` equal row groups; y is a
    single (n, k) weight shared by every group. Returns grouped and flat
    (tensor, scale) pairs, an empty bf16 output buffer, and a bf16 einsum
    reference result. When ``is_masked`` is False the group and M dims of the
    grouped activation and both outputs are flattened to (m, ...).
    """
    # Verify input shapes
    m, k = x.shape
    n, k_y = y.shape
    assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
    assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
    assert m % 4 == 0, f"TMA alignment error: {m}"

    # Reshape inputs for grouped processing; y is broadcast (expanded, not
    # copied) across groups since all groups share the same weight.
    m_per_group = m // num_groups
    x_grouped = x.view(num_groups, m_per_group, k)
    y_grouped = y.unsqueeze(0).expand(num_groups, n, k)

    # Initialize output tensors; ref_out is the bf16 ground truth.
    out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
    ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)

    # Quantize grouped tensors: per-token (1x128) scales for x, per-128x128
    # block scales for y, one slice per group.
    x_fp8_grouped = (
        torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
        ),
    )
    y_fp8_grouped = (
        torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
        ),
    )
    for i in range(num_groups):
        x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
        y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])

    # Quantize flat tensors
    x_fp8_flat = per_token_cast_to_fp8(x)
    y_fp8_flat = per_block_cast_to_fp8(y)

    # For non-masked input, merge the group and M dims in output
    if not is_masked:
        x_fp8_grouped = (
            x_fp8_grouped[0].view(-1, k),
            per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
        )
        out, ref_out = out.view(-1, n), ref_out.view(-1, n)

    # Transpose earlier for testing — DeepGEMM expects activation scales in a
    # column-major TMA-aligned layout.
    x_fp8_grouped = (
        x_fp8_grouped[0],
        get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
    )
    x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))

    return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||
|
||||
|
||||
# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
# custom kernel based on the Triton tutorial.
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
@triton.jit
def fp8_gemm_group_triton_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Pointers to scaling factors
    a_scale_ptr,
    b_scale_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension.
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Strides for scaling factors
    stride_a_scale_m,
    stride_a_scale_k,
    stride_b_scale_n,
    stride_b_scale_k,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)

    Note: Block sizes must be multiples of 32 for optimal TMA performance.
    """
    # Map program ids to the block of C it should compute.
    # Grid is 3-D: (group, N blocks, M blocks within group).
    pid_group = tl.program_id(axis=0)  # Group ID
    pid_n = tl.program_id(axis=1)  # N dimension ID

    # Compute the M block ID within this group
    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m_within_group = tl.program_id(axis=2) % group_size_m
    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group

    # Create pointers for the first blocks of A and B.
    # Offsets wrap modulo M/N, so out-of-range rows alias valid ones; the final
    # store mask keeps the result correct.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Initialize accumulator (fp32 for precision over the K loop)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Main loop
    for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_offset = k_block * BLOCK_SIZE_K

        # Load the next block of A and B, with masks
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)

        # Calculate indices for scaling factors for this K block.
        # NOTE(review): indexing scales by k_block assumes BLOCK_SIZE_K equals
        # the 128-wide quantization group size — confirm if block sizes change.
        a_scale_ptrs = a_scale_ptr + (
            offs_am * stride_a_scale_m + k_block * stride_a_scale_k
        )
        b_scale_ptrs = b_scale_ptr + (
            pid_n * stride_b_scale_n + k_block * stride_b_scale_k
        )

        # Perform matrix multiplication in FP8
        res = tl.dot(a, b)

        # Load scaling factors for the current block
        a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
        b_scale = tl.load(b_scale_ptrs)

        # Apply scaling factors to the accumulated result (per K block, so the
        # rescaled partial products sum correctly across blocks)
        accumulator += res * a_scale * b_scale

        # Advance pointers
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Convert to bfloat16 for output
    c = accumulator.to(tl.bfloat16)

    # Write back the result, masking the tail blocks at the M/N edges
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
    """
    Perform matrix multiplication with FP8 inputs and proper scaling.

    Args:
        a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
        b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
        c: Output tensor in BF16 format; written in place and returned
        num_groups: Number of groups for grouped GEMM

    Returns:
        Result tensor in BF16 format (the same object as ``c``)
    """
    # Unpack the tuples
    a, a_scale = a_tuple
    b, b_scale = b_tuple

    M, K = a.shape
    _, N = b.shape

    # Configure block sizes - must be multiples of 32 for TMA alignment.
    # 128 also matches the quantization group size of the scale tensors.
    BLOCK_SIZE_M = 128
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128

    # Calculate grid dimensions
    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
    num_groups_grid = triton.cdiv(num_pid_m, num_groups)

    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))

    fp8_gemm_group_triton_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        # NOTE(review): hard-coded scale strides assume contiguous scale
        # tensors with unit stride along K — confirm against the callers.
        1,  # Stride in the K dimension may be 1
        b_scale.stride(0),
        1 if b_scale.dim() > 1 else 0,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=num_groups,
    )

    return c
|
||||
|
||||
|
||||
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
    """Run DeepGEMM's contiguous m-grouped FP8 GEMM.

    Writes the bf16 result into ``out`` in place (row i belongs to group
    ``m_indices[i]``) and returns it.
    """
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        x_fp8_grouped, y_fp8_grouped, out, m_indices
    )
    return out
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int, num_groups: int):
    """Run one grouped FP8 GEMM with DeepGEMM and the custom Triton kernel,
    compare both against a bf16 torch reference, and print the results.
    """
    # Fixed: the original f-string never closed the parenthesis it opened.
    print(f"Shape (m={m}, n={n}, k={k})")
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
    x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
        construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
    )
    m_per_group = m // num_groups
    out_deepgemm = out.clone()
    # m_indices[i] is the group id of flattened row i: [0]*m_per_group + [1]*... etc.
    m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
    m_indices = (
        m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
    )

    fp8_gemm_group_deepgemm(
        x_fp8_grouped,
        y_fp8_grouped,
        out_deepgemm,
        m_indices,
    )
    torch.cuda.synchronize()

    # Prepare inputs for Triton: the custom kernel takes B as (K, N).
    a, a_scale = x_fp8_flat
    b, b_scale = y_fp8_flat
    b = b.T.contiguous()
    # Ensure scales are in the right format and contiguous
    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
    M, _ = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
    torch.cuda.synchronize()

    diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
    diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
    diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"Torch output: {out_torch[0, 0:5]}")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"Triton output: {out_triton[0, 0:5]}")
    print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
    print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
    print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")

    deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
    triton_torch_diff = calc_diff(out_triton, out_torch)
    deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)

    DIFF_THRESHOLD = 0.001
    all_match = (
        deepgemm_torch_diff < DIFF_THRESHOLD
        and triton_torch_diff < DIFF_THRESHOLD
        and deepgemm_triton_diff < DIFF_THRESHOLD
    )
    if all_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        # Fixed: the original passed three adjacent f-strings to a single print
        # call, which concatenated all three comparisons onto one line; print
        # each on its own line like the non-grouped benchmark does.
        print(
            f"  - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}"
        )
        print(
            f"  - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}"
        )
        print(
            f"  - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
        )
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
    """Return the (N, K) weight shapes of DeepSeek V3/R1 GEMMs under ``tp_size``.

    Three categories: shapes that are never sharded, shapes sharded along N,
    and shapes sharded along K; the sharded dimension is divided by ``tp_size``.
    """
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    shapes = list(total)
    shapes.extend((n // tp_size, k) for n, k in n_tp)
    shapes.extend((n, k // tp_size) for n, k in k_tp)
    return shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
    """Cross every (n, k) weight shape at ``tp_size`` with each batch size and
    group count.

    Returns a list of (m, n, k, num_groups, tp_size) tuples, shapes outermost.
    """
    batch_sizes = [2048, 4096]
    group_sizes = [4, 8]
    return [
        (m, n, k, num_groups, tp_size)
        for n, k in get_weight_shapes(tp_size)
        for m in batch_sizes
        for num_groups in group_sizes
    ]
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
    """Build a triton perf_report benchmark comparing DeepGEMM's grouped FP8
    GEMM against the custom Triton group kernel for every config at ``tp_size``.

    Returns the decorated benchmark function; call ``.run(...)`` on it.
    """
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "num_groups", "tp_size"],
            x_vals=[config for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "triton"],
            line_names=["DeepGEMM", "Triton"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="ms",
            plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, num_groups, tp_size, provider):
        # Fixed: the original f-string never closed the parenthesis it opened.
        print(
            f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}), Provider: {provider}"
        )
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
        x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
            construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
        )
        m_per_group = m // num_groups
        # m_indices[i] is the group id of flattened row i.
        m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
        m_indices = (
            m_indices.unsqueeze(-1)
            .expand(num_groups, m_per_group)
            .contiguous()
            .view(-1)
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "deepgemm":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_group_deepgemm(
                    x_fp8_grouped,
                    y_fp8_grouped,
                    out,
                    m_indices,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton":
            # Prepare inputs for Triton
            # We did it outside of the lambda function to make it fair comparison like deepgemm
            a, a_scale = x_fp8_flat
            b, b_scale = y_fp8_flat
            b = b.T.contiguous()
            # Ensure scales are in the right format and contiguous
            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
            M, _ = a.shape
            _, N = b.shape
            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_group_triton(
                    (a, a_scale),
                    (b, b_scale),
                    c,
                    num_groups,
                ),
                quantiles=quantiles,
            )

        # Calculate TFLOPS. triton.testing.do_bench reports milliseconds.
        flops = 2 * m * n * k  # multiply-adds
        tflops = flops / (ms * 1e-3) / 1e12

        # Fixed: do_bench already returns milliseconds; the original multiplied
        # by 1000 and printed/returned microseconds while labeling them "ms"
        # (the plot ylabel is "ms").
        print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")
        return ms, max_ms, min_ms

    return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_group_gemm/",
        help="Path to save deepgemm fp8 group gemm benchmark results",
    )
    # store_true: correctness checks run only when the flag is passed.
    parser.add_argument(
        "--run_correctness",
        action="store_true",
        help="Whether to run correctness test",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Run correctness tests on a few examples (m, n, k, num_groups)
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(8192, 7168, 4096, 4)
        calculate_diff(8192, 2048, 7168, 4)
        calculate_diff(4096, 7168, 4096, 8)
        calculate_diff(4096, 2048, 7168, 8)
        calculate_diff(4096, 576, 7168, 8)

    # Get the benchmark function with the specified tp_size
    benchmark = get_benchmark(args.tp_size)

    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)
|
||||
Reference in New Issue
Block a user