Add shapes for int8 gemm benchmark (#3093)
This commit is contained in:
@@ -1,3 +1,7 @@
|
|||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from sgl_kernel import int8_scaled_mm
|
from sgl_kernel import int8_scaled_mm
|
||||||
@@ -8,6 +12,56 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||||
|
|
||||||
|
|
||||||
|
WEIGHT_SHAPES = {
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 57344], 1),
|
||||||
|
([28672, 8192], 0),
|
||||||
|
],
|
||||||
|
"mistralai/Mistral-Large-Instruct-2407": [
|
||||||
|
([12288, 14336], 1),
|
||||||
|
([12288, 12288], 0),
|
||||||
|
([12288, 57344], 1),
|
||||||
|
([28672, 12288], 0),
|
||||||
|
],
|
||||||
|
"Qwen/Qwen2.5-7B-Instruct": [
|
||||||
|
([3584, 4608], 1),
|
||||||
|
([3584, 3584], 0),
|
||||||
|
([3584, 37888], 1),
|
||||||
|
([18944, 3584], 0),
|
||||||
|
],
|
||||||
|
"Qwen/Qwen2.5-32B-Instruct": [
|
||||||
|
([5120, 7168], 1),
|
||||||
|
([5120, 5120], 0),
|
||||||
|
([5120, 55296], 1),
|
||||||
|
([27648, 5120], 0),
|
||||||
|
],
|
||||||
|
"Qwen/Qwen2.5-72B-Instruct": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 59136], 1),
|
||||||
|
([29568, 8192], 0),
|
||||||
|
],
|
||||||
|
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||||
|
([2048, 3072], 1),
|
||||||
|
([2048, 4096], 1),
|
||||||
|
([2048, 2048], 0),
|
||||||
|
([2048, 576], 0),
|
||||||
|
([2048, 21888], 1),
|
||||||
|
([10944, 2048], 0),
|
||||||
|
([2048, 2816], 1),
|
||||||
|
([1408, 2048], 0),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["batch_size"],
|
x_names=["batch_size"],
|
||||||
@@ -22,8 +76,8 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
args={},
|
args={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def benchmark(batch_size, provider):
|
def benchmark(batch_size, provider, N, K):
|
||||||
M, N, K = batch_size, 4096, 8192
|
M = batch_size
|
||||||
a = to_int8(torch.randn((M, K), device="cuda") * 5)
|
a = to_int8(torch.randn((M, K), device="cuda") * 5)
|
||||||
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
|
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
|
||||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||||
@@ -52,4 +106,41 @@ def benchmark(batch_size, provider):
|
|||||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||||
|
|
||||||
|
|
||||||
benchmark.run(print_data=True, show_plots=True, save_path="bench_int8_res")
|
def prepare_shapes(args):
|
||||||
|
KN_model_names = []
|
||||||
|
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||||
|
for model, tp_size in models_tps:
|
||||||
|
assert model in WEIGHT_SHAPES
|
||||||
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
KN.append(model)
|
||||||
|
KN_model_names.append(KN)
|
||||||
|
return KN_model_names
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||||
|
help="List of models to benchmark",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=[1],
|
||||||
|
help="List of tensor parallel sizes",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
KN_model_names = prepare_shapes(args)
|
||||||
|
for K, N, model_name in KN_model_names:
|
||||||
|
print(f"{model_name} N={N} K={K}: ")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
||||||
|
|||||||
Reference in New Issue
Block a user