[Benchmark] tilelang vs deepgemm vs w8a8_block_fp8_matmul (#4735)
@@ -1,11 +1,10 @@
import itertools
from typing import Tuple

import deep_gemm
import numpy as np
+import tilelang
+import tilelang.language as T
import torch
import triton
import triton.language as tl
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
@@ -14,6 +13,84 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul


+# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
+def tl_gemm(
+    M,
+    N,
+    K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+):
+    assert in_dtype in [
+        "e4m3_float8",
+    ], "Currently only e4m3_float8 is supported"
+    assert out_dtype in [
+        "bfloat16",
+        "float16",
+    ], "Currently only bfloat16 and float16 are supported"
+
+    TILE_SIZE = (128, 128, 128)
+    block_M = TILE_SIZE[0]
+    block_N = TILE_SIZE[1]
+    block_K = TILE_SIZE[2]
+
+    A_shape = (M, K)
+    Scales_A_shape = (M, T.ceildiv(K, block_K))
+    B_shape = (N, K)
+    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
+    A_shared_shape = (block_M, block_K)
+    B_shared_shape = (block_N, block_K)
+    C_shared_shape = (block_M, block_N)
+
+    @T.prim_func
+    def main(
+        A: T.Buffer(A_shape, in_dtype),
+        scales_a: T.Buffer(Scales_A_shape, "float32"),
+        B: T.Buffer(B_shape, in_dtype),
+        scales_b: T.Buffer(Scales_B_shape, "float32"),
+        C: T.Buffer((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
+            bx,
+            by,
+        ):
+
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
+            Scale_C_shared = T.alloc_shared((block_M), "float32")
+            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
+            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
+
+            # Improve L2 cache locality
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+            T.clear(C_local_accum)
+            K_iters = T.ceildiv(K, block_K)
+            for k in T.Pipelined(K_iters, num_stages=4):
+                # Load A tile into shared memory
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                # Load B tile into shared memory
+                T.copy(B[bx * block_N, k * block_K], B_shared)
+                # Combine the per-token A scales with the per-block B scale
+                Scale_B = scales_b[bx, k]
+                for i in T.Parallel(block_M):
+                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
+
+                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+                # Promote to enable 2xAcc
+                for i, j in T.Parallel(block_M, block_N):
+                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
+                T.clear(C_local)
+            # TMA store
+            T.copy(C_local_accum, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+
+    return main
+
+
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
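The kernel above uses DeepGEMM-style two-level accumulation (2xAcc): each 128-deep K slice is multiplied into the fast accumulator `C_local`, then scaled by the combined per-token/per-block factor and promoted into the persistent FP32 accumulator `C_local_accum`, which bounds the error the tensor-core accumulator can build up across a long K dimension.

The hunk also cuts off inside `per_token_cast_to_fp8`. As a reading aid, here is a minimal sketch of how this helper is conventionally completed in DeepGEMM-derived benchmarks (128-wide scale groups, E4M3 max of 448.0); the body below is an assumption, not part of the diff:

```python
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sketch following the DeepGEMM example convention; the actual file may differ.
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    # Max magnitude per 128-wide group, clamped to avoid dividing by zero
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return x_fp8.view(m, n), (x_amax / 448.0).view(m, -1)
```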
@@ -114,35 +191,42 @@ def calculate_diff(m: int, n: int, k: int):
    out_sglang = fp8_gemm_sglang(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
    )
-    out_vllm = fp8_gemm_vllm(
-        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
-    )

+    tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
+    tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
+    out_tilelang = tilelang_kernel(
+        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
+    )
+
    diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
-    diff_vllm_deepgemm = torch.abs(out_deepgemm - out_vllm).mean().item()
-    diff_vllm_sglang = torch.abs(out_vllm - out_sglang).mean().item()
+    diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
+    diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"SGLang output: {out_sglang[0, 0:5]}")
-    print(f"vLLM output: {out_vllm[0, 0:5]}")
+    print(f"TileLang output: {out_tilelang[0, 0:5]}")
    print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
-    print(f"Mean absolute difference (vLLM-DeepGEMM): {diff_vllm_deepgemm}")
-    print(f"Mean absolute difference (vLLM-SGLang): {diff_vllm_sglang}")
+    print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
+    print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")

    sglang_deepgemm_match = torch.allclose(
        out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
    )
-    vllm_deepgemm_match = torch.allclose(out_deepgemm, out_vllm, atol=1e-2, rtol=1e-2)
-    vllm_sglang_match = torch.allclose(out_vllm, out_sglang, atol=1e-2, rtol=1e-2)
+    tilelang_deepgemm_match = torch.allclose(
+        out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
+    )
+    tilelang_sglang_match = torch.allclose(
+        out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
+    )

-    if sglang_deepgemm_match and vllm_deepgemm_match and vllm_sglang_match:
+    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        print(f" - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
-        print(f" - vLLM vs DeepGEMM: {'✅' if vllm_deepgemm_match else '❌'}")
-        print(f" - vLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}\n")
+        print(f" - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
+        print(f" - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")


def get_weight_shapes(tp_size):
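calculate_diff now checks TileLang against DeepGEMM and SGLang instead of vLLM, using mean absolute difference plus torch.allclose with atol=rtol=1e-2. For intuition about what all of these kernels are expected to compute, a plain PyTorch reference for the blockwise-dequantized GEMM is sketched below (per-token activation scales over 128-wide K groups, per-128x128-block weight scales); the helper name and loop structure are illustrative, and N/K are assumed divisible by 128:

```python
def ref_block_fp8_gemm(x_fp8, x_scale, y_fp8, y_scale, block=128):
    # Dequantize activations: one scale per (token, 128-wide K group)
    m, k = x_fp8.shape
    x = (x_fp8.float().view(m, -1, block) * x_scale.unsqueeze(-1)).view(m, k)
    # Dequantize weights: one scale per (128 x 128) block
    y = y_fp8.float()
    n, _ = y.shape
    for i in range(0, n, block):
        for j in range(0, k, block):
            y[i:i + block, j:j + block] *= y_scale[i // block, j // block]
    # bfloat16 output to match the benchmarked kernels
    return (x @ y.t()).to(torch.bfloat16)
```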
@@ -198,8 +282,8 @@ def get_benchmark(tp_size):
        x_names=["m", "n", "k", "tp_size"],
        x_vals=[list(config) for config in all_configs],
        line_arg="provider",
-        line_vals=["deepgemm", "sglang", "vllm"],
-        line_names=["DeepGEMM", "SGLang", "vLLM"],
+        line_vals=["deepgemm", "sglang", "tilelang"],
+        line_names=["DeepGEMM", "SGLang", "TileLang"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="ms",
        plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
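These keyword arguments are fields of triton.testing.Benchmark, consumed through the triton.testing.perf_report decorator. A minimal sketch of the surrounding scaffolding that the hunk does not show, with placeholder values where the diff is silent:

```python
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["m", "n", "k", "tp_size"],
        x_vals=[],  # filled from all_configs in the real file
        line_arg="provider",
        line_vals=["deepgemm", "sglang", "tilelang"],
        line_names=["DeepGEMM", "SGLang", "TileLang"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="ms",
        plot_name="fp8-gemm-performance-comparison",
        args={},
    )
)
def benchmark(m, n, k, tp_size, provider):
    ...  # dispatches on provider, as in the following hunk
```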
@@ -244,16 +328,15 @@ def get_benchmark(tp_size):
                ),
                quantiles=quantiles,
            )
-        else:  # vllm
+        else:  # tilelang
+            tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
+            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: fp8_gemm_vllm(
+                lambda: tilelang_kernel(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
-                    m,
-                    n,
-                    k,
                ),
                quantiles=quantiles,
            )
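Note the call-signature change in this branch: fp8_gemm_vllm took m, n, k at call time, while the compiled TileLang kernel has those shapes baked in by tl_gemm at compile time, so only the four tensors are passed. Because tilelang.compile runs inside the benchmark body, a shape-keyed cache such as the hypothetical helper below (not in the diff) could avoid recompiling when the same (m, n, k) is benchmarked repeatedly:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def get_tilelang_kernel(m: int, n: int, k: int):
    # Compile once per shape; later calls with the same (m, n, k) hit the cache.
    func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
    return tilelang.compile(func, out_idx=[-1])
```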
@@ -282,6 +365,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--run_correctness",
        action="store_true",
+        default=True,
        help="Whether to run correctness test",
    )
    parser.add_argument(
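One note on the flag above: with argparse, action="store_true" already implies default=False, and combining it with default=True leaves --run_correctness permanently true with no way to turn it off from the command line. If an on-by-default but switchable flag is intended, here is a sketch of one standard alternative (Python 3.9+, not part of the diff):

```python
import argparse

parser = argparse.ArgumentParser()
# Generates paired --run_correctness / --no-run_correctness options.
parser.add_argument(
    "--run_correctness",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Whether to run correctness test",
)
args = parser.parse_args()
```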