"""Benchmark end-to-end latency of torch_mlu_ops group GEMM kernels.

For each (batch, seq_len, k, n) configuration below, the script builds random
inputs on the MLU device, runs either ``tmo.group_gemm`` (float path) or
``tmo.smooth_quant_group_gemm`` (int8 path with float32 scales) through
``benchmark_forward``, and prints a table of hardware vs. end-to-end times.
Results can optionally be written to CSV via ``--csv`` (folder chosen by
``-o``).
"""
import argparse
import os
import random  # NOTE(review): unused here; kept in case of import-time side effects — confirm before removing.

import torch
import torch_mlu
import torch_mlu_ops as tmo
from tabulate import tabulate

from common import benchmark_forward, save_to_csv

# Hyper-parameters shared by every benchmarked case.
_EXPERT_NUM = 32
_TOPK = 5

# Decode-style cases sweep the batch size at seq_len == 1; prefill-style cases
# sweep the sequence length at batch == 1.  Every (batch, seq_len) pair is run
# for both (k, n) = (8192, 1024) and (1024, 8192), in that order, matching the
# original hand-written list exactly.
_BATCH_SEQ_PAIRS = [(b, 1) for b in (1, 16, 72, 128, 490, 525)] + \
                   [(1, s) for s in (1024, 2048, 4096, 8192, 32768)]
_KN_PAIRS = [(8192, 1024), (1024, 8192)]

e2e_time_param_dict_list = [
    {
        "batch": batch,
        "seq_len": seq_len,
        "k": k,
        "n": n,
        "expert_num": _EXPERT_NUM,
        "topk": _TOPK,
        "is_quant": True,
        "dtype": [torch.bfloat16],
    }
    for batch, seq_len in _BATCH_SEQ_PAIRS
    for k, n in _KN_PAIRS
]


def _benchmark_case(batch, seq_len, k, n, expert_num, topk, is_quant, dtype,
                    repeat_times):
    """Run one group-GEMM case and return ``(hardware_time, e2e_time)``.

    Tokens are spread as evenly as possible across the experts so that
    ``token_count`` sums to ``batch * seq_len * topk``.
    """
    max_m = batch * seq_len
    m = batch * seq_len * topk
    # Distribute m tokens over expert_num experts; the first `rem` experts
    # each receive one extra token (bool `i < rem` adds 0 or 1).
    avg, rem = m // expert_num, m % expert_num
    m_list = [avg + (i < rem) for i in range(expert_num)]
    token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')

    if not is_quant:
        a = torch.randn(m, k, dtype=dtype, device='mlu')
        b = torch.randn(expert_num, n, k, dtype=dtype, device='mlu')
        return benchmark_forward(tmo.group_gemm, a, b, token_count,
                                 None, None, None, None, max_m,
                                 repeats=repeat_times)

    # Smooth-quant path: int8 activations/weights plus float32 scales
    # (one scale per activation row, one per expert output channel).
    a = torch.randint(-128, 127, (m, k)).to(torch.int8).mlu()
    b = torch.randint(-128, 127, (expert_num, n, k)).to(torch.int8).mlu()
    a_scale = torch.randn(a.size(0), dtype=torch.float32, device='mlu')
    b_scale = torch.randn(expert_num, n, dtype=torch.float32, device='mlu')
    return benchmark_forward(tmo.smooth_quant_group_gemm, a, b, token_count,
                             None, None, None, None, a_scale, b_scale,
                             dtype, max_m, repeats=repeat_times)


def main():
    """Benchmark every configured case, print the report, optionally save CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10,
                        help='repeat times for testing')
    parser.add_argument('--csv', action='store_true',
                        help='write the report data to csv')
    parser.add_argument('-o', type=str,
                        help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["batch", "seq_len", "k", "n", "expert_num", "topk",
              "smooth_quant", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        k = params_dict["k"]
        n = params_dict["n"]
        expert_num = params_dict["expert_num"]
        topk = params_dict["topk"]
        is_quant = params_dict["is_quant"]
        for dtype in params_dict["dtype"]:
            # Fall back to fp16 on hardware without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            hardware_time, e2e_time = _benchmark_case(
                batch, seq_len, k, n, expert_num, topk, is_quant, dtype,
                args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{k}", f"{n}",
                             f"{expert_num}", f"{topk}", f"{is_quant}",
                             f"{dtype}", f"{hardware_time}", f"{e2e_time}"])

    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV output after this script file.
        # NOTE(review): with --csv but no -o, args.o is None — presumably
        # save_to_csv picks a default folder; verify in common.py.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()