"""End-to-end latency benchmark for ``torch_mlu_ops.fused_rope``.

Sweeps batch sizes over three KV-cache layouts (contiguous, paged, mixed
low-precision), with and without int8 KV quantization, and reports hardware
and end-to-end latency per case as a table (optionally saved to CSV).
"""
import argparse
import os
import random

import torch
import torch_mlu
import torch_mlu_ops as tmo
from tabulate import tabulate

from common import benchmark_forward, save_to_csv

# Batch sizes swept by every cache layout, in report order.
_BATCHES = (40, 72, 128, 256, 512)
# (num_blocks, block_size) for the paged-cache cases, aligned with _BATCHES.
_PAGED_LAYOUTS = ((128, 16), (64, 32), (128, 16), (64, 32), (128, 16))


def _build_e2e_param_dicts():
    """Return the benchmark case list (same contents/order as the original
    hand-written table): 10 contiguous-cache cases, 10 paged-cache cases,
    then 5 mixed-cache cases."""

    def base(batch):
        # Keys shared by every case; a fresh dict (and dtype list) per case.
        return {
            "batch": batch,
            "seq_len": 1,
            "head_num_q": 8,
            "head_num_k": 1,
            "head_size": 128,
            "rotary_dim": 128,
            "input_dtype": [torch.float16, torch.bfloat16],
        }

    params = []
    # Contiguous KV cache, quant_kv off then on, per batch size.
    for batch in _BATCHES:
        for quant in (False, True):
            params.append({
                **base(batch),
                "quant_kv": quant,
                "paged_cache": False,
                "max_decode_len": 2048,
            })
    # Paged KV cache, quant_kv off then on, per batch size.
    for batch, (num_blocks, block_size) in zip(_BATCHES, _PAGED_LAYOUTS):
        for quant in (False, True):
            params.append({
                **base(batch),
                "quant_kv": quant,
                "paged_cache": True,
                "num_blocks": num_blocks,
                "block_size": block_size,
            })
    # Mixed low-precision cache, per batch size. NOTE: these cases omit
    # "quant_kv"/"paged_cache", so main() applies its defaults (True/False).
    for batch in _BATCHES:
        params.append({**base(batch), "mixed_cache": True})
    return params


e2e_time_param_dict_list = _build_e2e_param_dicts()


def main():
    # This benchmark is not supported on MLU3-series devices.
    if 'MLU3' in torch.mlu.get_device_name():
        return

    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10,
                        help='repeat times for testing')
    parser.add_argument('--csv', action='store_true',
                        help='write the report data to csv')
    parser.add_argument('-o', type=str,
                        help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim",
              "quant_kv", "paged_cache", "max_decode_len", "num_blocks",
              "block_size", "mixed_cache", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)"]
    contents = []

    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Defaults mirror the original behavior: quantize KV unless the case
        # says otherwise; contiguous (non-paged) cache by default.
        quant_kv = params_dict.get("quant_kv", True)
        paged_cache = params_dict.get("paged_cache", False)
        mixed_cache = params_dict.get("mixed_cache", False)

        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict.get("max_decode_len", 32)

        for dtype in params_dict["input_dtype"]:
            # Batch slots in the cache are indexed indirectly (discrete), so
            # allocate one spare slot beyond the live batch.
            discrete_batch = True
            max_bs = bs + 1 if discrete_batch else bs

            # Packed QKV activation: q_heads for Q plus kv_heads each for K/V.
            qkv_input = torch.randn(
                size=(bs, seq_len, q_heads + 2 * kv_heads, head_size),
                dtype=dtype).mlu()
            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()

            # Optional int8 KV quantization: the op takes reciprocal scales.
            cache_dtype = dtype
            if quant_kv:
                k_scale = torch.randn(size=(kv_heads, head_size),
                                      dtype=torch.float).mlu()
                v_scale = torch.randn(size=(kv_heads, head_size),
                                      dtype=torch.float).mlu()
                cache_dtype = torch.int8
                k_scale_ops = 1 / k_scale
                v_scale_ops = 1 / v_scale
            else:
                k_scale_ops = None
                v_scale_ops = None

            # KV cache: leading dim 2 packs K and V together.
            if paged_cache:
                cache = torch.randn(
                    (2, num_blocks, kv_heads, block_size, head_size),
                    dtype=dtype, device='mlu')
            else:
                cache = torch.randn(
                    (2, max_bs, kv_heads, max_decode_len, head_size),
                    dtype=dtype, device='mlu')
            if quant_kv:
                # Spread values across the int8 range before narrowing.
                cache = (cache - 0.5) * 256
                cache = cache.to(cache_dtype)
            k_cache = cache[0]
            v_cache = cache[1]

            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if not paged_cache:
                if discrete_batch:
                    # Random, non-repeating cache slot per live batch entry.
                    cache_bs_id = torch.IntTensor(
                        random.sample([*range(0, max_bs)], bs)).mlu()
                # -1 means "no cached history yet" for that batch entry.
                cache_seq_offsets = torch.randint(
                    size=(bs, ), low=-1, high=max_decode_len - 2,
                    dtype=torch.int32, device='mlu')
            else:
                # Unique slot per entry; -1 marks an unmapped slot.
                slot_mapping = torch.IntTensor(
                    random.sample([*range(-1, block_size * num_blocks)],
                                  bs)).mlu()

            position_id = torch.randint(size=(bs, ), low=0, high=bs,
                                        dtype=torch.int32, device='mlu')

            # Mixed cache: a second, low-precision (int8 + per-token scale)
            # cache with halved K head size / halved V sequence length.
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None
            if mixed_cache:
                max_decode_len_lp = 1024
                k_cache_raw = torch.randn(
                    (max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)),
                    dtype=dtype, device='mlu')
                v_cache_raw = torch.randn(
                    (max_bs, kv_heads, int(max_decode_len_lp / 2), head_size),
                    dtype=dtype, device='mlu')
                # Normalize into [-7, 7] so the int8 narrowing keeps signal.
                max_value = torch.amax(torch.abs(k_cache_raw))
                k_cache_raw = k_cache_raw * (7 / max_value)
                max_value = torch.amax(torch.abs(v_cache_raw))
                v_cache_raw = v_cache_raw * (7 / max_value)
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)
                k_scale_lp = torch.randn(
                    size=(max_bs, kv_heads, max_decode_len_lp, 1),
                    dtype=torch.float).mlu()
                v_scale_lp = torch.randn(
                    size=(max_bs, kv_heads, max_decode_len_lp, 1),
                    dtype=torch.float).mlu()
                cache_bs_id_lp = torch.IntTensor(
                    random.sample([*range(0, max_bs)], bs)).mlu()
                cache_seq_offsets_lp = torch.randint(
                    size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                    dtype=torch.int32, device='mlu')

            hardware_time, e2e_time = benchmark_forward(
                tmo.fused_rope, qkv_input, k_cache, v_cache, sin_table,
                cos_table, position_id, gamma, beta, k_cache_lp, v_cache_lp,
                cache_bs_id, cache_seq_offsets, cache_bs_id_lp,
                cache_seq_offsets_lp, k_scale_ops, v_scale_ops, k_scale_lp,
                v_scale_lp, slot_mapping, None, 1e-5,
                repeats=args.repeat_times)

            contents.append([
                str(bs), str(q_heads), str(kv_heads), str(head_size),
                str(rope_dim), str(quant_kv), str(paged_cache),
                str(max_decode_len), str(num_blocks), str(block_size),
                str(mixed_cache), str(dtype), str(hardware_time),
                str(e2e_time),
            ])

    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # CSV file is named after this script.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()