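"""Correctness check and performance benchmark for blockwise FP8 GEMM kernels.

Activations are quantized per token in 1x128 groups and weights in 128x128
blocks; the DeepGEMM, SGLang, and TileLang outputs are compared for agreement
and then timed with triton.testing.do_bench across a range of weight shapes
and batch sizes.
"""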
from typing import Tuple
import deep_gemm
import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
from sglang.srt.layers.quantization.fp8_kernel import (
    w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
)
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
def tl_gemm(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "e4m3_float8",
    ], "Currently only e4m3_float8 is supported"
    assert out_dtype in [
        "bfloat16",
        "float16",
    ], "Currently only bfloat16 and float16 are supported"

    TILE_SIZE = (128, 128, 128)
    block_M = TILE_SIZE[0]
    block_N = TILE_SIZE[1]
    block_K = TILE_SIZE[2]

    A_shape = (M, K)
    Scales_A_shape = (M, T.ceildiv(K, block_K))
    B_shape = (N, K)
    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M, block_N)

    @T.prim_func
    def main(
        A: T.Buffer(A_shape, in_dtype),
        scales_a: T.Buffer(Scales_A_shape, "float32"),
        B: T.Buffer(B_shape, in_dtype),
        scales_b: T.Buffer(Scales_B_shape, "float32"),
        C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
            Scale_C_shared = T.alloc_shared((block_M), "float32")
            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)

            # Improve L2 cache locality
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
                Scale_B = scales_b[bx, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
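# Quantization helpers. Activations are cast to FP8 per token in 1x128 groups
# and weights per 128x128 block; 448.0 is the maximum representable magnitude
# of torch.float8_e4m3fn, so each group/block is rescaled into that range and
# the corresponding dequantization scale is returned alongside the FP8 tensor.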
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
        x_view.size(0), x_view.size(2)
    )


def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """DeepGEMM implementation of FP8 GEMM"""
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    # Run DeepGEMM kernel
    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
    return out


def fp8_gemm_sglang(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """SGLang implementation of FP8 GEMM"""
    block_size = [128, 128]  # Matches the block size in per_block_cast_to_fp8
    # Run SGLang kernel
    out = w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
    )
    return out


def fp8_gemm_vllm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """vLLM implementation of FP8 GEMM"""
    block_size = [128, 128]  # Matches the block size in per_block_cast_to_fp8
    # Run vLLM kernel
    out = vllm_w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
    )
    return out
def calculate_diff(m: int, n: int, k: int):
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
    y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())

    out_deepgemm = fp8_gemm_deepgemm(
        x_fp8.clone(),
        x_scale_col_major.clone(),
        y_fp8.clone(),
        y_scale.clone(),
        m,
        n,
        k,
    )
    out_sglang = fp8_gemm_sglang(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
    )

    tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
    tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
    out_tilelang = tilelang_kernel(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
    )

    diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
    diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
    diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"SGLang output: {out_sglang[0, 0:5]}")
    print(f"TileLang output: {out_tilelang[0, 0:5]}")
    print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
    print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
    print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")

    sglang_deepgemm_match = torch.allclose(
        out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
    )
    tilelang_deepgemm_match = torch.allclose(
        out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
    )
    tilelang_sglang_match = torch.allclose(
        out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
    )

    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        print(f"  - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
        print(f"  - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
        print(f"  - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
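# (n, k) weight shapes to benchmark, grouped by whether the output (N) or
# reduction (K) dimension can be sharded under tensor parallelism. The sizes
# appear to correspond to DeepSeek-V3-style projection layers.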
def get_weight_shapes(tp_size):
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    weight_shapes = []
    for t in total:
        weight_shapes.append(t)
    for n_t in n_tp:
        new_t = (n_t[0] // tp_size, n_t[1])
        weight_shapes.append(new_t)
    for k_t in k_tp:
        new_t = (k_t[0], k_t[1] // tp_size)
        weight_shapes.append(new_t)
    return weight_shapes


def create_benchmark_configs(tp_size):
    configs = []
    weight_shapes = get_weight_shapes(tp_size)
    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
    for n, k in weight_shapes:
        for m in batch_sizes:
            configs.append((m, n, k, tp_size))
    return configs
def get_benchmark(tp_size):
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "tp_size"],
            x_vals=[list(config) for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "sglang", "tilelang"],
            line_names=["DeepGEMM", "SGLang", "TileLang"],
            styles=[("blue", "-"), ("red", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, tp_size, provider):
        print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

        # Preprocess data before benchmarking
        x_fp8, x_scale = per_token_cast_to_fp8(x)
        y_fp8, y_scale = per_block_cast_to_fp8(y)
        x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())

        quantiles = [0.5, 0.2, 0.8]
        if provider == "deepgemm":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_deepgemm(
                    x_fp8.clone(),
                    x_scale_col_major.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                    m,
                    n,
                    k,
                ),
                quantiles=quantiles,
            )
        elif provider == "sglang":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_sglang(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                    m,
                    n,
                    k,
                ),
                quantiles=quantiles,
            )
        else:  # tilelang
            tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: tilelang_kernel(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                ),
                quantiles=quantiles,
            )

        # Calculate TFLOPS
        flops = 2 * m * n * k  # multiply-adds
        tflops = flops / (ms * 1e-3) / 1e12

        # Print shape-specific results with TFLOPS
        print(f"Time: {ms * 1000:.2f} ms, TFLOPS: {tflops:.2f}")

        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms

    return benchmark
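# Example invocation (the file name below is just a placeholder for this script):
#   python benchmark_fp8_blockwise_gemm.py --tp_size 1 --save_path ./results/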
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_gemm/",
        help="Path to save fp8 gemm benchmark results",
    )
    parser.add_argument(
        "--run_correctness",
        action="store_true",
        default=True,
        help="Whether to run correctness test",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Run correctness tests on a few examples
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(64, 512, 7168)  # Small test
        calculate_diff(64, 7168, 16384)  # Medium test
        calculate_diff(64, 18432, 7168)  # Large test

    # Get the benchmark function with the specified tp_size
    benchmark = get_benchmark(args.tp_size)
    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)