# Benchmark script: end-to-end latency of torch_mlu_ops.fused_rope across
# contiguous, paged, and mixed KV-cache configurations.
import torch
|
|
import torch_mlu
|
|
import torch_mlu_ops as tmo
|
|
from common import benchmark_forward, save_to_csv
|
|
import argparse
|
|
from tabulate import tabulate
|
|
import os
|
|
import random
|
|
|
|
# Benchmark configurations for fused_rope end-to-end timing.
#
# Three cache layouts are swept over the same batch sizes, in this order:
#   1. contiguous cache (paged_cache=False), unquantized then quantized KV
#   2. paged cache (paged_cache=True), unquantized then quantized KV
#   3. mixed cache (mixed_cache=True)
# Built programmatically instead of 25 hand-written near-identical dicts.
_BATCHES = (40, 72, 128, 256, 512)

# Per-batch paged-cache geometry: batch -> (num_blocks, block_size).
_PAGED_GEOMETRY = {
    40: (128, 16),
    72: (64, 32),
    128: (128, 16),
    256: (64, 32),
    512: (128, 16),
}


def _base_params(batch):
    """Return a fresh dict of the parameters shared by every benchmark case."""
    return {"batch": batch, "seq_len": 1, "head_num_q": 8, "head_num_k": 1,
            "head_size": 128, "rotary_dim": 128,
            "input_dtype": [torch.float16, torch.bfloat16]}


def _build_param_list():
    """Construct the full case list: contiguous, then paged, then mixed cache."""
    cases = []
    # 1) contiguous-cache cases: quant_kv False then True per batch.
    for batch in _BATCHES:
        for quant_kv in (False, True):
            cases.append({**_base_params(batch), "quant_kv": quant_kv,
                          "paged_cache": False, "max_decode_len": 2048})
    # 2) paged-cache cases: geometry varies per batch, quant_kv False then True.
    for batch in _BATCHES:
        num_blocks, block_size = _PAGED_GEOMETRY[batch]
        for quant_kv in (False, True):
            cases.append({**_base_params(batch), "quant_kv": quant_kv,
                          "paged_cache": True, "num_blocks": num_blocks,
                          "block_size": block_size})
    # 3) mixed-cache cases.
    for batch in _BATCHES:
        cases.append({**_base_params(batch), "mixed_cache": True})
    return cases


e2e_time_param_dict_list = _build_param_list()
|
|
def main():
    """Benchmark the end-to-end latency of tmo.fused_rope.

    Iterates over every configuration in ``e2e_time_param_dict_list`` and
    every dtype listed in it, builds random inputs and KV caches on the MLU
    device, times the fused op via ``benchmark_forward``, and prints a
    tabulated report (optionally saved to CSV via ``--csv`` / ``-o``).
    """
    # This benchmark does not target MLU3 devices; bail out early on them.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    # Report column headers; one row is appended per (config, dtype) pair.
    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache", "max_decode_len", "num_blocks", \
        "block_size", "mixed_cache", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Optional keys default to: quantized KV on, contiguous (non-paged)
        # cache, no mixed cache.
        quant_kv = params_dict["quant_kv"] if "quant_kv" in params_dict else True
        paged_cache = params_dict["paged_cache"] if "paged_cache" in params_dict else False
        mixed_cache = params_dict["mixed_cache"] if "mixed_cache" in params_dict else False
        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict["max_decode_len"] if "max_decode_len" in params_dict else 32
        input_dtype_list = params_dict["input_dtype"]

        for dtype in input_dtype_list:
            # Reserve one spare batch slot so cache batch ids need not be the
            # identity mapping (exercises the discrete-batch addressing path).
            discrete_batch = True
            max_bs = bs + 1 if discrete_batch else bs
            # Packed QKV activations: q_heads query heads plus kv_heads key
            # heads and kv_heads value heads per token.
            input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
            input = torch.randn(size=input_shape, dtype=dtype).mlu()
            input_ref = input.clone()

            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            cache_dtype = dtype
            if quant_kv:
                # Quantized KV cache: int8 storage with float scales.
                # NOTE(review): the op consumes reciprocal scales (1/scale);
                # randn scales can be near zero, so 1/scale may be huge —
                # presumably acceptable for a timing-only benchmark; confirm.
                k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                cache_dtype = torch.int8
                k_scale_ops = 1 / k_scale
                v_scale_ops = 1 / v_scale
            else:
                k_scale = None
                v_scale = None
                k_scale_ops = None
                v_scale_ops = None
            if paged_cache:
                # Paged layout: [K/V, num_blocks, kv_heads, block_size, head_size].
                cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size), dtype=dtype, device='mlu')
            else:
                # Contiguous layout: [K/V, max_bs, kv_heads, max_decode_len, head_size].
                cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')

            if quant_kv:
                # Spread randn output over roughly the int8 value range
                # before the cast to int8 below.
                cache = (cache - 0.5) * 256
            cache = cache.to(cache_dtype)
            k_cache = cache[0]
            v_cache = cache[1]

            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if not paged_cache:
                if discrete_batch:
                    # Random, non-repeating mapping of logical batch index to
                    # physical cache row in [0, max_bs).
                    cache_bs_id = random.sample([*range(0, max_bs)], bs)
                    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()

                # Per-batch write offset into the decode cache; -1 is included
                # in the range — presumably a "no write" marker; TODO confirm
                # against the op's contract.
                cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                                  dtype=torch.int32, device='mlu')
            else:
                # Paged path: one flat slot index per batch element; -1 is
                # included — presumably "no write"; TODO confirm.
                slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs)
                slot_mapping = torch.IntTensor(slot_mapping).mlu()

            position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')

            # Low-precision ("lp") side of the mixed cache; only populated
            # when mixed_cache is set, otherwise passed as None.
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None

            if mixed_cache:
                max_decode_len_lp = 1024
                # NOTE(review): K uses half the head_size and V half the
                # decode length — assumed to match the mixed-cache storage
                # layout expected by tmo.fused_rope; verify against op docs.
                k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)), dtype=dtype, device='mlu')
                v_cache_raw = torch.randn((max_bs, kv_heads, int(max_decode_len_lp / 2), head_size), dtype=dtype, device='mlu')
                # Normalize values into [-7, 7] (int4-like range) before the
                # int8 cast.
                max_value = torch.amax(torch.abs(k_cache_raw))
                k_cache_raw = k_cache_raw * (7 / max_value)
                max_value = torch.amax(torch.abs(v_cache_raw))
                v_cache_raw = v_cache_raw * (7 / max_value)
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)

                # Per-position float scales for the low-precision cache.
                k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
                v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()

                cache_bs_id_lp = random.sample([*range(0, max_bs)], bs)
                cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
                cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                                                     dtype=torch.int32, device='mlu')

            # Time the fused op. Returns (hardware_time, e2e_time) —
            # presumably device kernel time vs host-visible latency in
            # microseconds, averaged over args.repeat_times runs; confirm in
            # common.benchmark_forward.
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
                                                        input,
                                                        k_cache,
                                                        v_cache,
                                                        sin_table,
                                                        cos_table,
                                                        position_id,
                                                        gamma,
                                                        beta,
                                                        k_cache_lp,
                                                        v_cache_lp,
                                                        cache_bs_id,
                                                        cache_seq_offsets,
                                                        cache_bs_id_lp,
                                                        cache_seq_offsets_lp,
                                                        k_scale_ops,
                                                        v_scale_ops,
                                                        k_scale_lp,
                                                        v_scale_lp,
                                                        slot_mapping,
                                                        None,
                                                        1e-5,
                                                        repeats=args.repeat_times)
            # One report row per (config, dtype); all values stringified for
            # tabulate/CSV output.
            content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}", f"{paged_cache}", \
                f"{max_decode_len}", f"{num_blocks}", f"{block_size}", f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        # CSV file is named after this script and written under the folder
        # given by -o (handled inside save_to_csv).
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
|
|
|
|
# Script entry point: run the benchmark only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|