diff --git a/sgl-kernel/benchmark/bench_int8_gemm.py b/sgl-kernel/benchmark/bench_int8_gemm.py index 2657c616c..c5a709393 100644 --- a/sgl-kernel/benchmark/bench_int8_gemm.py +++ b/sgl-kernel/benchmark/bench_int8_gemm.py @@ -1,3 +1,7 @@ +import argparse +import copy +import itertools + import torch import triton from sgl_kernel import int8_scaled_mm @@ -8,6 +12,56 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], @@ -22,8 +76,8 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: args={}, ) ) -def benchmark(batch_size, provider): - M, N, K = batch_size, 4096, 8192 +def benchmark(batch_size, provider, N, K): + M = batch_size a = to_int8(torch.randn((M, K), device="cuda") * 5) b = to_int8(torch.randn((N, K), device="cuda").t() * 5) scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) @@ -52,4 +106,41 @@ def benchmark(batch_size, provider): return gbps(ms), gbps(max_ms), gbps(min_ms) -benchmark.run(print_data=True, show_plots=True, save_path="bench_int8_res") +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K + ) + + print("Benchmark finished!")