This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,51 @@
## benchmark测试脚本使用方式
Torch-MLU-Ops benchmark测试脚本为用户提供了进行算子性能测试的便捷入口。
用户可通过以下命令获取各个参数的含义。
```bash
# 测试命令帮助
python3 benchmark_xxx.py --help
```
各个参数含义如下:
`options`:
- -h, --help show this help message and exit
- --repeat_times REPEAT_TIMES repeat times for testing
- --csv write the report data to csv
- -o O specify the output folder name under --csv mode
```bash
# 测试命令示例如下
python3 benchmark_active.py --repeat_times 10 --csv -o './active/'
```
支持如下算子:
| op_name |
| ---------------------------------|
| active |
| apply_rotary |
| attention_project |
| ffn |
| flash_attn |
| fused_layer_norm |
| fused_moe |
| fused_norm_attention_project |
| fused_norm_residual_ffn |
| fused_rms_norm |
| group_gemm |
| matmul |
| offline_quant_to_linear_cache |
| per_token_smooth_quantize |
| preload |
| quantize |
| reshape_linear_cache |
| quant_to_linear_cache |
| reshape_paged_cache |
| single_query_cached_kv_attn |
| smooth_quant_matmul |
| weight_only_quant_matmul |
| moe_gen_idx |
| moe_expand_input |
| moe_softmax_topk |
| moe_combine_result |

View File

@@ -0,0 +1,64 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
{"batch": 16, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
{"batch": 72, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
{"batch": 1024, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
{"batch": 4096, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
{"batch": 8192, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
{"batch": 32768, "seq_len": 5, "hidden_size": 1024,
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}]
def main():
    """Benchmark tmo.active over the configured shapes and report timings.

    Parses --repeat_times/--csv/-o, runs each configuration in
    ``e2e_time_param_dict_list``, prints a table with hardware latency,
    end-to-end latency and IO efficiency, and optionally writes it to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["input_shape", "act_mode", "is_gated", "input_dtype", "hardware_time(us)",
              "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()  # device memory bandwidth, denominator of IO efficiency
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        act_mode = params_dict["act_mode"]
        is_gated = params_dict["is_gated"]
        for dtype in params_dict["input_dtype"]:
            # Unlike the sibling scripts (which skip), this one falls back
            # to fp16 when the device has no bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            # Renamed from ``input`` to avoid shadowing the builtin.
            x = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.active,
                                                        x,
                                                        act_mode,
                                                        is_gated,
                                                        repeats=args.repeat_times)
            # Presumably one read + one write, with a gated activation
            # halving the output size -- TODO confirm against tmo.active.
            io_bytes = x.element_size() * x.nelement() * (2 - 0.5 * is_gated)
            io_eff = io_bytes / hardware_time / bd
            content = [f"{batch,seq_len,hidden_size}", f"{act_mode}", f"{is_gated}",
                       f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,84 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "rotary_dim": 128,
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "rotary_dim": 64,
"interleaved": True, "discrete": False, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "rotary_dim": 128,
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "rotary_dim": 128,
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "rotary_dim": 64,
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "rotary_dim": 96,
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 4, "seq_len": 1, "head_num": 80, "head_size": 128, "rotary_dim": 128,
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.apply_rotary over the configured rotary-embedding cases.

    Builds the query tensor plus sin/cos caches (per-batch when dynamic_ntk)
    and optional discrete position ids, times the op, prints the table and
    optionally writes it to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "head_num", "head_size", "rotary_dim", "interleaved",
              "discrete", "dynamic_ntk", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        rotary_dim = params_dict["rotary_dim"]      # full/partial rotation
        interleaved = params_dict["interleaved"]    # cross/fold layout
        discrete = params_dict["discrete"]          # explicit position ids
        dynamic_ntk = params_dict["dynamic_ntk"]    # per-batch sin/cos caches
        for dtype in params_dict["input_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Renamed from ``input`` to avoid shadowing the builtin.
            # Shape: [batch, seq_len, head_num, head_size].
            x = torch.randn(batch, seq_len, head_num, head_size).to(device).to(dtype)
            if dynamic_ntk:
                sin_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
                cos_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
            else:
                sin_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
                cos_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
            if discrete:
                pos_ids = torch.randint(0, seq_len, (batch * seq_len,)).to(device).to(torch.int32)
            else:
                pos_ids = None
            hardware_time, e2e_time = benchmark_forward(tmo.apply_rotary,
                                                        x,
                                                        sin_cache,
                                                        cos_cache,
                                                        pos_ids,
                                                        None,
                                                        interleaved,
                                                        discrete,
                                                        dynamic_ntk,
                                                        seq_len,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{head_num}", f"{head_size}",
                       f"{rotary_dim}", f"{interleaved}", f"{discrete}", f"{dynamic_ntk}",
                       f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,74 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "input_size": 1600, "hidden_size": 1600,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 2048, "hidden_size": 2048,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 4096, "hidden_size": 4096,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 6144, "hidden_size": 6144,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 6656, "hidden_size": 6656,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 8192, "hidden_size": 8192,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 12288, "hidden_size": 12288,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 14336, "hidden_size": 14336,
"has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.attention_project across the configured shapes.

    Prints a tabulated report of hardware and end-to-end latency per
    configuration; writes the same table to CSV when --csv is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "input_size", "hidden_size", "has_residual",
              "has_bias", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        input_size, hidden_size = cfg["input_size"], cfg["hidden_size"]
        has_residual, has_bias = cfg["has_residual"], cfg["has_bias"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device)
            weight = torch.randn(hidden_size, input_size).to(dtype).to(device)
            residual = (torch.randn(batch, seq_len, hidden_size).to(dtype).to(device)
                        if has_residual else None)
            bias = torch.randn(hidden_size).to(dtype).to(device) if has_bias else None
            hw_us, e2e_us = benchmark_forward(tmo.attention_project,
                                              x, weight, bias, residual, 1.0, 1.0,
                                              repeats=args.repeat_times)
            rows.append([f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}",
                         f"{has_residual}", f"{has_bias}", f"{dtype}", f"{hw_us}",
                         f"{e2e_us}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,60 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 2, "m": 1024, "k": 1600, "n": 6400, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 2, "m": 1024, "k": 2048, "n": 8192, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 2, "m": 1024, "k": 4096, "n": 11008, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 2, "m": 1024, "k": 5120, "n": 16384, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 2, "m": 1024, "k": 6144, "n": 24576, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.batch_matmul for each configured (batch, m, k, n) case.

    Prints a tabulated latency report and optionally saves it to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "m", "k", "n", "has_c", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        batch, m, k, n = cfg["batch"], cfg["m"], cfg["k"], cfg["n"]
        has_c = cfg["has_c"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(batch, m, k).to(device).to(dtype)
            b = torch.randn(batch, n, k).to(device).to(dtype)
            c = torch.randn(batch, m, n).to(device).to(dtype) if has_c else None
            hw_us, e2e_us = benchmark_forward(tmo.batch_matmul,
                                              a, b, c, 1.0, 1.0,
                                              repeats=args.repeat_times)
            rows.append([f"{batch}", f"{m}", f"{k}", f"{n}", f"{has_c}", f"{dtype}",
                         f"{hw_us}", f"{e2e_us}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,120 @@
import argparse
import random
import os
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
from itertools import product
from tabulate import tabulate
e2e_time_param_dict_list = [
{"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
"head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "head_size": 128,
"input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [4, 8],
"use_offset": True},
]
def main():
    """Benchmark tmo.dequant_from_linear_cache over the configured sweeps.

    For every (batch_size, max_context_len, quant_mode, quant_bit, dtype)
    combination this builds a synthetic packed QKV context plus a quantized
    linear KV cache, times the dequant op, prints a result table and
    optionally writes it to CSV.
    """
    parser = argparse.ArgumentParser(description="Benchmark for dequant from linear cache.")
    parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    # Fixed typo in the report header: "input_dytpe" -> "input_dtype".
    titles = ["max_batch_size", "batch_size", "max_context_len", "head_num_q", "head_num_kv",
              "cache_mem_len", "head_size", "input_dtype", "quant_mode", "quant_bit",
              "use_offset", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    mlu_name = torch.mlu.get_device_name()
    for params_dict in e2e_time_param_dict_list:
        max_batch_size = params_dict["max_batch_size"]
        batch_size_list = params_dict["batch_size"]
        max_context_len_list = params_dict["max_context_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        cache_mem_len = params_dict["cache_mem_len"]
        input_dtype_list = params_dict["input_dtype"]
        quant_mode_list = params_dict["quant_mode"]
        quant_bit_list = params_dict["quant_bit"]
        use_offset = params_dict["use_offset"]
        for batch_size, max_context_len, quant_mode, quant_bit, dtype in product(
                batch_size_list, max_context_len_list, quant_mode_list, quant_bit_list,
                input_dtype_list):
            # Re-seed per combination so every case sees the same torch data.
            torch.manual_seed(2766)
            torch.mlu.manual_seed(2766)
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # MLU3xx devices cannot hold tensors with >= 2**31 - 1 elements.
            if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv
                                       * head_size >= 2**31 - 1):
                print("large tensor is not support on {}, skip".format(mlu_name))
                continue
            total_heads = head_num_q + head_num_kv * 2
            assert max_context_len <= cache_mem_len, "max_context_len should smaller than or " \
                "equal to cache_mem_len."
            max_seq_offset = cache_mem_len - max_context_len
            # Generate key/value views from a packed QKV context tensor.
            context_lens = torch.randint(size=[batch_size], low=max_context_len,
                                         high=max_context_len + 1,
                                         dtype=torch.int32, device="mlu")
            if use_offset:
                context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                                 dtype=torch.int32, device="mlu")
            else:
                context_paddings = torch.zeros_like(context_lens)
            cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
            total_seqlen = cu_context_lens[-1]
            context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
            context_seq_offset[1:] = cu_context_lens[:-1]
            context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
            key = context[..., head_num_q:head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
            # Generate key_cache and value_cache.
            # NOTE(review): ids are sampled from range(0, batch_size + 1) rather
            # than range(0, max_batch_size) -- confirm this is intentional.
            cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu()
            cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size],
                                             dtype=torch.int32, device="mlu")
            if quant_bit == 4:
                # int4 packs two values per byte, hence the halved dimension.
                key_cache = torch.randint(low=-128, high=127, dtype=torch.int32,
                                          size=(max_batch_size, head_num_kv, cache_mem_len,
                                                head_size // 2), device="mlu")
                value_cache = torch.randint(low=-128, high=127, dtype=torch.int32,
                                            size=(max_batch_size, head_num_kv, cache_mem_len // 2,
                                                  head_size), device="mlu")
                key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8)
            else:
                cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size),
                                      low=-128, high=127, dtype=torch.int32, device="mlu")
                cache = cache.to(torch.int8)
                key_cache, value_cache = cache[[0, 1]]
            # Generate key_cache_scale and value_cache_scale.
            if quant_mode == 0:  # quant_mode == 0 is per channel
                cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
            else:  # quant_mode != 1 (== 1 for extend) is per head
                cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len),
                                          dtype=torch.float, device="mlu")
            key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
            hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_linear_cache,
                                                        key, value, key_cache, value_cache,
                                                        key_cache_scale, value_cache_scale,
                                                        context_lens, max_context_len,
                                                        context_seq_offset if use_offset else None,
                                                        cache_bs_id, cache_seq_offset, quant_mode,
                                                        quant_bit, repeats=args.repeat_times)
            # Fixed: the row previously listed quant_mode twice (14 cells for
            # 13 headers), shifting every column after "quant_bit".
            content = [f"{max_batch_size}", f"{batch_size}", f"{max_context_len}", f"{head_num_q}",
                       f"{head_num_kv}", f"{cache_mem_len}", f"{head_size}", f"{dtype}",
                       f"{quant_mode}", f"{quant_bit}", f"{use_offset}", f"{hardware_time}",
                       f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,110 @@
import argparse
import math
import random
import os
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
from itertools import product
from tabulate import tabulate
e2e_time_param_dict_list = [
{"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
"head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "block_size": 16, "head_size": 128,
"input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [8],
"use_offset": True},
]
def main():
    """Benchmark tmo.dequant_from_paged_cache over the configured sweeps.

    Builds a synthetic packed QKV context, a random block table and a
    quantized paged KV cache per combination, times the dequant op, prints
    a result table and optionally writes it to CSV.  MLU300 devices are not
    supported by this op and exit early.
    """
    parser = argparse.ArgumentParser(description="Benchmark for dequant from paged cache.")
    parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    if "MLU3" in torch.mlu.get_device_name():
        print("Op dequant_from_paged_cache does not support MLU300 devices.")
        return
    # Fixed typo in the report header: "input_dytpe" -> "input_dtype".
    titles = ["batch_size", "max_context_len", "head_num_q", "head_num_kv",
              "cache_mem_len", "block_size", "head_size", "input_dtype", "quant_mode", "quant_bit",
              "use_offset", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch_size_list = params_dict["batch_size"]
        max_context_len_list = params_dict["max_context_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        block_size = params_dict["block_size"]
        head_size = params_dict["head_size"]
        cache_mem_len = params_dict["cache_mem_len"]
        input_dtype_list = params_dict["input_dtype"]
        quant_mode_list = params_dict["quant_mode"]
        quant_bit_list = params_dict["quant_bit"]
        use_offset = params_dict["use_offset"]
        for quant_mode, batch_size, max_context_len, quant_bit, dtype in product(
                quant_mode_list, batch_size_list, max_context_len_list, quant_bit_list,
                input_dtype_list):
            # Re-seed per combination so every case sees the same torch data.
            torch.manual_seed(2766)
            torch.mlu.manual_seed(2766)
            total_heads = head_num_q + head_num_kv * 2
            assert max_context_len <= cache_mem_len, "max_context_len should smaller than or " \
                "equal to cache_mem_len."
            max_seq_offset = cache_mem_len - max_context_len
            max_block_num = int(math.ceil(max_context_len / block_size))
            total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size
            # Random, collision-free block ids for each sequence.
            block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size,
                                                                                    max_block_num)
            # Generate key/value views from a packed QKV context tensor.
            context_lens = torch.randint(size=[batch_size], low=max_context_len,
                                         high=max_context_len + 1,
                                         dtype=torch.int32, device="mlu")
            if use_offset:
                context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                                 dtype=torch.int32, device="mlu")
            else:
                context_paddings = torch.zeros_like(context_lens)
            cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
            total_seqlen = cu_context_lens[-1]
            context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
            context_seq_offset[1:] = cu_context_lens[:-1]
            context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
            key = context[..., head_num_q:head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
            # Generate key_cache and value_cache.
            cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size),
                                  low=-128, high=127, dtype=torch.int32, device="mlu")
            cache = cache.to(torch.int8)
            key_cache, value_cache = cache[[0, 1]]
            # Generate key_cache_scale and value_cache_scale.
            if quant_mode == 0:  # quant_mode == 0 is per channel
                cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
            else:  # quant_mode != 1 (== 1 for extend) is per head
                cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size),
                                          dtype=torch.float, device="mlu")
            key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
            hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_paged_cache,
                                                        key, value, key_cache, value_cache,
                                                        key_cache_scale, value_cache_scale,
                                                        context_lens, max_context_len,
                                                        context_seq_offset if use_offset else None,
                                                        block_tables, quant_mode,
                                                        quant_bit, repeats=args.repeat_times)
            content = [f"{batch_size}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}",
                       f"{cache_mem_len}", f"{block_size}", f"{head_size}", f"{dtype}",
                       f"{quant_mode}", f"{quant_bit}", f"{use_offset}", f"{hardware_time}",
                       f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,89 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400,
"gated_ffn": False, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344,
"gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.ffn over the configured shapes and report timings.

    Builds up/down (and optionally gate-up) projection weights per case,
    times the fused FFN, prints a table and optionally writes it to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "act_mode",
              "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        inner_size = params_dict["inner_size"]
        gated_ffn = params_dict["gated_ffn"]
        act_mode = params_dict["act_mode"]
        for dtype in params_dict["input_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Renamed from ``input`` to avoid shadowing the builtin.
            x = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
            up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
            down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
            gate_up_proj_weight, gate_up_proj_bias = None, None
            if gated_ffn:
                gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
                gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.ffn,
                                                        x,
                                                        up_proj_weight,
                                                        up_proj_bias,
                                                        down_proj_weight,
                                                        down_proj_bias,
                                                        gate_up_proj_weight,
                                                        gate_up_proj_bias,
                                                        act_mode,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}",
                       f"{gated_ffn}", f"{act_mode}", f"{dtype}", f"{hardware_time}",
                       f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,92 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# for e2e time test
e2e_time_param_dict_list = [{"batch": 1, "seq_q": 32768, "seq_kv": 32768, "head_num": 8,
"head_num_kv": 1, "head_size": 128, "use_causal": True,
"softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_q": 16384, "seq_kv": 16384, "head_num": 8,
"head_num_kv": 1, "head_size": 128, "use_causal": True,
"softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_q": 8192, "seq_kv": 24576, "head_num": 8,
"head_num_kv": 1, "head_size": 128, "use_causal": True,
"softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_q": 4096, "seq_kv": 28672, "head_num": 8,
"head_num_kv": 1, "head_size": 128, "use_causal": True,
"softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_q": 4096, "seq_kv": 32768, "head_num": 8,
"head_num_kv": 1, "head_size": 128, "use_causal": True,
"softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
]
def main():
    """Benchmark tmo.flash_attention over the configured attention shapes.

    Packs Q/K/V into a single tensor for self-attention (seq_q == seq_kv)
    and uses a separate Q plus packed KV when seq_q < seq_kv; prints a
    latency table and optionally writes it to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_q", "seq_kv", "head_num", "head_num_kv", "head_size",
              "use_causal", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_q = params_dict["seq_q"]
        seq_kv = params_dict["seq_kv"]
        head_num = params_dict["head_num"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        use_causal = params_dict["use_causal"]
        softmax_scale = params_dict["softmax_scale"]
        for dtype in params_dict["input_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            if seq_q == seq_kv:
                qkv = torch.randn(batch, seq_q, head_num + 2 * head_num_kv, head_size).to(dtype).to(device)
                q = qkv[:, :, :head_num, :]
                k = qkv[:, :, head_num:head_num + head_num_kv, :]
                # Fixed: upper bound was ``head_num + head_num * 2`` which
                # over-slices V past the packed KV heads (7 heads instead of
                # head_num_kv when head_num=8, head_num_kv=1).
                v = qkv[:, :, head_num + head_num_kv:head_num + 2 * head_num_kv, :]
            elif seq_q < seq_kv:
                q = torch.randn(batch, seq_q, head_num, head_size).to(device).to(dtype)
                kv = torch.randn(batch, seq_kv, head_num_kv * 2, head_size).to(device).to(dtype)
                k = kv[:, :, :head_num_kv, :]
                v = kv[:, :, head_num_kv:, :]
            else:
                # seq_q > seq_kv previously fell through with q/k/v unbound
                # (NameError on the call below); skip such configs explicitly.
                continue
            hardware_time, e2e_time = benchmark_forward(tmo.flash_attention,
                                                        q=q,
                                                        k=k,
                                                        v=v,
                                                        out=None,
                                                        cu_seq_lens_q=None,
                                                        cu_seq_lens_kv=None,
                                                        alibi_slope=None,
                                                        attn_bias=None,
                                                        max_seq_len_q=seq_q,
                                                        max_seq_len_kv=seq_kv,
                                                        softmax_scale=softmax_scale,
                                                        is_causal=use_causal,
                                                        window_size_left=-1,
                                                        window_size_right=-1,
                                                        compute_dtype=dtype,
                                                        return_lse=False,
                                                        block_tables=None,
                                                        k_cache_quant_scale=None,
                                                        v_cache_quant_scale=None,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_q}", f"{seq_kv}", f"{head_num}", f"{head_num_kv}",
                       f"{head_size}", f"{use_causal}", f"{dtype}", f"{hardware_time}",
                       f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,103 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [
{"batch": 1, "seq_len": 2048, "hidden_size": 8192, "has_residual": True,
"has_bias": False, "has_quant": True, "dynamic_quant": True,
"input_dtype": torch.bfloat16},
{"batch": 1, "seq_len": 4096, "hidden_size": 8192, "has_residual": True,
"has_bias": False, "has_quant": True, "dynamic_quant": True,
"input_dtype": torch.bfloat16},
{"batch": 1, "seq_len": 8192, "hidden_size": 8192, "has_residual": True,
"has_bias": False, "has_quant": True, "dynamic_quant": True,
"input_dtype": torch.bfloat16},
{"batch": 1, "seq_len": 32768, "hidden_size": 8192, "has_residual": True,
"has_bias": False, "has_quant": True, "dynamic_quant": True,
"input_dtype": torch.bfloat16},
{"batch": 490, "seq_len": 1, "hidden_size": 8192, "has_residual": True,
"has_bias": False, "has_quant": True, "dynamic_quant": True,
"input_dtype": torch.bfloat16},
{"batch": 525, "seq_len": 1, "hidden_size": 8192, "has_residual": True,
"has_bias": False, "has_quant": True, "dynamic_quant": True,
"input_dtype": torch.bfloat16},
]
def main():
    """Benchmark tmo.fused_layer_norm over the predefined parameter list.

    Prints hardware/e2e latency plus an IO-efficiency estimate per config,
    and optionally writes the table to CSV.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    arg_parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    arg_parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = arg_parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "has_residual", "has_bias", "has_quant",
              "dynamic_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bandwidth = get_band_width()
    for cfg in e2e_time_param_dict_list:
        batch = cfg["batch"]
        seq_len = cfg["seq_len"]
        hidden_size = cfg["hidden_size"]
        has_residual = cfg["has_residual"]
        has_bias = cfg["has_bias"]
        has_quant = cfg["has_quant"]
        dynamic_quant = cfg["dynamic_quant"]
        dtype = cfg["input_dtype"]
        eps = 1e-6
        # Fall back to half precision on devices without bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        x = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
        beta = torch.randn(hidden_size, dtype=dtype, device=device)
        gamma = torch.randn(hidden_size, dtype=dtype, device=device)
        residual = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device) if has_residual else None
        bias = torch.randn(hidden_size, dtype=dtype, device=device) if has_bias else None
        quant_scale = None
        if has_quant or dynamic_quant:
            quant_scale = torch.randn(hidden_size, dtype=torch.float, device=device)
        store_output_before_norm = has_residual
        hardware_time, e2e_time = benchmark_forward(tmo.fused_layer_norm,
                                                    x,
                                                    residual,
                                                    gamma,
                                                    beta,
                                                    bias,
                                                    eps,
                                                    store_output_before_norm,
                                                    quant_scale,
                                                    None,
                                                    dynamic_quant,
                                                    repeats=args.repeat_times)
        elem_count = x.nelement()
        elem_bytes = x.element_size()
        # Approximate bytes moved by the kernel; divided by time and device
        # bandwidth (get_band_width) to yield an IO-efficiency ratio.
        io_bytes = (elem_bytes + 1) * elem_count + \
            (1 + store_output_before_norm) * (elem_bytes * elem_count if has_residual else 0) + \
            elem_bytes * hidden_size * 2 + \
            (hidden_size * 4 if has_quant else 0) + \
            (batch * seq_len * 4 if dynamic_quant else 0)
        io_eff = io_bytes / hardware_time / bandwidth
        contents.append([f"{batch}", f"{seq_len}", f"{hidden_size}", f"{has_residual}", f"{has_bias}",
                         f"{has_quant}", f"{dynamic_quant}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,143 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [
{"batch": 1, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 16, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 72, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 128, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 490, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 525, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 1, "seq_len": 2048, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 1, "seq_len": 4096, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 1, "seq_len": 8192, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
{"batch": 1, "seq_len": 32768, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
"gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
"topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
]
def main():
    """Benchmark tmo.fused_moe for every configuration in e2e_time_param_dict_list.

    Builds router logits and (optionally int8 smooth-quantized) expert weights,
    times the fused MoE op, prints the results as a table and optionally
    writes them to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "num_expert", "topk", "act_mode", "quant_weight", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        inner_size = params_dict["inner_size"]
        gated_ffn = params_dict["gated_ffn"]
        act_mode = params_dict["act_mode"]
        num_expert = params_dict["num_expert"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        topk = params_dict["topk"]
        has_residual = params_dict["has_residual"]
        smooth_quant = params_dict["smooth_quant"]
        renormalize = params_dict["renormalize"]
        input_dtype_list = params_dict["dtype"]
        for dtype in input_dtype_list:
            # Fall back to half precision on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            hidden_states = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
            router_logit = torch.randn(batch, seq_len, num_expert, device=device, dtype=torch.float32)
            # (Removed a dead `if False:` debug block that printed per-expert
            # token counts; it never executed and referenced a typo'd name.)
            residual = None
            if has_residual:
                residual = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
            # weight1 projects hidden -> inner (doubled when gated); weight2 projects back.
            weight1 = torch.randn(num_expert, inner_size*(1+gated_ffn), hidden_size, device=device, dtype=dtype)
            weight2 = torch.randn(num_expert, hidden_size, inner_size, device=device, dtype=dtype)
            bias1 = None  # biases are not exercised by this benchmark
            bias2 = None
            input_smooth, act_smooth, w1_scale, w2_scale = None, None, None, None
            if smooth_quant:
                # Replace float weights with int8 weights plus smoothing factors
                # and per-channel dequant scales for the served expert slice.
                input_smooth = torch.randn(expert_size, hidden_size, device=device, dtype=torch.float32).abs() + 0.1
                act_smooth = torch.randn(expert_size, inner_size, device=device, dtype=torch.float32).abs() + 0.1
                weight1 = torch.randint(-128, 127, (num_expert, inner_size*(1+gated_ffn), hidden_size)).to(torch.int8).mlu()
                weight2 = torch.randint(-128, 127, (num_expert, hidden_size, inner_size)).to(torch.int8).mlu()
                w1_scale = torch.randn(expert_size, (1+gated_ffn)*inner_size).to(device).to(torch.float32)
                w2_scale = torch.randn(expert_size, hidden_size).to(device).to(torch.float32)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_moe,
                                                        hidden_states,
                                                        router_logit,
                                                        weight1[start_expert_id:start_expert_id+expert_size],
                                                        weight2[start_expert_id:start_expert_id+expert_size],
                                                        bias1,
                                                        bias2,
                                                        residual,
                                                        input_smooth,
                                                        act_smooth,
                                                        w1_scale,
                                                        w2_scale,
                                                        topk,
                                                        renormalize,
                                                        gated_ffn,
                                                        act_mode,
                                                        start_expert_id,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{num_expert}", f"{topk}", f"{act_mode}", f"{smooth_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,71 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "input_size": 1600, "head_size": 80, "hidden_size": 1600, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 2048, "head_size": 128, "hidden_size": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 4096, "head_size": 128, "hidden_size": 4096, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 6144, "head_size": 128, "hidden_size": 6144, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 6656, "head_size": 128, "hidden_size": 6656, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 8192, "head_size": 128, "hidden_size": 8192, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 12288, "head_size": 128, "hidden_size": 12288, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "input_size": 14336, "head_size": 128, "hidden_size": 14336, "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Run the fused_norm_attention_project benchmark suite and report timings."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "input_size", "hidden_size", "head_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch = cfg["batch"]
        seq_len = cfg["seq_len"]
        input_size = cfg["input_size"]
        hidden_size = cfg["hidden_size"]
        head_size = cfg["head_size"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, input_size).to(device).to(dtype)
            qkv_weight = torch.randn(hidden_size * 3, input_size).to(device).to(dtype)
            qkv_bias = torch.randn(hidden_size * 3).to(device).to(dtype)
            # Split the fused QKV projection into its three per-head-group parts.
            q_w, k_w, v_w = torch.chunk(qkv_weight, 3)
            q_b, k_b, v_b = torch.chunk(qkv_bias, 3)
            norm_weight = torch.randn(input_size).to(device).to(dtype)
            norm_bias = torch.randn(input_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_attention_project,
                                                        x,
                                                        q_w, q_b,
                                                        k_w, k_b,
                                                        v_w, v_b,
                                                        norm_weight,
                                                        norm_bias,
                                                        1e-6,
                                                        'nthc',
                                                        head_size,
                                                        False,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}",
                             f"{head_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,97 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400,
"gated_ffn": False, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192,
"gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008,
"gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336,
"gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384,
"gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824,
"gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392,
"gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920,
"gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016,
"gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576,
"gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672,
"gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152,
"gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768,
"gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344,
"gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Time tmo.fused_norm_residual_ffn for every configuration listed above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "residual_is", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch = cfg["batch"]
        seq_len = cfg["seq_len"]
        hidden_size = cfg["hidden_size"]
        inner_size = cfg["inner_size"]
        gated_ffn = cfg["gated_ffn"]
        residual_is = cfg["residual_is"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # bf16 configs only run on devices that support bf16.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            up_w = torch.randn(inner_size, hidden_size).to(device).to(dtype)
            up_b = torch.randn(inner_size).to(device).to(dtype)
            down_w = torch.randn(hidden_size, inner_size).to(device).to(dtype)
            down_b = torch.randn(hidden_size).to(device).to(dtype)
            ln_w = torch.randn(hidden_size).to(device).to(dtype)
            ln_b = torch.randn(hidden_size).to(device).to(dtype)
            # Gate projection tensors only exist for gated FFN variants.
            gate_w = torch.randn(inner_size, hidden_size).to(device).to(dtype) if gated_ffn else None
            gate_b = torch.randn(inner_size).to(device).to(dtype) if gated_ffn else None
            hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_residual_ffn,
                                                        x,
                                                        up_w,
                                                        up_b,
                                                        down_w,
                                                        down_b,
                                                        gate_w,
                                                        gate_b,
                                                        ln_w,
                                                        ln_b,
                                                        1e-6,
                                                        act_mode,
                                                        residual_is,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}",
                             f"{gated_ffn}", f"{residual_is}", f"{act_mode}", f"{dtype}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,90 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
from itertools import product
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 16, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 96, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 1, "seq_len": 1024, "head_num": 112, "head_size": 128, "has_residual": True,
"has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.fused_rms_norm across the parameter grid above.

    For each parameter dict, iterates the cross product of the dynamic_quant
    options and input dtypes, times the op, and prints (optionally saves) a
    result table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["input_shape", "has_residual", "has_bias", "has_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        has_residual = params_dict["has_residual"]
        has_quant = params_dict["has_quant"]
        has_bias = params_dict["has_bias"]
        eps = params_dict["eps"]
        # Deduplicated: these two assignments appeared twice in the original.
        dynamic_quant_list = params_dict["dynamic_quant"]
        input_dtype_list = params_dict["input_dtype"]
        for dynamic_quant, dtype in product(dynamic_quant_list, input_dtype_list):
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
            beta = torch.randn(head_size).to(dtype).to(device)
            gamma = torch.randn(head_size).to(dtype).to(device)
            residual, bias, quant_scale = None, None, None
            if has_residual:
                residual = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
            if has_bias:
                bias = torch.randn(head_size).to(dtype).to(device)
            if has_quant or dynamic_quant:
                quant_scale = torch.randn(head_size).to(device)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rms_norm,
                                                        x,
                                                        residual,
                                                        gamma,
                                                        beta,
                                                        bias,
                                                        eps,
                                                        False,
                                                        quant_scale,
                                                        None,
                                                        dynamic_quant,
                                                        repeats=args.repeat_times)
            content = [f"{batch, seq_len, head_num, head_size}", f"{has_residual}", f"{has_bias}", f"{has_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,207 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
e2e_time_param_dict_list = [{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
{"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
"mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},]
def main():
    """Benchmark tmo.fused_rope (rotary position embedding fused with KV-cache
    update) for every configuration in e2e_time_param_dict_list and report
    hardware / end-to-end latencies.
    """
    # NOTE(review): the benchmark is skipped entirely on MLU3-series devices --
    # presumably the op is unsupported there; confirm against the op's support matrix.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache", "max_decode_len", "num_blocks", \
        "block_size", "mixed_cache", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Optional keys fall back to defaults when absent (e.g. the mixed-cache configs
        # omit quant_kv / paged_cache).
        quant_kv = params_dict["quant_kv"] if "quant_kv" in params_dict else True
        paged_cache = params_dict["paged_cache"] if "paged_cache" in params_dict else False
        mixed_cache = params_dict["mixed_cache"] if "mixed_cache" in params_dict else False
        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict["max_decode_len"] if "max_decode_len" in params_dict else 32
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # discrete_batch: requests are mapped to cache batch slots indirectly
            # through cache_bs_id rather than by position.
            discrete_batch = True
            max_bs = bs + 1 if discrete_batch else bs
            # Packed QKV layout: q_heads query heads followed by kv_heads K and kv_heads V heads.
            input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
            input = torch.randn(size=input_shape, dtype=dtype).mlu()
            # NOTE(review): input_ref is never used below -- likely a leftover from an
            # accuracy comparison; consider removing.
            input_ref = input.clone()
            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            cache_dtype = dtype
            if quant_kv:
                # Quantized KV cache: per-(kv_head, channel) scales; the op is passed
                # their reciprocals (k_scale_ops / v_scale_ops).
                k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                cache_dtype = torch.int8
                k_scale_ops = 1 / k_scale
                v_scale_ops = 1 / v_scale
            else:
                k_scale = None
                v_scale = None
                k_scale_ops = None
                v_scale_ops = None
            # Cache tensor: dim 0 selects K (index 0) vs V (index 1).
            if paged_cache:
                cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size), dtype=dtype, device='mlu')
            else:
                cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')
            if quant_kv:
                # Spread values across the int8 range before the dtype cast.
                cache = (cache - 0.5) * 256
            cache = cache.to(cache_dtype)
            k_cache = cache[0]
            v_cache = cache[1]
            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if not paged_cache:
                if discrete_batch:
                    # Assign each request a distinct random cache batch slot.
                    cache_bs_id = random.sample([*range(0, max_bs)], bs)
                    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
                cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                                  dtype=torch.int32, device='mlu')
            else:
                # Paged cache: one random slot per request; -1 presumably marks
                # "no slot" -- confirm against the op's contract.
                slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs)
                slot_mapping = torch.IntTensor(slot_mapping).mlu()
            position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None
            if mixed_cache:
                # Secondary low-precision ("lp") cache with per-token scales.
                max_decode_len_lp = 1024
                k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)), dtype=dtype, device='mlu')
                v_cache_raw = torch.randn((max_bs, kv_heads, int(max_decode_len_lp / 2), head_size), dtype=dtype, device='mlu')
                # Scale values into [-7, 7] before the int8 cast -- presumably an
                # int4-style value range; confirm with the kernel's expectations.
                max_value = torch.amax(torch.abs(k_cache_raw))
                k_cache_raw = k_cache_raw * (7 / max_value)
                max_value = torch.amax(torch.abs(v_cache_raw))
                v_cache_raw = v_cache_raw * (7 / max_value)
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)
                k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
                v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
                cache_bs_id_lp = random.sample([*range(0, max_bs)], bs)
                cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
                cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                                                     dtype=torch.int32, device='mlu')
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
                                                        input,
                                                        k_cache,
                                                        v_cache,
                                                        sin_table,
                                                        cos_table,
                                                        position_id,
                                                        gamma,
                                                        beta,
                                                        k_cache_lp,
                                                        v_cache_lp,
                                                        cache_bs_id,
                                                        cache_seq_offsets,
                                                        cache_bs_id_lp,
                                                        cache_seq_offsets_lp,
                                                        k_scale_ops,
                                                        v_scale_ops,
                                                        k_scale_lp,
                                                        v_scale_lp,
                                                        slot_mapping,
                                                        None,
                                                        1e-5,
                                                        repeats=args.repeat_times)
            content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}", f"{paged_cache}", \
                f"{max_decode_len}", f"{num_blocks}", f"{block_size}", f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,117 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark shapes for (smooth-quant) group GEMM.  Each entry fixes the token
# load (batch * seq_len tokens, fanned out "topk" ways across "expert_num"
# experts) and the per-expert GEMM dims (k, n); "is_quant" selects the int8
# smooth-quant kernel path.
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 16, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 16, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 72, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 72, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 128, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 128, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 490, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 490, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 525, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 525, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 2048, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 2048, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 4096, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 4096, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 8192, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 8192, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 32768, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 32768, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
]
def main():
    """Benchmark tmo.group_gemm (floating point) / tmo.smooth_quant_group_gemm
    (int8) over the shapes in e2e_time_param_dict_list and print the report,
    optionally saving it to CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "seq_len", "k", "n", "expert_num", "topk", "smooth_quant", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        k, n = cfg["k"], cfg["n"]
        expert_num, topk = cfg["expert_num"], cfg["topk"]
        is_quant = cfg["is_quant"]
        for dtype in cfg["dtype"]:
            # Fall back to fp16 on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            max_m = batch * seq_len
            m = max_m * topk
            # Spread the expanded tokens as evenly as possible over the experts.
            base, extra = divmod(m, expert_num)
            token_count = torch.tensor([base + (i < extra) for i in range(expert_num)],
                                       dtype=torch.int32, device='mlu')
            if is_quant:
                a = torch.randint(-128, 127, (m, k)).to(torch.int8).mlu()
                b = torch.randint(-128, 127, (expert_num, n, k)).to(torch.int8).mlu()
                a_scale = torch.randn(a.size(0), dtype=torch.float32, device='mlu')
                b_scale = torch.randn(expert_num, n, dtype=torch.float32, device='mlu')
                hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_group_gemm,
                                                            a, b, token_count,
                                                            None, None, None, None,
                                                            a_scale, b_scale, dtype, max_m,
                                                            repeats=args.repeat_times)
            else:
                a = torch.randn(m, k, dtype=dtype, device='mlu')
                b = torch.randn(expert_num, n, k, dtype=dtype, device='mlu')
                hardware_time, e2e_time = benchmark_forward(tmo.group_gemm,
                                                            a, b, token_count,
                                                            None, None, None, None,
                                                            max_m,
                                                            repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{k}", f"{n}", f"{expert_num}", f"{topk}",
                             f"{is_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,75 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark shapes for tmo.matmul: an (m, k) activation against an (n, k)
# weight, with optional elementwise C accumulation, bias and activation,
# swept over fp16 and bf16 inputs.
e2e_time_param_dict_list = [{"m": 1024, "k": 1600, "n": 6400, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 2048, "n": 8192, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 4096, "n": 11008, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 4096, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 27392, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 6144, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 6656, "n": 17920, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 22016, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 28672, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 49152, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 12288, "n": 32768, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 14336, "n": 57344, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.matmul over the shapes in e2e_time_param_dict_list and
    print the latency report, optionally saving it to CSV.

    Note: the weight is generated as (n, k) — presumably consumed transposed
    by the kernel; confirm against the tmo.matmul documentation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        m, k, n = cfg["m"], cfg["k"], cfg["n"]
        has_c, has_bias = cfg["has_c"], cfg["has_bias"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 cases on devices that cannot run them.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(m, k).to(device).to(dtype)
            b = torch.randn(n, k).to(device).to(dtype)
            c = torch.randn(m, n).to(device).to(dtype) if has_c else None
            bias = torch.randn(n).to(device).to(dtype) if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.matmul,
                                                        a, b, bias, c,
                                                        act_mode, 1.0, 1.0,
                                                        repeats=args.repeat_times)
            contents.append([f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}",
                             f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,117 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configs for tmo.moe_active: total token count is batch * seq_len,
# "inner_size" is the activation width, "is_gated" doubles the input width
# (gate + value halves); "is_ep" exercises a random expert-parallel sub-range.
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": True, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 4096, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": False, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 8192, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 32768, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 16, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 32, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 64, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},]
def gen_data(num_expert,
             total_tokens,
             inner_size,
             output_stride,
             dtype,
             is_gated,
             has_bias,
             is_ep):
    """Build randomized device inputs for the tmo.moe_active benchmark.

    Returns (input, bias, token_count, cusum_token_count, output,
    start_expert_id, expert_size).  The input has 2 * inner_size columns when
    gated (gate and value halves), otherwise inner_size.  When is_ep is set,
    a random contiguous expert sub-range [start_expert_id,
    start_expert_id + expert_size) is selected; otherwise all experts.
    """
    ci = inner_size * (1 + is_gated)
    input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu')
    cusum_token_count, token_count = generate_token_count(num_expert, total_tokens)
    output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu')
    # Bug fix: Tensor.as_strided returns a new view and does not modify the
    # tensor in place; the original call discarded the result, so the
    # requested row stride was never applied.  Assign the view back.
    # NOTE(review): current callers pass output_stride == inner_size; a larger
    # stride would need a larger underlying allocation — confirm before use.
    output = output.as_strided(output.size(), (output_stride, 1))
    start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0
    expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert
    bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None
    return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size
def main():
    """Benchmark tmo.moe_active over the configs in e2e_time_param_dict_list,
    reporting hardware/e2e latency and an achieved-bandwidth ratio.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["input_shape", "act_mode", "is_gated", "has_bias", "expert_num", "start_expert_id",
              "expert_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        inner_size = params_dict["inner_size"]
        act_mode = params_dict["act_mode"]
        is_gated = params_dict["is_gated"]
        input_dtype_list = params_dict["input_dtype"]
        has_bias = params_dict["has_bias"]
        is_ep = params_dict["is_ep"]
        for dtype in input_dtype_list:
            # Fall back to fp16 on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            # Fix: the original had a duplicated assignment
            # (expert_num = expert_num = random.randint(...)).
            expert_num = random.randint(1, 256)
            input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
                gen_data(expert_num, batch * seq_len, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
            real_bias = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.moe_active,
                                                        input,
                                                        act_mode,
                                                        is_gated,
                                                        output,
                                                        real_bias,
                                                        cusum_token_count.mlu() if has_bias or is_ep else None,
                                                        start_expert_id,
                                                        expert_size,
                                                        repeats=args.repeat_times)
            # IO estimate: input read plus output write (the output is half the
            # input width when gated), plus bias / cumulative-count reads when
            # present.
            # Fix: the original ternary bound the entire sum
            # (a + b + c if cond else 0), so io_bytes collapsed to 0 whenever
            # the condition was false, and the bias term dereferenced
            # real_bias (None) in the is_ep-without-bias case.
            io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated)
            if real_bias is not None:
                io_bytes += real_bias.element_size() * real_bias.nelement()
            if has_bias or is_ep:
                io_bytes += cusum_token_count.element_size() * cusum_token_count.nelement()
            io_eff = io_bytes / hardware_time / bd
            content = [f"{batch,seq_len,inner_size}", f"{act_mode}", f"{is_gated}", f"{has_bias}", f"{expert_num}",
                       f"{start_expert_id}", f"{expert_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,59 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configs for tmo.moe_cast_gating: batch * seq_len tokens of width
# hidden_size scored against expert_num fp32 gating weights.
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 2048, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 16, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 128, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 512, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}]
def main():
    """Benchmark tmo.moe_cast_gating and report latency plus an
    achieved-bandwidth ratio against the device peak.
    """
    # NOTE(review): the benchmark exits early on MLU3 devices — presumably the
    # op is unsupported there; confirm against the op's support matrix.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "seq_len", "hidden_size", "expert_num", "input_dtype", "hardware_time(us)",
              "e2e_latency(us)", "IO efficiency"]
    contents = []
    bandwidth = get_band_width()
    for param_dict in e2e_time_param_dict_list:
        batch = param_dict["batch"]
        seq_len = param_dict["seq_len"]
        hidden_size = param_dict["hidden_size"]
        expert_num = param_dict["expert_num"]
        input_dtype = param_dict["input_dtype"]
        # Fall back to fp16 on devices without bf16 support.
        if input_dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            input_dtype = torch.half
        input = torch.randn(batch, seq_len, hidden_size, dtype=input_dtype, device="mlu")
        weight = torch.randn(expert_num, hidden_size, dtype=torch.float32, device="mlu")
        # Fix: --repeat_times was parsed but never forwarded, so the benchmark
        # always ran with benchmark_forward's default repeat count (every
        # sibling benchmark script passes repeats=args.repeat_times).
        hardware_time, e2e_time = benchmark_forward(tmo.moe_cast_gating,
                                                    input,
                                                    weight,
                                                    repeats=args.repeat_times)
        # IO estimate: input read + weight read + gating-logits write.
        io_bytes = batch * seq_len * hidden_size * input.element_size() + \
                   expert_num * hidden_size * weight.element_size() + batch * seq_len * expert_num * weight.element_size()
        io_coeff = io_bytes / hardware_time / bandwidth
        content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{expert_num}", f"{input_dtype}",
                   f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,166 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# Benchmark configs for tmo.moe_combine_result: num_tokens tokens of width
# hidden_size, expanded topk ways; expert_size == num_expert means the full
# expert range is combined.
e2e_time_param_dict_list = [
    {"num_tokens": 16, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 128, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 490, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 525, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 2048, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 4096, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 8192, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 32768, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
]
def gen_case(num_tokens,
             topk,
             hidden_size,
             num_expert,
             expert_size,
             has_bias,
             has_residual,
             dtype,
             device):
    """Build randomized inputs for the tmo.moe_combine_result benchmark.

    Returns (input, reduce_weight, gather_ids, residual, bias,
    cusum_token_count); the cumulative token table is produced only when a
    bias is requested or only a sub-range of the experts is combined.
    """
    expand = num_tokens * topk
    input = torch.randn((expand, hidden_size), dtype=dtype, device=device)
    reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device)
    gather_ids = torch.randperm(expand, dtype=torch.int32, device=device)
    bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) if has_bias else None
    residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) if has_residual else None
    cusum_token_count = None
    if has_bias or expert_size < num_expert:
        cusum_token_count, _ = generate_token_count(num_expert, expand)
        cusum_token_count = cusum_token_count.to(device=device)
    return input, reduce_weight, gather_ids, residual, bias, cusum_token_count
def get_io_bytes(num_tokens,
                 topk,
                 hidden_size,
                 num_expert,
                 expert_size,
                 start_expert_id,
                 has_bias,
                 has_residual,
                 dtype,
                 cusum_token_count,
                 gather_ids):
    """Estimate the bytes moved by moe_combine_result for IO-efficiency
    reporting: expanded-token reads (restricted to the combined expert range
    when a cumulative token table is given), optional bias/residual reads,
    the int32 gather-id read, and the combined-output write.
    """
    elem = 4 if dtype is torch.float32 else 2
    if cusum_token_count is None:
        # Every expanded token is read.
        total = num_tokens * topk * hidden_size * elem
    else:
        # Only rows whose gather id falls inside the combined expert range.
        lo = cusum_token_count[start_expert_id]
        hi = cusum_token_count[start_expert_id + expert_size]
        in_range = ((gather_ids >= lo) * (gather_ids < hi)).to(dtype=torch.float32)
        total = torch.sum(in_range).item() * hidden_size * elem
    if has_bias:
        total += expert_size * hidden_size * elem
    if has_residual:
        total += num_tokens * hidden_size * elem
    total += num_tokens * topk * 4          # int32 gather ids
    total += num_tokens * hidden_size * elem  # combined output write
    return total
def main():
    """Benchmark tmo.moe_combine_result over the configs in
    e2e_time_param_dict_list, reporting hardware/e2e latency plus an
    achieved-bandwidth ratio against the device peak.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["num_tokens", "num_expert", "topk", "start_expert_id", "expert_size",
              "hidden_size", "has_residual", "dtype", "hardware_time(us)", "e2e_latency(us)", "io_coeff"]
    contents = []
    bandwidth = get_band_width()
    for cfg in e2e_time_param_dict_list:
        num_tokens = cfg["num_tokens"]
        num_expert = cfg["num_expert"]
        topk = cfg["topk"]
        start_expert_id = cfg["start_expert_id"]
        expert_size = cfg["expert_size"]
        has_residual = cfg["has_residual"]
        hidden_size = cfg["hidden_size"]
        for dtype in cfg["dtype"]:
            # Skip bf16 cases on devices that cannot run them.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # has_bias is fixed to False for this benchmark.
            (input, reduce_weight, gather_ids,
             residual, bias, cusum_token_count) = gen_case(num_tokens,
                                                           topk,
                                                           hidden_size,
                                                           num_expert,
                                                           expert_size,
                                                           False,
                                                           has_residual,
                                                           dtype,
                                                           device)
            io_bytes = get_io_bytes(num_tokens,
                                    topk,
                                    hidden_size,
                                    num_expert,
                                    expert_size,
                                    start_expert_id,
                                    False,
                                    has_residual,
                                    dtype,
                                    cusum_token_count,
                                    gather_ids)
            hardware_time, e2e_time = benchmark_forward(tmo.moe_combine_result, input, reduce_weight,
                                                        gather_ids, residual, cusum_token_count,
                                                        start_expert_id, expert_size,
                                                        repeats=args.repeat_times)
            io_coeff = io_bytes / hardware_time / bandwidth
            contents.append([f"{num_tokens}", f"{num_expert}", f"{topk}", f"{start_expert_id}",
                             f"{expert_size}", f"{hidden_size}", f"{has_residual}", f"{dtype}",
                             f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,93 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import numpy as np
# Benchmark configs for tmo.moe_expand_input: token_num source tokens of width
# hidden_size, expanded topk ways; swept over int8/fp16/bf16 inputs.
e2e_time_param_dict_list = [{"token_num": 1, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 16, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 32, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 64, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 128, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 512, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 1024, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 4096, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 8192, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 32768, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}]
def gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype):
    """Build device inputs for the tmo.moe_expand_input benchmark.

    Returns (input, gather_idx, cusum_token_count, real_token_count); the
    cumulative token table is dropped (None) when the full expert range is
    used, in which case every expanded token is real.
    """
    expand_num = token_num * topk
    input = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
    gather_idx = torch.randint(low=0, high=token_num, size=(expand_num,)).to(torch.int32).to('mlu')
    cusum_token_count, _ = generate_token_count(expert_num, expand_num)
    cusum_token_count = cusum_token_count.to('mlu')
    if expert_num == expert_size:
        # Full expert range: the kernel takes no cumulative table.
        return input, gather_idx, None, expand_num
    real_token_count = cusum_token_count[start_expert_id + expert_size] - cusum_token_count[start_expert_id]
    return input, gather_idx, cusum_token_count, real_token_count
def main():
    """Benchmark tmo.moe_expand_input (gather source rows into the expanded
    token layout) and report latency plus an achieved-bandwidth ratio.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["token_num", "hidden_size", "expert_num", "topk", "start_expert_id", "expert_size", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        token_num = params_dict["token_num"]
        hidden_size = params_dict["hidden_size"]
        expert_num = params_dict["expert_num"]
        topk = params_dict["topk"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on devices that cannot run them.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input, gather_idx, cusum_token_count, real_token_count = \
                gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.moe_expand_input,
                                                        input,
                                                        gather_idx,
                                                        cusum_token_count,
                                                        start_expert_id,
                                                        expert_size,
                                                        repeats=args.repeat_times)
            # IO estimate — reads: input, gather ids and (optionally) the
            # cumulative token table; writes: one hidden_size-wide row per
            # gathered token.
            # Fix: the write term previously omitted hidden_size
            # (real_token_count * element_size), understating the moved bytes.
            io_bytes = input.element_size() * input.nelement() + \
                       gather_idx.element_size() * gather_idx.nelement() + \
                       (cusum_token_count.element_size() * cusum_token_count.nelement() if cusum_token_count is not None else 0) + \
                       real_token_count * hidden_size * input.element_size()
            io_eff = io_bytes / hardware_time / bd
            content = [f"{token_num}", f"{hidden_size}", f"{expert_num}", f"{topk}", f"{start_expert_id}", f"{expert_size}",
                       f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,69 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# Benchmark configs for tmo.moe_gen_idx: token_num tokens each routed to
# topk of expert_num experts; all index outputs are int32.
e2e_time_param_dict_list = [{"token_num": 1, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 16, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 32, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 64, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 512, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 1024, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 4096, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 8192, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 32767, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 1, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 16, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 32, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 64, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 512, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 1024, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 4096, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 8192, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 32767, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}]
def main():
    """Benchmark tmo.moe_gen_idx, which derives gather/combine indices and
    per-expert token counts from a (token_num, topk) expert-id table, and
    report latency plus an achieved-bandwidth ratio.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["token_num", "expert_num", "topk", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for cfg in e2e_time_param_dict_list:
        token_num = cfg["token_num"]
        expert_num = cfg["expert_num"]
        topk = cfg["topk"]
        dtype = cfg["input_dtype"]
        expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu')
        # These mirror the op's outputs; they are allocated here only to size
        # the IO estimate below.
        gather_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
        combine_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
        token_count = torch.empty((expert_num), dtype=dtype, device='mlu')
        cusum_token_count = torch.empty((expert_num + 1), dtype=dtype, device='mlu')
        hardware_time, e2e_time = benchmark_forward(tmo.moe_gen_idx,
                                                    expert_id,
                                                    expert_num,
                                                    repeats=args.repeat_times)
        tensors = (expert_id, gather_idx, combine_idx, token_count, cusum_token_count)
        io_bytes = sum(t.element_size() * t.nelement() for t in tensors)
        io_eff = io_bytes / hardware_time / bd
        contents.append([f"{token_num}", f"{expert_num}", f"{topk}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,114 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configs for tmo.moe_quantize: token_num tokens of width
# hidden_size, routed topk ways across expert_num experts; "has_gather_idx"
# selects the fused gather-then-quantize path.
params_dict = [
    {"token_num": 1, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 16, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 128, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 490, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 512, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 525, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 2048, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 4096, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 8192, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 32768, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 1, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 16, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 128, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 490, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 512, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 525, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 2048, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 4096, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 8192, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 32768, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
]
def main():
    """Benchmark tmo.moe_quantize for every configuration in ``params_dict``.

    Prints a results table (hardware time, e2e latency, IO efficiency) and,
    with ``--csv``, also writes it to a CSV file named after this script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["token_num", "hidden_size", "expert_num", "topk", "has_gather_idx", "dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for param in params_dict:
        # Extract by key instead of the previous positional ``param.values()``
        # unpacking, which silently scrambles values if the dict literal's key
        # order ever changes.
        token_num = param["token_num"]
        hidden_size = param["hidden_size"]
        expert_num = param["expert_num"]
        topk = param["topk"]
        has_gather_idx = param["has_gather_idx"]
        dtype = param["dtype"]
        # Fall back to fp16 on devices without bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        # The gather-index path is disabled on MLU300-series devices.
        if "MLU3" in torch.mlu.get_device_name():
            has_gather_idx = False
        expand_token_num = token_num * topk
        # With a gather index the op expands tokens itself, so the input holds
        # only the original tokens; otherwise the input is pre-expanded.
        input_shape = (token_num if has_gather_idx else expand_token_num, hidden_size)
        input_tensor = torch.randn(input_shape).to(device).to(dtype)
        scale = torch.randn(expert_num, hidden_size).to(device).to(torch.float32)
        # Spread the expanded tokens as evenly as possible across experts.
        avg, rem = expand_token_num // expert_num, expand_token_num % expert_num
        m_list = [avg + (i < rem) for i in range(expert_num)]
        token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')
        if has_gather_idx:
            # Random permutation mapping expanded slots back to source tokens.
            gather_idx = torch.arange(0, token_num).repeat([topk])
            gather_idx = gather_idx[torch.randperm(gather_idx.size(0))].to(torch.int32).mlu()
        else:
            gather_idx = None
        hardware_time, e2e_time = benchmark_forward(tmo.moe_quantize,
                                                    input_tensor,
                                                    scale,
                                                    None,
                                                    token_count,
                                                    gather_idx,
                                                    None,
                                                    None,
                                                    None,
                                                    True,
                                                    repeats=args.repeat_times)
        # IO estimate: input read plus one-byte-per-element (presumably int8)
        # output, scaled by the topk expansion when the op gathers internally
        # — TODO confirm against the kernel's actual traffic.
        expand_num = topk if has_gather_idx else 1
        io_bytes = (input_tensor.element_size() + 1) * input_tensor.nelement() * expand_num
        io_eff = io_bytes / hardware_time / bd
        content = [f"{token_num}", f"{hidden_size}", f"{expert_num}",
                   f"{topk}", f"{has_gather_idx}", f"{dtype}",
                   f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,87 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
def _softmax_topk_case(num_batch, seq_len, num_expert, topk,
                       num_expert_group, topk_group, normalize):
    """Build one moe_softmax_topk benchmark configuration (bf16 input)."""
    return {"num_batch": num_batch, "seq_len": seq_len, "num_expert": num_expert,
            "topk": topk, "num_expert_group": num_expert_group,
            "topk_group": topk_group, "normalize": normalize,
            "input_dtype": torch.bfloat16}

e2e_time_param_dict_list = [
    # Sequence-length sweep: 32 experts, topk 5, no grouping, no normalization.
    *[_softmax_topk_case(1, s, 32, 5, -1, -1, False)
      for s in (1, 32, 72, 1024, 2048, 4096, 8192, 32768)],
    # Mixed batch/sequence shapes with weight normalization enabled.
    *[_softmax_topk_case(b, s, 32, 5, -1, -1, True)
      for b, s in ((1, 1), (2, 16), (2, 36), (8, 128), (16, 128),
                   (4, 1024), (2, 4096), (16, 2048))],
    # Grouped-expert cases: 160 experts in 8 groups, topk 6, topk_group 3.
    *[_softmax_topk_case(1, s, 160, 6, 8, 3, False)
      for s in (1, 16, 64, 1024, 2048, 8192, 32768)],
]
def main():
    """Benchmark tmo.moe_softmax_topk for each configuration listed above,
    printing hardware time, e2e latency and IO efficiency per case.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["num_batch", "seq_len", "num_expert", "topk", "num_expert_group", "topk_group", "normalize", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    rows = []
    bandwidth = get_band_width()
    for cfg in e2e_time_param_dict_list:
        num_batch = cfg["num_batch"]
        seq_len = cfg["seq_len"]
        num_expert = cfg["num_expert"]
        topk = cfg["topk"]
        num_expert_group = cfg["num_expert_group"]
        topk_group = cfg["topk_group"]
        normalize = cfg["normalize"]
        dtype = cfg["input_dtype"]
        # Fall back to fp16 on devices without bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        logits = torch.randn(num_batch, seq_len, num_expert, dtype=dtype, device='mlu')
        mask = torch.randint(0, 2, (1, seq_len, num_expert), dtype=dtype, device='mlu')
        # Grouped-expert configurations run without an expert mask.
        if num_expert_group > 1:
            mask = None
        normed_by = "softmax_logit"
        # These buffers only feed the IO-volume estimate below.
        # NOTE(review): shapes are (num_batch, topk); for seq_len > 1 the op
        # presumably emits per-token results — confirm the IO estimate.
        reduce_weight = torch.empty(num_batch, topk, dtype=torch.float, device='mlu')
        expert_id = torch.empty(num_batch, topk, dtype=torch.int32, device='mlu')
        hardware_time, e2e_time = benchmark_forward(tmo.moe_softmax_topk,
                                                    logits,
                                                    topk,
                                                    normalize,
                                                    num_expert_group,
                                                    topk_group,
                                                    mask,
                                                    normed_by,
                                                    repeats=args.repeat_times)
        io_bytes = logits.element_size() * logits.nelement() \
                 + reduce_weight.element_size() * reduce_weight.nelement() \
                 + expert_id.element_size() * expert_id.nelement()
        io_eff = io_bytes / hardware_time / bandwidth
        rows.append([f"{num_batch}", f"{seq_len}", f"{num_expert}", f"{topk}",
                     f"{num_expert_group}", f"{topk_group}", f"{normalize}",
                     f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,145 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
def _linear_cache_case(packed, quantize_mode):
    """One offline_quant_to_linear_cache benchmark config; geometry is fixed."""
    return {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
            "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
            "packed": packed, "input_dtype": [torch.float16, torch.bfloat16],
            "quantize_mode": quantize_mode}

# Cross packed/unpacked layouts with both quantization granularities.
e2e_time_param_dict_list = [
    _linear_cache_case(packed, mode)
    for mode in ("per_head", "per_channel")
    for packed in (True, False)
]
def main():
    """Benchmark tmo.offline_quant_to_linear_cache over packed/unpacked KV
    layouts and per_head/per_channel quantization modes.

    Prints a results table and, with ``--csv``, writes it to a CSV file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    # "input_dtype" fixes the "input_dytpe" typo in the report header.
    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dtype", "quantize_mode", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        max_seq_offset = params_dict["max_seq_offset"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        quantize_mode = params_dict["quantize_mode"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 configurations on devices that cannot run them.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # All sequences share the same length/offset (low == high - 1).
            context_lens = torch.randint(size=(batch, ), low=max_context_len,
                                         high=max_context_len+1,
                                         dtype=torch.int32, device='mlu')
            context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
                                                dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
                                              high=(cache_mem_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            total_heads = head_num_q + 2 * head_num_kv
            # Packed layout concatenates all sequences; unpacked keeps a batch
            # dimension padded to max_context_len + max_seq_offset.
            if packed:
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
            context = context.to(dtype)
            cache = cache.to(dtype)
            # Heads are laid out as [Q | K | V] along the head axis.
            key = context[..., head_num_q : head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = cache[0]
            value_cache = cache[1]
            # Distinct random cache slots, one per batch element.
            cache_bs_id = torch.IntTensor(random.sample(range(max_batch), batch)).mlu()
            # Spread cache values over the int8 range before casting.
            key_cache = ((key_cache - 0.5) * 256).to(torch.int8)
            value_cache = ((value_cache - 0.5) * 256).to(torch.int8)
            # per_channel scales span head_size (mode flag 0); per_head scales
            # span the cache sequence dimension (mode flag 1).  This replaces
            # four near-identical call sites that differed only in these values.
            if quantize_mode == "per_channel":
                scale_dim, mode_flag = head_size, 0
            elif quantize_mode == "per_head":
                scale_dim, mode_flag = cache_mem_len, 1
            else:
                # Previously an unknown mode crashed later with a NameError.
                raise ValueError(f"unknown quantize_mode: {quantize_mode}")
            key_cache_quantize_scale = torch.randn(head_num_kv, scale_dim).to('mlu').to(torch.float32)
            value_cache_quantize_scale = torch.randn(head_num_kv, scale_dim).to('mlu').to(torch.float32)
            # Packed inputs use cumulative lengths and need no per-sequence offsets.
            seq_lens = cu_context_lens if packed else context_lens
            seq_offsets = None if packed else context_seq_offsets
            hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
                                                        key, value,
                                                        key_cache, value_cache,
                                                        key_cache_quantize_scale,
                                                        value_cache_quantize_scale,
                                                        seq_lens, max_context_len, mode_flag,
                                                        packed > 0, seq_offsets,
                                                        cache_bs_id, cache_seq_offsets,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{quantize_mode}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,73 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from itertools import product
from tabulate import tabulate
import os
import random
# Token-count sweeps at three scales (x1, x32, x64) with a fixed single-KV-head
# geometry of head_size 128 and paged blocks of 16 tokens.
e2e_time_param_dict_list = [
    {"token_num": [base * mult for base in (1024, 2048, 3072, 4096)],
     "head_num_kv": 1, "head_size": 128, "block_size": 16,
     "input_dtype": [torch.float16, torch.bfloat16]}
    for mult in (1, 32, 64)
]
def main():
    """Benchmark tmo.offline_quant_to_paged_cache for the shapes listed above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["token_num", "head_num_kv", "head_size", "block_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    # The op has no MLU300 implementation; bail out early on those devices.
    if "MLU3" in torch.mlu.get_device_name():
        print("Op offline_quant_to_paged_cache does not support MLU300 devices.")
        return
    for cfg in e2e_time_param_dict_list:
        block_size = cfg["block_size"]
        head_num_kv = cfg["head_num_kv"]
        head_size = cfg["head_size"]
        for token_num, dtype in product(cfg["token_num"], cfg["input_dtype"]):
            # Skip bf16 configurations on devices that cannot run them.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Just enough blocks to hold every token.
            block_num = (token_num + block_size - 1) // block_size
            key = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
            value = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
            key_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
            value_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
            # Map every token to a distinct random slot in the paged cache.
            num_slots = block_num * block_size
            slot_mapping = random.sample(range(num_slots), token_num)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu")
            key_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
            value_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
            hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_paged_cache,
                                                        key, value,
                                                        key_cache_scale, value_cache_scale,
                                                        slot_mapping,
                                                        key_cache, value_cache,
                                                        repeats=args.repeat_times)
            rows.append([f"{token_num}", f"{head_num_kv}", f"{head_size}",
                         f"{block_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,61 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
import csv
# Token counts are multiples of 5 (5 tokens per request at each batch scale);
# only fp16 inputs are exercised here.
params_dict = {
    "token_num": [5 * base for base in (1, 72, 512, 1024, 4096, 32768)],
    "hidden_size": [1024, 8192],
    "input_dtype": [torch.float16],
}
def main():
    """Benchmark tmo.per_token_smooth_quantize over the token_num/hidden_size grid."""
    # Import locally: this file never imports itertools, so ``product``
    # previously worked only by leaking out of ``from common import *``.
    from itertools import product
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    # "input_dtype" fixes the "input_dytpe" typo in the report header.
    titles = ["token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    params_list = product(params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
    bd = get_band_width()
    for token_num, hidden_size, dtype in params_list:
        # Skip bf16 configurations on devices that cannot run them.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            continue
        input_tensor = torch.randn((token_num, hidden_size)).to(device).to(dtype)
        # One smoothing factor per hidden channel.  Note the explicit 1-tuple:
        # the previous ``(hidden_size)`` was just a parenthesised int.
        smooth = torch.randn((hidden_size,)).to(device).to(torch.float32)
        zero = None
        token_count = None
        hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
                                                    input_tensor,
                                                    smooth,
                                                    zero,
                                                    token_count,
                                                    repeats=args.repeat_times)
        # IO estimate: input read plus one-byte-per-element (presumably int8)
        # output, the smooth vector, and 4 bytes per token — presumably one
        # fp32 scale each; TODO confirm against the kernel.
        io_bytes = (input_tensor.element_size() + 1) * input_tensor.nelement() + \
                   smooth.element_size() * smooth.nelement() + \
                   token_num * 4
        io_eff = io_bytes / hardware_time / bd
        contents.append([f"{token_num}", f"{hidden_size}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,46 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Each preload case pairs one input shape with both half-precision dtypes.
e2e_time_param_dict_list = [
    {"input_shape": shape, "input_dtype": [torch.float16, torch.bfloat16]}
    for shape in ([100, 100, 100], [100, 100], [50, 50, 50], [1, 100, 1000])
]
def main():
    """Benchmark tmo.preload for each input shape/dtype combination above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["input_shape", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        input_shape = cfg["input_shape"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 configurations on devices that cannot run them.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            data = torch.randn(input_shape).to(device).to(dtype)
            # Preload the tensor's full byte size.
            hardware_time, e2e_time = benchmark_forward(tmo.preload,
                                                        data,
                                                        data.element_size() * data.numel(),
                                                        repeats=args.repeat_times)
            rows.append([f"{input_shape}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,201 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import random
import os
def _quant_cache_case(max_batch, batch, cache_mem_len, max_context_len,
                      packed, quant_bit, group_size):
    """One benchmark configuration; geometry is fixed at 1 KV head of size 128."""
    return {"max_batch": max_batch, "batch": batch, "cache_mem_len": cache_mem_len,
            "max_context_len": max_context_len, "head_num_kv": 1, "head_size": 128,
            "packed": packed, "input_dtype": [torch.bfloat16],
            "quant_bit": quant_bit, "group_size": group_size}

# For each (quant_bit, group_size) pair: a packed prefill sweep over context
# length followed by an unpacked decode sweep over batch size.
e2e_time_param_dict_list = [
    case
    for quant_bit, group_size in ((8, 128), (4, 128), (8, 16), (4, 16))
    for case in (
        [_quant_cache_case(8, 1, 32768, ctx, True, quant_bit, group_size)
         for ctx in (1024, 2048, 4096, 8192, 32768)]
        +
        [_quant_cache_case(512, b, 4096, 1, False, quant_bit, group_size)
         for b in (16, 72, 128, 256, 512)]
    )
]
def main():
    """Benchmark tmo.quant_to_linear_cache over e2e_time_param_dict_list.

    For each config this builds key/value tensors (packed/continuous or
    batch-padded layout), pre-allocated int8/int4 caches with matching
    scale tensors, then reports hardware time, e2e latency and an
    approximate IO efficiency per configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "max_context_len", "head_num_kv", "head_size", "packed", "input_dtype",
              "quant_bit", "group_size", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bandwidth = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        input_dtype_list = params_dict["input_dtype"]
        quant_bit = params_dict["quant_bit"]
        group_size = params_dict["group_size"]
        for dtype in input_dtype_list:
            # Fall back to fp16 on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            context_lens = torch.tensor([max_context_len] * batch).to(torch.int32).mlu()
            context_seq_offsets = torch.zeros(batch, dtype=torch.int32, device='mlu')
            # NOTE(review): offsets are randomized only when max_context_len == 1
            # (decode-style cases); prefill cases (max_context_len > 1) force
            # offset 0 via high=1 — confirm this inversion is intended.
            cache_seq_offsets = torch.randint(size=(batch, ),
                                        low=0,
                                        high = 1 if max_context_len > 1 else cache_mem_len,
                                        dtype=torch.int32, device='mlu')
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            if packed > 0:
                # Packed layout: all sequences concatenated along dim 0.
                key = torch.randn((total_seqlen, head_num_kv, head_size),
                                  dtype=torch.float, device='mlu')
                value = torch.randn((total_seqlen, head_num_kv, head_size),
                                    dtype=torch.float, device='mlu')
            else:
                # Batch-padded layout: one row of max_context_len per sequence.
                key = torch.randn((batch, max_context_len, head_num_kv, head_size),
                                  dtype=torch.float, device='mlu')
                value = torch.randn((batch, max_context_len, head_num_kv, head_size),
                                    dtype=torch.float, device='mlu')
            key = key.to(dtype)
            value = value.to(dtype)
            # Cache layout depends on quant_bit (int8 vs int4 packed two-per-byte)
            # and on the scale granularity: group_size == head_size means one
            # scale per token, otherwise head_size // group_size scales per token.
            if quant_bit == 8 and group_size == head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
            if quant_bit == 8 and group_size != head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
            if quant_bit == 4 and group_size == head_size:
                # int4: key packs along head_size, value packs along cache_mem_len.
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
            if quant_bit == 4 and group_size != head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
            # Map each active sequence to a distinct slot of the max_batch cache.
            cache_bs_id = None
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            if packed > 0:
                # Packed path: cumulative lengths, no per-sequence context offsets.
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            key_cache_scale,
                                                            value_cache_scale,
                                                            cu_context_lens, max_context_len,
                                                            packed > 0, None,
                                                            cache_bs_id, cache_seq_offsets,
                                                            quant_bit,
                                                            repeats=args.repeat_times)
            else:
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            key_cache_scale,
                                                            value_cache_scale,
                                                            context_lens, max_context_len,
                                                            packed > 0, context_seq_offsets,
                                                            cache_bs_id, cache_seq_offsets,
                                                            quant_bit,
                                                            repeats=args.repeat_times)
            # Rough IO model: read key+value (element_size bytes each) and write
            # their quantized int8 copies (+1 byte); scale traffic is ignored.
            io_bytes = key.nelement() * (key.element_size() + 1) * 2
            io_eff = io_bytes / hardware_time / bandwidth
            content = [f"{batch}", f"{max_context_len}", f"{head_num_kv}", f"{head_size}", f"{packed}",
                       f"{dtype}", f"{quant_bit}", f"{group_size}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,61 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Benchmark sweep: main() iterates the cartesian product of these lists.
# "dynamic" selects tmo.per_token_smooth_quantize (dynamic per-token
# quantization) vs tmo.quantize (static).
params_dict = {"dynamic": [True],
               "token_num": [1, 72, 490, 512, 525, 1024, 4096, 8192, 32768],
               "hidden_size": [8192, 1024],
               "input_dtype": [torch.bfloat16]}
def main():
    """Benchmark static vs per-token dynamic quantization ops.

    Sweeps the cartesian product of ``params_dict`` and reports hardware
    time, end-to-end latency, and achieved IO efficiency per config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["dynamic", "token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    params_list = product(params_dict["dynamic"], params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
    bd = get_band_width()
    for dynamic, token_num, hidden_size, dtype in params_list:
        input_shape = (token_num, hidden_size)
        # Fall back to fp16 on devices without bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        # Renamed from `input` to avoid shadowing the builtin.
        inp = torch.randn(input_shape).to(device).to(dtype)
        scale = torch.randn(input_shape[-1]).to(device).to(torch.float32)
        zero = None
        if dynamic:
            hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
                                                        inp,
                                                        scale,
                                                        zero,
                                                        None,
                                                        repeats=args.repeat_times)
        else:
            hardware_time, e2e_time = benchmark_forward(tmo.quantize,
                                                        inp,
                                                        scale,
                                                        zero,
                                                        repeats=args.repeat_times)
        # IO model: read the input (element_size bytes/elem), write the int8
        # output (1 byte/elem), plus reading the scale vector once.
        io_bytes = (inp.element_size() + 1) * inp.nelement() + scale.element_size() * scale.nelement()
        io_eff = io_bytes / hardware_time / bd
        content = [f"{dynamic}", f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,109 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configurations for tmo.reshape_linear_cache.
# NOTE(review): entries 1/3 and 2/4 appear identical — possibly intentional
# repetition (warm vs cold timing) but worth confirming.
e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16]}
                            ]
def main():
    """Benchmark tmo.reshape_linear_cache (scatter packed or padded K/V
    slices into pre-allocated linear KV caches).

    Iterates e2e_time_param_dict_list and reports hardware time and
    end-to-end latency per configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        max_seq_offset = params_dict["max_seq_offset"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # All sequences share the same length / offset (randint over a
            # one-value range yields a constant tensor).
            context_lens = torch.randint(size=(batch, ), low=max_context_len,
                                         high=max_context_len+1,
                                         dtype=torch.int32, device='mlu')
            # max_seq_offset = max_context_len // 3 + 1
            context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
                                                dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
                                              high=(cache_mem_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            total_heads = head_num_q + 2 * head_num_kv
            if packed > 0:
                # Packed layout: all sequences concatenated along dim 0.
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
            context = context.to(dtype)
            cache = cache.to(dtype)
            # K/V are slices of the fused QKV projection along the head dim.
            key = context[..., head_num_q : head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = cache[0]
            value_cache = cache[1]
            # Map each active sequence to a distinct slot of the max_batch cache.
            cache_bs_id = None
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            if packed > 0:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            cu_context_lens, max_context_len,
                                                            packed > 0, None,
                                                            cache_bs_id, cache_seq_offsets,
                                                            repeats=args.repeat_times)
            else:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            context_lens, max_context_len,
                                                            packed > 0, context_seq_offsets,
                                                            cache_bs_id, cache_seq_offsets,
                                                            repeats=args.repeat_times)
            content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,76 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configurations: quantize=True exercises tmo.quant_to_paged_cache,
# quantize=False exercises tmo.reshape_paged_cache.
e2e_time_param_dict_list = [{"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
                             "head_num_kv": 32, "head_size": 128, "quantize": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
                             "head_num_kv": 32, "head_size": 128, "quantize": False, "input_dtype": [torch.float16, torch.bfloat16]}
                            ]
def main():
    """Benchmark tmo.reshape_paged_cache / tmo.quant_to_paged_cache.

    Scatters per-token K/V slices into a paged KV cache via a random
    slot mapping, optionally quantizing to int8 with per-token scales.
    """
    # NOTE(review): benchmark exits early on devices whose name contains
    # "MLU3" — presumably these ops are unsupported there; confirm.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["num_tokens", "num_block", "block_size", "head_num_q", "head_num_kv", "head_size", "input_dtype", "quantize", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        num_tokens = params_dict["num_tokens"]
        num_blocks = params_dict["num_block"]
        block_size = params_dict["block_size"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        quantize = params_dict["quantize"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # K/V are slices of the fused QKV tensor along the head dim.
            qkv = torch.randn(num_tokens, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
            key = qkv[:, head_num_q : head_num_q + head_num_kv, :]
            value = qkv[:, head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
            value_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
            # Each token maps to a distinct random cache slot; the final -1
            # exercises the "skip this token" path of the kernel.
            num_slots = num_blocks * block_size
            slot_mapping = random.sample(range(num_slots), num_tokens)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
            slot_mapping[-1] = -1
            if not quantize:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_paged_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            slot_mapping,
                                                            repeats=args.repeat_times)
            else:
                # Per-token dequantization scales for the int8 cache path.
                k_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
                v_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_paged_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            k_cache_quant_scale,
                                                            v_cache_quant_scale,
                                                            slot_mapping,
                                                            repeats=args.repeat_times)
            content = [f"{num_tokens}", f"{num_blocks}", f"{block_size}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{dtype}", f"{quantize}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,116 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import math
import random
# Benchmark configurations for tmo.single_query_cached_kv_attn:
# decode-step attention against an int8-quantized contiguous KV cache
# (use_paged_attn=False) with per-channel scales (is_pertoken=False).
e2e_time_param_dict_list = [
    {"batch": 16, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
     "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
     "is_pertoken": False, "input_dtype": [torch.bfloat16]},
    {"batch": 128, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
     "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
     "is_pertoken": False, "input_dtype": [torch.bfloat16]},
    {"batch": 512, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
     "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
     "is_pertoken": False, "input_dtype": [torch.bfloat16]},
    {"batch": 16, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
     "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
     "is_pertoken": False, "input_dtype": [torch.bfloat16]},
    {"batch": 128, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
     "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
     "is_pertoken": False, "input_dtype": [torch.bfloat16]},
]
def main():
    """Benchmark tmo.single_query_cached_kv_attn (single-token decode
    attention over a cached, optionally int8-quantized KV cache)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "max_seq_len", "head_num_q", "head_num_kv", "head_size", "block_size", "alibi_bias", "kv_cache_dtype", "use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        max_seq_len = params_dict["max_seq_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        block_size = params_dict["block_size"]
        alibi_bias = params_dict["alibi_bias"]
        kv_cache_dtype = params_dict["kv_cache_dtype"]
        use_paged_attn = params_dict["use_paged_attn"]
        input_dtype_list = params_dict["input_dtype"]
        is_pertoken = params_dict["is_pertoken"]
        for dtype in input_dtype_list:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Decode step: q sequence length is 1; q is a slice of fused QKV.
            input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
            input_q = input_qkv[..., 0 : head_num_q, :]
            # All sequences use the full max_seq_len context.
            context_lens = torch.randint(max_seq_len, max_seq_len + 1, (batch, ), dtype=torch.int32).to(device)
            max_context_len = int(max(context_lens))
            if use_paged_attn:
                mlu_name = torch.mlu.get_device_name()
                if "MLU3" in mlu_name:
                    print("pagedattn is not implement on mlu370, skip it")
                    continue
                block_size = 16
            else:
                # Contiguous cache: one oversized block holds a whole sequence.
                block_size = max_seq_len + 512
            num_blocks = batch * ((max_seq_len + block_size - 1) // block_size)
            max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
            cache_shape = (num_blocks, head_num_kv, block_size, head_size)
            # Per-token scales index by (block, head, slot); per-channel by (head, dim).
            scale_shape = (num_blocks, head_num_kv, block_size) if is_pertoken else (head_num_kv, head_size)
            # Randomly assign distinct cache blocks to each sequence.
            block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
            if kv_cache_dtype is not torch.int8:
                key_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
                value_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
                key_cache_scale = None
                value_cache_scale = None
            else:
                # int8 KV cache plus dequantization scales.
                key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
                value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
                key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
                value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
            alibi_slopes = None
            if alibi_bias:
                alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
                alibi_slopes.uniform_(0, 0.125)
            softmax_scale = 1 / math.sqrt(head_size)
            hardware_time, e2e_time = benchmark_forward(tmo.single_query_cached_kv_attn,
                                                        input_q,
                                                        key_cache,
                                                        value_cache,
                                                        None,
                                                        block_tables,
                                                        context_lens,
                                                        key_cache_scale,
                                                        value_cache_scale,
                                                        alibi_slopes,
                                                        max_context_len,
                                                        -1,
                                                        -1,
                                                        softmax_scale,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{max_seq_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{alibi_bias}", f"{kv_cache_dtype}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,131 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import math
import random
# Benchmark configurations for tmo.single_query_mixed_cached_kv_attn:
# a low-precision (int4) cache covering max_seq_len_lp tokens combined with
# a high-precision (int8) cache covering max_seq_len_hp tokens.
e2e_time_param_dict_list = [{"batch": 16, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": True, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 16, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}]
def gen_cache(batch, num_kv_heads, head_size, is_pagedattn, max_context_len, data_type, quant_bit, quant_mode):
    """Generate KV-cache tensors for the mixed-precision attention benchmark.

    Args:
        batch: number of sequences.
        num_kv_heads: number of KV heads.
        head_size: per-head hidden size.
        is_pagedattn: paged (block_size=16) vs contiguous (one block per seq).
        max_context_len: context length shared by every sequence.
        data_type: cache dtype when quant_bit == -1 (no quantization).
        quant_bit: 4 (int4 packed two-per-byte in int8 storage), 8 (int8),
            or -1 (unquantized).
        quant_mode: "per_token" scales (one per cache slot) or per-channel
            scales (one per head dim).

    Returns:
        (key_cache, value_cache, key_cache_scale, value_cache_scale,
         context_lens, block_tables); scales are None when quant_bit == -1.

    Raises:
        ValueError: if quant_bit is not one of {-1, 4, 8}.
    """
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    # All sequences share the same context length.
    context_lens = torch.randint(max_context_len, max_context_len + 1, (batch, ), dtype=torch.int32).mlu()
    block_size = 16
    if is_pagedattn is False:
        # Contiguous cache: one block holds the whole sequence.
        block_size = max_context_len
    num_blocks = int(batch * ((max_context_len + block_size - 1) / block_size))
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
    if quant_mode == "per_token":
        scale_shape = (num_blocks, num_kv_heads, block_size, 1)
    else:  # per channel
        scale_shape = (num_kv_heads, head_size)
    if quant_bit == 4:
        # int4 packs two values per byte: key packs along head_size,
        # value packs along block_size.
        cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size//2)
        cache_shape_v_int4 = (num_blocks, num_kv_heads, block_size//2, head_size)
        key_cache = torch.zeros(cache_shape_k_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape_v_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == 8:
        key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == -1:
        key_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        value_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        key_cache_scale = None
        value_cache_scale = None
    else:
        # Previously this branch only printed an error and fell through,
        # causing a confusing NameError at the return below; fail fast instead.
        raise ValueError("quant_bit must be in {-1, 4, 8}, got %r" % (quant_bit,))
    # Randomly assign distinct cache blocks to each sequence.
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    return key_cache, value_cache, key_cache_scale, value_cache_scale, context_lens, block_tables
def main():
    """Benchmark tmo.single_query_mixed_cached_kv_attn: decode attention
    reading two KV caches of different precisions (lp/hp) per sequence."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "max_seq_len_lp", "max_seq_len_hp", "head_num_q", "head_num_kv", "head_size", "alibi_bias", "quant_bit_lp", "quant_bit_hp","use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        max_seq_len_lp = params_dict["max_seq_len_lp"]
        max_seq_len_hp = params_dict["max_seq_len_hp"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        alibi_bias = params_dict["alibi_bias"]
        quant_bit_lp = params_dict["quant_bit_lp"]
        quant_bit_hp = params_dict["quant_bit_hp"]
        use_paged_attn = params_dict["use_paged_attn"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Decode step: q sequence length is 1; q is a slice of fused QKV.
            input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
            input_q = input_qkv[..., 0 : head_num_q, :]
            # lp cache: per-token scales; hp cache: per-channel scales
            # (presumably lp covers the long history and hp the recent
            # window — confirm against the kernel's contract).
            params_lp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_lp, dtype, quant_bit_lp, "per_token")
            params_hp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_hp, dtype, quant_bit_hp, "per_channel")
            key_cache_lp, value_cache_lp, key_cache_scale_lp, value_cache_scale_lp, context_lens_lp, block_tables_lp = params_lp
            key_cache_hp, value_cache_hp, key_cache_scale_hp, value_cache_scale_hp, context_lens_hp, block_tables_hp = params_hp
            alibi_slopes = None
            if alibi_bias:
                alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
                alibi_slopes.uniform_(0, 0.125)
            softmax_scale = 1 / math.sqrt(head_size)
            hardware_time, e2e_time = benchmark_forward(tmo.single_query_mixed_cached_kv_attn,
                                                        input_q,
                                                        key_cache_lp, value_cache_lp,
                                                        key_cache_hp, value_cache_hp,
                                                        None, #output
                                                        block_tables_lp, block_tables_hp,
                                                        context_lens_lp, context_lens_hp,
                                                        key_cache_scale_lp, value_cache_scale_lp,
                                                        key_cache_scale_hp, value_cache_scale_hp,
                                                        alibi_slopes,
                                                        max_seq_len_lp, max_seq_len_hp,
                                                        softmax_scale, True,
                                                        quant_bit_lp, quant_bit_hp,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{max_seq_len_lp}", f"{max_seq_len_hp}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{alibi_bias}", f"{quant_bit_lp}", f"{quant_bit_hp}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,69 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# GEMM shapes and epilogue options swept by the smooth_quant_matmul benchmark.
e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
                             "act_mode": "none", "output_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
                             "act_mode": "silu", "output_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.smooth_quant_matmul over the configured GEMM shapes.

    For each (m, k, n) config and output dtype, builds int8 operands with
    per-row/per-column scales (plus optional C matrix and bias) and reports
    hardware time and end-to-end latency.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "output_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        m, k, n = cfg["m"], cfg["k"], cfg["n"]
        has_c, has_bias = cfg["has_c"], cfg["has_bias"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["output_dtype"]:
            # Skip bf16 configs on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(m, k).to(device).to(torch.int8)
            b = torch.randn(n, k).to(device).to(torch.int8)
            a_scale = torch.randn(m).to(device)
            b_scale = torch.randn(n).to(device)
            c = torch.randn(m, n).to(device).to(dtype) if has_c else None
            bias = torch.randn(n).to(device).to(dtype) if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_matmul,
                                                        a,
                                                        a_scale,
                                                        b,
                                                        b_scale,
                                                        dtype,
                                                        bias,
                                                        c,
                                                        act_mode,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            contents.append([f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}",
                             f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Run every benchmark_*.py in the current directory, tee each script's
# output to ${LOG_PATH}/<op_name>.log, and abort on the first failure.
LOG_PATH=${LOG_PATH:-.}
files=($(ls benchmark_*.py))
for file in "${files[@]}"; do
    echo "test ${file}..."
    op_name=$(basename "$file" .py)
    python "$file" > ${LOG_PATH}/${op_name}.log 2>&1
    ret_tmp=$?
    cat ${LOG_PATH}/${op_name}.log
    if [ $ret_tmp != 0 ]; then
        # Fixed: previously referenced the undefined variable ${sc}.
        echo "${file} test failed..."
        exit $ret_tmp
    fi
done

View File

@@ -0,0 +1,99 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark configurations for tmo.update_out_and_lse: pack=False exercises
# the fixed-shape (batch, seq, head, dim) layout; pack=True exercises the
# variable-length packed layout driven by cumulative sequence lengths.
e2e_time_param_dict_list = [{"batch": 16, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 128, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 512, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 1024, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 1024, "max_seq_len": 1024,
                             "dtype": [torch.float16], "pack": True},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 2048, "max_seq_len": 2048,
                             "dtype": [torch.float16], "pack": True},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 8192, "max_seq_len": 8192,
                             "dtype": [torch.float16], "pack": True},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 32768, "max_seq_len": 32768,
                             "dtype": [torch.float16], "pack": True},]
def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack):
    """Build inputs for the tmo.update_out_and_lse benchmark.

    NOTE(review): the name is a typo for ``gen_tensor`` but is kept since
    main() calls it by this name.

    Returns:
        (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs).
        The last three are None in the non-packed (fixed-shape) layout.
    """
    if not pack:
        # Fixed-shape layout: (batch, seq, head, dim) outputs and
        # (batch, head, seq) log-sum-exp tensors.
        out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
        block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
        seq_offset = None
        cu_seqs = None
        block_cu_seqs = None
    else:
        # Packed layout: sequences concatenated along dim 0, addressed via
        # cumulative length tensors; each block is a sub-range of its
        # sequence starting at a random seq_offset.
        seq_lens = torch.randint(low=max_seq_len, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32)
        block_seq_lens = torch.randint(low=block_seq_len, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32)
        block_seq_lens = torch.minimum(seq_lens, block_seq_lens)
        seq_offset = torch.zeros_like(seq_lens)
        for i in range(batch):
            seq_offset[i] = torch.randint(low=0, high=seq_lens[i]-block_seq_lens[i]+1, size=(1,), dtype=torch.int32)
        seq_offset = seq_offset.mlu()
        cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu()
        block_cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu()
        total_seqs = torch.sum(seq_lens)
        block_total_seqs = torch.sum(block_seq_lens)
        out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype)
        lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
        block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype)
        block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
    return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
def main():
    """Benchmark tmo.update_out_and_lse over e2e_time_param_dict_list.

    Each parameter combination is timed with benchmark_forward; the results
    are printed as a grid table and, when --csv is given, also written to a
    CSV file inside the folder selected by -o.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "head_num", "head_size", "block_seq_len", "max_seq_len", "dtype", "pack", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        block_seq_len = params_dict["block_seq_len"]
        max_seq_len = params_dict["max_seq_len"]
        pack = params_dict["pack"]
        for dtype in params_dict["dtype"]:
            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            hardware_time, e2e_time = benchmark_forward(tmo.update_out_and_lse,
                                                        out,
                                                        lse,
                                                        block_out,
                                                        block_lse,
                                                        seq_offset,
                                                        cu_seqs,
                                                        block_cu_seqs,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{head_num}", f"{head_size}", f"{block_seq_len}",
                             f"{max_seq_len}", f"{dtype}", f"{pack}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        # Name the CSV report after this script file.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,70 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
# Fixed duplicate name in the import list (save_to_csv was imported twice).
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os

# Parameter sweeps for the weight-only-quant matmul benchmark; each entry is
# expanded over every dtype in its "input_dtype" list.
e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
                             "act_mode": "none", "quant_bit": 8, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
                             "act_mode": "silu", "quant_bit": 4, "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.weight_only_quant_matmul for every configured shape/dtype.

    Prints a grid table of hardware/e2e latencies and, when --csv is given,
    also writes it to a CSV file inside the folder selected by -o.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "quant_bit", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        m, k, n = cfg["m"], cfg["k"], cfg["n"]
        quant_bit = cfg["quant_bit"]
        has_c = cfg["has_c"]
        has_bias = cfg["has_bias"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 cases on hardware without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # For 4-bit quantization the int8 weight buffer packs two values
            # per byte, so its inner dimension is k // 2.
            weight_k = k if quant_bit == 8 else k // 2
            a = torch.randn(m, k).to(device).to(dtype)
            b = torch.randn(n, weight_k).to(device).to(torch.int8)
            scale = torch.randn(n).to(device)
            zero = None
            c = torch.randn(m, n).to(device).to(dtype) if has_c else None
            bias = torch.randn(n).to(device).to(dtype) if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.weight_only_quant_matmul,
                                                        a, b, scale, zero, bias, c,
                                                        act_mode, quant_bit, 1.0, 1.0,
                                                        repeats=args.repeat_times)
            rows.append([f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}",
                         f"{act_mode}", f"{quant_bit}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,65 @@
import torch
import torch_mlu
import time
from pathlib import Path
import csv
import os
import subprocess
from itertools import product
def benchmark_forward(fn, *inputs, repeats=1, **kwinputs):
    """Time ``fn(*inputs, **kwinputs)`` over ``repeats`` launches.

    Returns a ``(hardware_time, e2e_time)`` pair, both averaged per call.
    The e2e value is wall-clock time around the launch loop in microseconds;
    the hardware value is the device-side elapsed time reported by the MLU
    events' ``hardware_time`` (in microseconds, per the report headers).
    """
    start_event = torch.mlu.Event(enable_timing=True)
    end_event = torch.mlu.Event(enable_timing=True)
    start_event.record()
    wall_start = time.perf_counter()
    for _ in range(repeats):
        fn(*inputs, **kwinputs)
    end_event.record()
    # Wait for all queued work so both timers cover the full execution.
    end_event.synchronize()
    wall_elapsed = time.perf_counter() - wall_start
    avg_e2e_us = wall_elapsed / repeats * 1e6
    avg_hw_us = start_event.hardware_time(end_event) / repeats
    return avg_hw_us, avg_e2e_us
def save_to_csv(table, file_path, file_name):
    """Write ``table`` (a list of rows) to a CSV file.

    Args:
        table: list of rows, each row an iterable of cell values.
        file_path: target location. ``None`` defaults to the current
            directory. A path carrying a suffix (e.g. ``out/report.csv``) is
            treated as the full output file path; a suffix-less path is
            treated as a directory and the file name is derived from
            ``file_name``.
        file_name: source script name; its extension is replaced with
            ``.csv`` when ``file_path`` does not already name a file.
    """
    file_name_without_ext, _ = os.path.splitext(file_name)
    new_file_name = file_name_without_ext + '.csv'
    path = Path(file_path) if file_path is not None else Path('./')
    if path.suffix:
        # A suffix means the caller supplied a concrete file name.
        directory = path.parent
        filename = path.name
    else:
        directory = path
        filename = new_file_name
    # exist_ok makes a prior existence check unnecessary, and open(..., "w")
    # below creates the file, so the former touch() call was redundant.
    directory.mkdir(parents=True, exist_ok=True)
    full_path = directory / filename
    with open(full_path, mode="w", newline="") as file:
        csv.writer(file).writerows(table)
    print(f"output saved at: {full_path}")
def get_band_width(card_id: int = 0):
    """Return the memory bandwidth of MLU ``card_id`` as reported by cnmon.

    Runs ``cnmon info`` through a shell pipeline that extracts the numeric
    field of the 'MEM BandWidth' line; raises AssertionError if the command
    fails.
    """
    cmd = ("cnmon info -c " + str(card_id)
           + " | grep 'MEM BandWidth'| cut -d ':' -f2 | cut -d ' ' -f 2")
    proc = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    assert proc.returncode == 0, "Failed to get BandWidth."
    return int(proc.stdout.decode().strip())
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split ``total_token_count`` tokens across ``num_expert`` experts.

    Draws a random int weight per expert, rescales the weights (as floats)
    so they sum to roughly ``total_token_count``, truncates back to int32,
    then clamps the cumulative counts so the final entry is exactly
    ``total_token_count``.

    Returns:
        A tuple ``(cusum_token_count, token_count)`` where
        ``cusum_token_count`` holds ``num_expert + 1`` monotonically
        non-decreasing int32 prefix sums starting at 0 and ending at
        ``total_token_count``, and ``token_count`` is its first difference
        (the per-expert counts, all >= 0).
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # Renamed from ``sum`` to avoid shadowing the builtin; the float cast is
    # already done above, so the former ``* 1.0`` was redundant.
    total = torch.sum(token_count, dim=-1)
    token_count *= total_token_count / total.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Float truncation may over/undershoot the target, so cap each prefix sum
    # at total_token_count and pin the last entry exactly.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]