From 14269198e3b59cf662da209cafdd7790c9dd634b Mon Sep 17 00:00:00 2001
From: Chunan Zeng
Date: Mon, 24 Mar 2025 20:56:31 -0700
Subject: [PATCH] [Benchmark] tilelang vs deepgemm vs w8a8_block_fp8_matmul
 (#4735)

---
 .../deepseek/benchmark_deepgemm_fp8_gemm.py | 128 +++++++++++++++---
 1 file changed, 106 insertions(+), 22 deletions(-)

diff --git a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
index d80142483..7151f861f 100644
--- a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
+++ b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
@@ -1,11 +1,10 @@
-import itertools
 from typing import Tuple
 import deep_gemm
-import numpy as np
+import tilelang
+import tilelang.language as T
 import torch
 import triton
-import triton.language as tl
 from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
 )
@@ -14,6 +13,84 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
 
 
+# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
+def tl_gemm(
+    M,
+    N,
+    K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+):
+    assert in_dtype in [
+        "e4m3_float8",
+    ], "Currently only e4m3_float8 is supported"
+    assert out_dtype in [
+        "bfloat16",
+        "float16",
+    ], "Currently only bfloat16 and float16 are supported"
+
+    TILE_SIZE = (128, 128, 128)
+    block_M = TILE_SIZE[0]
+    block_N = TILE_SIZE[1]
+    block_K = TILE_SIZE[2]
+
+    A_shape = (M, K)
+    Scales_A_shape = (M, T.ceildiv(K, block_K))
+    B_shape = (N, K)
+    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
+    A_shared_shape = (block_M, block_K)
+    B_shared_shape = (block_N, block_K)
+    C_shared_shape = (block_M, block_N)
+
+    @T.prim_func
+    def main(
+        A: T.Buffer(A_shape, in_dtype),
+        scales_a: T.Buffer(Scales_A_shape, "float32"),
+        B: T.Buffer(B_shape, in_dtype),
+        scales_b: T.Buffer(Scales_B_shape, "float32"),
+        C: T.Buffer((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
+            bx,
+            by,
+        ):
+
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
+            Scale_C_shared = T.alloc_shared((block_M), "float32")
+            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
+            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
+
+            # Improve L2 Cache
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+            T.clear(C_local_accum)
+            K_iters = T.ceildiv(K, block_K)
+            for k in T.Pipelined(K_iters, num_stages=4):
+                # Load A into shared memory
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                # Load B into shared memory
+                T.copy(B[bx * block_N, k * block_K], B_shared)
+                # Load scale into shared memory
+                Scale_B = scales_b[bx, k]
+                for i in T.Parallel(block_M):
+                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
+
+                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+                # Promote to enable 2xAcc
+                for i, j in T.Parallel(block_M, block_N):
+                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
+                T.clear(C_local)
+            # TMA store
+            T.copy(C_local_accum, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+
+    return main
+
+
 def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2 and x.size(1) % 128 == 0
     m, n = x.shape
@@ -114,35 +191,42 @@ def calculate_diff(m: int, n: int, k: int):
     out_sglang = fp8_gemm_sglang(
         x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
     )
-    out_vllm = fp8_gemm_vllm(
-        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
+
+    tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
+    tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
+    out_tilelang = tilelang_kernel(
+        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
     )
 
     diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
-    diff_vllm_deepgemm = torch.abs(out_deepgemm - out_vllm).mean().item()
-    diff_vllm_sglang = torch.abs(out_vllm - out_sglang).mean().item()
+    diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
+    diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()
 
     print(f"Shape m={m}, n={n}, k={k}:")
     print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
     print(f"SGLang output: {out_sglang[0, 0:5]}")
-    print(f"vLLM output: {out_vllm[0, 0:5]}")
+    print(f"TileLang output: {out_tilelang[0, 0:5]}")
     print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
-    print(f"Mean absolute difference (vLLM-DeepGEMM): {diff_vllm_deepgemm}")
-    print(f"Mean absolute difference (vLLM-SGLang): {diff_vllm_sglang}")
+    print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
+    print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")
 
     sglang_deepgemm_match = torch.allclose(
         out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
     )
-    vllm_deepgemm_match = torch.allclose(out_deepgemm, out_vllm, atol=1e-2, rtol=1e-2)
-    vllm_sglang_match = torch.allclose(out_vllm, out_sglang, atol=1e-2, rtol=1e-2)
+    tilelang_deepgemm_match = torch.allclose(
+        out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
+    )
+    tilelang_sglang_match = torch.allclose(
+        out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
+    )
 
-    if sglang_deepgemm_match and vllm_deepgemm_match and vllm_sglang_match:
+    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
         print("✅ All implementations match\n")
     else:
         print("❌ Some implementations differ:")
         print(f"  - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
-        print(f"  - vLLM vs DeepGEMM: {'✅' if vllm_deepgemm_match else '❌'}")
-        print(f"  - vLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}\n")
+        print(f"  - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
+        print(f"  - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
 
 
 def get_weight_shapes(tp_size):
@@ -198,8 +282,8 @@ def get_benchmark(tp_size):
             x_names=["m", "n", "k", "tp_size"],
             x_vals=[list(config) for config in all_configs],
             line_arg="provider",
-            line_vals=["deepgemm", "sglang", "vllm"],
-            line_names=["DeepGEMM", "SGLang", "vLLM"],
+            line_vals=["deepgemm", "sglang", "tilelang"],
+            line_names=["DeepGEMM", "SGLang", "TileLang"],
             styles=[("blue", "-"), ("red", "-"), ("green", "-")],
             ylabel="ms",
             plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
@@ -244,16 +328,15 @@ def get_benchmark(tp_size):
                 ),
                 quantiles=quantiles,
             )
-        else:  # vllm
+        else:  # tilelang
+            tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
+            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
             ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: fp8_gemm_vllm(
+                lambda: tilelang_kernel(
                     x_fp8.clone(),
                     x_scale.clone(),
                     y_fp8.clone(),
                     y_scale.clone(),
-                    m,
-                    n,
-                    k,
                 ),
                 quantiles=quantiles,
             )
@@ -282,6 +365,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--run_correctness",
         action="store_true",
+        default=True,
         help="Whether to run correctness test",
     )
     parser.add_argument(