forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
207
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py
Normal file
207
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Batch sizes shared by every benchmark family below.
_BENCH_BATCHES = (40, 72, 128, 256, 512)

# (num_blocks, block_size) used by the paged-cache cases, one pair per entry of
# _BENCH_BATCHES (same order).
_PAGED_LAYOUTS = ((128, 16), (64, 32), (128, 16), (64, 32), (128, 16))


def _rope_case(batch, **cache_options):
    """Build one benchmark configuration dict.

    The head geometry is the same for every case; ``cache_options`` carries the
    cache-specific keys and is inserted between the geometry keys and
    ``input_dtype``. Every call returns its own ``input_dtype`` list.
    """
    case = {"batch": batch, "seq_len": 1, "head_num_q": 8, "head_num_k": 1,
            "head_size": 128, "rotary_dim": 128}
    case.update(cache_options)
    case["input_dtype"] = [torch.float16, torch.bfloat16]
    return case


# Benchmark configurations consumed by main(), in three families:
#   1. linear (non-paged) cache, without and with quantized KV,
#   2. paged cache, without and with quantized KV,
#   3. mixed cache.
e2e_time_param_dict_list = [
    *(_rope_case(batch, quant_kv=quant, paged_cache=False, max_decode_len=2048)
      for batch in _BENCH_BATCHES
      for quant in (False, True)),
    *(_rope_case(batch, quant_kv=quant, paged_cache=True,
                 num_blocks=blocks, block_size=blk_size)
      for batch, (blocks, blk_size) in zip(_BENCH_BATCHES, _PAGED_LAYOUTS)
      for quant in (False, True)),
    *(_rope_case(batch, mixed_cache=True) for batch in _BENCH_BATCHES),
]
|
||||
def main():
    """Benchmark ``tmo.fused_rope`` for every entry of ``e2e_time_param_dict_list``.

    For each configuration and each dtype listed under its ``"input_dtype"``
    key, random inputs and KV caches are created on the MLU device, the op is
    timed through ``benchmark_forward``, and one row is added to the report.
    The report is printed as a grid table and, when ``--csv`` is given, also
    handed to ``save_to_csv`` together with the ``-o`` value and this file's
    name.

    Command-line options:
        --repeat_times  value forwarded to ``benchmark_forward`` as ``repeats``
                        (default 10).
        --csv           also write the report through ``save_to_csv``.
        -o              output folder name passed to ``save_to_csv``.

    The process exits immediately, before argument parsing, when the device
    name contains ``'MLU3'``.
    """
    # Imported locally so this function does not depend on the interactive-only
    # ``exit`` helper that the ``site`` module injects into builtins.
    import sys

    # NOTE(review): presumably fused_rope is not supported on MLU3xx devices -
    # confirm. The benchmark is skipped silently with exit status 0.
    if 'MLU3' in torch.mlu.get_device_name():
        sys.exit()

    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache",
              "max_decode_len", "num_blocks", "block_size", "mixed_cache", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)"]
    contents = []

    # The linear cache is always allocated with one spare batch slot and
    # addressed through a random permutation of slot ids.
    discrete_batch = True

    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Entries that omit "quant_kv" (the mixed-cache ones) run quantized.
        quant_kv = params_dict.get("quant_kv", True)
        paged_cache = params_dict.get("paged_cache", False)
        mixed_cache = params_dict.get("mixed_cache", False)

        # Only the fields relevant to the selected cache layout are non-zero;
        # the zeros are what the report shows for the other layout.
        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict.get("max_decode_len", 32)

        for dtype in params_dict["input_dtype"]:
            max_bs = bs + 1 if discrete_batch else bs

            # One tensor holding q_heads + 2 * kv_heads heads per token.
            input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
            rope_input = torch.randn(size=input_shape, dtype=dtype).mlu()

            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()

            cache_dtype = dtype
            k_scale_ops = None
            v_scale_ops = None
            if quant_kv:
                cache_dtype = torch.int8
                # The op is given the reciprocals of the randomly drawn scales.
                k_scale_ops = 1 / torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                v_scale_ops = 1 / torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()

            if paged_cache:
                cache_shape = (2, num_blocks, kv_heads, block_size, head_size)
            else:
                cache_shape = (2, max_bs, kv_heads, max_decode_len, head_size)
            cache = torch.randn(cache_shape, dtype=dtype, device='mlu')
            if quant_kv:
                # Shift and scale the random values before the cast to int8.
                cache = (cache - 0.5) * 256
            cache = cache.to(cache_dtype)
            k_cache = cache[0]
            v_cache = cache[1]

            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if not paged_cache:
                if discrete_batch:
                    # bs distinct slot ids drawn from the max_bs available slots.
                    cache_bs_id = torch.IntTensor(random.sample(range(max_bs), bs)).mlu()
                cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                                  dtype=torch.int32, device='mlu')
            else:
                # bs distinct slots; the population includes -1.
                slot_mapping = torch.IntTensor(
                    random.sample(range(-1, block_size * num_blocks), bs)).mlu()

            position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')

            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None
            if mixed_cache:
                max_decode_len_lp = 1024
                # NOTE(review): the halved dimensions (head_size for K, sequence
                # length for V) and the [-7, 7] value range look like two 4-bit
                # values packed per int8 element - confirm against the op docs.
                k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, head_size // 2),
                                          dtype=dtype, device='mlu')
                v_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp // 2, head_size),
                                          dtype=dtype, device='mlu')
                # Rescale so the largest magnitude becomes 7 before the cast.
                k_cache_raw = k_cache_raw * (7 / torch.amax(torch.abs(k_cache_raw)))
                v_cache_raw = v_cache_raw * (7 / torch.amax(torch.abs(v_cache_raw)))
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)

                lp_scale_shape = (max_bs, kv_heads, max_decode_len_lp, 1)
                k_scale_lp = torch.randn(size=lp_scale_shape, dtype=torch.float).mlu()
                v_scale_lp = torch.randn(size=lp_scale_shape, dtype=torch.float).mlu()

                cache_bs_id_lp = torch.IntTensor(random.sample(range(max_bs), bs)).mlu()
                cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                                                     dtype=torch.int32, device='mlu')

            # Arguments are positional; their order must match tmo.fused_rope.
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
                                                        rope_input,
                                                        k_cache,
                                                        v_cache,
                                                        sin_table,
                                                        cos_table,
                                                        position_id,
                                                        gamma,
                                                        beta,
                                                        k_cache_lp,
                                                        v_cache_lp,
                                                        cache_bs_id,
                                                        cache_seq_offsets,
                                                        cache_bs_id_lp,
                                                        cache_seq_offsets_lp,
                                                        k_scale_ops,
                                                        v_scale_ops,
                                                        k_scale_lp,
                                                        v_scale_lp,
                                                        slot_mapping,
                                                        None,   # NOTE(review): presumably the low-precision slot mapping - confirm
                                                        1e-5,   # NOTE(review): presumably the layernorm epsilon - confirm
                                                        repeats=args.repeat_times)
            content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}",
                       f"{paged_cache}", f"{max_decode_len}", f"{num_blocks}", f"{block_size}",
                       f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)

    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
|
||||
|
||||
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user