+
+Torch-MLU-Ops
+===========================
+
+
+
+
+
+## Table of Contents
+
+- [Torch-MLU-Ops](#torch-mlu-ops)
+  - [Table of Contents](#table-of-contents)
+  - [Introduction](#introduction)
+  - [Installation](#installation)
+    - [Requirements](#requirements)
+    - [Launching from a Docker Image](#launching-from-a-docker-image)
+    - [Building and Installing](#building-and-installing)
+    - [Testing](#testing)
+  - [Directory Structure](#directory-structure)
+  - [Key Features](#key-features)
+
+## Introduction
+
+Torch-MLU-Ops is a third-party PyTorch operator library designed and developed by Cambricon. Through Torch-MLU-Ops, developers working with the PyTorch framework can conveniently use these custom operators for operator integration, evaluation, and production deployment.
+
+Torch-MLU-Ops fully covers the operators commonly used in LLM (Large Language Model) inference. As a backend, it already supports Cambricon vLLM, Cambricon TGI, Cambricon Stable Diffusion web UI, Cambricon ComfyUI, and Cambricon Diffusers.
+
+
+## Installation
+
+### Requirements
+
+Torch-MLU-Ops supports the following operating systems and software versions:
+
+
+| Operating System    | Compiler            | CMake          |
+| ------------------- | ------------------- | -------------- |
+| Ubuntu 22.04 x86-64 | gcc 9.4.0 or later  | 3.10 or later  |
+| Debian 10.11 x86-64 | gcc 9.4.0 or later  | 3.10 or later  |
+
+
+Building and running Torch-MLU-Ops additionally requires the following dependency versions.
+
+| Torch-MLU-Ops | Torch-MLU | CNNL | CNNL_Extra | CNCL | CNToolkit |
+| -----------------| ------------------| ---------| --------------- | --------| -----------|
+| v1.3.2 | v1.24.1 | v1.28.3 | v1.12.3 | v1.24.1 | v3.15.6 |
+| v1.3.1 | v1.24.1 | v1.28.3 | v1.12.2 | v1.24.1 | v3.15.5 |
+| v1.3.0 | v1.24.0 | v1.28.2 | v1.12.1 | v1.24.0 | v3.15.4 |
+| v1.2.3 | v1.24.0 | v1.28.1 | v1.12.1 | v1.24.0 | v3.15.3 |
+| v1.2.2 | v1.23.1 | v1.27.4 | v1.11.3 | v1.23.0 | v3.14.3 |
+| v1.2.1 | v1.23.1 | v1.27.3 | v1.11.3 | v1.22.3 | v3.14.2 |
+| v1.2.0 | v1.23.0 | v1.27.1 | v1.11.1 | v1.22.1 | v3.14.1 |
+| v1.1.4 | v1.22.2 | v1.26.7 | v1.10.4 | v1.21.1 | v3.13.7 |
+| v1.1.3 | v1.22.2 | v1.26.6 | v1.10.3 | v1.21.1 | v3.13.5 |
+
+
+Running Torch-MLU-Ops also requires a PyTorch environment. The PyTorch versions supported by each Torch-MLU-Ops release are listed below.
+
+| Torch-MLU-Ops | PyTorch |
+| -----------------| ------------------|
+| v1.3.2 | v2.1, v2.4, v2.5 |
+| v1.3.1 | v2.1, v2.4, v2.5 |
+| v1.3.0 | v2.1, v2.4, v2.5 |
+| v1.2.3 | v2.1, v2.4, v2.5 |
+| v1.2.2 | v2.1, v2.4, v2.5 |
+| v1.2.1 | v2.1, v2.4, v2.5 |
+| v1.2.0 | v2.1, v2.4, v2.5 |
+| v1.1.4 | v2.1, v2.3, v2.4 |
+| v1.1.3 | v2.1, v2.3 |
+
+### Launching from a Docker Image
+
+The solution images provided by Cambricon already include all dependencies required by Torch-MLU-Ops; refer to the Cambricon PyTorch Container Image documentation for usage. You can obtain the image from the [Cambricon Developer Community](https://account.cambricon.com/interaction/SMmmcGFmw837gghtZ7VR4).
+
+### Building and Installing
+
+The Torch-MLU-Ops build scripts depend on the `NEUWARE_HOME` environment variable, which points to the installation path of Cambricon CNToolkit, usually `/usr/local/neuware`.
+
+Before running the example programs, you need to add the `lib64` directory under `NEUWARE_HOME` to the `LD_LIBRARY_PATH` environment variable, for example:
+
+```shell
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NEUWARE_HOME/lib64
+```
+
+If you use a Docker image provided by Cambricon, all of the above environment variables are already set.
+
+Once `NEUWARE_HOME` is set, you can build Torch-MLU-Ops with the following commands.
+
+Clone the repository from the remote and enter the project's root directory:
+```shell
+cd torch_mlu_ops
+```
+
+Install from source:
+```shell
+pip install -e .
+```
+
+Install from a wheel package:
+```shell
+python setup.py bdist_wheel
+pip install dist/torch_mlu_ops*.whl
+```
+
+> _Note: if you run into permission issues during installation, obtain the required permissions or switch to a user who has them before running the `pip install` step._
+
+Uninstall Torch-MLU-Ops (optional):
+```shell
+pip uninstall torch_mlu_ops
+```
+
+### Testing
+
+Before testing, make sure Torch-MLU-Ops is installed. You can check the installation with `pip list`; output similar to the following indicates a successful installation.
+```shell
+torch-mlu-ops 1.1.0+pt21
+```
+
+Import the Torch-MLU-Ops module as follows.
+```python
+import torch_mlu_ops as tmo
+```
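As an optional convenience (not part of the official workflow), you can also check availability programmatically before importing; `importlib.util.find_spec` from the standard library returns `None` when a package is not installed:

```python
import importlib.util

def is_installed(package: str) -> bool:
    """Return True if `package` can be imported in the current environment."""
    return importlib.util.find_spec(package) is not None

# Guard an optional import of torch_mlu_ops:
if is_installed("torch_mlu_ops"):
    import torch_mlu_ops as tmo
```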
+
+The `./tests/ops_pytest` directory provides op-level test programs, which you can run as follows.
+```shell
+cd tests/ops_pytest/
+./run_test.sh
+# or run a single test case, e.g.
+python test_flash_attention.py
+```
+The `./tests/kernels_pytest` directory provides kernel-level test programs, which you can run as follows.
+```shell
+cd tests/kernels_pytest
+./build.sh # build first
+./run_test.sh
+```
+
+## Directory Structure
+```
+torch_mlu_ops
+|-- benchmarks ## performance benchmark code
+|-- csrc ## C and BangC source code
+| |-- common
+| |-- kernels
+| |-- ops
+| `-- torch_api
+|-- docs ## documentation
+| |-- release_notes
+| `-- user_guide
+|-- tests ## test code, covering both kernel- and op-level tests
+| |-- kernels_pytest
+| `-- ops_pytest
+|-- tools ## tooling, mainly CI-related code
+| `-- ci
+`-- torch_mlu_ops ## PyTorch third-party operator API implementation
+```
+
+## Key Features
+
+For the list of custom operators provided by Torch-MLU-Ops and their API descriptions, refer to the [Torch-MLU-Ops User Guide](./docs/user_guide/).
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/README.md b/torch_mlu_ops-v1.3.2/benchmarks/README.md
new file mode 100644
index 0000000..4717d61
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/README.md
@@ -0,0 +1,51 @@
+## Using the Benchmark Scripts
+
+The Torch-MLU-Ops benchmark scripts give users a convenient entry point for operator performance testing.
+The meaning of each argument can be obtained with the following command.
+
+```bash
+# show help for the benchmark command
+python3 benchmark_xxx.py --help
+```
+The arguments are as follows:
+
+`options`:
+- `-h, --help`: show this help message and exit
+- `--repeat_times REPEAT_TIMES`: repeat times for testing
+- `--csv`: write the report data to csv
+- `-o O`: specify the output folder name under `--csv` mode
+
+```bash
+# example benchmark invocation
+python3 benchmark_active.py --repeat_times 10 --csv -o './active/'
+```
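The `hardware_time(us)` and `e2e_latency(us)` columns are produced by `benchmark_forward` from the repository's `common.py`, whose implementation is not shown here. As a rough, hypothetical host-side sketch (hardware timing on the MLU device is not modeled), a wall-clock-only helper might look like:

```python
import time
from typing import Callable

def naive_benchmark_forward(fn: Callable, *args, repeats: int = 10,
                            warmup: int = 2, **kwargs) -> float:
    """Return the average wall-clock latency of fn(*args, **kwargs) in microseconds.

    Note: this only measures host-side end-to-end time; the real benchmark_forward
    also reports hardware time measured on the MLU device.
    """
    for _ in range(warmup):            # warm-up runs are excluded from timing
        fn(*args, **kwargs)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args, **kwargs)
    return (time.perf_counter() - start) / repeats * 1e6
```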
+The following operators are supported:
+
+| op_name |
+| ---------------------------------|
+| active |
+| apply_rotary |
+| attention_project |
+| ffn |
+| flash_attn |
+| fused_layer_norm |
+| fused_moe |
+| fused_norm_attention_project |
+| fused_norm_residual_ffn |
+| fused_rms_norm |
+| group_gemm |
+| matmul |
+| offline_quant_to_linear_cache |
+| per_token_smooth_quantize |
+| preload |
+| quantize |
+| reshape_linear_cache |
+| quant_to_linear_cache |
+| reshape_paged_cache |
+| single_query_cached_kv_attn |
+| smooth_quant_matmul |
+| weight_only_quant_matmul |
+| moe_gen_idx |
+| moe_expand_input |
+| moe_softmax_topk |
+| moe_combine_result |
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py
new file mode 100644
index 0000000..e87a7ef
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py
@@ -0,0 +1,64 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 16, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 72, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 1024, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 4096, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 8192, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 32768, "seq_len": 5, "hidden_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+ titles = ["input_shape", "act_mode", "is_gated", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ hidden_size = params_dict["hidden_size"]
+ act_mode = params_dict["act_mode"]
+ is_gated = params_dict["is_gated"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.active,
+ input,
+ act_mode,
+ is_gated,
+ repeats=args.repeat_times)
+ io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated)
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{batch,seq_len,hidden_size}", f"{act_mode}", f"{is_gated}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py
new file mode 100644
index 0000000..c2bec3c
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py
@@ -0,0 +1,84 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "rotary_dim": 128,
+ "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "rotary_dim": 64,
+ "interleaved": True, "discrete": False, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "rotary_dim": 128,
+ "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "rotary_dim": 128,
+ "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "rotary_dim": 64,
+ "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "rotary_dim": 96,
+ "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 4, "seq_len": 1, "head_num": 80, "head_size": 128, "rotary_dim": 128,
+ "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}]
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "seq_len", "head_num", "head_size", "rotary_dim", "interleaved", "discrete", "dynamic_ntk", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ head_num = params_dict["head_num"]
+ head_size = params_dict["head_size"]
+ # full/partial
+ rotary_dim = params_dict["rotary_dim"]
+ # cross/fold
+ interleaved = params_dict["interleaved"]
+ # discrete
+ discrete = params_dict["discrete"]
+ # dynamic_ntk
+ dynamic_ntk = params_dict["dynamic_ntk"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input = torch.randn(batch, seq_len, head_num, head_size).to(device).to(dtype) # [batch, seqlen, head_num, head_size]
+ if dynamic_ntk:
+ sin_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
+ cos_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
+ else:
+ sin_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
+ cos_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
+ if discrete:
+ pos_ids = torch.randint(0, seq_len, (batch * seq_len,)).to(device).to(torch.int32)
+ else:
+ pos_ids = None
+ hardware_time, e2e_time = benchmark_forward(tmo.apply_rotary,
+ input,
+ sin_cache,
+ cos_cache,
+ pos_ids,
+ None,
+ interleaved,
+ discrete,
+ dynamic_ntk,
+ seq_len,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{head_num}", f"{head_size}", f"{rotary_dim}", f"{interleaved}", f"{discrete}", f"{dynamic_ntk}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_attention_project.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_attention_project.py
new file mode 100644
index 0000000..0cbc627
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_attention_project.py
@@ -0,0 +1,74 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "input_size": 1600, "hidden_size": 1600,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 2048, "hidden_size": 2048,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 4096, "hidden_size": 4096,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 6144, "hidden_size": 6144,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 6656, "hidden_size": 6656,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 8192, "hidden_size": 8192,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 12288, "hidden_size": 12288,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 14336, "hidden_size": 14336,
+ "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "seq_len", "input_size", "hidden_size", "has_residual", "has_bias", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ input_size = params_dict["input_size"]
+ hidden_size = params_dict["hidden_size"]
+ has_residual = params_dict["has_residual"]
+ has_bias = params_dict["has_bias"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ x = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device)
+ weight = torch.randn(hidden_size, input_size).to(dtype).to(device)
+ residual, bias = None, None
+ if has_residual:
+ residual = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device)
+ if has_bias:
+ bias = torch.randn(hidden_size).to(dtype).to(device)
+ hardware_time, e2e_time = benchmark_forward(tmo.attention_project,
+ x,
+ weight,
+ bias,
+ residual,
+ 1.0,
+ 1.0,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}", f"{has_residual}", f"{has_bias}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py
new file mode 100644
index 0000000..740c7e0
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py
@@ -0,0 +1,60 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 2, "m": 1024, "k": 1600, "n": 6400, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 2, "m": 1024, "k": 2048, "n": 8192, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 2, "m": 1024, "k": 4096, "n": 11008, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 2, "m": 1024, "k": 5120, "n": 16384, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 2, "m": 1024, "k": 6144, "n": 24576, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "m", "k", "n", "has_c", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ m = params_dict["m"]
+ k = params_dict["k"]
+ n = params_dict["n"]
+ has_c = params_dict["has_c"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ a = torch.randn(batch, m, k).to(device).to(dtype)
+ b = torch.randn(batch, n, k).to(device).to(dtype)
+ c = None
+ if has_c:
+ c = torch.randn(batch, m, n).to(device).to(dtype)
+
+ hardware_time, e2e_time = benchmark_forward(tmo.batch_matmul,
+ a,
+ b,
+ c,
+ 1.0,
+ 1.0,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{m}", f"{k}", f"{n}", f"{has_c}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_linear_cache.py
new file mode 100644
index 0000000..4cb5f7d
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_linear_cache.py
@@ -0,0 +1,120 @@
+import argparse
+import random
+import os
+
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+
+from common import benchmark_forward, save_to_csv
+from itertools import product
+from tabulate import tabulate
+
+e2e_time_param_dict_list = [
+ {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
+ "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "head_size": 128,
+ "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [4, 8],
+ "use_offset": True},
+]
+
+def main():
+ parser = argparse.ArgumentParser(description="Benchmark for dequant from linear cache.")
+ parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ titles = ["max_batch_size", "batch_size", "max_context_len", "head_num_q", "head_num_kv",
+ "cache_mem_len", "head_size", "input_dtype", "quant_mode", "quant_bit",
+ "use_offset", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ mlu_name = torch.mlu.get_device_name()
+ for params_dict in e2e_time_param_dict_list:
+ max_batch_size = params_dict["max_batch_size"]
+ batch_size_list = params_dict["batch_size"]
+ max_context_len_list = params_dict["max_context_len"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ cache_mem_len = params_dict["cache_mem_len"]
+ input_dtype_list = params_dict["input_dtype"]
+ quant_mode_list = params_dict["quant_mode"]
+ quant_bit_list = params_dict["quant_bit"]
+ use_offset = params_dict["use_offset"]
+ for batch_size, max_context_len, quant_mode, quant_bit, dtype in list(product( \
+ batch_size_list, max_context_len_list, quant_mode_list, quant_bit_list, \
+ input_dtype_list)):
+ torch.manual_seed(2766)
+ torch.mlu.manual_seed(2766)
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv \
+ * head_size >= 2**31 - 1):
+ print("large tensor is not supported on {}, skip".format(mlu_name))
+ continue
+ total_heads = head_num_q + head_num_kv * 2
+ assert max_context_len <= cache_mem_len, "max_context_len should be smaller than or " \
+ "equal to cache_mem_len."
+ max_seq_offset = cache_mem_len - max_context_len
+ # Generates key and cache from context
+ context_lens = torch.randint(size=[batch_size], low=max_context_len,
+ high=max_context_len + 1,
+ dtype=torch.int32, device="mlu")
+ if use_offset:
+ context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
+ dtype=torch.int32, device="mlu")
+ else:
+ context_paddings = torch.zeros_like(context_lens)
+ cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
+ total_seqlen = cu_context_lens[-1]
+ context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
+ context_seq_offset[1:] = cu_context_lens[:-1]
+ context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
+ key = context[..., head_num_q:head_num_q + head_num_kv, :]
+ value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
+
+ # Generates key_cache and value_cache
+ cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu()
+ cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size],
+ dtype=torch.int32, device="mlu")
+ if quant_bit == 4:
+ key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
+ head_num_kv, cache_mem_len, head_size // 2), device="mlu")
+ value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
+ head_num_kv, cache_mem_len // 2, head_size), device="mlu")
+ key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8)
+ else:
+ cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size),
+ low=-128, high=127, dtype=torch.int32, device="mlu")
+ cache = cache.to(torch.int8)
+ key_cache, value_cache = cache[[0, 1]]
+
+ # Generates key_cache_scale and value_cache_scale
+ if quant_mode == 0: # quant_mode == 0 is per channel
+ cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
+ else: # quant_mode != 0 is per head
+ cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len),
+ dtype=torch.float, device="mlu")
+ key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
+ hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_linear_cache,
+ key, value, key_cache, value_cache,
+ key_cache_scale, value_cache_scale,
+ context_lens, max_context_len,
+ context_seq_offset if use_offset else None,
+ cache_bs_id, cache_seq_offset, quant_mode,
+ quant_bit, repeats=args.repeat_times)
+ content = [f"{max_batch_size}", f"{batch_size}", f"{max_context_len}", f"{head_num_q}",
+ f"{head_num_kv}", f"{cache_mem_len}", f"{head_size}", f"{dtype}", f"{quant_mode}",
+ f"{quant_bit}", f"{use_offset}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_paged_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_paged_cache.py
new file mode 100644
index 0000000..e00d0cb
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_paged_cache.py
@@ -0,0 +1,110 @@
+import argparse
+import math
+import random
+import os
+
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+
+from common import benchmark_forward, save_to_csv
+from itertools import product
+from tabulate import tabulate
+
+e2e_time_param_dict_list = [
+ {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
+ "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "block_size": 16, "head_size": 128,
+ "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [8],
+ "use_offset": True},
+]
+
+def main():
+ parser = argparse.ArgumentParser(description="Benchmark for dequant from paged cache.")
+ parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ if "MLU3" in torch.mlu.get_device_name():
+ print("Op dequant_from_paged_cache does not support MLU300 devices.")
+ return
+ titles = ["batch_size", "max_context_len", "head_num_q", "head_num_kv",
+ "cache_mem_len", "block_size", "head_size", "input_dtype", "quant_mode", "quant_bit",
+ "use_offset", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch_size_list = params_dict["batch_size"]
+ max_context_len_list = params_dict["max_context_len"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ block_size = params_dict["block_size"]
+ head_size = params_dict["head_size"]
+ cache_mem_len = params_dict["cache_mem_len"]
+ input_dtype_list = params_dict["input_dtype"]
+ quant_mode_list = params_dict["quant_mode"]
+ quant_bit_list = params_dict["quant_bit"]
+ use_offset = params_dict["use_offset"]
+ for quant_mode, batch_size, max_context_len, quant_bit, dtype in list(product( \
+ quant_mode_list, batch_size_list, max_context_len_list, quant_bit_list, \
+ input_dtype_list)):
+ torch.manual_seed(2766)
+ torch.mlu.manual_seed(2766)
+ total_heads = head_num_q + head_num_kv * 2
+
+ assert max_context_len <= cache_mem_len, "max_context_len should be smaller than or " \
+ "equal to cache_mem_len."
+ max_seq_offset = cache_mem_len - max_context_len
+ max_block_num = int(math.ceil(max_context_len / block_size))
+ total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size
+ block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num)
+ block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size,
+ max_block_num)
+ # Generates key and cache from context
+ context_lens = torch.randint(size=[batch_size], low=max_context_len, high=max_context_len + 1,
+ dtype=torch.int32, device="mlu")
+ if use_offset:
+ context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
+ dtype=torch.int32, device="mlu")
+ else:
+ context_paddings = torch.zeros_like(context_lens)
+ cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
+ total_seqlen = cu_context_lens[-1]
+ context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
+ context_seq_offset[1:] = cu_context_lens[:-1]
+ context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
+ key = context[..., head_num_q:head_num_q + head_num_kv, :]
+ value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
+ # Generates key_cache and value_cache
+ cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size),
+ low=-128, high=127, dtype=torch.int32, device="mlu")
+ cache = cache.to(torch.int8)
+ key_cache, value_cache = cache[[0, 1]]
+
+ # Generates key_cache_scale and value_cache_scale
+ if quant_mode == 0: # quant_mode == 0 is per channel
+ cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
+ else: # quant_mode != 0 is per head
+ cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size),
+ dtype=torch.float, device="mlu")
+ key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
+ hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_paged_cache,
+ key, value, key_cache, value_cache,
+ key_cache_scale, value_cache_scale,
+ context_lens, max_context_len,
+ context_seq_offset if use_offset else None,
+ block_tables, quant_mode,
+ quant_bit, repeats=args.repeat_times)
+ content = [f"{batch_size}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}",
+ f"{cache_mem_len}", f"{block_size}", f"{head_size}", f"{dtype}", f"{quant_mode}",
+ f"{quant_bit}", f"{use_offset}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py
new file mode 100644
index 0000000..0031aa9
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py
@@ -0,0 +1,89 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400,
+ "gated_ffn": False, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344,
+ "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ hidden_size = params_dict["hidden_size"]
+ inner_size = params_dict["inner_size"]
+ gated_ffn = params_dict["gated_ffn"]
+ act_mode = params_dict["act_mode"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
+ up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
+ up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
+ down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
+ down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
+ gate_up_proj_weight, gate_up_proj_bias = None, None
+ if gated_ffn:
+ gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
+ gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.ffn,
+ input,
+ up_proj_weight,
+ up_proj_bias,
+ down_proj_weight,
+ down_proj_bias,
+ gate_up_proj_weight,
+ gate_up_proj_bias,
+ act_mode,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py
new file mode 100644
index 0000000..f7e9045
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py
@@ -0,0 +1,92 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+# for e2e time test
+e2e_time_param_dict_list = [{"batch": 1, "seq_q": 32768, "seq_kv": 32768, "head_num": 8,
+ "head_num_kv": 1, "head_size": 128, "use_causal": True,
+ "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_q": 16384, "seq_kv": 16384, "head_num": 8,
+ "head_num_kv": 1, "head_size": 128, "use_causal": True,
+ "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_q": 8192, "seq_kv": 24576, "head_num": 8,
+ "head_num_kv": 1, "head_size": 128, "use_causal": True,
+ "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_q": 4096, "seq_kv": 28672, "head_num": 8,
+ "head_num_kv": 1, "head_size": 128, "use_causal": True,
+ "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_q": 4096, "seq_kv": 32768, "head_num": 8,
+ "head_num_kv": 1, "head_size": 128, "use_causal": True,
+ "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+ titles = ["batch", "seq_q", "seq_kv", "head_num", "head_num_kv", "head_size", "use_causal", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_q = params_dict["seq_q"]
+ seq_kv = params_dict["seq_kv"]
+ head_num = params_dict["head_num"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ use_causal = params_dict["use_causal"]
+ softmax_scale = params_dict["softmax_scale"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ if seq_q == seq_kv:
+ qkv = torch.randn(batch, seq_q, head_num + 2 * head_num_kv, head_size).to(dtype).to(device)
+ q = qkv[:, :, : head_num, :]
+ k = qkv[:, :, head_num : head_num + head_num_kv, :]
+                v = qkv[:, :, head_num + head_num_kv : head_num + 2 * head_num_kv, :]
+            else:  # separate q and kv tensors cover any seq_q != seq_kv
+ q = torch.randn(batch, seq_q, head_num, head_size).to(device).to(dtype)
+ kv = torch.randn(batch, seq_kv, head_num_kv * 2, head_size).to(device).to(dtype)
+ k = kv[:, :, : head_num_kv, :]
+ v = kv[:, :, head_num_kv :, :]
+ hardware_time, e2e_time = benchmark_forward(tmo.flash_attention,
+ q = q,
+ k = k,
+ v = v,
+ out = None,
+ cu_seq_lens_q = None,
+ cu_seq_lens_kv = None,
+ alibi_slope = None,
+ attn_bias = None,
+ max_seq_len_q = seq_q,
+ max_seq_len_kv = seq_kv,
+ softmax_scale = softmax_scale,
+ is_causal = use_causal,
+ window_size_left = -1,
+ window_size_right = -1,
+ compute_dtype = dtype,
+ return_lse = False,
+ block_tables = None,
+ k_cache_quant_scale = None,
+ v_cache_quant_scale = None,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_q}", f"{seq_kv}", f"{head_num}", f"{head_num_kv}", f"{head_size}", f"{use_causal}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py
new file mode 100644
index 0000000..f702f69
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py
@@ -0,0 +1,103 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [
+ {"batch": 1, "seq_len": 2048, "hidden_size": 8192, "has_residual": True,
+ "has_bias": False, "has_quant": True, "dynamic_quant": True,
+ "input_dtype": torch.bfloat16},
+ {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "has_residual": True,
+ "has_bias": False, "has_quant": True, "dynamic_quant": True,
+ "input_dtype": torch.bfloat16},
+ {"batch": 1, "seq_len": 8192, "hidden_size": 8192, "has_residual": True,
+ "has_bias": False, "has_quant": True, "dynamic_quant": True,
+ "input_dtype": torch.bfloat16},
+ {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "has_residual": True,
+ "has_bias": False, "has_quant": True, "dynamic_quant": True,
+ "input_dtype": torch.bfloat16},
+ {"batch": 490, "seq_len": 1, "hidden_size": 8192, "has_residual": True,
+ "has_bias": False, "has_quant": True, "dynamic_quant": True,
+ "input_dtype": torch.bfloat16},
+ {"batch": 525, "seq_len": 1, "hidden_size": 8192, "has_residual": True,
+ "has_bias": False, "has_quant": True, "dynamic_quant": True,
+ "input_dtype": torch.bfloat16},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "seq_len", "hidden_size", "has_residual", "has_bias", "has_quant",
+ "dynamic_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ hidden_size = params_dict["hidden_size"]
+ has_residual = params_dict["has_residual"]
+ has_bias = params_dict["has_bias"]
+ has_quant = params_dict["has_quant"]
+ dynamic_quant = params_dict["dynamic_quant"]
+ dtype = params_dict["input_dtype"]
+ eps = 1e-6
+
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+
+ x = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
+ beta = torch.randn(hidden_size, dtype=dtype, device=device)
+ gamma = torch.randn(hidden_size, dtype=dtype, device=device)
+ residual, bias, quant_scale = None, None, None
+ if has_residual:
+ residual = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
+ if has_bias:
+ bias = torch.randn(hidden_size, dtype=dtype, device=device)
+ if has_quant or dynamic_quant:
+ quant_scale = torch.randn(hidden_size, dtype=torch.float, device=device)
+
+ store_output_before_norm = has_residual
+ hardware_time, e2e_time = benchmark_forward(tmo.fused_layer_norm,
+ x,
+ residual,
+ gamma,
+ beta,
+ bias,
+ eps,
+ store_output_before_norm,
+ quant_scale,
+ None,
+ dynamic_quant,
+ repeats=args.repeat_times)
+ n = x.nelement()
+ sizeoft = x.element_size()
+ io_bytes = (sizeoft + 1) * n + \
+ (1 + store_output_before_norm) * (sizeoft * n if has_residual else 0) + \
+ sizeoft * hidden_size * 2 + \
+ (hidden_size * 4 if has_quant else 0) + \
+ (batch * seq_len * 4 if dynamic_quant else 0)
+ io_eff = io_bytes / hardware_time / bd
+
+ content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{has_residual}", f"{has_bias}",
+ f"{has_quant}", f"{dynamic_quant}", f"{dtype}",
+ f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py
new file mode 100644
index 0000000..3c02664
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py
@@ -0,0 +1,143 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [
+ {"batch": 1, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 16, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 72, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 128, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 490, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 525, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 1, "seq_len": 2048, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 1, "seq_len": 8192, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+
+ {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
+ "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
+ "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "num_expert", "topk", "act_mode", "quant_weight", "dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ hidden_size = params_dict["hidden_size"]
+ inner_size = params_dict["inner_size"]
+ gated_ffn = params_dict["gated_ffn"]
+ act_mode = params_dict["act_mode"]
+ num_expert = params_dict["num_expert"]
+ start_expert_id = params_dict["start_expert_id"]
+ expert_size = params_dict["expert_size"]
+ topk = params_dict["topk"]
+ has_residual = params_dict["has_residual"]
+ smooth_quant = params_dict["smooth_quant"]
+ renormalize = params_dict["renormalize"]
+ input_dtype_list = params_dict["dtype"]
+ # print(f"batch:{batch}, seq_len:{seq_len}, hidden_size:{hidden_size}, inner_size:{inner_size}, "
+ # f"gated_ffn:{gated_ffn}, act_mode:{act_mode}, num_expert:{num_expert}, topk:{topk}")
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ hidden_states = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
+ router_logit = torch.randn(batch, seq_len, num_expert, device=device, dtype=torch.float32)
+
+            if False:  # set to True to print per-expert token counts (debug only)
+ softmax = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1)
+ topk_logit, expert_id = torch.topk(softmax, k=topk, dim=1)
+ if renormalize:
+ topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1)
+ sorted_expert_id, indices = expert_id.int().flatten().sort()
+                token_count = torch.bincount(sorted_expert_id, minlength=num_expert).int()
+                print(token_count)
+
+ residual = None
+ if has_residual:
+ residual = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
+ weight1 = torch.randn(num_expert, inner_size*(1+gated_ffn), hidden_size, device=device, dtype=dtype)
+ bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device=device, dtype=data_type)
+ weight2 = torch.randn(num_expert, hidden_size, inner_size, device=device, dtype=dtype)
+ bias2 = None # torch.randn(expert_num, hidden_size, device=device, dtype=data_type)
+ input_smooth, act_smooth, w1_scale, w2_scale = None, None, None, None
+ if smooth_quant:
+ input_smooth = torch.randn(expert_size, hidden_size, device=device, dtype=torch.float32).abs() + 0.1
+ act_smooth = torch.randn(expert_size, inner_size, device=device, dtype=torch.float32).abs() + 0.1
+ weight1 = torch.randint(-128, 127, (num_expert, inner_size*(1+gated_ffn), hidden_size)).to(torch.int8).mlu()
+ weight2 = torch.randint(-128, 127, (num_expert, hidden_size, inner_size)).to(torch.int8).mlu()
+ w1_scale = torch.randn(expert_size, (1+gated_ffn)*inner_size).to(device).to(torch.float32)
+ w2_scale = torch.randn(expert_size, hidden_size).to(device).to(torch.float32)
+
+ hardware_time, e2e_time = benchmark_forward(tmo.fused_moe,
+ hidden_states,
+ router_logit,
+ weight1[start_expert_id:start_expert_id+expert_size],
+ weight2[start_expert_id:start_expert_id+expert_size],
+ bias1,
+ bias2,
+ residual,
+ input_smooth,
+ act_smooth,
+ w1_scale,
+ w2_scale,
+ topk,
+ renormalize,
+ gated_ffn,
+ act_mode,
+ start_expert_id,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{num_expert}", f"{topk}", f"{act_mode}", f"{smooth_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_attention_project.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_attention_project.py
new file mode 100644
index 0000000..ed8e7bf
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_attention_project.py
@@ -0,0 +1,71 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "input_size": 1600, "head_size": 80, "hidden_size": 1600, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 2048, "head_size": 128, "hidden_size": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 4096, "head_size": 128, "hidden_size": 4096, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 6144, "head_size": 128, "hidden_size": 6144, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 6656, "head_size": 128, "hidden_size": 6656, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 8192, "head_size": 128, "hidden_size": 8192, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 12288, "head_size": 128, "hidden_size": 12288, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "input_size": 14336, "head_size": 128, "hidden_size": 14336, "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+ titles = ["batch", "seq_len", "input_size", "hidden_size", "head_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ input_size = params_dict["input_size"]
+ hidden_size = params_dict["hidden_size"]
+ head_size = params_dict["head_size"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input = torch.randn(batch, seq_len, input_size).to(device).to(dtype)
+ weight = torch.randn(hidden_size * 3, input_size).to(device).to(dtype)
+ bias = torch.randn(hidden_size * 3).to(device).to(dtype)
+ weights = torch.chunk(weight, 3)
+ biases = torch.chunk(bias, 3)
+ norm_weight = torch.randn(input_size).to(device).to(dtype)
+ norm_bias = torch.randn(input_size).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_attention_project,
+ input,
+ weights[0],
+ biases[0],
+ weights[1],
+ biases[1],
+ weights[2],
+ biases[2],
+ norm_weight,
+ norm_bias,
+ 1e-6,
+ 'nthc',
+ head_size,
+ False,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}", f"{head_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_residual_ffn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_residual_ffn.py
new file mode 100644
index 0000000..ec6a392
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_residual_ffn.py
@@ -0,0 +1,97 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400,
+ "gated_ffn": False, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192,
+ "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008,
+ "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336,
+ "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384,
+ "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824,
+ "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392,
+ "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920,
+ "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016,
+ "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576,
+ "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672,
+ "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152,
+ "gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768,
+ "gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344,
+ "gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "residual_is", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ hidden_size = params_dict["hidden_size"]
+ inner_size = params_dict["inner_size"]
+ gated_ffn = params_dict["gated_ffn"]
+ residual_is = params_dict["residual_is"]
+ act_mode = params_dict["act_mode"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
+ up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
+ up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
+ down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
+ down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
+ layernorm_weight = torch.randn(hidden_size).to(device).to(dtype)
+ layernorm_bias = torch.randn(hidden_size).to(device).to(dtype)
+ gate_up_proj_weight, gate_up_proj_bias = None, None
+ if gated_ffn:
+ gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
+ gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_residual_ffn,
+ input,
+ up_proj_weight,
+ up_proj_bias,
+ down_proj_weight,
+ down_proj_bias,
+ gate_up_proj_weight,
+ gate_up_proj_bias,
+ layernorm_weight,
+ layernorm_bias,
+ 1e-6,
+ act_mode,
+ residual_is,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{residual_is}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py
new file mode 100644
index 0000000..b9653b6
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py
@@ -0,0 +1,90 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+from itertools import product
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 16, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 96, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "head_num": 112, "head_size": 128, "has_residual": True,
+ "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["input_shape", "has_residual", "has_bias", "has_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ head_num = params_dict["head_num"]
+ head_size = params_dict["head_size"]
+ has_residual = params_dict["has_residual"]
+ has_quant = params_dict["has_quant"]
+ has_bias = params_dict["has_bias"]
+ eps = params_dict["eps"]
+ dynamic_quant_list = params_dict["dynamic_quant"]
+ input_dtype_list = params_dict["input_dtype"]
+ iters = product(dynamic_quant_list, input_dtype_list)
+ for dynamic_quant, dtype in iters:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ x = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
+ beta = torch.randn(head_size).to(dtype).to(device)
+ gamma = torch.randn(head_size).to(dtype).to(device)
+ residual, bias, quant_scale = None, None, None
+ if has_residual:
+ residual = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
+ if has_bias:
+ bias = torch.randn(head_size).to(dtype).to(device)
+ if has_quant or dynamic_quant:
+ quant_scale = torch.randn(head_size).to(device)
+ hardware_time, e2e_time = benchmark_forward(tmo.fused_rms_norm,
+ x,
+ residual,
+ gamma,
+ beta,
+ bias,
+ eps,
+ False,
+ quant_scale,
+ None,
+ dynamic_quant,
+ repeats=args.repeat_times)
+ content = [f"{batch, seq_len, head_num, head_size}", f"{has_residual}", f"{has_bias}", f"{has_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py
new file mode 100644
index 0000000..07bf082
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py
@@ -0,0 +1,207 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
+
+ {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
+
+ {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
+ "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},]
+def main():
+ # skip this benchmark on MLU3-series devices
+ if 'MLU3' in torch.mlu.get_device_name():
+ exit()
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache", "max_decode_len", "num_blocks", \
+ "block_size", "mixed_cache", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ bs = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ q_heads = params_dict["head_num_q"]
+ kv_heads = params_dict["head_num_k"]
+ head_size = params_dict["head_size"]
+ rope_dim = params_dict["rotary_dim"]
+ quant_kv = params_dict.get("quant_kv", True)
+ paged_cache = params_dict.get("paged_cache", False)
+ mixed_cache = params_dict.get("mixed_cache", False)
+ max_decode_len = 0
+ num_blocks = 0
+ block_size = 0
+ if paged_cache:
+ num_blocks = params_dict["num_blocks"]
+ block_size = params_dict["block_size"]
+ else:
+ max_decode_len = params_dict.get("max_decode_len", 32)
+ input_dtype_list = params_dict["input_dtype"]
+
+ for dtype in input_dtype_list:
+ discrete_batch = True
+ max_bs = bs + 1 if discrete_batch else bs
+ input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
+ input = torch.randn(size=input_shape, dtype=dtype).mlu()
+ input_ref = input.clone()
+
+ cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
+ sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
+ gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
+ beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
+ cache_dtype = dtype
+ if quant_kv:
+ k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
+ v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
+ cache_dtype = torch.int8
+ k_scale_ops = 1 / k_scale
+ v_scale_ops = 1 / v_scale
+ else:
+ k_scale = None
+ v_scale = None
+ k_scale_ops = None
+ v_scale_ops = None
+ if paged_cache:
+ cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size), dtype=dtype, device='mlu')
+ else:
+ cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')
+
+ if quant_kv:
+ cache = (cache - 0.5) * 256
+ cache = cache.to(cache_dtype)
+ k_cache = cache[0]
+ v_cache = cache[1]
+
+ cache_bs_id = None
+ cache_seq_offsets = None
+ slot_mapping = None
+ if not paged_cache:
+ if discrete_batch:
+ cache_bs_id = random.sample([*range(0, max_bs)], bs)
+ cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
+
+ cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
+ dtype=torch.int32, device='mlu')
+ else:
+ slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs)
+ slot_mapping = torch.IntTensor(slot_mapping).mlu()
+
+ position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
+
+ k_cache_lp = None
+ v_cache_lp = None
+ k_scale_lp = None
+ v_scale_lp = None
+ cache_bs_id_lp = None
+ cache_seq_offsets_lp = None
+
+ if mixed_cache:
+ max_decode_len_lp = 1024
+ k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)), dtype=dtype, device='mlu')
+ v_cache_raw = torch.randn((max_bs, kv_heads, int(max_decode_len_lp / 2), head_size), dtype=dtype, device='mlu')
+ max_value = torch.amax(torch.abs(k_cache_raw))
+ k_cache_raw = k_cache_raw * (7 / max_value)
+ max_value = torch.amax(torch.abs(v_cache_raw))
+ v_cache_raw = v_cache_raw * (7 / max_value)
+ k_cache_lp = k_cache_raw.to(torch.int8)
+ v_cache_lp = v_cache_raw.to(torch.int8)
+
+ k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
+ v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
+
+ cache_bs_id_lp = random.sample([*range(0, max_bs)], bs)
+ cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
+ cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
+ dtype=torch.int32, device='mlu')
+
+ hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
+ input,
+ k_cache,
+ v_cache,
+ sin_table,
+ cos_table,
+ position_id,
+ gamma,
+ beta,
+ k_cache_lp,
+ v_cache_lp,
+ cache_bs_id,
+ cache_seq_offsets,
+ cache_bs_id_lp,
+ cache_seq_offsets_lp,
+ k_scale_ops,
+ v_scale_ops,
+ k_scale_lp,
+ v_scale_lp,
+ slot_mapping,
+ None,
+ 1e-5,
+ repeats=args.repeat_times)
+ content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}", f"{paged_cache}", \
+ f"{max_decode_len}", f"{num_blocks}", f"{block_size}", f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py
new file mode 100644
index 0000000..edb2cc6
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py
@@ -0,0 +1,117 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [
+ {"batch": 1, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 16, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 16, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 72, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 490, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 490, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 525, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 525, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+
+ {"batch": 1, "seq_len": 1024, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 1024, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 2048, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 2048, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 4096, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 4096, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 8192, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 8192, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 32768, "k": 8192, "n": 1024, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 32768, "k": 1024, "n": 8192, "expert_num": 32,
+ "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+
+ titles = ["batch", "seq_len", "k", "n", "expert_num", "topk", "smooth_quant", "dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ k = params_dict["k"]
+ n = params_dict["n"]
+ expert_num = params_dict["expert_num"]
+ topk = params_dict["topk"]
+ is_quant = params_dict["is_quant"]
+ input_dtype_list = params_dict["dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ max_m = batch * seq_len
+ m = batch * seq_len * topk
+ avg, rem = m // expert_num, m % expert_num
+ m_list = [avg + (i < rem) for i in range(expert_num)]
+ token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')
+
+ if not is_quant:
+ a = torch.randn(m, k, dtype=dtype, device='mlu')
+ b = torch.randn(expert_num, n, k, dtype=dtype, device='mlu')
+ hardware_time, e2e_time = benchmark_forward(tmo.group_gemm,
+ a, b, token_count,
+ None, None, None, None,
+ max_m,
+ repeats=args.repeat_times)
+ else:
+ a = torch.randint(-128, 127, (m, k)).to(torch.int8).mlu()
+ b = torch.randint(-128, 127, (expert_num, n, k)).to(torch.int8).mlu()
+ a_scale = torch.randn(a.size(0), dtype=torch.float32, device='mlu')
+ b_scale = torch.randn(expert_num, n, dtype=torch.float32, device='mlu')
+ hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_group_gemm,
+ a, b, token_count,
+ None, None, None, None,
+ a_scale, b_scale, dtype, max_m,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{seq_len}", f"{k}", f"{n}", f"{expert_num}", f"{topk}",
+ f"{is_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py
new file mode 100644
index 0000000..40f3e96
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py
@@ -0,0 +1,75 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"m": 1024, "k": 1600, "n": 6400, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 2048, "n": 8192, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 4096, "n": 11008, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 4096, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 5120, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 5120, "n": 27392, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 6144, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 6656, "n": 17920, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 8192, "n": 22016, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 8192, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 8192, "n": 28672, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 8192, "n": 49152, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 12288, "n": 32768, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 14336, "n": 57344, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ m = params_dict["m"]
+ k = params_dict["k"]
+ n = params_dict["n"]
+ has_c = params_dict["has_c"]
+ has_bias = params_dict["has_bias"]
+ act_mode = params_dict["act_mode"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ a = torch.randn(m, k).to(device).to(dtype)
+ b = torch.randn(n, k).to(device).to(dtype)
+ c = None
+ if has_c:
+ c = torch.randn(m, n).to(device).to(dtype)
+ bias = None
+ if has_bias:
+ bias = torch.randn(n).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.matmul,
+ a,
+ b,
+ bias,
+ c,
+ act_mode,
+ 1.0,
+ 1.0,
+ repeats=args.repeat_times)
+ content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py
new file mode 100644
index 0000000..6acde62
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py
@@ -0,0 +1,117 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": True, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 4096, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 8192, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 32768, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 1, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 16, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 32, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 64, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 128, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 256, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},
+ {"batch": 512, "seq_len": 1, "inner_size": 1024,
+ "act_mode": "gelu", "is_gated": False, "has_bias": True,
+ "is_ep": True, "input_dtype": [torch.bfloat16]},]
+
+def gen_data(num_expert,
+ total_tokens,
+ inner_size,
+ output_stride,
+ dtype,
+ is_gated,
+ has_bias,
+ is_ep):
+ ci = inner_size * (1 + is_gated)
+ input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu')
+ cusum_token_count, token_count = generate_token_count(num_expert, total_tokens)
+ output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu')
+ output = output.as_strided(output.size(), (output_stride, 1))
+ start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0
+ expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert
+ bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None
+ return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ titles = ["input_shape", "act_mode", "is_gated", "has_bias", "expert_num", "start_expert_id",
+ "expert_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ seq_len = params_dict["seq_len"]
+ inner_size = params_dict["inner_size"]
+ act_mode = params_dict["act_mode"]
+ is_gated = params_dict["is_gated"]
+ input_dtype_list = params_dict["input_dtype"]
+ has_bias = params_dict["has_bias"]
+ is_ep = params_dict["is_ep"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ expert_num = random.randint(1, 256)
+ input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
+ gen_data(expert_num, batch * seq_len, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
+
+ real_bias = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_active,
+ input,
+ act_mode,
+ is_gated,
+ output,
+ real_bias,
+ cusum_token_count.mlu() if has_bias or is_ep else None,
+ start_expert_id,
+ expert_size,
+ repeats=args.repeat_times)
+
+ io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated) + \
+ (real_bias.element_size() * real_bias.nelement() if has_bias else 0) + \
+ ((cusum_token_count.element_size() * cusum_token_count.nelement()) if has_bias or is_ep else 0)
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{batch,seq_len,inner_size}", f"{act_mode}", f"{is_gated}", f"{has_bias}", f"{expert_num}",
+ f"{start_expert_id}", f"{expert_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py
new file mode 100644
index 0000000..0c947ec
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py
@@ -0,0 +1,59 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"batch": 1, "seq_len": 2048, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
+ {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
+ {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
+ {"batch": 16, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
+ {"batch": 128, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
+ {"batch": 512, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}]
+
+def main():
+ # skip this benchmark on MLU3-series devices
+ if 'MLU3' in torch.mlu.get_device_name():
+ exit()
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ titles = ["batch", "seq_len", "hidden_size", "expert_num", "input_dtype", "hardware_time(us)",
+ "e2e_latency(us)", "IO efficiency"]
+
+ contents = []
+ bandwidth = get_band_width()
+ for param_dict in e2e_time_param_dict_list:
+ batch = param_dict["batch"]
+ seq_len = param_dict["seq_len"]
+ hidden_size = param_dict["hidden_size"]
+ expert_num = param_dict["expert_num"]
+ input_dtype = param_dict["input_dtype"]
+ if input_dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ input_dtype = torch.half
+ input = torch.randn(batch, seq_len, hidden_size, dtype=input_dtype, device="mlu")
+ weight = torch.randn(expert_num, hidden_size, dtype=torch.float32, device="mlu")
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_cast_gating,
+ input,
+ weight, repeats=args.repeat_times)
+ io_bytes = batch * seq_len * hidden_size * input.element_size() + \
+ expert_num * hidden_size * weight.element_size() + batch * seq_len * expert_num * weight.element_size()
+ io_coeff = io_bytes / hardware_time / bandwidth
+ content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{expert_num}", f"{input_dtype}",
+ f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py
new file mode 100644
index 0000000..9c1483c
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py
@@ -0,0 +1,166 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [
+ {"num_tokens": 16, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 128, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 490, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 525, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 2048, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 4096, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 8192, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ {"num_tokens": 32768, "num_expert": 32, "topk": 5, "start_expert_id": 0,
+ "expert_size": 32, "has_residual": False, "hidden_size": 8192,
+ "dtype": [torch.bfloat16]},
+ ]
+
+def gen_case(num_tokens,
+ topk,
+ hidden_size,
+ num_expert,
+ expert_size,
+ has_bias,
+ has_residual,
+ dtype,
+ device):
+ input = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device=device)
+ reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device)
+ gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32, device=device)
+ bias = None
+ residual = None
+
+ cusum_token_count = None
+
+ if has_bias:
+ bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device)
+ if has_residual:
+ residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
+
+ if has_bias or expert_size < num_expert:
+ cusum_token_count, _ = generate_token_count(num_expert, num_tokens * topk)
+ cusum_token_count = cusum_token_count.to(device=device)
+ return input, reduce_weight, gather_ids, residual, bias, cusum_token_count
+
+
+def get_io_bytes(num_tokens,
+ topk,
+ hidden_size,
+ num_expert,
+ expert_size,
+ start_expert_id,
+ has_bias,
+ has_residual,
+ dtype,
+ cusum_token_count,
+ gather_ids):
+ io_bytes = 0
+ dtype_size = 4 if dtype is torch.float32 else 2
+ if cusum_token_count is not None:
+ filtered_ids = (gather_ids >= cusum_token_count[start_expert_id]) * \
+ (gather_ids < cusum_token_count[start_expert_id + expert_size])
+ filtered_ids = filtered_ids.to(dtype=torch.float32)
+ io_bytes += torch.sum(filtered_ids).item() * hidden_size * dtype_size
+ else:
+ io_bytes += num_tokens * topk * hidden_size * dtype_size
+
+ if has_bias:
+ io_bytes += expert_size * hidden_size * dtype_size
+
+ if has_residual:
+ io_bytes += num_tokens * hidden_size * dtype_size
+
+ io_bytes += num_tokens * topk * 4
+ io_bytes += num_tokens * hidden_size * dtype_size
+ return io_bytes
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+ titles = ["num_tokens", "num_expert", "topk", "start_expert_id", "expert_size", \
+ "hidden_size", "has_residual", "dtype", "hardware_time(us)", "e2e_latency(us)", "io_coeff"]
+ contents = []
+ bandwidth = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ num_tokens = params_dict["num_tokens"]
+ num_expert = params_dict["num_expert"]
+ topk = params_dict["topk"]
+ start_expert_id = params_dict["start_expert_id"]
+ expert_size = params_dict["expert_size"]
+ has_residual = params_dict["has_residual"]
+ hidden_size = params_dict["hidden_size"]
+ dtype_list = params_dict["dtype"]
+ for dtype in dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ inputs = gen_case(num_tokens,
+ topk,
+ hidden_size,
+ num_expert,
+ expert_size,
+ False,
+ has_residual,
+ dtype,
+ device)
+ input = inputs[0]
+ reduce_weight = inputs[1]
+ gather_ids = inputs[2]
+ residual = inputs[3]
+ bias = inputs[4]
+ cusum_token_count = inputs[5]
+
+ io_bytes = get_io_bytes(num_tokens,
+ topk,
+ hidden_size,
+ num_expert,
+ expert_size,
+ start_expert_id,
+ False,
+ has_residual,
+ dtype,
+ cusum_token_count,
+ gather_ids)
+
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_combine_result, input, reduce_weight,
+                                                        gather_ids, residual, cusum_token_count,
+ start_expert_id, expert_size,
+ repeats=args.repeat_times)
+ io_coeff = io_bytes / hardware_time / bandwidth
+ content = [f"{num_tokens}", f"{num_expert}", f"{topk}", f"{start_expert_id}", \
+ f"{expert_size}", f"{hidden_size}", f"{has_residual}", f"{dtype}", \
+ f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_expand_input.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_expand_input.py
new file mode 100644
index 0000000..2e4cee0
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_expand_input.py
@@ -0,0 +1,93 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import numpy as np
+
+e2e_time_param_dict_list = [{"token_num": 1, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 16, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 32, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 64, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 128, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 512, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 1024, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 4096, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 8192, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
+ {"token_num": 32768, "hidden_size": 4096, "expert_num": 32, "topk": 5,
+ "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}]
+
+def gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype):
+ input = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
+ gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,)).to(torch.int32).to('mlu')
+ cusum_token_count, _ = generate_token_count(expert_num, token_num * topk)
+ cusum_token_count = cusum_token_count.to('mlu')
+ use_all_experts = expert_num == expert_size
+ if use_all_experts:
+ cusum_token_count = None
+ real_token_count = token_num * topk
+ else:
+ real_token_count = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
+ return input, gather_idx, cusum_token_count, real_token_count
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+
+ titles = ["token_num", "hidden_size", "expert_num", "topk", "start_expert_id", "expert_size", "input_dtype",
+ "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ token_num = params_dict["token_num"]
+ hidden_size = params_dict["hidden_size"]
+ expert_num = params_dict["expert_num"]
+ topk = params_dict["topk"]
+ start_expert_id = params_dict["start_expert_id"]
+ expert_size = params_dict["expert_size"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input, gather_idx, cusum_token_count, real_token_count = \
+                gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_expand_input,
+ input,
+ gather_idx,
+ cusum_token_count,
+ start_expert_id,
+ expert_size,
+ repeats=args.repeat_times)
+ io_bytes = input.element_size() * input.nelement() + \
+ gather_idx.element_size() * gather_idx.nelement() + \
+ (cusum_token_count.element_size() * cusum_token_count.nelement() if cusum_token_count is not None else 0) + \
+ real_token_count * input.element_size()
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{token_num}", f"{hidden_size}", f"{expert_num}", f"{topk}", f"{start_expert_id}", f"{expert_size}",
+ f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py
new file mode 100644
index 0000000..0900741
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py
@@ -0,0 +1,69 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"token_num": 1, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 16, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 32, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 64, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 512, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 1024, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 4096, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 8192, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 32767, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
+ {"token_num": 1, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 16, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 32, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 64, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 512, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 1024, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 4096, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 8192, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
+ {"token_num": 32767, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}]
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+
+ titles = ["token_num", "expert_num", "topk", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ token_num = params_dict["token_num"]
+ expert_num = params_dict["expert_num"]
+ topk = params_dict["topk"]
+ dtype = params_dict["input_dtype"]
+ expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu')
+ gather_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
+ combine_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
+ token_count = torch.empty((expert_num), dtype=dtype, device='mlu')
+ cusum_token_count = torch.empty((expert_num + 1), dtype=dtype, device='mlu')
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_gen_idx,
+ expert_id,
+ expert_num,
+ repeats=args.repeat_times)
+ io_bytes = expert_id.element_size() * expert_id.nelement() + \
+ gather_idx.element_size() * gather_idx.nelement() + \
+ combine_idx.element_size() * combine_idx.nelement() + \
+ token_count.element_size() * token_count.nelement() + \
+ cusum_token_count.element_size() * cusum_token_count.nelement()
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{token_num}", f"{expert_num}", f"{topk}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py
new file mode 100644
index 0000000..2aa5279
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py
@@ -0,0 +1,114 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+params_dict = [
+ {"token_num": 1, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 16, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 128, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 490, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 512, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 525, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 2048, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 4096, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 8192, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+ {"token_num": 32768, "hidden_size": 8192, "expert_num": 32, "topk": 5,
+ "has_gather_idx": True, "dtype": torch.bfloat16},
+
+ {"token_num": 1, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 16, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 128, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 490, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 512, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 525, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 2048, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 4096, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 8192, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ {"token_num": 32768, "hidden_size": 1024, "expert_num": 32, "topk": 5,
+ "has_gather_idx": False, "dtype": torch.bfloat16},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+ titles = ["token_num", "hidden_size", "expert_num", "topk", "has_gather_idx", "dtype",
+ "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+ for param in params_dict:
+ token_num, hidden_size, expert_num, topk, has_gather_idx, dtype = param.values()
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ if "MLU3" in torch.mlu.get_device_name():
+ has_gather_idx = False
+ expand_token_num = token_num * topk
+ input_shape = (token_num if has_gather_idx else expand_token_num, hidden_size)
+ input = torch.randn(input_shape).to(device).to(dtype)
+ scale = torch.randn(expert_num, hidden_size).to(device).to(torch.float32)
+
+ avg, rem = expand_token_num // expert_num, expand_token_num % expert_num
+ m_list = [avg + (i < rem) for i in range(expert_num)]
+ token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')
+
+ if has_gather_idx:
+ gather_idx = torch.arange(0, token_num).repeat([topk])
+ gather_idx = gather_idx[torch.randperm(gather_idx.size(0))].to(torch.int32).mlu()
+ else:
+ gather_idx = None
+
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_quantize,
+ input,
+ scale,
+ None,
+ token_count,
+ gather_idx,
+ None,
+ None,
+ None,
+ True,
+ repeats=args.repeat_times)
+ expand_num = topk if has_gather_idx else 1
+ io_bytes = (input.element_size() + 1) * input.nelement() * expand_num
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{token_num}", f"{hidden_size}", f"{expert_num}",
+ f"{topk}", f"{has_gather_idx}", f"{dtype}",
+ f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
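The `avg, rem` idiom above spreads `expand_token_num` across experts as evenly as possible, giving the first `rem` experts one extra token. Factored out as a standalone helper:

```python
def balanced_token_count(expand_token_num, expert_num):
    """Split expand_token_num across expert_num experts as evenly as possible.

    The first `rem` experts get one extra token, exactly as in the
    `m_list = [avg + (i < rem) ...]` expression in the benchmark above.
    """
    avg, rem = divmod(expand_token_num, expert_num)
    return [avg + (i < rem) for i in range(expert_num)]

# Hypothetical: 10 expanded tokens over 4 experts -> [3, 3, 2, 2].
counts = balanced_token_count(10, 4)
```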
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_softmax_topk.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_softmax_topk.py
new file mode 100644
index 0000000..1000e58
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_softmax_topk.py
@@ -0,0 +1,87 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"num_batch": 1, "seq_len": 1, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 32, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 72, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 1024, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 2048, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 4096, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 8192, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 32768, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 1, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 2, "seq_len": 16, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 2, "seq_len": 36, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 8, "seq_len": 128, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 16, "seq_len": 128, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 4, "seq_len": 1024, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 2, "seq_len": 4096, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 16, "seq_len": 2048, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 1, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 16, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 64, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 1024, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 2048, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 8192, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ {"num_batch": 1, "seq_len": 32768, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ titles = ["num_batch", "seq_len", "num_expert", "topk", "num_expert_group", "topk_group", "normalize", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bd = get_band_width()
+
+ for params_dict in e2e_time_param_dict_list:
+ num_batch = params_dict["num_batch"]
+ seq_len = params_dict["seq_len"]
+ num_expert = params_dict["num_expert"]
+ topk = params_dict["topk"]
+ num_expert_group = params_dict["num_expert_group"]
+ topk_group = params_dict["topk_group"]
+ normalize = params_dict["normalize"]
+ dtype = params_dict["input_dtype"]
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ input = torch.randn(num_batch, seq_len, num_expert, dtype=dtype, device='mlu')
+        mask = torch.randint(0, 2, (1, seq_len, num_expert), dtype=dtype, device='mlu')
+ if num_expert_group > 1:
+ mask = None
+ normed_by = "softmax_logit"
+ reduce_weight = torch.empty(num_batch, topk, dtype=torch.float, device='mlu')
+ expert_id = torch.empty(num_batch, topk, dtype=torch.int32, device='mlu')
+ hardware_time, e2e_time = benchmark_forward(tmo.moe_softmax_topk,
+ input,
+ topk,
+ normalize,
+ num_expert_group,
+ topk_group,
+ mask,
+ normed_by,
+ repeats=args.repeat_times)
+ io_bytes = input.element_size() * input.nelement() + \
+ reduce_weight.element_size() * reduce_weight.nelement() + \
+ expert_id.element_size() * expert_id.nelement()
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{num_batch}", f"{seq_len}", f"{num_expert}", f"{topk}", f"{num_expert_group}", f"{topk_group}", f"{normalize}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
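For the ungrouped case benchmarked above (`num_expert_group == -1`), the routing step is softmax over the expert logits followed by a top-k selection, with optional renormalization of the selected weights. The plain-Python reference below captures that assumed behavior; the grouped (`num_expert_group`/`topk_group`) and masked paths of `tmo.moe_softmax_topk` are not modeled.

```python
import math

def softmax_topk(logits, topk, normalize=False):
    """Softmax over one token's expert logits, then pick the top-k experts.

    Returns (weights, expert_ids). A plain-Python sketch of the ungrouped
    path only; the grouped/masked behavior of the real op is an assumption
    not covered here.
    """
    m = max(logits)                                  # subtract max for stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    probs = [e / s for e in exps]
    top = sorted(enumerate(probs), key=lambda p: p[1], reverse=True)[:topk]
    ids = [i for i, _ in top]
    weights = [p for _, p in top]
    if normalize:                                    # renormalize selected weights to 1
        t = sum(weights)
        weights = [w / t for w in weights]
    return weights, ids

w, ids = softmax_topk([0.0, 1.0, 2.0, 3.0], topk=2, normalize=True)
```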
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_linear_cache.py
new file mode 100644
index 0000000..af6dd74
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_linear_cache.py
@@ -0,0 +1,145 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": True, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_head"},
+ {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": False, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_head"},
+ {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": True, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_channel"},
+ {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": False, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_channel"}
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dtype", "quantize_mode", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ max_batch = params_dict["max_batch"]
+ batch = params_dict["batch"]
+ cache_mem_len = params_dict["cache_mem_len"]
+ max_context_len = params_dict["max_context_len"]
+ max_seq_offset = params_dict["max_seq_offset"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ packed = params_dict["packed"]
+ quantize_mode = params_dict["quantize_mode"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ context_lens = torch.randint(size=(batch, ), low=max_context_len,
+ high=max_context_len+1,
+ dtype=torch.int32, device='mlu')
+ # max_seq_offset = max_context_len // 3 + 1
+ context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
+ dtype=torch.int32, device='mlu')
+ cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
+ high=(cache_mem_len - max_context_len) // 3 + 1,
+ dtype=torch.int32, device='mlu')
+
+ cu_context_lens = torch.cumsum(context_lens, dim=-1)
+ cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
+ total_seqlen = cu_context_lens[-1]
+ total_heads = head_num_q + 2 * head_num_kv
+ if packed > 0:
+ context = torch.randn((total_seqlen, total_heads, head_size),
+ dtype=torch.float, device='mlu')
+ else:
+ context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
+ dtype=torch.float, device='mlu')
+ cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
+ context = context.to(dtype)
+ cache = cache.to(dtype)
+ key = context[..., head_num_q : head_num_q + head_num_kv, :]
+ value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
+ key_cache = cache[0]
+ value_cache = cache[1]
+
+            # choose `batch` distinct cache slot ids from [0, max_batch)
+ cache_bs_id = random.sample([*range(0, max_batch)], batch)
+ cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
+ key_cache = (key_cache - 0.5) * 256
+ value_cache = (value_cache - 0.5) * 256
+ key_cache = key_cache.to(torch.int8)
+ value_cache = value_cache.to(torch.int8)
+ if packed > 0:
+ if quantize_mode == "per_channel":
+ key_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
+ value_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
+ hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ key_cache_quantize_scale,
+ value_cache_quantize_scale,
+ cu_context_lens, max_context_len, 0,
+ packed > 0, None,
+ cache_bs_id, cache_seq_offsets,
+ repeats=args.repeat_times)
+ elif quantize_mode == "per_head":
+ key_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
+ value_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
+ hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ key_cache_quantize_scale,
+ value_cache_quantize_scale,
+ cu_context_lens, max_context_len, 1,
+ packed > 0, None,
+ cache_bs_id, cache_seq_offsets,
+ repeats=args.repeat_times)
+ else:
+ if quantize_mode == "per_channel":
+ key_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
+ value_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
+ hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ key_cache_quantize_scale,
+ value_cache_quantize_scale,
+ context_lens, max_context_len, 0,
+ packed > 0, context_seq_offsets,
+ cache_bs_id, cache_seq_offsets,
+ repeats=args.repeat_times)
+ elif quantize_mode == "per_head":
+ key_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
+ value_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
+ hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ key_cache_quantize_scale,
+ value_cache_quantize_scale,
+ context_lens, max_context_len, 1,
+ packed > 0, context_seq_offsets,
+ cache_bs_id, cache_seq_offsets,
+ repeats=args.repeat_times)
+
+ content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{quantize_mode}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
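The two quantization modes above differ only in which axis the int8 scale runs along, as the scale tensor shapes in the benchmark show: `per_channel` scales over `head_size`, `per_head` over cache positions. A small helper makes the shape rule explicit:

```python
def kv_scale_shape(quantize_mode, head_num_kv, head_size, cache_mem_len):
    """Scale-tensor shape per KV-cache quantization mode.

    Matches the shapes allocated in the benchmark above:
    per_channel -> (head_num_kv, head_size), scaling each channel;
    per_head    -> (head_num_kv, cache_mem_len), scaling each cache slot.
    """
    if quantize_mode == "per_channel":
        return (head_num_kv, head_size)
    if quantize_mode == "per_head":
        return (head_num_kv, cache_mem_len)
    raise ValueError(f"unknown quantize_mode: {quantize_mode}")

shape = kv_scale_shape("per_channel", 32, 128, 1024)
```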
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_paged_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_paged_cache.py
new file mode 100644
index 0000000..6ed47e9
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_paged_cache.py
@@ -0,0 +1,73 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from itertools import product
+from tabulate import tabulate
+import os
+import random
+
+
+e2e_time_param_dict_list = [
+ {"token_num": [1024, 2048, 3072, 4096], "head_num_kv": 1, "head_size": 128, "block_size": 16,
+ "input_dtype": [torch.float16, torch.bfloat16]},
+ {"token_num": [1024 * 32, 2048 * 32, 3072 * 32, 4096 * 32], "head_num_kv": 1, "head_size": 128, "block_size": 16,
+ "input_dtype": [torch.float16, torch.bfloat16]},
+ {"token_num": [1024 * 64, 2048 * 64, 3072 * 64, 4096 * 64], "head_num_kv": 1, "head_size": 128, "block_size": 16,
+ "input_dtype": [torch.float16, torch.bfloat16]},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+
+ args = parser.parse_args()
+ titles = ["token_num", "head_num_kv", "head_size", "block_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ if "MLU3" in torch.mlu.get_device_name():
+ print("Op offline_quant_to_paged_cache does not support MLU300 devices.")
+ return
+ for params_dict in e2e_time_param_dict_list:
+ token_num_list = params_dict["token_num"]
+        # block_num is derived from token_num and block_size below
+ block_size = params_dict["block_size"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ input_dtype_list = params_dict["input_dtype"]
+ for token_num, dtype in product(token_num_list, input_dtype_list):
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ block_num = (token_num + block_size - 1) // block_size
+ key = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
+ value = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
+ key_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
+ value_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
+ num_slots = block_num * block_size
+ slot_mapping = random.sample(range(num_slots), token_num)
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu")
+ key_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
+ value_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
+ hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_paged_cache,
+ key, value,
+ key_cache_scale, value_cache_scale,
+ slot_mapping,
+ key_cache, value_cache,
+ repeats=args.repeat_times)
+
+
+ content = [f"{token_num}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_per_token_smooth_quantize.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_per_token_smooth_quantize.py
new file mode 100644
index 0000000..8043f28
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_per_token_smooth_quantize.py
@@ -0,0 +1,61 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import random
+import csv
+
+params_dict = {
+ "token_num": [n * 5 for n in [1, 72, 512, 1024, 4096, 32768]],
+ "hidden_size": [1024, 8192],
+ "input_dtype": [torch.float16]
+ }
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+    titles = ["token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ params_list = product(params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
+ bd = get_band_width()
+ for params in params_list:
+ token_num, hidden_size = params[0], params[1]
+ input_shape = (token_num, hidden_size)
+        smooth_shape = (hidden_size,)  # trailing comma makes this a tuple, not a parenthesized int
+ dtype = params[2]
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input = torch.randn(input_shape).to(device).to(dtype)
+ smooth = torch.randn(smooth_shape).to(device).to(torch.float32)
+ zero = None
+ token_count = None
+ hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
+ input,
+ smooth,
+ zero,
+ token_count,
+ repeats=args.repeat_times)
+ io_bytes = (input.element_size() + 1) * input.nelement() + \
+ smooth.element_size() * smooth.nelement() + \
+ token_num * 4
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py
new file mode 100644
index 0000000..4261e9f
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py
@@ -0,0 +1,46 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"input_shape": [100, 100, 100], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"input_shape": [100, 100], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"input_shape": [50, 50, 50], "input_dtype": [torch.float16, torch.bfloat16]},
+ {"input_shape": [1, 100, 1000], "input_dtype": [torch.float16, torch.bfloat16]}
+ ]
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+ titles = ["input_shape", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ input_shape = params_dict["input_shape"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input = torch.randn(input_shape).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.preload,
+ input,
+ input.element_size() * input.numel(),
+ repeats=args.repeat_times)
+ content = [f"{input_shape}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quant_to_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quant_to_linear_cache.py
new file mode 100644
index 0000000..bfd70ba
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quant_to_linear_cache.py
@@ -0,0 +1,201 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import random
+import os
+
+e2e_time_param_dict_list = [
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1,
+ "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1,
+ "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ titles = ["batch", "max_context_len", "head_num_kv", "head_size", "packed", "input_dtype",
+ "quant_bit", "group_size", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ bandwidth = get_band_width()
+ for params_dict in e2e_time_param_dict_list:
+ max_batch = params_dict["max_batch"]
+ batch = params_dict["batch"]
+ cache_mem_len = params_dict["cache_mem_len"]
+ max_context_len = params_dict["max_context_len"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ packed = params_dict["packed"]
+ input_dtype_list = params_dict["input_dtype"]
+ quant_bit = params_dict["quant_bit"]
+ group_size = params_dict["group_size"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ context_lens = torch.tensor([max_context_len] * batch).to(torch.int32).mlu()
+ context_seq_offsets = torch.zeros(batch, dtype=torch.int32, device='mlu')
+            cache_seq_offsets = torch.randint(size=(batch, ),
+                                              low=0,
+                                              high=1 if max_context_len > 1 else cache_mem_len,
+                                              dtype=torch.int32, device='mlu')
+ cu_context_lens = torch.cumsum(context_lens, dim=-1)
+ cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
+            total_seqlen = cu_context_lens[-1].item()  # plain Python int for use as a tensor size
+ if packed > 0:
+ key = torch.randn((total_seqlen, head_num_kv, head_size),
+ dtype=torch.float, device='mlu')
+ value = torch.randn((total_seqlen, head_num_kv, head_size),
+ dtype=torch.float, device='mlu')
+ else:
+ key = torch.randn((batch, max_context_len, head_num_kv, head_size),
+ dtype=torch.float, device='mlu')
+ value = torch.randn((batch, max_context_len, head_num_kv, head_size),
+ dtype=torch.float, device='mlu')
+ key = key.to(dtype)
+ value = value.to(dtype)
+ if quant_bit == 8 and group_size == head_size:
+ key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
+ value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
+ key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
+ value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
+            elif quant_bit == 8 and group_size != head_size:
+ key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
+ value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
+ key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
+ value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
+            elif quant_bit == 4 and group_size == head_size:
+ key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
+ value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
+ key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
+ value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
+            elif quant_bit == 4 and group_size != head_size:
+ key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
+ value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
+ key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
+ value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
+
+            # choose `batch` distinct cache slots out of the `max_batch` pool
+            cache_bs_id = random.sample(range(max_batch), batch)
+            cache_bs_id = torch.tensor(cache_bs_id, dtype=torch.int32, device='mlu')
+
+ if packed > 0:
+ hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ key_cache_scale,
+ value_cache_scale,
+ cu_context_lens, max_context_len,
+ packed > 0, None,
+ cache_bs_id, cache_seq_offsets,
+ quant_bit,
+ repeats=args.repeat_times)
+ else:
+ hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ key_cache_scale,
+ value_cache_scale,
+ context_lens, max_context_len,
+ packed > 0, context_seq_offsets,
+ cache_bs_id, cache_seq_offsets,
+ quant_bit,
+ repeats=args.repeat_times)
+
+ io_bytes = key.nelement() * (key.element_size() + 1) * 2
+ io_eff = io_bytes / hardware_time / bandwidth
+ content = [f"{batch}", f"{max_context_len}", f"{head_num_kv}", f"{head_size}", f"{packed}",
+ f"{dtype}", f"{quant_bit}", f"{group_size}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py
new file mode 100644
index 0000000..26dde4e
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py
@@ -0,0 +1,61 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+params_dict = {"dynamic": [True],
+ "token_num": [1, 72, 490, 512, 525, 1024, 4096, 8192, 32768],
+ "hidden_size": [8192, 1024],
+ "input_dtype": [torch.bfloat16]}
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+    titles = ["dynamic", "token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
+ contents = []
+ params_list = product(params_dict["dynamic"], params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
+ bd = get_band_width()
+ for param in params_list:
+ dynamic, token_num, hidden_size, dtype = param[0], param[1], param[2], param[3]
+ input_shape = (token_num, hidden_size)
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ dtype = torch.half
+ input = torch.randn(input_shape).to(device).to(dtype)
+ scale = torch.randn(input_shape[-1]).to(device).to(torch.float32)
+ zero = None
+ if dynamic:
+ hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
+ input,
+ scale,
+ zero,
+ None,
+ repeats=args.repeat_times)
+ else:
+ hardware_time, e2e_time = benchmark_forward(tmo.quantize,
+ input,
+ scale,
+ zero,
+ repeats=args.repeat_times)
+ io_bytes = (input.element_size() + 1) * input.nelement() + scale.element_size() * scale.nelement()
+ io_eff = io_bytes / hardware_time / bd
+ content = [f"{dynamic}", f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_linear_cache.py
new file mode 100644
index 0000000..4ef8551
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_linear_cache.py
@@ -0,0 +1,109 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
+ "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
+ "packed": False, "input_dtype": [torch.float16, torch.bfloat16]}
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ max_batch = params_dict["max_batch"]
+ batch = params_dict["batch"]
+ cache_mem_len = params_dict["cache_mem_len"]
+ max_context_len = params_dict["max_context_len"]
+ max_seq_offset = params_dict["max_seq_offset"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ packed = params_dict["packed"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ context_lens = torch.randint(size=(batch, ), low=max_context_len,
+ high=max_context_len+1,
+ dtype=torch.int32, device='mlu')
+            # context_seq_offsets is a constant max_seq_offset for every batch entry
+ context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
+ dtype=torch.int32, device='mlu')
+ cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
+ high=(cache_mem_len - max_context_len) // 3 + 1,
+ dtype=torch.int32, device='mlu')
+
+ cu_context_lens = torch.cumsum(context_lens, dim=-1)
+ cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
+            total_seqlen = cu_context_lens[-1].item()  # plain Python int for use as a tensor size
+ total_heads = head_num_q + 2 * head_num_kv
+ if packed > 0:
+ context = torch.randn((total_seqlen, total_heads, head_size),
+ dtype=torch.float, device='mlu')
+ else:
+ context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
+ dtype=torch.float, device='mlu')
+ cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
+ context = context.to(dtype)
+ cache = cache.to(dtype)
+ key = context[..., head_num_q : head_num_q + head_num_kv, :]
+ value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
+ key_cache = cache[0]
+ value_cache = cache[1]
+
+            # choose `batch` distinct cache slots out of the `max_batch` pool
+            cache_bs_id = random.sample(range(max_batch), batch)
+            cache_bs_id = torch.tensor(cache_bs_id, dtype=torch.int32, device='mlu')
+
+ if packed > 0:
+ hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ cu_context_lens, max_context_len,
+ packed > 0, None,
+ cache_bs_id, cache_seq_offsets,
+ repeats=args.repeat_times)
+
+ else:
+ hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
+ key, value,
+ key_cache, value_cache,
+ context_lens, max_context_len,
+ packed > 0, context_seq_offsets,
+ cache_bs_id, cache_seq_offsets,
+ repeats=args.repeat_times)
+
+ content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_paged_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_paged_cache.py
new file mode 100644
index 0000000..f0dff89
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_paged_cache.py
@@ -0,0 +1,76 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
+ "head_num_kv": 32, "head_size": 128, "quantize": True, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
+ "head_num_kv": 32, "head_size": 128, "quantize": False, "input_dtype": [torch.float16, torch.bfloat16]}
+ ]
+
+def main():
+    if 'MLU3' in torch.mlu.get_device_name():
+        return  # paged-cache ops are not supported on MLU300 devices
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+    titles = ["num_tokens", "num_block", "block_size", "head_num_q", "head_num_kv", "head_size", "input_dtype", "quantize", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ num_tokens = params_dict["num_tokens"]
+ num_blocks = params_dict["num_block"]
+ block_size = params_dict["block_size"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ quantize = params_dict["quantize"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ qkv = torch.randn(num_tokens, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
+ key = qkv[:, head_num_q : head_num_q + head_num_kv, :]
+ value = qkv[:, head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
+
+ key_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
+ value_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
+
+ num_slots = num_blocks * block_size
+ slot_mapping = random.sample(range(num_slots), num_tokens)
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
+            slot_mapping[-1] = -1  # a slot of -1 marks a token whose cache write is skipped
+ if not quantize:
+ hardware_time, e2e_time = benchmark_forward(tmo.reshape_paged_cache,
+ key, value,
+ key_cache, value_cache,
+ slot_mapping,
+ repeats=args.repeat_times)
+ else:
+ k_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
+ v_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
+ hardware_time, e2e_time = benchmark_forward(tmo.quant_to_paged_cache,
+ key, value,
+ key_cache, value_cache,
+ k_cache_quant_scale,
+ v_cache_quant_scale,
+ slot_mapping,
+ repeats=args.repeat_times)
+
+ content = [f"{num_tokens}", f"{num_blocks}", f"{block_size}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{dtype}", f"{quantize}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
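The slot-mapping setup above flattens (block, offset) pairs into a single index space and draws `num_tokens` distinct slots with `random.sample`. A device-free sketch of just that index math (the helper name is illustrative; the real benchmark builds an int32 tensor and moves it to the MLU):

```python
import random

def make_slot_mapping(num_blocks, block_size, num_tokens, seed=0):
    # Flat slot index space: block b, offset o -> b * block_size + o,
    # so valid slots lie in [0, num_blocks * block_size).
    num_slots = num_blocks * block_size
    # Distinct slots, mirroring the benchmark's random.sample call.
    return random.Random(seed).sample(range(num_slots), num_tokens)

slots = make_slot_mapping(num_blocks=8, block_size=16, num_tokens=32)
print(len(slots), len(set(slots)))  # → 32 32
```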
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_cached_kv_attn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_cached_kv_attn.py
new file mode 100644
index 0000000..de49901
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_cached_kv_attn.py
@@ -0,0 +1,116 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import math
+import random
+
+e2e_time_param_dict_list = [
+ {"batch": 16, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
+ "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
+ "is_pertoken": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 128, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
+ "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
+ "is_pertoken": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 512, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
+ "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
+ "is_pertoken": False, "input_dtype": [torch.bfloat16]},
+
+ {"batch": 16, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
+ "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
+ "is_pertoken": False, "input_dtype": [torch.bfloat16]},
+ {"batch": 128, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
+ "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
+ "is_pertoken": False, "input_dtype": [torch.bfloat16]},
+ ]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "max_seq_len", "head_num_q", "head_num_kv", "head_size", "block_size", "alibi_bias", "kv_cache_dtype", "use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ max_seq_len = params_dict["max_seq_len"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ block_size = params_dict["block_size"]
+ alibi_bias = params_dict["alibi_bias"]
+ kv_cache_dtype = params_dict["kv_cache_dtype"]
+ use_paged_attn = params_dict["use_paged_attn"]
+ input_dtype_list = params_dict["input_dtype"]
+ is_pertoken = params_dict["is_pertoken"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
+ input_q = input_qkv[..., 0 : head_num_q, :]
+
+ context_lens = torch.randint(max_seq_len, max_seq_len + 1, (batch, ), dtype=torch.int32).to(device)
+ max_context_len = int(max(context_lens))
+ if use_paged_attn:
+ mlu_name = torch.mlu.get_device_name()
+ if "MLU3" in mlu_name:
+ print("paged attention is not implemented on MLU370, skipping")
+ continue
+ block_size = 16
+ else:
+ block_size = max_seq_len + 512
+ num_blocks = batch * ((max_seq_len + block_size - 1) // block_size)
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+ cache_shape = (num_blocks, head_num_kv, block_size, head_size)
+ scale_shape = (num_blocks, head_num_kv, block_size) if is_pertoken else (head_num_kv, head_size)
+ block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
+ block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
+ if kv_cache_dtype is not torch.int8:
+ key_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
+ value_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
+ key_cache_scale = None
+ value_cache_scale = None
+ else:
+ key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
+ value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
+ key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
+ value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
+ alibi_slopes = None
+ if alibi_bias:
+ alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
+ alibi_slopes.uniform_(0, 0.125)
+ softmax_scale = 1 / math.sqrt(head_size)
+ hardware_time, e2e_time = benchmark_forward(tmo.single_query_cached_kv_attn,
+ input_q,
+ key_cache,
+ value_cache,
+ None,
+ block_tables,
+ context_lens,
+ key_cache_scale,
+ value_cache_scale,
+ alibi_slopes,
+ max_context_len,
+ -1,
+ -1,
+ softmax_scale,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{max_seq_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{alibi_bias}", f"{kv_cache_dtype}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
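The paged/contiguous sizing above is ceil-division arithmetic: a pool of cache blocks for the whole batch, plus a per-sequence block-table width. A sketch of just that math (hypothetical helper name):

```python
def block_layout(batch, max_seq_len, max_context_len, block_size):
    # Ceil division, as in the benchmark: total block pool for the batch,
    # and the number of block-table entries each sequence needs.
    num_blocks = batch * ((max_seq_len + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    return num_blocks, max_num_blocks_per_seq

print(block_layout(batch=2, max_seq_len=100, max_context_len=100, block_size=16))  # → (14, 7)
```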
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_mixed_cached_kv_attn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_mixed_cached_kv_attn.py
new file mode 100644
index 0000000..6ec0a06
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_mixed_cached_kv_attn.py
@@ -0,0 +1,131 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import *
+import argparse
+from tabulate import tabulate
+import os
+import math
+import random
+
+e2e_time_param_dict_list = [{"batch": 16, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
+ "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
+ "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
+ "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
+ "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 512, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
+ "head_size": 128, "alibi_bias": True, "quant_bit_lp": 4, "quant_bit_hp": 8,
+ "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 16, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
+ "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
+ "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"batch": 128, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
+ "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
+ "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def gen_cache(batch, num_kv_heads, head_size, is_pagedattn, max_context_len, data_type, quant_bit, quant_mode):
+ int_max = float(2 ** (quant_bit - 1) - 1)
+ int_min = -float(2 ** (quant_bit - 1))
+ context_lens = torch.randint(max_context_len, max_context_len + 1, (batch, ), dtype=torch.int32).mlu()
+ block_size = 16
+ if is_pagedattn is False:
+ block_size = max_context_len
+ num_blocks = batch * ((max_context_len + block_size - 1) // block_size)
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+ cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
+ if quant_mode == "per_token":
+ scale_shape = (num_blocks, num_kv_heads, block_size, 1)
+ else: # per channel
+ scale_shape = (num_kv_heads, head_size)
+
+ if quant_bit == 4:
+ cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size//2)
+ cache_shape_v_int4 = (num_blocks, num_kv_heads, block_size//2, head_size)
+ key_cache = torch.zeros(cache_shape_k_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
+ value_cache = torch.zeros(cache_shape_v_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
+ key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
+ value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
+ elif quant_bit == 8:
+ key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
+ value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
+ key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
+ value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
+ elif quant_bit == -1:
+ key_cache = torch.randn(cache_shape, dtype=data_type).mlu()
+ value_cache = torch.randn(cache_shape, dtype=data_type).mlu()
+ key_cache_scale = None
+ value_cache_scale = None
+ else:
+ raise ValueError("gen_cache error: quant_bit must be in {-1, 4, 8}")
+ block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
+ block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
+ return key_cache, value_cache, key_cache_scale, value_cache_scale, context_lens, block_tables
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "max_seq_len_lp", "max_seq_len_hp", "head_num_q", "head_num_kv", "head_size", "alibi_bias", "quant_bit_lp", "quant_bit_hp","use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ max_seq_len_lp = params_dict["max_seq_len_lp"]
+ max_seq_len_hp = params_dict["max_seq_len_hp"]
+ head_num_q = params_dict["head_num_q"]
+ head_num_kv = params_dict["head_num_kv"]
+ head_size = params_dict["head_size"]
+ alibi_bias = params_dict["alibi_bias"]
+ quant_bit_lp = params_dict["quant_bit_lp"]
+ quant_bit_hp = params_dict["quant_bit_hp"]
+ use_paged_attn = params_dict["use_paged_attn"]
+ input_dtype_list = params_dict["input_dtype"]
+
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
+ input_q = input_qkv[..., 0 : head_num_q, :]
+
+ params_lp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_lp, dtype, quant_bit_lp, "per_token")
+ params_hp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_hp, dtype, quant_bit_hp, "per_channel")
+ key_cache_lp, value_cache_lp, key_cache_scale_lp, value_cache_scale_lp, context_lens_lp, block_tables_lp = params_lp
+ key_cache_hp, value_cache_hp, key_cache_scale_hp, value_cache_scale_hp, context_lens_hp, block_tables_hp = params_hp
+
+ alibi_slopes = None
+ if alibi_bias:
+ alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
+ alibi_slopes.uniform_(0, 0.125)
+ softmax_scale = 1 / math.sqrt(head_size)
+ hardware_time, e2e_time = benchmark_forward(tmo.single_query_mixed_cached_kv_attn,
+ input_q,
+ key_cache_lp, value_cache_lp,
+ key_cache_hp, value_cache_hp,
+ None, #output
+ block_tables_lp, block_tables_hp,
+ context_lens_lp, context_lens_hp,
+ key_cache_scale_lp, value_cache_scale_lp,
+ key_cache_scale_hp, value_cache_scale_hp,
+ alibi_slopes,
+ max_seq_len_lp, max_seq_len_hp,
+ softmax_scale, True,
+ quant_bit_lp, quant_bit_hp,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{max_seq_len_lp}", f"{max_seq_len_hp}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{alibi_bias}", f"{quant_bit_lp}", f"{quant_bit_hp}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
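In `gen_cache` above, the INT4 branch packs two 4-bit values per int8 byte, so the key cache halves `head_size` while the value cache halves `block_size`. A sketch of just that shape arithmetic (illustrative helper name):

```python
def quant_cache_shapes(num_blocks, num_kv_heads, block_size, head_size, quant_bit):
    # INT8 and unquantized (-1) caches keep the full layout; INT4 packs
    # two nibbles per byte along different axes for K and V.
    base = (num_blocks, num_kv_heads, block_size, head_size)
    if quant_bit in (8, -1):
        return base, base
    if quant_bit == 4:
        k_shape = (num_blocks, num_kv_heads, block_size, head_size // 2)
        v_shape = (num_blocks, num_kv_heads, block_size // 2, head_size)
        return k_shape, v_shape
    raise ValueError("quant_bit must be in {-1, 4, 8}")

k, v = quant_cache_shapes(4, 1, 16, 128, quant_bit=4)
print(k, v)  # → (4, 1, 16, 64) (4, 1, 8, 128)
```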
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_smooth_quant_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_smooth_quant_matmul.py
new file mode 100644
index 0000000..7c7e427
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_smooth_quant_matmul.py
@@ -0,0 +1,69 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
+ "act_mode": "none", "output_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
+ "act_mode": "silu", "output_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "output_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ m = params_dict["m"]
+ k = params_dict["k"]
+ n = params_dict["n"]
+ has_c = params_dict["has_c"]
+ has_bias = params_dict["has_bias"]
+ act_mode = params_dict["act_mode"]
+ output_dtype_list = params_dict["output_dtype"]
+ for dtype in output_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ a = torch.randn(m, k).to(device).to(torch.int8)
+ b = torch.randn(n, k).to(device).to(torch.int8)
+ a_scale = torch.randn(m).to(device)
+ b_scale = torch.randn(n).to(device)
+ c = None
+ if has_c:
+ c = torch.randn(m, n).to(device).to(dtype)
+ bias = None
+ if has_bias:
+ bias = torch.randn(n).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_matmul,
+ a,
+ a_scale,
+ b,
+ b_scale,
+ dtype,
+ bias,
+ c,
+ act_mode,
+ 1.0,
+ 1.0,
+ repeats=args.repeat_times)
+ content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh
new file mode 100644
index 0000000..39f28e0
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+LOG_PATH=${LOG_PATH:-.}
+
+files=(benchmark_*.py)
+for file in "${files[@]}"; do
+ echo "test ${file}..."
+ op_name=$(basename "$file" .py)
+ python "$file" > ${LOG_PATH}/${op_name}.log 2>&1
+ ret_tmp=$?
+ cat ${LOG_PATH}/${op_name}.log
+ if [ $ret_tmp != 0 ]; then
+ echo "${file} test failed..."
+ exit $ret_tmp
+ fi
+done
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_update_out_and_lse.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_update_out_and_lse.py
new file mode 100644
index 0000000..29107b4
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_update_out_and_lse.py
@@ -0,0 +1,99 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+import random
+
+e2e_time_param_dict_list = [{"batch": 16, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
+ "dtype": [torch.float16], "pack": False},
+ {"batch": 128, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
+ "dtype": [torch.float16], "pack": False},
+ {"batch": 512, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
+ "dtype": [torch.float16], "pack": False},
+ {"batch": 1024, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
+ "dtype": [torch.float16], "pack": False},
+ {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 1024, "max_seq_len": 1024,
+ "dtype": [torch.float16], "pack": True},
+ {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 2048, "max_seq_len": 2048,
+ "dtype": [torch.float16], "pack": True},
+ {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 8192, "max_seq_len": 8192,
+ "dtype": [torch.float16], "pack": True},
+ {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 32768, "max_seq_len": 32768,
+ "dtype": [torch.float16], "pack": True},]
+
+def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack):
+ if not pack:
+ out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype)
+ lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
+ block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype)
+ block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
+ seq_offset = None
+ cu_seqs = None
+ block_cu_seqs = None
+ else:
+ seq_lens = torch.randint(low=max_seq_len, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32)
+ block_seq_lens = torch.randint(low=block_seq_len, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32)
+ block_seq_lens = torch.minimum(seq_lens, block_seq_lens)
+ seq_offset = torch.zeros_like(seq_lens)
+ for i in range(batch):
+ seq_offset[i] = torch.randint(low=0, high=seq_lens[i]-block_seq_lens[i]+1, size=(1,), dtype=torch.int32)
+ seq_offset = seq_offset.mlu()
+ cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu()
+ block_cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu()
+ total_seqs = torch.sum(seq_lens)
+ block_total_seqs = torch.sum(block_seq_lens)
+
+ out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype)
+ lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
+ block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype)
+ block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
+ return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["batch", "head_num", "head_size", "block_seq_len", "max_seq_len", "dtype", "pack", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ batch = params_dict["batch"]
+ head_num = params_dict["head_num"]
+ head_size = params_dict["head_size"]
+ block_seq_len = params_dict["block_seq_len"]
+ max_seq_len = params_dict["max_seq_len"]
+ dtype_list = params_dict["dtype"]
+ pack = params_dict["pack"]
+ for dtype in dtype_list:
+ out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
+ gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
+
+ hardware_time, e2e_time = benchmark_forward(tmo.update_out_and_lse,
+ out,
+ lse,
+ block_out,
+ block_lse,
+ seq_offset,
+ cu_seqs,
+ block_cu_seqs,
+ repeats=args.repeat_times)
+ content = [f"{batch}", f"{head_num}", f"{head_size}", f"{block_seq_len}", f"{max_seq_len}", f"{dtype}", f"{pack}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_weight_only_quant_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_weight_only_quant_matmul.py
new file mode 100644
index 0000000..7236df5
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_weight_only_quant_matmul.py
@@ -0,0 +1,70 @@
+import torch
+import torch_mlu
+import torch_mlu_ops as tmo
+from common import benchmark_forward, save_to_csv
+import argparse
+from tabulate import tabulate
+import os
+
+e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
+ "act_mode": "none", "quant_bit": 8, "input_dtype": [torch.float16, torch.bfloat16]},
+ {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
+ "act_mode": "silu", "quant_bit": 4, "input_dtype": [torch.float16, torch.bfloat16]}]
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
+ parser.add_argument('--csv', action='store_true', help='write the report data to csv')
+ parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
+
+ args = parser.parse_args()
+ device = 'mlu'
+
+ titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "quant_bit", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
+ contents = []
+ for params_dict in e2e_time_param_dict_list:
+ m = params_dict["m"]
+ k = params_dict["k"]
+ n = params_dict["n"]
+ quant_bit = params_dict["quant_bit"]
+ has_c = params_dict["has_c"]
+ has_bias = params_dict["has_bias"]
+ act_mode = params_dict["act_mode"]
+ input_dtype_list = params_dict["input_dtype"]
+ for dtype in input_dtype_list:
+ if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
+ continue
+ a = torch.randn(m, k).to(device).to(dtype)
+ b = torch.randn(n, k if quant_bit == 8 else k//2).to(device).to(torch.int8)
+ scale = torch.randn(n).to(device)
+ zero = None
+ c = None
+ if has_c:
+ c = torch.randn(m, n).to(device).to(dtype)
+ bias = None
+ if has_bias:
+ bias = torch.randn(n).to(device).to(dtype)
+ hardware_time, e2e_time = benchmark_forward(tmo.weight_only_quant_matmul,
+ a,
+ b,
+ scale,
+ zero,
+ bias,
+ c,
+ act_mode,
+ quant_bit,
+ 1.0,
+ 1.0,
+ repeats=args.repeat_times)
+ content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{quant_bit}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
+ contents.append(content)
+ table = [titles] + contents
+ print(tabulate(table, headers="firstrow", tablefmt="grid"))
+
+ if args.csv:
+ current_file_path = __file__
+ _, file_name = os.path.split(current_file_path)
+ save_to_csv(table, args.o, file_name)
+
+if __name__=="__main__":
+ main()
diff --git a/torch_mlu_ops-v1.3.2/benchmarks/common.py b/torch_mlu_ops-v1.3.2/benchmarks/common.py
new file mode 100644
index 0000000..940d88a
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/benchmarks/common.py
@@ -0,0 +1,65 @@
+import torch
+import torch_mlu
+import time
+from pathlib import Path
+import csv
+import os
+import subprocess
+from itertools import product
+
+def benchmark_forward(fn, *inputs, repeats=1, **kwinputs):
+ notify_start = torch.mlu.Event(enable_timing=True)
+ notify_end = torch.mlu.Event(enable_timing=True)
+ notify_start.record()
+ t0 = time.perf_counter()
+ for _ in range(repeats):
+ fn(*inputs, **kwinputs)
+ notify_end.record()
+ notify_end.synchronize()
+ total_e2e_time = time.perf_counter() - t0
+ average_e2e_time = total_e2e_time / repeats * 1e6
+ total_hardware_time = notify_start.hardware_time(notify_end)
+ average_hardware_time = total_hardware_time / repeats
+ return average_hardware_time, average_e2e_time
+
+def save_to_csv(table, file_path, file_name):
+ file_name_without_ext, _ = os.path.splitext(file_name)
+ new_file_name = file_name_without_ext + '.csv'
+ if file_path is None:
+ file_path = './'
+ path = Path(file_path)
+ if path.suffix:
+ directory = path.parent
+ filename = path.name
+ else:
+ directory = path
+ filename = new_file_name
+ if not directory.exists():
+ directory.mkdir(parents=True, exist_ok=True)
+ full_path = directory / filename
+ if not full_path.exists():
+ full_path.touch()
+ with open(full_path, mode="w", newline="") as file:
+ writer = csv.writer(file)
+ writer.writerows(table)
+ print(f"output saved at: {full_path}")
+
+def get_band_width(card_id: int = 0):
+ cmd = "cnmon info -c " + str(card_id) + " | grep 'MEM BandWidth'| cut -d ':' -f2 | cut -d ' ' -f 2"
+ res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ assert res.returncode == 0, "Failed to get BandWidth."
+ bd = int(res.stdout.decode().strip())
+ return bd
+
+def generate_token_count(num_expert,
+ total_token_count):
+ token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
+ total = torch.sum(token_count, dim=-1) * 1.0
+ token_count *= total_token_count / total.item()
+ token_count = token_count.to(dtype=torch.int32)
+ cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
+ end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
+ cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
+ cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
+ cusum_token_count[-1] = total_token_count
+ return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]
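`common.py`'s `benchmark_forward` brackets the repeat loop with MLU events for hardware time and `perf_counter` for end-to-end time. The e2e half of that pattern can be sketched without an MLU (CPU-only illustration; the device-specific `hardware_time` event API is omitted):

```python
import time

def benchmark_forward_cpu(fn, *inputs, repeats=1, **kwinputs):
    """Same end-to-end timing pattern as benchmark_forward above, minus
    the MLU hardware events: wall-clock the whole repeat loop, then
    report the per-call average in microseconds."""
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*inputs, **kwinputs)
    total = time.perf_counter() - t0
    return total / repeats * 1e6

avg_us = benchmark_forward_cpu(lambda x: x * x, 12345, repeats=100)
print(f"average e2e latency: {avg_us:.2f} us")
```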
diff --git a/torch_mlu_ops-v1.3.2/build.property b/torch_mlu_ops-v1.3.2/build.property
new file mode 100644
index 0000000..ccad5aa
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/build.property
@@ -0,0 +1,8 @@
+TMO_VERSION=1.3.2-1
+CNCL_VERSION=1.24.1-1
+CNNL_VERSION=1.28.3-1
+CNNLEXTRA_VERSION=1.12.3-1
+CNTOOLKIT_VERSION=3.15.6-1
+MLUOPS_VERSION=1.4.2-1
+TORCH_VERSION=1.24.1-1
+TRITON_VERSION=1.3.1-1
diff --git a/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp
new file mode 100644
index 0000000..b292474
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp
@@ -0,0 +1,66 @@
+/*************************************************************************
+ * Copyright (C) [2023-2024] by Cambricon, Inc.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifdef __GNUC__
+
+#include "stack_exception.h"
+#include <execinfo.h>
+#include <dlfcn.h>
+#include <cxxabi.h>
+#include <cstdio>
+#include <cstdlib>
+
+#define MAX_DEPTH 32
+
+namespace tmo {
+namespace stack_exception {
+
+call_stack::call_stack(const size_t num_discard /*= 0*/) {
+ using namespace abi;
+
+ // retrieve call-stack
+ void *trace[MAX_DEPTH];
+ int stack_depth = backtrace(trace, MAX_DEPTH);
+
+ for (int i = num_discard + 1; i < stack_depth; i++) {
+ Dl_info dlinfo;
+ if (!dladdr(trace[i], &dlinfo)) break;
+
+ const char *symname = dlinfo.dli_sname;
+
+ int status;
+ char *demangled = abi::__cxa_demangle(symname, NULL, 0, &status);
+ if (status == 0 && demangled) symname = demangled;
+
+ // printf("entry: %s, %s\n", dlinfo.dli_fname,symname);
+
+ // store entry to stack
+ if (dlinfo.dli_fname && symname) {
+ entry e;
+ e.file = dlinfo.dli_fname;
+ e.line = 0; // unsupported
+ e.function = symname;
+ stack.push_back(e);
+ } else {
+ break; // skip last entries below main
+ }
+
+ if (demangled) free(demangled);
+ }
+}
+
+call_stack::~call_stack() throw() {
+ // automatic cleanup
+}
+} // namespace stack_exception
+} // namespace tmo
+
+#endif // __GNUC__
diff --git a/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h
new file mode 100644
index 0000000..2f37c73
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h
@@ -0,0 +1,127 @@
+/*************************************************************************
+ * Copyright (C) [2023-2024] by Cambricon, Inc.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef CSRC_COMMON_STACK_EXCEPTION_H_
+#define CSRC_COMMON_STACK_EXCEPTION_H_
+
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace tmo {
+namespace stack_exception {
+/** Call-stack entry datastructure. */
+struct entry {
+ /** Default constructor that clears all fields. */
+ entry() : line(0) {}
+
+ std::string file; ///< filename
+ size_t line; ///< line number
+ std::string function; ///< name of function or method
+
+ /** Serialize entry into a text string. */
+ std::string to_string() const {
+ std::ostringstream os;
+ os << file;
+ if (line > 0) {
+ os << ":" << line;
+ }
+ os << " @" << function;
+ return os.str();
+ }
+};
+
+/** Stack-trace base class, for retrieving the current call-stack. */
+class call_stack {
+ public:
+ /** Stack-trace consructor.
+ \param num_discard - number of stack entries to discard at the top. */
+ call_stack(const size_t num_discard = 0);
+
+ virtual ~call_stack() throw();
+
+ /** Serializes the entire call-stack into a text string. */
+ std::string to_string() const {
+ std::ostringstream os;
+ for (size_t i = 0; i < stack.size(); i++)
+ os << stack[i].to_string() << std::endl;
+ return os.str();
+ }
+
+ /** Call stack. */
+ std::vector<entry> stack;
+};
+
+/** Abstract base-class for all stack-augmented exception classes.
+ * Enables catching of all stack-augmented exception classes. */
+class stack_exception_base : public call_stack {
+ public:
+ stack_exception_base(const bool _show_stack) : call_stack(2), show_stack(_show_stack) {}
+ virtual ~stack_exception_base() throw() {}
+
+ virtual const char *what() const throw() = 0;
+
+ /// flag to indicate if stack-trace is included in what() messages
+ bool show_stack;
+};
+
+/** Template for stack-augmented exception classes. */
+template <class T>
+class stack_exception : public T, public stack_exception_base {
+ public:
+ stack_exception(const std::string &msg) : T(msg), stack_exception_base(true) {}
+ virtual ~stack_exception() throw() {}
+
+ stack_exception(const char *file, int line, const char *pretty_function, const std::string &msg)
+ : T(msg), stack_exception_base(true) {
+ entry e;
+ e.file = file;
+ e.line = line;
+ e.function = pretty_function;
+ stack.insert(stack.begin(), e);
+ }
+
+ virtual const char *what() const throw() {
+ if (show_stack) {
+ // concatenate message with stack trace
+ buffer = "[" + std::string(T::what()) + "]\n" + stack_exception::to_string();
+ return buffer.c_str();
+ } else {
+ return T::what();
+ }
+ }
+
+ private:
+ mutable std::string buffer;
+};
+/** Stack-augmented exception classes for all std::exception classes. */
+// typedef stack_exception<std::runtime_error> TmoException;
+// typedef stack_exception<std::range_error> stack_range_error;
+// typedef stack_exception<std::overflow_error> stack_overflow_error;
+// typedef stack_exception<std::underflow_error> stack_underflow_error;
+// typedef stack_exception<std::logic_error> stack_logic_error;
+// typedef stack_exception<std::domain_error> stack_domain_error;
+// typedef stack_exception<std::invalid_argument> stack_invalid_argument;
+// typedef stack_exception<std::length_error> stack_length_error;
+// typedef stack_exception<std::out_of_range> stack_out_of_range;
+} // namespace stack_exception
+
+// The wrapped base type T must be constructible from std::string;
+// std::runtime_error satisfies this and matches the message-carrying usage.
+class _TmoException : public stack_exception::stack_exception<std::runtime_error> {
+ public:
+  _TmoException(const std::string &msg)
+      : stack_exception::stack_exception<std::runtime_error>(msg) {}
+  _TmoException(const char *file, int line, const char *pretty_function, const std::string &msg)
+      : stack_exception::stack_exception<std::runtime_error>(file, line, pretty_function, msg) {}
+};
+#define TmoException(msg) tmo::_TmoException(__FILE__, __LINE__, __PRETTY_FUNCTION__, msg)
+} // namespace tmo
+
+#endif // CSRC_COMMON_STACK_EXCEPTION_H_
diff --git a/torch_mlu_ops-v1.3.2/csrc/common/utils.h b/torch_mlu_ops-v1.3.2/csrc/common/utils.h
new file mode 100644
index 0000000..3c1e916
--- /dev/null
+++ b/torch_mlu_ops-v1.3.2/csrc/common/utils.h
@@ -0,0 +1,293 @@
+/*************************************************************************
+ * Copyright (C) [2023-2024] by Cambricon, Inc.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#ifndef CSRC_COMMON_UTILS_H_
+#define CSRC_COMMON_UTILS_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include // NOLINT
+#include
+#include
+#include