Files
2026-02-04 17:39:32 +08:00

208 lines
13 KiB
Python

import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configurations for the fused-rope end-to-end timing run, grouped
# by KV-cache layout:
#   1) contiguous cache (paged_cache=False) with max_decode_len=2048,
#   2) paged cache with a per-batch (num_blocks, block_size) layout,
#   3) mixed low-precision cache (mixed_cache=True).
# Each config is exercised once per dtype listed under "input_dtype".
_BASE_CFG = {"seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128,
             "rotary_dim": 128, "input_dtype": [torch.float16, torch.bfloat16]}
_BATCH_SIZES = (40, 72, 128, 256, 512)
# batch -> (num_blocks, block_size) used by the paged-cache configs
_PAGED_SHAPES = {40: (128, 16), 72: (64, 32), 128: (128, 16),
                 256: (64, 32), 512: (128, 16)}

e2e_time_param_dict_list = (
    # contiguous cache, quantized and unquantized
    [dict(_BASE_CFG, batch=b, quant_kv=q, paged_cache=False, max_decode_len=2048)
     for b in _BATCH_SIZES for q in (False, True)]
    # paged cache, quantized and unquantized
    + [dict(_BASE_CFG, batch=b, quant_kv=q, paged_cache=True,
            num_blocks=_PAGED_SHAPES[b][0], block_size=_PAGED_SHAPES[b][1])
       for b in _BATCH_SIZES for q in (False, True)]
    # mixed low-precision cache
    + [dict(_BASE_CFG, batch=b, mixed_cache=True) for b in _BATCH_SIZES]
)
def main():
    """Benchmark ``tmo.fused_rope`` over every config in
    ``e2e_time_param_dict_list`` and print a timing table (optionally also
    saved as CSV via --csv / -o).
    """
    # This benchmark targets boards other than MLU3xx; bail out early there.
    if 'MLU3' in torch.mlu.get_device_name():
        return

    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv",
              "paged_cache", "max_decode_len", "num_blocks", "block_size", "mixed_cache",
              "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Optional knobs with the same defaults as the original ternaries.
        quant_kv = params_dict.get("quant_kv", True)
        paged_cache = params_dict.get("paged_cache", False)
        mixed_cache = params_dict.get("mixed_cache", False)
        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict.get("max_decode_len", 32)

        for dtype in params_dict["input_dtype"]:
            # Batch ids index into a cache allocated one slot larger than bs,
            # exercising non-contiguous (discrete) batch placement.
            discrete_batch = True
            max_bs = bs + 1 if discrete_batch else bs

            # Packed QKV activations: q_heads for Q plus kv_heads each for K/V.
            qkv = torch.randn(size=(bs, seq_len, q_heads + 2 * kv_heads, head_size),
                              dtype=dtype).mlu()
            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()

            cache_dtype = dtype
            if quant_kv:
                # Per-(head, channel) fp32 scales; the op takes the reciprocal.
                k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                cache_dtype = torch.int8
                k_scale_ops = 1 / k_scale
                v_scale_ops = 1 / v_scale
            else:
                k_scale_ops = None
                v_scale_ops = None

            if paged_cache:
                cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size),
                                    dtype=dtype, device='mlu')
            else:
                cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size),
                                    dtype=dtype, device='mlu')
            if quant_kv:
                # Spread randn values across the int8 range before the cast.
                cache = (cache - 0.5) * 256
            cache = cache.to(cache_dtype)  # no-op when cache_dtype == dtype
            k_cache = cache[0]
            v_cache = cache[1]

            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if paged_cache:
                # One unique slot per sequence; -1 marks "no write".
                slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs)
                slot_mapping = torch.IntTensor(slot_mapping).mlu()
            else:
                if discrete_batch:
                    # Distinct cache rows for each sequence within [0, max_bs).
                    cache_bs_id = random.sample([*range(0, max_bs)], bs)
                    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
                cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                                  dtype=torch.int32, device='mlu')
            position_id = torch.randint(size=(bs, ), low=0, high=bs,
                                        dtype=torch.int32, device='mlu')

            # Low-precision secondary cache, only used for mixed_cache configs.
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None
            if mixed_cache:
                max_decode_len_lp = 1024
                k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, head_size // 2),
                                          dtype=dtype, device='mlu')
                v_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp // 2, head_size),
                                          dtype=dtype, device='mlu')
                # Rescale into [-7, 7] so the int8 cast keeps some signal.
                max_value = torch.amax(torch.abs(k_cache_raw))
                k_cache_raw = k_cache_raw * (7 / max_value)
                max_value = torch.amax(torch.abs(v_cache_raw))
                v_cache_raw = v_cache_raw * (7 / max_value)
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)
                # Per-token fp32 scales for the low-precision cache.
                k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1),
                                         dtype=torch.float).mlu()
                v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1),
                                         dtype=torch.float).mlu()
                cache_bs_id_lp = random.sample([*range(0, max_bs)], bs)
                cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
                cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1,
                                                     high=max_decode_len_lp - 2,
                                                     dtype=torch.int32, device='mlu')

            # Argument order follows the tmo.fused_rope signature;
            # None / 1e-5 are the residual input and epsilon slots.
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
                                                        qkv,
                                                        k_cache,
                                                        v_cache,
                                                        sin_table,
                                                        cos_table,
                                                        position_id,
                                                        gamma,
                                                        beta,
                                                        k_cache_lp,
                                                        v_cache_lp,
                                                        cache_bs_id,
                                                        cache_seq_offsets,
                                                        cache_bs_id_lp,
                                                        cache_seq_offsets_lp,
                                                        k_scale_ops,
                                                        v_scale_ops,
                                                        k_scale_lp,
                                                        v_scale_lp,
                                                        slot_mapping,
                                                        None,
                                                        1e-5,
                                                        repeats=args.repeat_times)
            contents.append([f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}",
                             f"{rope_dim}", f"{quant_kv}", f"{paged_cache}",
                             f"{max_decode_len}", f"{num_blocks}", f"{block_size}",
                             f"{mixed_cache}", f"{dtype}", f"{hardware_time}",
                             f"{e2e_time}"])

    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__ == "__main__":
    main()