diff --git a/.DS_Store b/.DS_Store index df710b9..b0f9aa7 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/torch_mlu_ops-v1.3.2/.clang-format b/torch_mlu_ops-v1.3.2/.clang-format new file mode 100644 index 0000000..d1fd6de --- /dev/null +++ b/torch_mlu_ops-v1.3.2/.clang-format @@ -0,0 +1,4 @@ +BasedOnStyle: Chromium +ColumnLimit: 100 +PointerAlignment: Right +AllowShortIfStatementsOnASingleLine: true diff --git a/torch_mlu_ops-v1.3.2/.clang-tidy b/torch_mlu_ops-v1.3.2/.clang-tidy new file mode 100644 index 0000000..569ec47 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/.clang-tidy @@ -0,0 +1,15 @@ +Checks: "bugprone-*,google-*,-google-explicit-constructor,cppcoreguidelines-avoid-const-or-ref-data-members,cppcoreguidelines-init-variables,cppcoreguidelines-interfaces-global-init,cppcoreguidelines-misleading-capture-default-by-value,cppcoreguidelines-missing-std-forward,cppcoreguidelines-no-malloc,cppcoreguidelines-pro-type-member-init,cppcoreguidelines-rvalue-reference-param-not-moved,cppcoreguidelines-slicing,cppcoreguidelines-virtual-class-destructor,performance-unnecessary-copy-initialization", +CheckOptions: [ + { + key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion, + value: false + }, + { + key: cppcoreguidelines-narrowing-conversions.WarnOnIntegerToFloatingPointNarrowingConversion, + value: false + }, + { + key: cppcoreguidelines-narrowing-conversions.IgnoreConversionFromTypes, + value: size_t;ptrdiff_t;size_type;difference_type + } +] diff --git a/torch_mlu_ops-v1.3.2/.dockerignore b/torch_mlu_ops-v1.3.2/.dockerignore new file mode 100644 index 0000000..24149d9 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/.dockerignore @@ -0,0 +1,13 @@ +.cache/ +.devcontainer/ +build/ +__pycache__/ +.vscode/ +*.so +*.out +*.o +core_dump_* +record.txt +*.log +CMakePresets.json +.clangd diff --git a/torch_mlu_ops-v1.3.2/.gitattributes b/torch_mlu_ops-v1.3.2/.gitattributes new file mode 100644 index 0000000..13fcb4e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/.gitattributes @@ -0,0 +1,4 @@ +# git archive exclude files +tools/ci export-ignore +docker export-ignore +# legacy export-ignore diff --git a/torch_mlu_ops-v1.3.2/.gitignore b/torch_mlu_ops-v1.3.2/.gitignore new file mode 100644 index 0000000..dd75353 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/.gitignore @@ -0,0 +1,25 @@ +.cache/ +.devcontainer/ +build/ +packages/ +__pycache__/ +.idea/ +.vscode/ +*.so +*.out +*.o +*.pyc +*.deb +*.pt +core_dump_* +record.txt +*.log +CMakePresets.json +.clangd +compile_flags.txt +legacy/samples/pytorch/outputs/ +legacy/src/thirdparty/bangtransformer.egg-info/ +legacy/src/thirdparty/dist/ +torch_mlu_ops.egg-info/ +torch_mlu_ops/_version.py +dist/ diff --git a/torch_mlu_ops-v1.3.2/.leak.supp b/torch_mlu_ops-v1.3.2/.leak.supp new file mode 100644 index 0000000..b9c7dcf --- /dev/null +++ b/torch_mlu_ops-v1.3.2/.leak.supp @@ -0,0 +1,15 @@ +leak:libtorch_cpu.so +leak:libtorch_mlu.so +leak:libtorch_mlu_python.so +leak:libtriton.so +leak:libtorch_python.so +leak:libcatch_python.so +leak:pybind11::cpp_function::initialize_generic +leak:pybind11::cpp_function::make_function_record +leak:pybind11::detail::process_attribute +leak:libstdc++.so +leak:numpy +leak:Objects +leak:Modules +leak:Python +leak:libcrypto.so diff --git a/torch_mlu_ops-v1.3.2/LICENSE b/torch_mlu_ops-v1.3.2/LICENSE new file mode 100644 index 0000000..08e8fd0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2024, Cambricon Technologies +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/torch_mlu_ops-v1.3.2/README.md b/torch_mlu_ops-v1.3.2/README.md new file mode 100644 index 0000000..2e818f6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/README.md @@ -0,0 +1,183 @@ +
+ +Torch-MLU-Ops +=========================== + +![system](https://img.shields.io/badge/system-ubuntu22.04_debian10.11-blue.svg)![python](https://img.shields.io/badge/python-3.10-green)![pytorch](https://img.shields.io/badge/pytorch-2.1_2.4_2.5-green)![release](https://img.shields.io/badge/release-1.3.2-green.svg)![license](https://img.shields.io/badge/license-BSD_3-blue.svg) + +
+ + + 📖 Torch_MLU_Ops用户手册 + + +      + + + 🌏 寒武纪开发者社区 + + +      + + + 🛠️ 依赖组件获取 + + +
+ + +--- +
+ +## 目录 + +- [Torch-MLU-Ops](#torch-mlu-ops) + - [目录](#目录) + - [简介](#简介) + - [安装](#安装) + - [环境要求](#环境要求) + - [通过docker镜像启动](#通过docker镜像启动) + - [编译安装](#编译安装) + - [测试](#测试) + - [目录结构](#目录结构) + - [关键特性](#关键特性) + +## 简介 + +Torch-MLU-Ops是寒武纪设计和开发的PyTorch第三方算子库。对于使用PyTorch框架的开发者,通过Torch-MLU-Ops,能够便捷地使用这些自定义算子,进行算子的集成、评测和业务部署。 + +Torch-MLU-Ops已全量覆盖LLM(Large Language Model)推理场景下的常见算子。作为后端,已支持Cambricon vLLM、Cambricon TGI、Cambricon Stable Diffusion web UI、Cambricon ComfyUI以及Cambricon Diffusers。 + + +## 安装 + +### 环境要求 + +Torch-MLU-Ops支持的操作系统以及软件依赖如下。 + +不同平台支持的软件版本如下: + +| 操作系统 | 编译器版本 | CMake版本 | +| ------------------------| ----------------------| -------------------------- | +| Ubuntu22.04 x86-64 | gcc 9.4.0或者更高版本 | 3.10或者更高版本 | +| Debian10.11 x86-64 | gcc 9.4.0或者更高版本 | 3.10或者更高版本 | + + +编译、运行Torch-MLU-Ops同时还需要配置以下依赖。 + +| Torch-MLU-Ops | Torch-MLU | CNNL | CNNL_Extra | CNCL | CNToolkit | +| -----------------| ------------------| ---------| --------------- | --------| -----------| +| v1.3.2 | v1.24.1 | v1.28.3 | v1.12.3 | v1.24.1 | v3.15.6 | +| v1.3.1 | v1.24.1 | v1.28.3 | v1.12.2 | v1.24.1 | v3.15.5 | +| v1.3.0 | v1.24.0 | v1.28.2 | v1.12.1 | v1.24.0 | v3.15.4 | +| v1.2.3 | v1.24.0 | v1.28.1 | v1.12.1 | v1.24.0 | v3.15.3 | +| v1.2.2 | v1.23.1 | v1.27.4 | v1.11.3 | v1.23.0 | v3.14.3 | +| v1.2.1 | v1.23.1 | v1.27.3 | v1.11.3 | v1.22.3 | v3.14.2 | +| v1.2.0 | v1.23.0 | v1.27.1 | v1.11.1 | v1.22.1 | v3.14.1 | +| v1.1.4 | v1.22.2 | v1.26.7 | v1.10.4 | v1.21.1 | v3.13.7 | +| v1.1.3 | v1.22.2 | v1.26.6 | v1.10.3 | v1.21.1 | v3.13.5 | + + +此外运行Torch-MLU-Ops还依赖PyTorch环境。Torch-MLU-Ops已支持的PyTorch版本请参考。 + +| Torch-MLU-Ops | PyTorch | +| -----------------| ------------------| +| v1.3.2 | v2.1, v2.4, v2.5 | +| v1.3.1 | v2.1, v2.4, v2.5 | +| v1.3.0 | v2.1, v2.4, v2.5 | +| v1.2.3 | v2.1, v2.4, v2.5 | +| v1.2.2 | v2.1, v2.4, v2.5 | +| v1.2.1 | v2.1, v2.4, v2.5 | +| v1.2.0 | v2.1, v2.4, v2.5 | +| v1.1.4 | v2.1, v2.3, v2.4 | +| v1.1.3 | v2.1, v2.3 | + +### 通过docker镜像启动 + +寒武纪提供的解决方案镜像已提供了Torch-MLU-Ops所需要的依赖,请参考Cambricon PyTorch Container Image使用。您可以从[寒武纪开发者社区](https://account.cambricon.com/interaction/SMmmcGFmw837gghtZ7VR4)获取该镜像。 + +### 编译安装 + +Torch-MLU-Ops编译脚本依赖于`NEUWARE_HOME`环境变量,该环境变量指向寒武纪CNToolkit的安装路径,通常为`/usr/local/neuware`。 + +在运行示例程序前,您需要将`NEUWARE_HOME`中的`lib64`目录添加到`LD_LIBRARY_PATH`环境变量中,例如: + +```shell +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NEUWARE_HOME/lib64 +``` + +如果您使用寒武纪提供的docker镜像,以上环境变量均已经被设置好了。 + +设置好`NEUWARE_HOME`环境变量后,您可以使用以下命令编译Torch-MLU-Ops。 + +从远端clone仓库,进入项目的主目录: +```shell +cd torch_mlu_ops +``` + +源码安装: +```shell +pip install -e . +``` + +wheel包安装: +```shell +python setup.py bdist_wheel +pip install dist/torch_mlu_ops*.whl +``` + +> _注意,如果安装过程中遇到权限问题,请获取相关权限或切换到拥有权限的user执行pip install 这步操作。_ + +移除Torch-MLU-Ops(Optional): +```shell +pip uninstall torch_mlu_ops +``` + +### 测试 + +在测试前,请确保Torch-MLU-Ops已安装,您可以通过 `pip list` 命令查询Torch-MLU-Ops的安装情况。若有如下类似打印,则表明安装成功。 +```shell +torch-mlu-ops 1.1.0+pt21 +``` + +Torch-MLU-Ops模块的导入方法请参考。 +```python +import torch_mlu_ops as tmo +``` + +在`./tests/ops_pytest`目录下,提供了`ops`级别的测试示例程序,您可以参考如下命令进行测试。 +```shell +cd tests/ops_pytest/ +./run_test.sh +# or 单独测试某个测例,例如 +python test_flash_attention.py +``` +在`./tests/kernels_pytest`目录下,提供了`kernels`级别的测试示例程序,您可以参考如下命令进行测试。 +```shell +cd tests/kernels_pytest +./build.sh # 需要先编译 +./run_test.sh +``` +## 目录结构 +``` +torch_mlu_ops +|-- benchmarks ## 性能benchmarks测试代码目录 +|-- csrc ## C以及BangC源码目录 +| |-- common +| |-- kernels +| |-- ops +| `-- torch_api +|-- docs ## 文档目录 +| |-- release_notes +| `-- user_guide +|-- tests ## 测试代码目录,包含了kernel和ops级别的测试 +| |-- kernels_pytest +| `-- ops_pytest +|-- tools ## 工具目录,主要包含CI相关代码 +| |-- ci +`-- torch_mlu_ops ## PyTorch第三方算子接口实现相关代码 +``` + +## 关键特性 + +Torch-MLU-Ops提供的自定义算子列表以及接口说明请参考[Torch-MLU-Ops用户手册](./docs/user_guide/)。 diff --git a/torch_mlu_ops-v1.3.2/benchmarks/README.md b/torch_mlu_ops-v1.3.2/benchmarks/README.md new file mode 100644 index 0000000..4717d61 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/README.md @@ -0,0 +1,51 @@ +## benchmark测试脚本使用方式 + +Torch-MLU-Ops benchmark测试脚本为用户提供了进行算子性能测试的便捷入口。 +用户可通过以下命令获取各个参数的含义。 + +```bash +# 测试命令帮助 +python3 benchmark_xxx.py --help +``` +各个参数含义如下: + +`options`: +- -h, --help show this help message and exit +- --repeat_times REPEAT_TIMES repeat times for testing +- --csv write the report data to csv +- -o O specify the output folder name under --csv mode + +```bash +# 测试命令示例如下 +python3 benchmark_active.py --repeat_times 10 --csv -o './active/' +``` +支持如下算子: + +| op_name | +| ---------------------------------| +| active | +| apply_rotary | +| attention_project | +| ffn | +| flash_attn | +| fused_layer_norm | +| fused_moe | +| fused_norm_attention_project | +| fused_norm_residual_ffn | +| fused_rms_norm | +| group_gemm | +| matmul | +| offline_quant_to_linear_cache | +| per_token_smooth_quantize | +| preload | +| quantize | +| reshape_linear_cache | +| quant_to_linear_cache | +| reshape_paged_cache | +| single_query_cached_kv_attn | +| smooth_quant_matmul | +| weight_only_quant_matmul | +| moe_gen_idx | +| moe_expand_input | +| moe_softmax_topk | +| moe_combine_result | diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py new file mode 100644 index 0000000..e87a7ef --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py @@ -0,0 +1,64 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}, + {"batch": 16, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}, + {"batch": 72, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}, + {"batch": 1024, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}, + {"batch": 4096, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}, + {"batch": 8192, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}, + {"batch": 32768, "seq_len": 5, "hidden_size": 1024, + "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["input_shape", "act_mode", "is_gated", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + hidden_size = params_dict["hidden_size"] + act_mode = params_dict["act_mode"] + is_gated = params_dict["is_gated"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.active, + input, + act_mode, + is_gated, + repeats=args.repeat_times) + io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated) + io_eff = io_bytes / hardware_time / bd + content = [f"{batch,seq_len,hidden_size}", f"{act_mode}", f"{is_gated}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py new file mode 100644 index 0000000..c2bec3c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py @@ -0,0 +1,84 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "rotary_dim": 128, + "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "rotary_dim": 64, + "interleaved": True, "discrete": False, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "rotary_dim": 128, + "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "rotary_dim": 128, + "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "rotary_dim": 64, + "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "rotary_dim": 96, + "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 4, "seq_len": 1, "head_num": 80, "head_size": 128, "rotary_dim": 128, + "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}] +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "seq_len", "head_num", "head_size", "rotary_dim", "interleaved", "discrete", "dynamic_ntk", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + head_num = params_dict["head_num"] + head_size = params_dict["head_size"] + # full/partial + rotary_dim = params_dict["rotary_dim"] + # cross/fold + interleaved = params_dict["interleaved"] + # discrete + discrete = params_dict["discrete"] + # dynamic_ntk + dynamic_ntk = params_dict["dynamic_ntk"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input = torch.randn(batch, seq_len, head_num, head_size).to(device).to(dtype) # [batch, seqlen, head_num, head_size] + if dynamic_ntk: + sin_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype) + cos_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype) + else: + sin_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype) + cos_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype) + if discrete: + pos_ids = torch.randint(0, seq_len, (batch * seq_len,)).to(device).to(torch.int32) + else: + pos_ids = None + hardware_time, e2e_time = benchmark_forward(tmo.apply_rotary, + input, + sin_cache, + cos_cache, + pos_ids, + None, + interleaved, + discrete, + dynamic_ntk, + seq_len, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{head_num}", f"{head_size}", f"{rotary_dim}", f"{interleaved}", f"{discrete}", f"{dynamic_ntk}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_attention_project.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_attention_project.py new file mode 100644 index 0000000..0cbc627 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_attention_project.py @@ -0,0 +1,74 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "input_size": 1600, "hidden_size": 1600, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 2048, "hidden_size": 2048, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 4096, "hidden_size": 4096, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 6144, "hidden_size": 6144, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 6656, "hidden_size": 6656, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 8192, "hidden_size": 8192, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 12288, "hidden_size": 12288, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 14336, "hidden_size": 14336, + "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "seq_len", "input_size", "hidden_size", "has_residual", "has_bias", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + input_size = params_dict["input_size"] + hidden_size = params_dict["hidden_size"] + has_residual = params_dict["has_residual"] + has_bias = params_dict["has_bias"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + x = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device) + weight = torch.randn(hidden_size, input_size).to(dtype).to(device) + residual, bias = None, None + if has_residual: + residual = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device) + if has_bias: + bias = torch.randn(hidden_size).to(dtype).to(device) + hardware_time, e2e_time = benchmark_forward(tmo.attention_project, + x, + weight, + bias, + residual, + 1.0, + 1.0, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}", f"{has_residual}", f"{has_bias}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py new file mode 100644 index 0000000..740c7e0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py @@ -0,0 +1,60 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 2, "m": 1024, "k": 1600, "n": 6400, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 2, "m": 1024, "k": 2048, "n": 8192, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 2, "m": 1024, "k": 4096, "n": 11008, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 2, "m": 1024, "k": 5120, "n": 16384, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 2, "m": 1024, "k": 6144, "n": 24576, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "m", "k", "n", "has_c", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + m = params_dict["m"] + k = params_dict["k"] + n = params_dict["n"] + has_c = params_dict["has_c"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + a = torch.randn(batch, m, k).to(device).to(dtype) + b = torch.randn(batch, n, k).to(device).to(dtype) + c = None + if has_c: + c = torch.randn(batch, m, n).to(device).to(dtype) + + hardware_time, e2e_time = benchmark_forward(tmo.batch_matmul, + a, + b, + c, + 1.0, + 1.0, + repeats=args.repeat_times) + content = [f"{batch}", f"{m}", f"{k}", f"{n}", f"{has_c}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_linear_cache.py new file mode 100644 index 0000000..4cb5f7d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_linear_cache.py @@ -0,0 +1,120 @@ +import argparse +import random +import os + +import torch +import torch_mlu +import torch_mlu_ops as tmo + +from common import benchmark_forward, save_to_csv +from itertools import product +from tabulate import tabulate + +e2e_time_param_dict_list = [ + {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096], + "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "head_size": 128, + "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [4, 8], + "use_offset": True}, +] + +def main(): + parser = argparse.ArgumentParser(description="Benchmark for dequant from linear cache.") + parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["max_batch_size", "batch_size", "max_context_len", "head_num_q", "head_num_kv", + "cache_mem_len", "head_size", "input_dytpe", "quant_mode", "quant_bit", + "use_offset", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + mlu_name = torch.mlu.get_device_name() + for params_dict in e2e_time_param_dict_list: + max_batch_size = params_dict["max_batch_size"] + batch_size_list = params_dict["batch_size"] + max_context_len_list = params_dict["max_context_len"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + cache_mem_len = params_dict["cache_mem_len"] + input_dtype_list = params_dict["input_dtype"] + quant_mode_list = params_dict["quant_mode"] + quant_bit_list = params_dict["quant_bit"] + use_offset = params_dict["use_offset"] + for batch_size, max_context_len, quant_mode, quant_bit, dtype in list(product( \ + batch_size_list, max_context_len_list, quant_mode_list, quant_bit_list, \ + input_dtype_list)): + torch.manual_seed(2766) + torch.mlu.manual_seed(2766) + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv \ + * head_size >= 2**31 - 1): + print("large tensor is not support on {}, skip".format(mlu_name)) + continue + total_heads = head_num_q + head_num_kv * 2 + assert max_context_len <= cache_mem_len, "max_context_len should smaller than or " \ + "equal to cache_mem_len." + max_seq_offset = cache_mem_len - max_context_len + # Generates key and cache from context + context_lens = torch.randint(size=[batch_size], low=max_context_len, + high=max_context_len + 1, + dtype=torch.int32, device="mlu") + if use_offset: + context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset, + dtype=torch.int32, device="mlu") + else: + context_paddings = torch.zeros_like(context_lens) + cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1) + total_seqlen = cu_context_lens[-1] + context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu") + context_seq_offset[1:] = cu_context_lens[:-1] + context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu") + key = context[..., head_num_q:head_num_q + head_num_kv, :] + value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :] + + # Generates key_cache and value_cache + cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu() + cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size], + dtype=torch.int32, device="mlu") + if quant_bit == 4: + key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size, + head_num_kv, cache_mem_len, head_size // 2), device="mlu") + value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size, + head_num_kv, cache_mem_len // 2, head_size), device="mlu") + key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8) + else: + cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size), + low=-128, high=127, dtype=torch.int32, device="mlu") + cache = cache.to(torch.int8) + key_cache, value_cache = cache[[0, 1]] + + # Generates key_cache_scale and value_cache_scale + if quant_mode == 0: # quant_mode == 0 is per channel + cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu") + else: # quant_mode != 1 (== 1 for extend) is per head + cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len), + dtype=torch.float, device="mlu") + key_cache_scale, value_cache_scale = cache_scale[[0, 1]] + hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_linear_cache, + key, value, key_cache, value_cache, + key_cache_scale, value_cache_scale, + context_lens, max_context_len, + context_seq_offset if use_offset else None, + cache_bs_id, cache_seq_offset, quant_mode, + quant_bit, repeats=args.repeat_times) + content = [f"{max_batch_size}", f"{batch_size}", f"{max_context_len}", f"{head_num_q}", + f"{head_num_kv}", f"{cache_mem_len}", f"{head_size}", f"{dtype}", f"{quant_mode}", + f"{quant_bit}", f"{quant_mode}", f"{use_offset}", f"{hardware_time}", + f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_paged_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_paged_cache.py new file mode 100644 index 0000000..e00d0cb --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_dequant_from_paged_cache.py @@ -0,0 +1,110 @@ +import argparse +import math +import random +import os + +import torch +import torch_mlu +import torch_mlu_ops as tmo + +from common import benchmark_forward, save_to_csv +from itertools import product +from tabulate import tabulate + +e2e_time_param_dict_list = [ + {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096], + "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "block_size": 16, "head_size": 128, + "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [8], + "use_offset": True}, +] + +def main(): + parser = argparse.ArgumentParser(description="Benchmark for dequant from paged cache.") + parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + if "MLU3" in torch.mlu.get_device_name(): + print("Op dequant_from_paged_cache does not support MLU300 devices.") + return + titles = ["batch_size", "max_context_len", "head_num_q", "head_num_kv", + "cache_mem_len", "block_size", "head_size", "input_dytpe", "quant_mode", "quant_bit", + "use_offset", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch_size_list = params_dict["batch_size"] + max_context_len_list = params_dict["max_context_len"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + block_size = params_dict["block_size"] + head_size = params_dict["head_size"] + cache_mem_len = params_dict["cache_mem_len"] + input_dtype_list = params_dict["input_dtype"] + quant_mode_list = params_dict["quant_mode"] + quant_bit_list = params_dict["quant_bit"] + use_offset = params_dict["use_offset"] + for quant_mode, batch_size, max_context_len, quant_bit, dtype in list(product( \ + quant_mode_list, batch_size_list, max_context_len_list, quant_bit_list, \ + input_dtype_list)): + torch.manual_seed(2766) + torch.mlu.manual_seed(2766) + total_heads = head_num_q + head_num_kv * 2 + + assert max_context_len <= cache_mem_len, "max_context_len should smaller than or " \ + "equal to cache_mem_len." + max_seq_offset = cache_mem_len - max_context_len + max_block_num = int(math.ceil(max_context_len / block_size)) + total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size + block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size, + max_block_num) + # Generates key and cache from context + context_lens = torch.randint(size=[batch_size], low=max_context_len, high=max_context_len + 1, + dtype=torch.int32, device="mlu") + if use_offset: + context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset, + dtype=torch.int32, device="mlu") + else: + context_paddings = torch.zeros_like(context_lens) + cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1) + total_seqlen = cu_context_lens[-1] + context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu") + context_seq_offset[1:] = cu_context_lens[:-1] + context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu") + key = context[..., head_num_q:head_num_q + head_num_kv, :] + value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :] + # Generates key_cache and value_cache + cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size), + low=-128, high=127, dtype=torch.int32, device="mlu") + cache = cache.to(torch.int8) + key_cache, value_cache = cache[[0, 1]] + + # Generates key_cache_scale and value_cache_scale + if quant_mode == 0: # quant_mode == 0 is per channel + cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu") + else: # quant_mode != 1 (== 1 for extend) is per head + cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size), + dtype=torch.float, device="mlu") + key_cache_scale, value_cache_scale = cache_scale[[0, 1]] + hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_paged_cache, + key, value, key_cache, value_cache, + key_cache_scale, value_cache_scale, + context_lens, max_context_len, + context_seq_offset if use_offset else None, + block_tables, quant_mode, + quant_bit, repeats=args.repeat_times) + content = [f"{batch_size}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", + f"{cache_mem_len}", f"{block_size}", f"{head_size}", f"{dtype}", f"{quant_mode}", + f"{quant_bit}", f"{use_offset}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py new file mode 100644 index 0000000..0031aa9 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py @@ -0,0 +1,89 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400, + "gated_ffn": False, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344, + "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}] +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + hidden_size = params_dict["hidden_size"] + inner_size = params_dict["inner_size"] + gated_ffn = params_dict["gated_ffn"] + act_mode = params_dict["act_mode"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype) + up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype) + up_proj_bias = torch.randn(inner_size).to(device).to(dtype) + down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype) + down_proj_bias = torch.randn(hidden_size).to(device).to(dtype) + gate_up_proj_weight, gate_up_proj_bias = None, None + if gated_ffn: + gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype) + gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.ffn, + input, + up_proj_weight, + up_proj_bias, + down_proj_weight, + down_proj_bias, + gate_up_proj_weight, + gate_up_proj_bias, + act_mode, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py new file mode 100644 index 0000000..f7e9045 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py @@ -0,0 +1,92 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +# for e2e time test +e2e_time_param_dict_list = [{"batch": 1, "seq_q": 32768, "seq_kv": 32768, "head_num": 8, + "head_num_kv": 1, "head_size": 128, "use_causal": True, + "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_q": 16384, "seq_kv": 16384, "head_num": 8, + "head_num_kv": 1, "head_size": 128, "use_causal": True, + "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_q": 8192, "seq_kv": 24576, "head_num": 8, + "head_num_kv": 1, "head_size": 128, "use_causal": True, + "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_q": 4096, "seq_kv": 28672, "head_num": 8, + "head_num_kv": 1, "head_size": 128, "use_causal": True, + "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_q": 4096, "seq_kv": 32768, "head_num": 8, + "head_num_kv": 1, "head_size": 128, "use_causal": True, + "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["batch", "seq_q", "seq_kv", "head_num", "head_num_kv", "head_size", "use_causal", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_q = params_dict["seq_q"] + seq_kv = params_dict["seq_kv"] + head_num = params_dict["head_num"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + use_causal = params_dict["use_causal"] + softmax_scale = params_dict["softmax_scale"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + if seq_q == seq_kv: + qkv = torch.randn(batch, seq_q, head_num + 2 * head_num_kv, head_size).to(dtype).to(device) + q = qkv[:, :, : head_num, :] + k = qkv[:, :, head_num : head_num + head_num_kv, :] + v = qkv[:, :, head_num + head_num_kv : head_num + head_num * 2, :] + elif seq_q < seq_kv: + q = torch.randn(batch, seq_q, head_num, head_size).to(device).to(dtype) + kv = torch.randn(batch, seq_kv, head_num_kv * 2, head_size).to(device).to(dtype) + k = kv[:, :, : head_num_kv, :] + v = kv[:, :, head_num_kv :, :] + hardware_time, e2e_time = benchmark_forward(tmo.flash_attention, + q = q, + k = k, + v = v, + out = None, + cu_seq_lens_q = None, + cu_seq_lens_kv = None, + alibi_slope = None, + attn_bias = None, + max_seq_len_q = seq_q, + max_seq_len_kv = seq_kv, + softmax_scale = softmax_scale, + is_causal = use_causal, + window_size_left = -1, + window_size_right = -1, + compute_dtype = dtype, + return_lse = False, + block_tables = None, + k_cache_quant_scale = None, + v_cache_quant_scale = None, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_q}", f"{seq_kv}", f"{head_num}", f"{head_num_kv}", f"{head_size}", f"{use_causal}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py new file mode 100644 index 0000000..f702f69 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py @@ -0,0 +1,103 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [ + {"batch": 1, "seq_len": 2048, "hidden_size": 8192, "has_residual": True, + "has_bias": False, "has_quant": True, "dynamic_quant": True, + "input_dtype": torch.bfloat16}, + {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "has_residual": True, + "has_bias": False, "has_quant": True, "dynamic_quant": True, + "input_dtype": torch.bfloat16}, + {"batch": 1, "seq_len": 8192, "hidden_size": 8192, "has_residual": True, + "has_bias": False, "has_quant": True, "dynamic_quant": True, + "input_dtype": torch.bfloat16}, + {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "has_residual": True, + "has_bias": False, "has_quant": True, "dynamic_quant": True, + "input_dtype": torch.bfloat16}, + {"batch": 490, "seq_len": 1, "hidden_size": 8192, "has_residual": True, + "has_bias": False, "has_quant": True, "dynamic_quant": True, + "input_dtype": torch.bfloat16}, + {"batch": 525, "seq_len": 1, "hidden_size": 8192, "has_residual": True, + "has_bias": False, "has_quant": True, "dynamic_quant": True, + "input_dtype": torch.bfloat16}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "seq_len", "hidden_size", "has_residual", "has_bias", "has_quant", + "dynamic_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + hidden_size = params_dict["hidden_size"] + has_residual = params_dict["has_residual"] + has_bias = params_dict["has_bias"] + has_quant = params_dict["has_quant"] + dynamic_quant = params_dict["dynamic_quant"] + dtype = params_dict["input_dtype"] + eps = 1e-6 + + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + + x = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device) + beta = torch.randn(hidden_size, dtype=dtype, device=device) + gamma = torch.randn(hidden_size, dtype=dtype, device=device) + residual, bias, quant_scale = None, None, None + if has_residual: + residual = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device) + if has_bias: + bias = torch.randn(hidden_size, dtype=dtype, device=device) + if has_quant or dynamic_quant: + quant_scale = torch.randn(hidden_size, dtype=torch.float, device=device) + + store_output_before_norm = has_residual + hardware_time, e2e_time = benchmark_forward(tmo.fused_layer_norm, + x, + residual, + gamma, + beta, + bias, + eps, + store_output_before_norm, + quant_scale, + None, + dynamic_quant, + repeats=args.repeat_times) + n = x.nelement() + sizeoft = x.element_size() + io_bytes = (sizeoft + 1) * n + \ + (1 + store_output_before_norm) * (sizeoft * n if has_residual else 0) + \ + sizeoft * hidden_size * 2 + \ + (hidden_size * 4 if has_quant else 0) + \ + (batch * seq_len * 4 if dynamic_quant else 0) + io_eff = io_bytes / hardware_time / bd + + content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{has_residual}", f"{has_bias}", + f"{has_quant}", f"{dynamic_quant}", f"{dtype}", + f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py new file mode 100644 index 0000000..3c02664 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py @@ -0,0 +1,143 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [ + {"batch": 1, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 16, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 72, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 128, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 490, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 525, "seq_len": 1, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 1, "seq_len": 2048, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 1, "seq_len": 8192, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + + {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "inner_size": 1024, "num_expert": 32, "start_expert_id": 0, "expert_size": 32, + "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu", + "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "num_expert", "topk", "act_mode", "quant_weight", "dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + hidden_size = params_dict["hidden_size"] + inner_size = params_dict["inner_size"] + gated_ffn = params_dict["gated_ffn"] + act_mode = params_dict["act_mode"] + num_expert = params_dict["num_expert"] + start_expert_id = params_dict["start_expert_id"] + expert_size = params_dict["expert_size"] + topk = params_dict["topk"] + has_residual = params_dict["has_residual"] + smooth_quant = params_dict["smooth_quant"] + renormalize = params_dict["renormalize"] + input_dtype_list = params_dict["dtype"] + # print(f"batch:{batch}, seq_len:{seq_len}, hidden_size:{hidden_size}, inner_size:{inner_size}, " + # f"gated_ffn:{gated_ffn}, act_mode:{act_mode}, num_expert:{num_expert}, topk:{topk}") + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + hidden_states = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype) + router_logit = torch.randn(batch, seq_len, num_expert, device=device, dtype=torch.float32) + + if False: # print token_count + softmax = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1) + topk_logit, expert_id = torch.topk(softmax, k=topk, dim=1) + if renormalize: + topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1) + sorted_expert_id, indices = expert_id.int().flatten().sort() + token_cout = torch.bincount(sorted_expert_id, minlength=num_expert).int() + print(token_cout) + + residual = None + if has_residual: + residual = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype) + weight1 = torch.randn(num_expert, inner_size*(1+gated_ffn), hidden_size, device=device, dtype=dtype) + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device=device, dtype=data_type) + weight2 = torch.randn(num_expert, hidden_size, inner_size, device=device, dtype=dtype) + bias2 = None # torch.randn(expert_num, hidden_size, device=device, dtype=data_type) + input_smooth, act_smooth, w1_scale, w2_scale = None, None, None, None + if smooth_quant: + input_smooth = torch.randn(expert_size, hidden_size, device=device, dtype=torch.float32).abs() + 0.1 + act_smooth = torch.randn(expert_size, inner_size, device=device, dtype=torch.float32).abs() + 0.1 + weight1 = torch.randint(-128, 127, (num_expert, inner_size*(1+gated_ffn), hidden_size)).to(torch.int8).mlu() + weight2 = torch.randint(-128, 127, (num_expert, hidden_size, inner_size)).to(torch.int8).mlu() + w1_scale = torch.randn(expert_size, (1+gated_ffn)*inner_size).to(device).to(torch.float32) + w2_scale = torch.randn(expert_size, hidden_size).to(device).to(torch.float32) + + hardware_time, e2e_time = benchmark_forward(tmo.fused_moe, + hidden_states, + router_logit, + weight1[start_expert_id:start_expert_id+expert_size], + weight2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + input_smooth, + act_smooth, + w1_scale, + w2_scale, + topk, + renormalize, + gated_ffn, + act_mode, + start_expert_id, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{num_expert}", f"{topk}", f"{act_mode}", f"{smooth_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_attention_project.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_attention_project.py new file mode 100644 index 0000000..ed8e7bf --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_attention_project.py @@ -0,0 +1,71 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "input_size": 1600, "head_size": 80, "hidden_size": 1600, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 2048, "head_size": 128, "hidden_size": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 4096, "head_size": 128, "hidden_size": 4096, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 6144, "head_size": 128, "hidden_size": 6144, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 6656, "head_size": 128, "hidden_size": 6656, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 8192, "head_size": 128, "hidden_size": 8192, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 12288, "head_size": 128, "hidden_size": 12288, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "input_size": 14336, "head_size": 128, "hidden_size": 14336, "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["batch", "seq_len", "input_size", "hidden_size", "head_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + input_size = params_dict["input_size"] + hidden_size = params_dict["hidden_size"] + head_size = params_dict["head_size"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input = torch.randn(batch, seq_len, input_size).to(device).to(dtype) + weight = torch.randn(hidden_size * 3, input_size).to(device).to(dtype) + bias = torch.randn(hidden_size * 3).to(device).to(dtype) + weights = torch.chunk(weight, 3) + biases = torch.chunk(bias, 3) + norm_weight = torch.randn(input_size).to(device).to(dtype) + norm_bias = torch.randn(input_size).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_attention_project, + input, + weights[0], + biases[0], + weights[1], + biases[1], + weights[2], + biases[2], + norm_weight, + norm_bias, + 1e-6, + 'nthc', + head_size, + False, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}", f"{head_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_residual_ffn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_residual_ffn.py new file mode 100644 index 0000000..ec6a392 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_norm_residual_ffn.py @@ -0,0 +1,97 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400, + "gated_ffn": False, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192, + "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008, + "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336, + "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384, + "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824, + "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392, + "gated_ffn": True, "residual_is": "input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920, + "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016, + "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576, + "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672, + "gated_ffn": True, "residual_is": "normed_input", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152, + "gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768, + "gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344, + "gated_ffn": True, "residual_is": "none", "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "residual_is", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + hidden_size = params_dict["hidden_size"] + inner_size = params_dict["inner_size"] + gated_ffn = params_dict["gated_ffn"] + residual_is = params_dict["residual_is"] + act_mode = params_dict["act_mode"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype) + up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype) + up_proj_bias = torch.randn(inner_size).to(device).to(dtype) + down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype) + down_proj_bias = torch.randn(hidden_size).to(device).to(dtype) + layernorm_weight = torch.randn(hidden_size).to(device).to(dtype) + layernorm_bias = torch.randn(hidden_size).to(device).to(dtype) + gate_up_proj_weight, gate_up_proj_bias = None, None + if gated_ffn: + gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype) + gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_residual_ffn, + input, + up_proj_weight, + up_proj_bias, + down_proj_weight, + down_proj_bias, + gate_up_proj_weight, + gate_up_proj_bias, + layernorm_weight, + layernorm_bias, + 1e-6, + act_mode, + residual_is, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{residual_is}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py new file mode 100644 index 0000000..b9653b6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py @@ -0,0 +1,90 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +from itertools import product + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 16, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 96, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "head_num": 112, "head_size": 128, "has_residual": True, + "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["input_shape", "has_residual", "has_bias", "has_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + head_num = params_dict["head_num"] + head_size = params_dict["head_size"] + has_residual = params_dict["has_residual"] + has_quant = params_dict["has_quant"] + has_bias = params_dict["has_bias"] + eps = params_dict["eps"] + dynamic_quant_list = params_dict["dynamic_quant"] + input_dtype_list = params_dict["input_dtype"] + dynamic_quant_list = params_dict["dynamic_quant"] + input_dtype_list = params_dict["input_dtype"] + iters = product(dynamic_quant_list, input_dtype_list) + for dynamic_quant, dtype in iters: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + x = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device) + beta = torch.randn(head_size).to(dtype).to(device) + gamma = torch.randn(head_size).to(dtype).to(device) + residual, bias, quant_scale = None, None, None + if has_residual: + residual = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device) + if has_bias: + bias = torch.randn(head_size).to(dtype).to(device) + if has_quant or dynamic_quant: + quant_scale = torch.randn(head_size).to(device) + hardware_time, e2e_time = benchmark_forward(tmo.fused_rms_norm, + x, + residual, + gamma, + beta, + bias, + eps, + False, + quant_scale, + None, + dynamic_quant, + repeats=args.repeat_times) + content = [f"{batch, seq_len, head_num, head_size}", f"{has_residual}", f"{has_bias}", f"{has_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py new file mode 100644 index 0000000..07bf082 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py @@ -0,0 +1,207 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]}, + + {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]}, + + {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128, + "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},] +def main(): + if 'MLU3' in torch.mlu.get_device_name(): + exit() + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache", "max_decode_len", "num_blocks", \ + "block_size", "mixed_cache", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + bs = params_dict["batch"] + seq_len = params_dict["seq_len"] + q_heads = params_dict["head_num_q"] + kv_heads = params_dict["head_num_k"] + head_size = params_dict["head_size"] + rope_dim = params_dict["rotary_dim"] + quant_kv = params_dict["quant_kv"] if "quant_kv" in params_dict else True + paged_cache = params_dict["paged_cache"] if "paged_cache" in params_dict else False + mixed_cache = params_dict["mixed_cache"] if "mixed_cache" in params_dict else False + max_decode_len = 0 + num_blocks = 0 + block_size = 0 + if paged_cache: + num_blocks = params_dict["num_blocks"] + block_size = params_dict["block_size"] + else: + max_decode_len = params_dict["max_decode_len"] if "max_decode_len" in params_dict else 32 + input_dtype_list = params_dict["input_dtype"] + + for dtype in input_dtype_list: + discrete_batch = True + max_bs = bs + 1 if discrete_batch else bs + input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size) + input = torch.randn(size=input_shape, dtype=dtype).mlu() + input_ref = input.clone() + + cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu() + sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu() + gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu() + beta = torch.randn(size=(head_size, ), dtype=dtype).mlu() + cache_dtype = dtype + if quant_kv: + k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu() + v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu() + cache_dtype = torch.int8 + k_scale_ops = 1 / k_scale + v_scale_ops = 1 / v_scale + else: + k_scale = None + v_scale = None + k_scale_ops = None + v_scale_ops = None + if paged_cache: + cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size), dtype=dtype, device='mlu') + else: + cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu') + + if quant_kv: + cache = (cache - 0.5) * 256 + cache = cache.to(cache_dtype) + k_cache = cache[0] + v_cache = cache[1] + + cache_bs_id = None + cache_seq_offsets = None + slot_mapping = None + if not paged_cache: + if discrete_batch: + cache_bs_id = random.sample([*range(0, max_bs)], bs) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + + cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2, + dtype=torch.int32, device='mlu') + else: + slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs) + slot_mapping = torch.IntTensor(slot_mapping).mlu() + + position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu') + + k_cache_lp = None + v_cache_lp = None + k_scale_lp = None + v_scale_lp = None + cache_bs_id_lp = None + cache_seq_offsets_lp = None + + if mixed_cache: + max_decode_len_lp = 1024 + k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)), dtype=dtype, device='mlu') + v_cache_raw = torch.randn((max_bs, kv_heads, int(max_decode_len_lp / 2), head_size), dtype=dtype, device='mlu') + max_value = torch.amax(torch.abs(k_cache_raw)) + k_cache_raw = k_cache_raw * (7 / max_value) + max_value = torch.amax(torch.abs(v_cache_raw)) + v_cache_raw = v_cache_raw * (7 / max_value) + k_cache_lp = k_cache_raw.to(torch.int8) + v_cache_lp = v_cache_raw.to(torch.int8) + + k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu() + v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu() + + cache_bs_id_lp = random.sample([*range(0, max_bs)], bs) + cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu() + cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2, + dtype=torch.int32, device='mlu') + + hardware_time, e2e_time = benchmark_forward(tmo.fused_rope, + input, + k_cache, + v_cache, + sin_table, + cos_table, + position_id, + gamma, + beta, + k_cache_lp, + v_cache_lp, + cache_bs_id, + cache_seq_offsets, + cache_bs_id_lp, + cache_seq_offsets_lp, + k_scale_ops, + v_scale_ops, + k_scale_lp, + v_scale_lp, + slot_mapping, + None, + 1e-5, + repeats=args.repeat_times) + content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}", f"{paged_cache}", \ + f"{max_decode_len}", f"{num_blocks}", f"{block_size}", f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py new file mode 100644 index 0000000..edb2cc6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py @@ -0,0 +1,117 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [ + {"batch": 1, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 16, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 16, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 72, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 490, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 490, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 525, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 525, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + + {"batch": 1, "seq_len": 1024, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 1024, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 2048, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 2048, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 4096, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 4096, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 8192, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 8192, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 32768, "k": 8192, "n": 1024, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 32768, "k": 1024, "n": 8192, "expert_num": 32, + "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + + titles = ["batch", "seq_len", "k", "n", "expert_num", "topk", "smooth_quant", "dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + k = params_dict["k"] + n = params_dict["n"] + expert_num = params_dict["expert_num"] + topk = params_dict["topk"] + is_quant = params_dict["is_quant"] + input_dtype_list = params_dict["dtype"] + # print(f"batch:{batch}, seq_len:{seq_len}, k:{k}, n:{n}, expert_num:{expert_num}, topk:{topk}, is_quant:{is_quant}") + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + max_m = batch * seq_len + m = batch * seq_len * topk + avg, rem = m // expert_num, m % expert_num + m_list = [avg + (i < rem) for i in range(expert_num)] + token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu') + + if not is_quant: + a = torch.randn(m, k, dtype=dtype, device='mlu') + b = torch.randn(expert_num, n, k, dtype=dtype, device='mlu') + hardware_time, e2e_time = benchmark_forward(tmo.group_gemm, + a, b, token_count, + None, None, None, None, + max_m, + repeats=args.repeat_times) + else: + a = torch.randint(-128, 127, (m, k)).to(torch.int8).mlu() + b = torch.randint(-128, 127, (expert_num, n, k)).to(torch.int8).mlu() + a_scale = torch.randn(a.size(0), dtype=torch.float32, device='mlu') + b_scale = torch.randn(expert_num, n, dtype=torch.float32, device='mlu') + hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_group_gemm, + a, b, token_count, + None, None, None, None, + a_scale, b_scale, dtype, max_m, + repeats=args.repeat_times) + content = [f"{batch}", f"{seq_len}", f"{k}", f"{n}", f"{expert_num}", f"{topk}", + f"{is_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py new file mode 100644 index 0000000..40f3e96 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py @@ -0,0 +1,75 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"m": 1024, "k": 1600, "n": 6400, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 2048, "n": 8192, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 4096, "n": 11008, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 4096, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 5120, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 5120, "n": 27392, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 6144, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 6656, "n": 17920, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 8192, "n": 22016, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 8192, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 8192, "n": 28672, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 8192, "n": 49152, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 12288, "n": 32768, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 14336, "n": 57344, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + m = params_dict["m"] + k = params_dict["k"] + n = params_dict["n"] + has_c = params_dict["has_c"] + has_bias = params_dict["has_bias"] + act_mode = params_dict["act_mode"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + a = torch.randn(m, k).to(device).to(dtype) + b = torch.randn(n, k).to(device).to(dtype) + c = None + if has_c: + c = torch.randn(m, n).to(device).to(dtype) + bias = None + if has_bias: + bias = torch.randn(n).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.matmul, + a, + b, + bias, + c, + act_mode, + 1.0, + 1.0, + repeats=args.repeat_times) + content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py new file mode 100644 index 0000000..6acde62 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py @@ -0,0 +1,117 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "inner_size": 1024, + "act_mode": "gelu", "is_gated": True, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 4096, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": False, "input_dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 8192, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 32768, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 1, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 16, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 32, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 64, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 128, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 256, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]}, + {"batch": 512, "seq_len": 1, "inner_size": 1024, + "act_mode": "gelu", "is_gated": False, "has_bias": True, + "is_ep": True, "input_dtype": [torch.bfloat16]},] + +def gen_data(num_expert, + total_tokens, + inner_size, + output_stride, + dtype, + is_gated, + has_bias, + is_ep): + ci = inner_size * (1 + is_gated) + input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu') + cusum_token_count, token_count = generate_token_count(num_expert, total_tokens) + output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu') + output.as_strided(output.size(), (output_stride, 1)) + start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0 + expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert + bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None + return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["input_shape", "act_mode", "is_gated", "has_bias", "expert_num", "start_expert_id", + "expert_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + seq_len = params_dict["seq_len"] + inner_size = params_dict["inner_size"] + act_mode = params_dict["act_mode"] + is_gated = params_dict["is_gated"] + input_dtype_list = params_dict["input_dtype"] + has_bias = params_dict["has_bias"] + is_ep = params_dict["is_ep"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + expert_num = expert_num = random.randint(1, 256) + input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(expert_num, batch * seq_len, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + + real_bias = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None + hardware_time, e2e_time = benchmark_forward(tmo.moe_active, + input, + act_mode, + is_gated, + output, + real_bias, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size, + repeats=args.repeat_times) + + io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated) + \ + real_bias.element_size() * real_bias.nelement() + \ + (cusum_token_count.element_size() * cusum_token_count.nelement()) if has_bias or is_ep else 0 + io_eff = io_bytes / hardware_time / bd + content = [f"{batch,seq_len,inner_size}", f"{act_mode}", f"{is_gated}", f"{has_bias}", f"{expert_num}", + f"{start_expert_id}", f"{expert_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py new file mode 100644 index 0000000..0c947ec --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py @@ -0,0 +1,59 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"batch": 1, "seq_len": 2048, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}, + {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}, + {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}, + {"batch": 16, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}, + {"batch": 128, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}, + {"batch": 512, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}] + +def main(): + if 'MLU3' in torch.mlu.get_device_name(): + exit() + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["batch", "seq_len", "hidden_size", "expert_num", "input_dtype", "hardware_time(us)", + "e2e_latency(us)", "IO efficiency"] + + contents = [] + bandwidth = get_band_width() + for param_dict in e2e_time_param_dict_list: + batch = param_dict["batch"] + seq_len = param_dict["seq_len"] + hidden_size = param_dict["hidden_size"] + expert_num = param_dict["expert_num"] + input_dtype = param_dict["input_dtype"] + if input_dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + input_dtype = torch.half + input = torch.randn(batch, seq_len, hidden_size, dtype=input_dtype, device="mlu") + weight = torch.randn(expert_num, hidden_size, dtype=torch.float32, device="mlu") + hardware_time, e2e_time = benchmark_forward(tmo.moe_cast_gating, + input, + weight) + io_bytes = batch * seq_len * hidden_size * input.element_size() + \ + expert_num * hidden_size * weight.element_size() + batch * seq_len * expert_num * weight.element_size() + io_coeff = io_bytes / hardware_time / bandwidth + content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{expert_num}", f"{input_dtype}", + f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py new file mode 100644 index 0000000..9c1483c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py @@ -0,0 +1,166 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [ + {"num_tokens": 16, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 128, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 490, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 525, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 2048, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 4096, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 8192, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + {"num_tokens": 32768, "num_expert": 32, "topk": 5, "start_expert_id": 0, + "expert_size": 32, "has_residual": False, "hidden_size": 8192, + "dtype": [torch.bfloat16]}, + ] + +def gen_case(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + has_bias, + has_residual, + dtype, + device): + input = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device=device) + reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device) + gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32, device=device) + bias = None + residual = None + + cusum_token_count = None + + if has_bias: + bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) + if has_residual: + residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) + + if has_bias or expert_size < num_expert: + cusum_token_count, _ = generate_token_count(num_expert, num_tokens * topk) + cusum_token_count = cusum_token_count.to(device=device) + return input, reduce_weight, gather_ids, residual, bias, cusum_token_count + + +def get_io_bytes(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + start_expert_id, + has_bias, + has_residual, + dtype, + cusum_token_count, + gather_ids): + io_bytes = 0 + dtype_size = 4 if dtype is torch.float32 else 2 + if cusum_token_count is not None: + filtered_ids = (gather_ids >= cusum_token_count[start_expert_id]) * \ + (gather_ids < cusum_token_count[start_expert_id + expert_size]) + filtered_ids = filtered_ids.to(dtype=torch.float32) + io_bytes += torch.sum(filtered_ids).item() * hidden_size * dtype_size + else: + io_bytes += num_tokens * topk * hidden_size * dtype_size + + if has_bias: + io_bytes += expert_size * hidden_size * dtype_size + + if has_residual: + io_bytes += num_tokens * hidden_size * dtype_size + + io_bytes += num_tokens * topk * 4 + io_bytes += num_tokens * hidden_size * dtype_size + return io_bytes + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["num_tokens", "num_expert", "topk", "start_expert_id", "expert_size", \ + "hidden_size", "has_residual", "dtype", "hardware_time(us)", "e2e_latency(us)", "io_coeff"] + contents = [] + bandwidth = get_band_width() + for params_dict in e2e_time_param_dict_list: + num_tokens = params_dict["num_tokens"] + num_expert = params_dict["num_expert"] + topk = params_dict["topk"] + start_expert_id = params_dict["start_expert_id"] + expert_size = params_dict["expert_size"] + has_residual = params_dict["has_residual"] + hidden_size = params_dict["hidden_size"] + dtype_list = params_dict["dtype"] + for dtype in dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + inputs = gen_case(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + False, + has_residual, + dtype, + device) + input = inputs[0] + reduce_weight = inputs[1] + gather_ids = inputs[2] + residual = inputs[3] + bias = inputs[4] + cusum_token_count = inputs[5] + + io_bytes = get_io_bytes(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + start_expert_id, + False, + has_residual, + dtype, + cusum_token_count, + gather_ids) + + hardware_time, e2e_time = benchmark_forward(tmo.moe_combine_result, input, reduce_weight, + gather_ids,residual, cusum_token_count, + start_expert_id, expert_size, + repeats=args.repeat_times) + io_coeff = io_bytes / hardware_time / bandwidth + content = [f"{num_tokens}", f"{num_expert}", f"{topk}", f"{start_expert_id}", \ + f"{expert_size}", f"{hidden_size}", f"{has_residual}", f"{dtype}", \ + f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_expand_input.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_expand_input.py new file mode 100644 index 0000000..2e4cee0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_expand_input.py @@ -0,0 +1,93 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import numpy as np + +e2e_time_param_dict_list = [{"token_num": 1, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 16, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 32, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 64, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 128, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 512, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 1024, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 4096, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 8192, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}, + {"token_num": 32768, "hidden_size": 4096, "expert_num": 32, "topk": 5, + "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}] + +def gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype): + input = torch.randn(token_num, hidden_size).to(dtype).to('mlu') + gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,)).to(torch.int32).to('mlu') + cusum_token_count, _ = generate_token_count(expert_num, token_num * topk) + cusum_token_count = cusum_token_count.to('mlu') + use_all_experts = expert_num == expert_size + if use_all_experts: + cusum_token_count = None + real_token_count = token_num * topk + else: + real_token_count = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id] + return input, gather_idx, cusum_token_count, real_token_count + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + + titles = ["token_num", "hidden_size", "expert_num", "topk", "start_expert_id", "expert_size", "input_dtype", + "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + for params_dict in e2e_time_param_dict_list: + token_num = params_dict["token_num"] + hidden_size = params_dict["hidden_size"] + expert_num = params_dict["expert_num"] + topk = params_dict["topk"] + start_expert_id = params_dict["start_expert_id"] + expert_size = params_dict["expert_size"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input, gather_idx, cusum_token_count, real_token_count = \ + gen_tensor(token_num, hidden_size, expert_num,topk, start_expert_id, expert_size, dtype) + hardware_time, e2e_time = benchmark_forward(tmo.moe_expand_input, + input, + gather_idx, + cusum_token_count, + start_expert_id, + expert_size, + repeats=args.repeat_times) + io_bytes = input.element_size() * input.nelement() + \ + gather_idx.element_size() * gather_idx.nelement() + \ + (cusum_token_count.element_size() * cusum_token_count.nelement() if cusum_token_count is not None else 0) + \ + real_token_count * input.element_size() + io_eff = io_bytes / hardware_time / bd + content = [f"{token_num}", f"{hidden_size}", f"{expert_num}", f"{topk}", f"{start_expert_id}", f"{expert_size}", + f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__ == "__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py new file mode 100644 index 0000000..0900741 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py @@ -0,0 +1,69 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"token_num": 1, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 16, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 32, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 64, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 512, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 1024, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 4096, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 8192, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 32767, "expert_num": 32, "topk": 5, "input_dtype": torch.int32}, + {"token_num": 1, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 16, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 32, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 64, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 512, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 1024, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 4096, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 8192, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}, + {"token_num": 32767, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}] +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + + titles = ["token_num", "expert_num", "topk", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + for params_dict in e2e_time_param_dict_list: + token_num = params_dict["token_num"] + expert_num = params_dict["expert_num"] + topk = params_dict["topk"] + dtype = params_dict["input_dtype"] + expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu') + gather_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu') + combine_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu') + token_count = torch.empty((expert_num), dtype=dtype, device='mlu') + cusum_token_count = torch.empty((expert_num + 1), dtype=dtype, device='mlu') + hardware_time, e2e_time = benchmark_forward(tmo.moe_gen_idx, + expert_id, + expert_num, + repeats=args.repeat_times) + io_bytes = expert_id.element_size() * expert_id.nelement() + \ + gather_idx.element_size() * gather_idx.nelement() + \ + combine_idx.element_size() * combine_idx.nelement() + \ + token_count.element_size() * token_count.nelement() + \ + cusum_token_count.element_size() * cusum_token_count.nelement() + io_eff = io_bytes / hardware_time / bd + content = [f"{token_num}", f"{expert_num}", f"{topk}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__ == "__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py new file mode 100644 index 0000000..2aa5279 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py @@ -0,0 +1,114 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import random + +params_dict = [ + {"token_num": 1, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 16, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 128, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 490, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 512, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 525, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 2048, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 4096, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 8192, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + {"token_num": 32768, "hidden_size": 8192, "expert_num": 32, "topk": 5, + "has_gather_idx": True, "dtype": torch.bfloat16}, + + {"token_num": 1, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 16, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 128, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 490, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 512, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 525, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 2048, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 4096, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 8192, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + {"token_num": 32768, "hidden_size": 1024, "expert_num": 32, "topk": 5, + "has_gather_idx": False, "dtype": torch.bfloat16}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["token_num", "hidden_size", "expert_num", "topk", "has_gather_idx", "dtype", + "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + for param in params_dict: + token_num, hidden_size, expert_num, topk, has_gather_idx, dtype = param.values() + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + if "MLU3" in torch.mlu.get_device_name(): + has_gather_idx = False + expand_token_num = token_num * topk + input_shape = (token_num if has_gather_idx else expand_token_num, hidden_size) + input = torch.randn(input_shape).to(device).to(dtype) + scale = torch.randn(expert_num, hidden_size).to(device).to(torch.float32) + + avg, rem = expand_token_num // expert_num, expand_token_num % expert_num + m_list = [avg + (i < rem) for i in range(expert_num)] + token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu') + + if has_gather_idx: + gather_idx = torch.arange(0, token_num).repeat([topk]) + gather_idx = gather_idx[torch.randperm(gather_idx.size(0))].to(torch.int32).mlu() + else: + gather_idx = None + + hardware_time, e2e_time = benchmark_forward(tmo.moe_quantize, + input, + scale, + None, + token_count, + gather_idx, + None, + None, + None, + True, + repeats=args.repeat_times) + expand_num = topk if has_gather_idx else 1 + io_bytes = (input.element_size() + 1) * input.nelement() * expand_num + io_eff = io_bytes / hardware_time / bd + content = [f"{token_num}", f"{hidden_size}", f"{expert_num}", + f"{topk}", f"{has_gather_idx}", f"{dtype}", + f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_softmax_topk.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_softmax_topk.py new file mode 100644 index 0000000..1000e58 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_softmax_topk.py @@ -0,0 +1,87 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"num_batch": 1, "seq_len": 1, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 32, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 72, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 1024, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 2048, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 4096, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 8192, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 32768, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 1, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 2, "seq_len": 16, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 2, "seq_len": 36, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 8, "seq_len": 128, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 16, "seq_len": 128, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 4, "seq_len": 1024, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 2, "seq_len": 4096, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 16, "seq_len": 2048, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 1, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 16, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 64, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 1024, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 2048, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 8192, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + {"num_batch": 1, "seq_len": 32768, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["num_batch", "seq_len", "num_expert", "topk", "num_expert_group", "topk_group", "normalize", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bd = get_band_width() + + for params_dict in e2e_time_param_dict_list: + num_batch = params_dict["num_batch"] + seq_len = params_dict["seq_len"] + num_expert = params_dict["num_expert"] + topk = params_dict["topk"] + num_expert_group = params_dict["num_expert_group"] + topk_group = params_dict["topk_group"] + normalize = params_dict["normalize"] + dtype = params_dict["input_dtype"] + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + input = torch.randn(num_batch, seq_len, num_expert, dtype=dtype, device='mlu') + mask = torch.randint(0, 2, (1, seq_len, num_expert), dtype = dtype, device='mlu') + if num_expert_group > 1: + mask = None + normed_by = "softmax_logit" + reduce_weight = torch.empty(num_batch, topk, dtype=torch.float, device='mlu') + expert_id = torch.empty(num_batch, topk, dtype=torch.int32, device='mlu') + hardware_time, e2e_time = benchmark_forward(tmo.moe_softmax_topk, + input, + topk, + normalize, + num_expert_group, + topk_group, + mask, + normed_by, + repeats=args.repeat_times) + io_bytes = input.element_size() * input.nelement() + \ + reduce_weight.element_size() * reduce_weight.nelement() + \ + expert_id.element_size() * expert_id.nelement() + io_eff = io_bytes / hardware_time / bd + content = [f"{num_batch}", f"{seq_len}", f"{num_expert}", f"{topk}", f"{num_expert_group}", f"{topk_group}", f"{normalize}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_linear_cache.py new file mode 100644 index 0000000..af6dd74 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_linear_cache.py @@ -0,0 +1,145 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": True, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_head"}, + {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": False, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_head"}, + {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": True, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_channel"}, + {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": False, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_channel"} + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dytpe", "quantize_mode", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + max_batch = params_dict["max_batch"] + batch = params_dict["batch"] + cache_mem_len = params_dict["cache_mem_len"] + max_context_len = params_dict["max_context_len"] + max_seq_offset = params_dict["max_seq_offset"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + packed = params_dict["packed"] + quantize_mode = params_dict["quantize_mode"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + context_lens = torch.randint(size=(batch, ), low=max_context_len, + high=max_context_len+1, + dtype=torch.int32, device='mlu') + # max_seq_offset = max_context_len // 3 + 1 + context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch, ), low=-1, + high=(cache_mem_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + total_heads = head_num_q + 2 * head_num_kv + if packed > 0: + context = torch.randn((total_seqlen, total_heads, head_size), + dtype=torch.float, device='mlu') + else: + context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size), + dtype=torch.float, device='mlu') + cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu') + context = context.to(dtype) + cache = cache.to(dtype) + key = context[..., head_num_q : head_num_q + head_num_kv, :] + value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :] + key_cache = cache[0] + value_cache = cache[1] + + cache_bs_id = None + cache_bs_id = random.sample([*range(0, max_batch)], batch) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + key_cache = (key_cache - 0.5) * 256 + value_cache = (value_cache - 0.5) * 256 + key_cache = key_cache.to(torch.int8) + value_cache = value_cache.to(torch.int8) + if packed > 0: + if quantize_mode == "per_channel": + key_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32) + value_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32) + hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache, + key, value, + key_cache, value_cache, + key_cache_quantize_scale, + value_cache_quantize_scale, + cu_context_lens, max_context_len, 0, + packed > 0, None, + cache_bs_id, cache_seq_offsets, + repeats=args.repeat_times) + elif quantize_mode == "per_head": + key_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32) + value_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32) + hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache, + key, value, + key_cache, value_cache, + key_cache_quantize_scale, + value_cache_quantize_scale, + cu_context_lens, max_context_len, 1, + packed > 0, None, + cache_bs_id, cache_seq_offsets, + repeats=args.repeat_times) + else: + if quantize_mode == "per_channel": + key_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32) + value_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32) + hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache, + key, value, + key_cache, value_cache, + key_cache_quantize_scale, + value_cache_quantize_scale, + context_lens, max_context_len, 0, + packed > 0, context_seq_offsets, + cache_bs_id, cache_seq_offsets, + repeats=args.repeat_times) + elif quantize_mode == "per_head": + key_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32) + value_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32) + hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache, + key, value, + key_cache, value_cache, + key_cache_quantize_scale, + value_cache_quantize_scale, + context_lens, max_context_len, 1, + packed > 0, context_seq_offsets, + cache_bs_id, cache_seq_offsets, + repeats=args.repeat_times) + + content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{quantize_mode}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_paged_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_paged_cache.py new file mode 100644 index 0000000..6ed47e9 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_offline_quant_to_paged_cache.py @@ -0,0 +1,73 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from itertools import product +from tabulate import tabulate +import os +import random + + +e2e_time_param_dict_list = [ + {"token_num": [1024, 2048, 3072, 4096], "head_num_kv": 1, "head_size": 128, "block_size": 16, + "input_dtype": [torch.float16, torch.bfloat16]}, + {"token_num": [1024 * 32, 2048 * 32, 3072 * 32, 4096 * 32], "head_num_kv": 1, "head_size": 128, "block_size": 16, + "input_dtype": [torch.float16, torch.bfloat16]}, + {"token_num": [1024 * 64, 2048 * 64, 3072 * 64, 4096 * 64], "head_num_kv": 1, "head_size": 128, "block_size": 16, + "input_dtype": [torch.float16, torch.bfloat16]}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + + args = parser.parse_args() + titles = ["token_num", "head_num_kv", "head_size", "block_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + if "MLU3" in torch.mlu.get_device_name(): + print("Op offline_quant_to_paged_cache does not support MLU300 devices.") + return + for params_dict in e2e_time_param_dict_list: + token_num_list = params_dict["token_num"] + # block_num = params_dict["block_num"] + block_size = params_dict["block_size"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + input_dtype_list = params_dict["input_dtype"] + for token_num, dtype in product(token_num_list, input_dtype_list): + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + block_num = (token_num + block_size - 1) // block_size + key = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu") + value = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu") + key_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu") + value_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu") + num_slots = block_num * block_size + slot_mapping = random.sample(range(num_slots), token_num) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu") + key_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu") + value_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu") + hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_paged_cache, + key, value, + key_cache_scale, value_cache_scale, + slot_mapping, + key_cache, value_cache, + repeats=args.repeat_times) + + + content = [f"{token_num}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_per_token_smooth_quantize.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_per_token_smooth_quantize.py new file mode 100644 index 0000000..8043f28 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_per_token_smooth_quantize.py @@ -0,0 +1,61 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import random +import csv + +params_dict = { + "token_num": [n * 5 for n in [1, 72, 512, 1024, 4096, 32768]], + "hidden_size": [1024, 8192], + "input_dtype": [torch.float16] + } + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["token_num", "hidden_size", "input_dytpe", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + params_list = product(params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"]) + bd = get_band_width() + for params in params_list: + token_num, hidden_size = params[0], params[1] + input_shape = (token_num, hidden_size) + smooth_shape = (hidden_size) + dtype = params[2] + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input = torch.randn(input_shape).to(device).to(dtype) + smooth = torch.randn(smooth_shape).to(device).to(torch.float32) + zero = None + token_count = None + hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize, + input, + smooth, + zero, + token_count, + repeats=args.repeat_times) + io_bytes = (input.element_size() + 1) * input.nelement() + \ + smooth.element_size() * smooth.nelement() + \ + token_num * 4 + io_eff = io_bytes / hardware_time / bd + content = [f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py new file mode 100644 index 0000000..4261e9f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py @@ -0,0 +1,46 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"input_shape": [100, 100, 100], "input_dtype": [torch.float16, torch.bfloat16]}, + {"input_shape": [100, 100], "input_dtype": [torch.float16, torch.bfloat16]}, + {"input_shape": [50, 50, 50], "input_dtype": [torch.float16, torch.bfloat16]}, + {"input_shape": [1, 100, 1000], "input_dtype": [torch.float16, torch.bfloat16]} + ] +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["input_shape", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + input_shape = params_dict["input_shape"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input = torch.randn(input_shape).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.preload, + input, + input.element_size() * input.numel(), + repeats=args.repeat_times) + content = [f"{input_shape}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quant_to_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quant_to_linear_cache.py new file mode 100644 index 0000000..bfd70ba --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quant_to_linear_cache.py @@ -0,0 +1,201 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import random +import os + +e2e_time_param_dict_list = [ + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 128}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 8, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 1024, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 2048, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 4096, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 8192, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 8, "batch": 1, "cache_mem_len": 32768, "max_context_len": 32768, "head_num_kv": 1, + "head_size": 128, "packed": True, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 512, "batch": 16, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 512, "batch": 72, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 512, "batch": 128, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 512, "batch": 256, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + {"max_batch": 512, "batch": 512, "cache_mem_len": 4096, "max_context_len": 1, "head_num_kv": 1, + "head_size": 128, "packed": False, "input_dtype": [torch.bfloat16], "quant_bit": 4, "group_size": 16}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["batch", "max_context_len", "head_num_kv", "head_size", "packed", "input_dtype", + "quant_bit", "group_size", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + bandwidth = get_band_width() + for params_dict in e2e_time_param_dict_list: + max_batch = params_dict["max_batch"] + batch = params_dict["batch"] + cache_mem_len = params_dict["cache_mem_len"] + max_context_len = params_dict["max_context_len"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + packed = params_dict["packed"] + input_dtype_list = params_dict["input_dtype"] + quant_bit = params_dict["quant_bit"] + group_size = params_dict["group_size"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + context_lens = torch.tensor([max_context_len] * batch).to(torch.int32).mlu() + context_seq_offsets = torch.zeros(batch, dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch, ), + low=0, + high = 1 if max_context_len > 1 else cache_mem_len, + dtype=torch.int32, device='mlu') + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + if packed > 0: + key = torch.randn((total_seqlen, head_num_kv, head_size), + dtype=torch.float, device='mlu') + value = torch.randn((total_seqlen, head_num_kv, head_size), + dtype=torch.float, device='mlu') + else: + key = torch.randn((batch, max_context_len, head_num_kv, head_size), + dtype=torch.float, device='mlu') + value = torch.randn((batch, max_context_len, head_num_kv, head_size), + dtype=torch.float, device='mlu') + key = key.to(dtype) + value = value.to(dtype) + if quant_bit == 8 and group_size == head_size: + key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu() + if quant_bit == 8 and group_size != head_size: + key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu() + if quant_bit == 4 and group_size == head_size: + key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu() + if quant_bit == 4 and group_size != head_size: + key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu() + + cache_bs_id = None + cache_bs_id = random.sample([*range(0, max_batch)], batch) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + + if packed > 0: + hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache, + key, value, + key_cache, value_cache, + key_cache_scale, + value_cache_scale, + cu_context_lens, max_context_len, + packed > 0, None, + cache_bs_id, cache_seq_offsets, + quant_bit, + repeats=args.repeat_times) + else: + hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache, + key, value, + key_cache, value_cache, + key_cache_scale, + value_cache_scale, + context_lens, max_context_len, + packed > 0, context_seq_offsets, + cache_bs_id, cache_seq_offsets, + quant_bit, + repeats=args.repeat_times) + + io_bytes = key.nelement() * (key.element_size() + 1) * 2 + io_eff = io_bytes / hardware_time / bandwidth + content = [f"{batch}", f"{max_context_len}", f"{head_num_kv}", f"{head_size}", f"{packed}", + f"{dtype}", f"{quant_bit}", f"{group_size}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py new file mode 100644 index 0000000..26dde4e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py @@ -0,0 +1,61 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import random + +params_dict = {"dynamic": [True], + "token_num": [1, 72, 490, 512, 525, 1024, 4096, 8192, 32768], + "hidden_size": [8192, 1024], + "input_dtype": [torch.bfloat16]} + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + titles = ["dynamic", "token_num", "hidden_size", "input_dytpe", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"] + contents = [] + params_list = product(params_dict["dynamic"], params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"]) + bd = get_band_width() + for param in params_list: + dynamic, token_num, hidden_size, dtype = param[0], param[1], param[2], param[3] + input_shape = (token_num, hidden_size) + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.half + input = torch.randn(input_shape).to(device).to(dtype) + scale = torch.randn(input_shape[-1]).to(device).to(torch.float32) + zero = None + if dynamic: + hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize, + input, + scale, + zero, + None, + repeats=args.repeat_times) + else: + hardware_time, e2e_time = benchmark_forward(tmo.quantize, + input, + scale, + zero, + repeats=args.repeat_times) + io_bytes = (input.element_size() + 1) * input.nelement() + scale.element_size() * scale.nelement() + io_eff = io_bytes / hardware_time / bd + content = [f"{dynamic}", f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_linear_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_linear_cache.py new file mode 100644 index 0000000..4ef8551 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_linear_cache.py @@ -0,0 +1,109 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512, + "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128, + "packed": False, "input_dtype": [torch.float16, torch.bfloat16]} + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dytpe", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + max_batch = params_dict["max_batch"] + batch = params_dict["batch"] + cache_mem_len = params_dict["cache_mem_len"] + max_context_len = params_dict["max_context_len"] + max_seq_offset = params_dict["max_seq_offset"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + packed = params_dict["packed"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + context_lens = torch.randint(size=(batch, ), low=max_context_len, + high=max_context_len+1, + dtype=torch.int32, device='mlu') + # max_seq_offset = max_context_len // 3 + 1 + context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch, ), low=-1, + high=(cache_mem_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + total_heads = head_num_q + 2 * head_num_kv + if packed > 0: + context = torch.randn((total_seqlen, total_heads, head_size), + dtype=torch.float, device='mlu') + else: + context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size), + dtype=torch.float, device='mlu') + cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu') + context = context.to(dtype) + cache = cache.to(dtype) + key = context[..., head_num_q : head_num_q + head_num_kv, :] + value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :] + key_cache = cache[0] + value_cache = cache[1] + + cache_bs_id = None + cache_bs_id = random.sample([*range(0, max_batch)], batch) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + + if packed > 0: + hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache, + key, value, + key_cache, value_cache, + cu_context_lens, max_context_len, + packed > 0, None, + cache_bs_id, cache_seq_offsets, + repeats=args.repeat_times) + + else: + hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache, + key, value, + key_cache, value_cache, + context_lens, max_context_len, + packed > 0, context_seq_offsets, + cache_bs_id, cache_seq_offsets, + repeats=args.repeat_times) + + content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_paged_cache.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_paged_cache.py new file mode 100644 index 0000000..f0dff89 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_reshape_paged_cache.py @@ -0,0 +1,76 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32, + "head_num_kv": 32, "head_size": 128, "quantize": True, "input_dtype": [torch.float16, torch.bfloat16]}, + {"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32, + "head_num_kv": 32, "head_size": 128, "quantize": False, "input_dtype": [torch.float16, torch.bfloat16]} + ] + +def main(): + if 'MLU3' in torch.mlu.get_device_name(): + exit() + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + titles = ["num_tokens", "num_block", "block_size", "head_num_q", "head_num_kv", "head_size", "input_dytpe", "quantize", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + num_tokens = params_dict["num_tokens"] + num_blocks = params_dict["num_block"] + block_size = params_dict["block_size"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + quantize = params_dict["quantize"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + qkv = torch.randn(num_tokens, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu() + key = qkv[:, head_num_q : head_num_q + head_num_kv, :] + value = qkv[:, head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :] + + key_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu() + value_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu() + + num_slots = num_blocks * block_size + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu() + slot_mapping[-1] = -1 + if not quantize: + hardware_time, e2e_time = benchmark_forward(tmo.reshape_paged_cache, + key, value, + key_cache, value_cache, + slot_mapping, + repeats=args.repeat_times) + else: + k_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32) + v_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32) + hardware_time, e2e_time = benchmark_forward(tmo.quant_to_paged_cache, + key, value, + key_cache, value_cache, + k_cache_quant_scale, + v_cache_quant_scale, + slot_mapping, + repeats=args.repeat_times) + + content = [f"{num_tokens}", f"{num_blocks}", f"{block_size}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{dtype}", f"{quantize}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_cached_kv_attn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_cached_kv_attn.py new file mode 100644 index 0000000..de49901 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_cached_kv_attn.py @@ -0,0 +1,116 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import math +import random + +e2e_time_param_dict_list = [ + {"batch": 16, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128, + "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False, + "is_pertoken": False, "input_dtype": [torch.bfloat16]}, + {"batch": 128, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128, + "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False, + "is_pertoken": False, "input_dtype": [torch.bfloat16]}, + {"batch": 512, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128, + "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False, + "is_pertoken": False, "input_dtype": [torch.bfloat16]}, + + {"batch": 16, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128, + "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False, + "is_pertoken": False, "input_dtype": [torch.bfloat16]}, + {"batch": 128, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128, + "block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False, + "is_pertoken": False, "input_dtype": [torch.bfloat16]}, + ] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "max_seq_len", "head_num_q", "head_num_kv", "head_size", "block_size", "alibi_bias", "kv_cache_dtype", "use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + max_seq_len = params_dict["max_seq_len"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + block_size = params_dict["block_size"] + alibi_bias = params_dict["alibi_bias"] + kv_cache_dtype = params_dict["kv_cache_dtype"] + use_paged_attn = params_dict["use_paged_attn"] + input_dtype_list = params_dict["input_dtype"] + is_pertoken = params_dict["is_pertoken"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype) + input_q = input_qkv[..., 0 : head_num_q, :] + + context_lens = torch.randint(max_seq_len, max_seq_len + 1, (batch, ), dtype=torch.int32).to(device) + max_context_len = int(max(context_lens)) + if use_paged_attn: + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + continue + block_size = 16 + else: + block_size = max_seq_len + 512 + num_blocks = batch * ((max_seq_len + block_size - 1) // block_size) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape = (num_blocks, head_num_kv, block_size, head_size) + scale_shape = (num_blocks, head_num_kv, block_size) if is_pertoken else (head_num_kv, head_size) + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + if kv_cache_dtype is not torch.int8: + key_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device) + value_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device) + key_cache_scale = None + value_cache_scale = None + else: + key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device) + value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device) + key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device) + value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device) + alibi_slopes = None + if alibi_bias: + alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device) + alibi_slopes.uniform_(0, 0.125) + softmax_scale = 1 / math.sqrt(head_size) + hardware_time, e2e_time = benchmark_forward(tmo.single_query_cached_kv_attn, + input_q, + key_cache, + value_cache, + None, + block_tables, + context_lens, + key_cache_scale, + value_cache_scale, + alibi_slopes, + max_context_len, + -1, + -1, + softmax_scale, + repeats=args.repeat_times) + content = [f"{batch}", f"{max_seq_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{alibi_bias}", f"{kv_cache_dtype}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_mixed_cached_kv_attn.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_mixed_cached_kv_attn.py new file mode 100644 index 0000000..6ec0a06 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_single_query_mixed_cached_kv_attn.py @@ -0,0 +1,131 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import * +import argparse +from tabulate import tabulate +import os +import math +import random + +e2e_time_param_dict_list = [{"batch": 16, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1, + "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8, + "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1, + "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8, + "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 512, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1, + "head_size": 128, "alibi_bias": True, "quant_bit_lp": 4, "quant_bit_hp": 8, + "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 16, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1, + "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8, + "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}, + {"batch": 128, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1, + "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8, + "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}] + +def gen_cache(batch, num_kv_heads, head_size, is_pagedattn, max_context_len, data_type, quant_bit, quant_mode): + int_max = float(2 ** (quant_bit - 1) - 1) + int_min = -float(2 ** (quant_bit - 1)) + context_lens = torch.randint(max_context_len, max_context_len + 1, (batch, ), dtype=torch.int32).mlu() + block_size = 16 + if is_pagedattn is False: + block_size = max_context_len + num_blocks = (int)(batch * ((max_context_len + block_size - 1)/ block_size)) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape = (num_blocks, num_kv_heads, block_size, head_size) + if quant_mode == "per_token": + scale_shape = (num_blocks, num_kv_heads, block_size, 1) + else: # per channel + scale_shape = (num_kv_heads, head_size) + + if quant_bit == 4: + cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size//2) + cache_shape_v_int4 = (num_blocks, num_kv_heads, block_size//2, head_size) + key_cache = torch.zeros(cache_shape_k_int4).uniform_(int_min, int_max).to(torch.int8).mlu() + value_cache = torch.zeros(cache_shape_v_int4).uniform_(int_min, int_max).to(torch.int8).mlu() + key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + elif quant_bit == 8: + key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu() + value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu() + key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + elif quant_bit == -1: + key_cache = torch.randn(cache_shape, dtype=data_type).mlu() + value_cache = torch.randn(cache_shape, dtype=data_type).mlu() + key_cache_scale = None + value_cache_scale = None + else: + print("!!!!!!!!!!!gen case error, quant_bit must be in {-1, 4, 8}") + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + return key_cache, value_cache, key_cache_scale, value_cache_scale, context_lens, block_tables + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "max_seq_len_lp", "max_seq_len_hp", "head_num_q", "head_num_kv", "head_size", "alibi_bias", "quant_bit_lp", "quant_bit_hp","use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + max_seq_len_lp = params_dict["max_seq_len_lp"] + max_seq_len_hp = params_dict["max_seq_len_hp"] + head_num_q = params_dict["head_num_q"] + head_num_kv = params_dict["head_num_kv"] + head_size = params_dict["head_size"] + alibi_bias = params_dict["alibi_bias"] + quant_bit_lp = params_dict["quant_bit_lp"] + quant_bit_hp = params_dict["quant_bit_hp"] + use_paged_attn = params_dict["use_paged_attn"] + input_dtype_list = params_dict["input_dtype"] + + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype) + input_q = input_qkv[..., 0 : head_num_q, :] + + params_lp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_lp, dtype, quant_bit_lp, "per_token") + params_hp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_hp, dtype, quant_bit_hp, "per_channel") + key_cache_lp, value_cache_lp, key_cache_scale_lp, value_cache_scale_lp, context_lens_lp, block_tables_lp = params_lp + key_cache_hp, value_cache_hp, key_cache_scale_hp, value_cache_scale_hp, context_lens_hp, block_tables_hp = params_hp + + alibi_slopes = None + if alibi_bias: + alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device) + alibi_slopes.uniform_(0, 0.125) + softmax_scale = 1 / math.sqrt(head_size) + hardware_time, e2e_time = benchmark_forward(tmo.single_query_mixed_cached_kv_attn, + input_q, + key_cache_lp, value_cache_lp, + key_cache_hp, value_cache_hp, + None, #output + block_tables_lp, block_tables_hp, + context_lens_lp, context_lens_hp, + key_cache_scale_lp, value_cache_scale_lp, + key_cache_scale_hp, value_cache_scale_hp, + alibi_slopes, + max_seq_len_lp, max_seq_len_hp, + softmax_scale, True, + quant_bit_lp, quant_bit_hp, + repeats=args.repeat_times) + content = [f"{batch}", f"{max_seq_len_lp}", f"{max_seq_len_hp}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{alibi_bias}", f"{quant_bit_lp}", f"{quant_bit_hp}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_smooth_quant_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_smooth_quant_matmul.py new file mode 100644 index 0000000..7c7e427 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_smooth_quant_matmul.py @@ -0,0 +1,69 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True, + "act_mode": "none", "output_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True, + "act_mode": "silu", "output_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "output_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + m = params_dict["m"] + k = params_dict["k"] + n = params_dict["n"] + has_c = params_dict["has_c"] + has_bias = params_dict["has_bias"] + act_mode = params_dict["act_mode"] + output_dtype_list = params_dict["output_dtype"] + for dtype in output_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + a = torch.randn(m, k).to(device).to(torch.int8) + b = torch.randn(n, k).to(device).to(torch.int8) + a_scale = torch.randn(m).to(device) + b_scale = torch.randn(n).to(device) + c = None + if has_c: + c = torch.randn(m, n).to(device).to(dtype) + bias = None + if has_bias: + bias = torch.randn(n).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_matmul, + a, + a_scale, + b, + b_scale, + dtype, + bias, + c, + act_mode, + 1.0, + 1.0, + repeats=args.repeat_times) + content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh new file mode 100644 index 0000000..39f28e0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh @@ -0,0 +1,15 @@ +#!/bin/bash +LOG_PATH=${LOG_PATH:-.} + +files=($(ls benchmark_*.py)) +for file in "${files[@]}"; do + echo "test ${file}..." + op_name=$(basename "$file" .py) + python "$file" > ${LOG_PATH}/${op_name}.log 2>&1 + ret_tmp=$? + cat ${LOG_PATH}/${op_name}.log + if [ $ret_tmp != 0 ]; then + echo "${sc} test failed..." + exit $ret_tmp + fi +done diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_update_out_and_lse.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_update_out_and_lse.py new file mode 100644 index 0000000..29107b4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_update_out_and_lse.py @@ -0,0 +1,99 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv +import argparse +from tabulate import tabulate +import os +import random + +e2e_time_param_dict_list = [{"batch": 16, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1, + "dtype": [torch.float16], "pack": False}, + {"batch": 128, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1, + "dtype": [torch.float16], "pack": False}, + {"batch": 512, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1, + "dtype": [torch.float16], "pack": False}, + {"batch": 1024, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1, + "dtype": [torch.float16], "pack": False}, + {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 1024, "max_seq_len": 1024, + "dtype": [torch.float16], "pack": True}, + {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 2048, "max_seq_len": 2048, + "dtype": [torch.float16], "pack": True}, + {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 8192, "max_seq_len": 8192, + "dtype": [torch.float16], "pack": True}, + {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 32768, "max_seq_len": 32768, + "dtype": [torch.float16], "pack": True},] + +def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack): + if not pack: + out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype) + lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32) + block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype) + block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32) + seq_offset = None + cu_seqs = None + block_cu_seqs = None + else: + seq_lens = torch.randint(low=max_seq_len, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32) + block_seq_lens = torch.randint(low=block_seq_len, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32) + block_seq_lens = torch.minimum(seq_lens, block_seq_lens) + seq_offset = torch.zeros_like(seq_lens) + for i in range(batch): + seq_offset[i] = torch.randint(low=0, high=seq_lens[i]-block_seq_lens[i]+1, size=(1,), dtype=torch.int32) + seq_offset = seq_offset.mlu() + cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu() + block_cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu() + total_seqs = torch.sum(seq_lens) + block_total_seqs = torch.sum(block_seq_lens) + + out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype) + lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32) + block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype) + block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32) + return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["batch", "head_num", "head_size", "block_seq_len", "max_seq_len", "dtype", "pack", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + batch = params_dict["batch"] + head_num = params_dict["head_num"] + head_size = params_dict["head_size"] + block_seq_len = params_dict["block_seq_len"] + max_seq_len = params_dict["max_seq_len"] + dtype_list = params_dict["dtype"] + pack = params_dict["pack"] + for dtype in dtype_list: + out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \ + gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack) + + hardware_time, e2e_time = benchmark_forward(tmo.update_out_and_lse, + out, + lse, + block_out, + block_lse, + seq_offset, + cu_seqs, + block_cu_seqs, + repeats=args.repeat_times) + content = [f"{batch}", f"{head_num}", f"{head_size}", f"{block_seq_len}", f"{max_seq_len}", f"{dtype}", f"{pack}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/benchmark_weight_only_quant_matmul.py b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_weight_only_quant_matmul.py new file mode 100644 index 0000000..7236df5 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/benchmark_weight_only_quant_matmul.py @@ -0,0 +1,70 @@ +import torch +import torch_mlu +import torch_mlu_ops as tmo +from common import benchmark_forward, save_to_csv, save_to_csv +import argparse +from tabulate import tabulate +import os + +e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True, + "act_mode": "none", "quant_bit": 8, "input_dtype": [torch.float16, torch.bfloat16]}, + {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True, + "act_mode": "silu", "quant_bit": 4, "input_dtype": [torch.float16, torch.bfloat16]}] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing') + parser.add_argument('--csv', action='store_true', help='write the report data to csv') + parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode') + + args = parser.parse_args() + device = 'mlu' + + titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "quant_bit", "input_dtype", "hardware_time(us)", "e2e_latency(us)"] + contents = [] + for params_dict in e2e_time_param_dict_list: + m = params_dict["m"] + k = params_dict["k"] + n = params_dict["n"] + quant_bit = params_dict["quant_bit"] + has_c = params_dict["has_c"] + has_bias = params_dict["has_bias"] + act_mode = params_dict["act_mode"] + input_dtype_list = params_dict["input_dtype"] + for dtype in input_dtype_list: + if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported(): + continue + a = torch.randn(m, k).to(device).to(dtype) + b = torch.randn(n, k if quant_bit == 8 else k//2).to(device).to(torch.int8) + scale = torch.randn(n).to(device) + zero = None + c = None + if has_c: + c = torch.randn(m, n).to(device).to(dtype) + bias = None + if has_bias: + bias = torch.randn(n).to(device).to(dtype) + hardware_time, e2e_time = benchmark_forward(tmo.weight_only_quant_matmul, + a, + b, + scale, + zero, + bias, + c, + act_mode, + quant_bit, + 1.0, + 1.0, + repeats=args.repeat_times) + content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{quant_bit}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"] + contents.append(content) + table = [titles] + contents + print(tabulate(table, headers="firstrow", tablefmt="grid")) + + if args.csv: + current_file_path = __file__ + _, file_name = os.path.split(current_file_path) + save_to_csv(table, args.o, file_name) + +if __name__=="__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/benchmarks/common.py b/torch_mlu_ops-v1.3.2/benchmarks/common.py new file mode 100644 index 0000000..940d88a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/benchmarks/common.py @@ -0,0 +1,65 @@ +import torch +import torch_mlu +import time +from pathlib import Path +import csv +import os +import subprocess +from itertools import product + +def benchmark_forward(fn, *inputs, repeats=1, **kwinputs): + notify_start = torch.mlu.Event(enable_timing=True) + notify_end = torch.mlu.Event(enable_timing=True) + notify_start.record() + t0 = time.perf_counter() + for _ in range(repeats): + fn(*inputs, **kwinputs) + notify_end.record() + notify_end.synchronize() + total_e2e_time = time.perf_counter() - t0 + average_e2e_time = total_e2e_time / repeats * 1e6 + total_hardware_time = notify_start.hardware_time(notify_end) + average_hardware_time = total_hardware_time / repeats + return average_hardware_time, average_e2e_time + +def save_to_csv(table, file_path, file_name): + file_name_without_ext, _ = os.path.splitext(file_name) + new_file_name = file_name_without_ext + '.csv' + if file_path is None: + file_path = './' + path = Path(file_path) + if path.suffix: + directory = path.parent + filename = path.name + else: + directory = path + filename = new_file_name + if not directory.exists(): + directory.mkdir(parents=True, exist_ok=True) + full_path = directory / filename + if not full_path.exists(): + full_path.touch() + with open(full_path, mode="w", newline="") as file: + writer = csv.writer(file) + writer.writerows(table) + print(f"output saved at: {full_path}") + +def get_band_width(card_id: int = 0): + cmd = "cnmon info -c " + str(card_id) + " | grep 'MEM BandWidth'| cut -d ':' -f2 | cut -d ' ' -f 2" + res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + assert res.returncode == 0, "Failed to get BandWidth." + bd = int(res.stdout.decode().strip()) + return bd + +def generate_token_count(num_expert, + total_token_count): + token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32) + sum = torch.sum(token_count, dim=-1) * 1.0 + token_count *= total_token_count / sum.item() + token_count = token_count.to(dtype=torch.int32) + cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32) + end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count + cusum_token_count[1:] = torch.cumsum(token_count, dim=-1) + cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count) + cusum_token_count[-1] = total_token_count + return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1] diff --git a/torch_mlu_ops-v1.3.2/build.property b/torch_mlu_ops-v1.3.2/build.property new file mode 100644 index 0000000..ccad5aa --- /dev/null +++ b/torch_mlu_ops-v1.3.2/build.property @@ -0,0 +1,8 @@ +TMO_VERSION=1.3.2-1 +CNCL_VERSION=1.24.1-1 +CNNL_VERSION=1.28.3-1 +CNNLEXTRA_VERSION=1.12.3-1 +CNTOOLKIT_VERSION=3.15.6-1 +MLUOPS_VERSION=1.4.2-1 +TORCH_VERSION=1.24.1-1 +TRITON_VERSION=1.3.1-1 diff --git a/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp new file mode 100644 index 0000000..b292474 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifdef __GNUC__ + +#include "stack_exception.h" +#include +#include +#include +#include +#include + +#define MAX_DEPTH 32 + +namespace tmo { +namespace stack_exception { + +call_stack::call_stack(const size_t num_discard /*= 0*/) { + using namespace abi; + + // retrieve call-stack + void *trace[MAX_DEPTH]; + int stack_depth = backtrace(trace, MAX_DEPTH); + + for (int i = num_discard + 1; i < stack_depth; i++) { + Dl_info dlinfo; + if (!dladdr(trace[i], &dlinfo)) break; + + const char *symname = dlinfo.dli_sname; + + int status; + char *demangled = abi::__cxa_demangle(symname, NULL, 0, &status); + if (status == 0 && demangled) symname = demangled; + + // printf("entry: %s, %s\n", dlinfo.dli_fname,symname); + + // store entry to stack + if (dlinfo.dli_fname && symname) { + entry e; + e.file = dlinfo.dli_fname; + e.line = 0; // unsupported + e.function = symname; + stack.push_back(e); + } else { + break; // skip last entries below main + } + + if (demangled) free(demangled); + } +} + +call_stack::~call_stack() throw() { + // automatic cleanup +} +} // namespace stack_exception +} // namespace tmo + +#endif // __GNUC__ diff --git a/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h new file mode 100644 index 0000000..2f37c73 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h @@ -0,0 +1,127 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_COMMON_STACK_EXCEPTION_H_ +#define CSRC_COMMON_STACK_EXCEPTION_H_ + +#include +#include +#include +#include + +namespace tmo { +namespace stack_exception { +/** Call-stack entry datastructure. */ +struct entry { + /** Default constructor that clears all fields. */ + entry() : line(0) {} + + std::string file; ///< filename + size_t line; ///< line number + std::string function; ///< name of function or method + + /** Serialize entry into a text string. */ + std::string to_string() const { + std::ostringstream os; + os << file; + if (line > 0) { + os << ":" << line; + } + os << " @" << function; + return os.str(); + } +}; + +/** Stack-trace base class, for retrieving the current call-stack. */ +class call_stack { + public: + /** Stack-trace consructor. + \param num_discard - number of stack entries to discard at the top. */ + call_stack(const size_t num_discard = 0); + + virtual ~call_stack() throw(); + + /** Serializes the entire call-stack into a text string. */ + std::string to_string() const { + std::ostringstream os; + for (size_t i = 0; i < stack.size(); i++) + os << stack[i].to_string() << std::endl; + return os.str(); + } + + /** Call stack. */ + std::vector stack; +}; + +/** Abstract base-class for all stack-augmented exception classes. + * Enables catching of all stack-augmented exception classes. */ +class stack_exception_base : public call_stack { + public: + stack_exception_base(const bool _show_stack) : call_stack(2), show_stack(_show_stack) {} + virtual ~stack_exception_base() throw() {} + + virtual const char *what() const throw() = 0; + + /// flag to indicate if stack-trace is included in what() messages + bool show_stack; +}; + +/** Template for stack-augmented exception classes. */ +template +class stack_exception : public T, public stack_exception_base { + public: + stack_exception(const std::string &msg) : T(msg), stack_exception_base(true) {} + virtual ~stack_exception() throw() {} + + stack_exception(const char *file, int line, const char *pretty_function, const std::string &msg) + : T(msg), stack_exception_base(true) { + entry e; + e.file = file; + e.line = line; + e.function = pretty_function; + stack.insert(stack.begin(), e); + } + + virtual const char *what() const throw() { + if (show_stack) { + // concatenate message with stack trace + buffer = "[" + std::string(T::what()) + "]\n" + stack_exception::to_string(); + return buffer.c_str(); + } else { + return T::what(); + } + } + + private: + mutable std::string buffer; +}; +/** Stack-augmented exception classes for all std::exception classes. */ +// typedef stack_exception TmoException; +// typedef stack_exception stack_range_error; +// typedef stack_exception stack_overflow_error; +// typedef stack_exception stack_underflow_error; +// typedef stack_exception stack_logic_error; +// typedef stack_exception stack_domain_error; +// typedef stack_exception stack_invalid_argument; +// typedef stack_exception stack_length_error; +// typedef stack_exception stack_out_of_range; +} // namespace stack_exception + +class _TmoException : public stack_exception::stack_exception { + public: + _TmoException(const std::string &msg) : stack_exception(msg) {} + _TmoException(const char *file, int line, const char *pretty_function, const std::string &msg) + : stack_exception(file, line, pretty_function, msg) {} +}; +#define TmoException(msg) tmo::_TmoException(__FILE__, __LINE__, __PRETTY_FUNCTION__, msg) +} // namespace tmo + +#endif // CSRC_COMMON_STACK_EXCEPTION_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/common/utils.h b/torch_mlu_ops-v1.3.2/csrc/common/utils.h new file mode 100644 index 0000000..3c1e916 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/common/utils.h @@ -0,0 +1,293 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_COMMON_UTILS_H_ +#define CSRC_COMMON_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include "cn_api.h" +#include "cnnl.h" +#include "cnnl_extra.h" +#include "cnrt.h" +#include "stack_exception.h" + +namespace tmo { +inline cnnlQuantizeLayout_t strToQuantizeLayout(std::string param) { + static std::map quantize_layout_map = { + {"quantize_none", CNNL_QUANTIZE_NONE}, + {"quantize_per_tensor", CNNL_QUANTIZE_PER_TENSOR}, + {"quantize_per_channel", CNNL_QUANTIZE_PER_CHANNEL}, + {"quantize_per_token", CNNL_QUANTIZE_PER_TOKEN}, + {"quantize_group_wise", CNNL_QUANTIZE_GROUP_WISE}}; + return quantize_layout_map[param]; +} + +inline cnnlActivationMode_t strToActivationMode(std::string param) { + static std::map act_mode_map = { + {"gelu", CNNL_ACTIVATION_GELU}, + {"relu", CNNL_ACTIVATION_RELU}, + {"sigmoid", CNNL_ACTIVATION_SIGMOID}, + {"silu", CNNL_ACTIVATION_SWISH}, + {"none", CNNL_ACTIVATION_IDENTITY}}; + return act_mode_map[param]; +} +inline cnnlLLMQuantAlgo_t strToQuantizeAlgo(std::string param) { + static std::map quant_algo_map = { + {"weight_only", CNNL_WEIGHT_ONLY}, + {"smooth_quant", CNNL_SMOOTH_QUANT}, + {"none", CNNL_NO_QUANT}}; + return quant_algo_map[param]; +} + +namespace lnres { +namespace internal { +using LnresEnum = cnnlTransformerLayernormResidualStructure_t; +struct Helper { + int layernorm_position; // 0: no layernorm, 1: pre layernorm, 2: post layernorm + int residual_position; // 0: no residual, 1: layernorm inside residual, 2: layernorm outside + // residual + constexpr Helper(cnnlTransformerLayernormResidualStructure_t mode); + + constexpr Helper(int layernorm_position, int residual_position) + : layernorm_position(layernorm_position), residual_position(residual_position) {} + + constexpr bool operator==(const Helper &other) const { + return layernorm_position == other.layernorm_position && + residual_position == other.residual_position; + } + + constexpr operator cnnlTransformerLayernormResidualStructure_t() const; +}; + +constexpr int NO = 0; +constexpr int PRE = 1; +constexpr int POST = 2; +constexpr int CONTAIN = 1; +constexpr int EXCLUDE = 2; +using TPair = std::pair; +constexpr std::array pairs = { + TPair{{NO, NO}, CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL}, // noResidual + {{NO, CONTAIN}, CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL}, // useInputAsResidual + {{NO, EXCLUDE}, CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL}, // useInputAsResidual + {{PRE, NO}, CNNL_TRANSFORMER_PRE_LAYERNORM_NO_RESIDUAL}, // noResidual + {{PRE, CONTAIN}, CNNL_TRANSFORMER_PRE_LAYERNORM_INSIDE_RESIDUAL}, // useInputAsResidual + // residualThenLayernorm + {{PRE, EXCLUDE}, CNNL_TRANSFORMER_PRE_LAYERNORM_OUTSIDE_RESIDUAL}, // useLayernormAsResidual + // residualThenLayernorm + {{POST, NO}, CNNL_TRANSFORMER_POST_LAYERNORM_NO_RESIDUAL}, // noResidual + {{POST, CONTAIN}, CNNL_TRANSFORMER_POST_LAYERNORM_INSIDE_RESIDUAL}, // useInputAsResidual + // layernormThenResidual + {{POST, EXCLUDE}, CNNL_TRANSFORMER_POST_LAYERNORM_OUTSIDE_RESIDUAL}, // useInputAsResidual + // residualThenLayernorm +}; + +constexpr Helper from(LnresEnum mode) { + for (size_t i = 0; i < pairs.size(); ++i) { + if (pairs[i].second == mode) { + return pairs[i].first; + } + } + // throw TmoException("Invalid cnnlTransformerLayernormResidualStructure_t"); + return Helper(NO, NO); +} + +constexpr LnresEnum to(Helper mode) { + for (size_t i = 0; i < pairs.size(); ++i) { + if (pairs[i].first == mode) { + return pairs[i].second; + } + } + return CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL; + // throw TmoException("Invalid Helper"); +} + +constexpr Helper::Helper(LnresEnum mode) : Helper(from(mode)) {} + +constexpr Helper::operator LnresEnum() const { + return to(*this); +} +} // namespace internal + +using namespace internal; + +inline LnresEnum makeLnresEnum(bool has_ln, bool has_residual, bool residual_is_input) { + return Helper(has_ln ? PRE : NO, has_residual ? (residual_is_input ? CONTAIN : EXCLUDE) : NO); +} + +inline LnresEnum removeResidual(LnresEnum mode) { + Helper helper(mode); + return Helper(helper.layernorm_position, NO); +} + +inline LnresEnum removeLayernorm(LnresEnum mode) { + Helper helper(mode); + return Helper(NO, helper.residual_position); +} + +inline bool useLayernormAsResidual(LnresEnum mode) { + Helper helper(mode); + return helper.layernorm_position == PRE && helper.residual_position == EXCLUDE; +} + +inline bool useInputAsResidual(LnresEnum mode) { + Helper helper(mode); + return helper.residual_position == CONTAIN || + (helper.layernorm_position == NO && helper.residual_position != NO) || + (helper.layernorm_position == POST && helper.residual_position == EXCLUDE); +} + +inline bool hasResidual(LnresEnum mode) { + Helper helper(mode); + return helper.residual_position != NO; +} + +inline bool hasLayernorm(LnresEnum mode) { + Helper helper(mode); + return helper.layernorm_position != NO; +} + +inline bool isPostLayernorm(LnresEnum mode) { + Helper helper(mode); + return helper.layernorm_position == POST; +} + +inline bool isPreLayernorm(LnresEnum mode) { + Helper helper(mode); + return helper.layernorm_position == PRE; +} + +inline bool residualThenLayernorm(LnresEnum first_layer, LnresEnum second_layer) { + Helper h1(first_layer); + Helper h2(second_layer); + if (h1.residual_position == NO) { // h1 has no residual + return false; + } + + if (h1.layernorm_position == POST && h2.layernorm_position == PRE) { + throw TmoException("too many layernorms"); + } + + return (h1.residual_position != NO && h1.layernorm_position != POST && + h2.layernorm_position == PRE) || // l1 residual + l2 pre layernorm + (h1.layernorm_position == POST && h2.layernorm_position != PRE && + h1.residual_position == EXCLUDE); // l1 inside residual + l1 post layernorm +} + +inline bool layernormThenResidual(LnresEnum first_layer, LnresEnum second_layer) { + Helper h1(first_layer); + Helper h2(second_layer); + if (h1.residual_position == NO) { // h1 has no residual + return false; + } + + if (h1.layernorm_position == POST && h2.layernorm_position == PRE) { + throw TmoException("too many layernorms"); + } + return (h1.layernorm_position == POST && h1.residual_position == CONTAIN); +} + +inline bool residualOnly(LnresEnum first_layer, LnresEnum second_layer) { + Helper h1(first_layer); + Helper h2(second_layer); + return h1.residual_position != NO && h1.layernorm_position != POST && + h2.layernorm_position != PRE; +} +} // namespace lnres +} // namespace tmo + +#ifndef CNNL_CHECK +#define CNNL_CHECK(expr) \ + if (expr != CNNL_STATUS_SUCCESS) { \ + std::cerr << __FILE__ << ":" << __LINE__ \ + << " Check failed: " #expr " == CNNL_STATUS_SUCCESS. " << std::endl; \ + } +#endif + +#define CNNL_CHECK_FATAL(expr) \ + if ((expr) != CNNL_STATUS_SUCCESS) { \ + std::cerr << __FILE__ << ":" << __LINE__ << ": " \ + << " Check failed: " #expr " == CNNL_STATUS_SUCCESS. " << std::endl; \ + throw TmoException("Check failed: " #expr " == CNNL_STATUS_SUCCESS."); \ + } + +#define TMO_KERNEL_CHECK_FATAL(expr) \ + if ((expr) != tmo::KernelStatus::KERNEL_STATUS_SUCCESS) { \ + std::cerr << __FILE__ << ":" << __LINE__ << ": " \ + << " Check failed: " #expr " == KernelStatus::KERNEL_STATUS_SUCCESS. " << std::endl; \ + throw TmoException("Check failed: " #expr " == KernelStatus::KERNEL_STATUS_SUCCESS."); \ + } + +#define CHECK_FATAL(expr, ...) \ + if (!(expr)) { \ + std::cerr << __FILE__ << ":" << __LINE__ << ": " \ + << " Check failed: " #expr ". " << tmo::stringize(__VA_ARGS__) << std::endl; \ + throw TmoException("Check failed: " #expr ". " + tmo::stringize(__VA_ARGS__)); \ + } + +#undef CNRT_CHECK +#define CNRT_CHECK(val) \ + do { \ + cnrtRet_t __ret = val; \ + if (__ret) { \ + printf("[%s:%d] CNRT error, code=%d(%s) \"%s\" \n", __FILE__, __LINE__, (unsigned int)__ret, \ + cnrtGetErrorStr(__ret), #val); \ + throw TmoException(cnrtGetErrorStr(__ret)); \ + } \ + } while (0) + +#define CN_CHECK(val) \ + do { \ + CNresult __ret = val; \ + if (__ret) { \ + const char *cn_err_string = nullptr; \ + cnGetErrorString(__ret, &cn_err_string); \ + printf("[%s:%d] CN error, code=%d(%s) \"%s\" \n", __FILE__, __LINE__, (unsigned int)__ret, \ + cn_err_string, #val); \ + throw TmoException(cn_err_string); \ + } \ + } while (0) + +#define PAD_UP_DIV(x, y) (((x) + (y) - 1) / (y)) + +#define TMO_EXPORT __attribute__((__visibility__("default"))) +#define TMO_HIDDEN __attribute__((__visibility__("hidden"))) + +#define DELETE_COPY_ASSIGN_CONSTRUCT(CLASSNAME) \ + CLASSNAME(const CLASSNAME &) = delete; \ + CLASSNAME(CLASSNAME &&) = delete; \ + CLASSNAME &operator=(const CLASSNAME &) = delete; \ + CLASSNAME &operator=(CLASSNAME &&) = delete; + +// Note: Return type without const when const object called. +#define CLASS_CAST_TYPE_OPERATOR_DEFINE(DESCNAME, DESCOBJECT) \ + inline operator DESCNAME() const { \ + return const_cast(DESCOBJECT); \ + } \ + inline operator DESCNAME() { \ + return DESCOBJECT; \ + } + +#endif // CSRC_COMMON_UTILS_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/CMakeLists.txt b/torch_mlu_ops-v1.3.2/csrc/kernels/CMakeLists.txt new file mode 100644 index 0000000..d5b1005 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/CMakeLists.txt @@ -0,0 +1,103 @@ +cmake_minimum_required(VERSION 3.8) +project(tmo_kernels) +message(STATUS "project name: ${PROJECT_NAME}") +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +################################################################################ +# Build Evironment +################################################################################ +set(BANG_TARGET_CPU_ARCH ${TARGET_CPU_ARCH}) +message("-- TARGET_CPU_ARCH=${TARGET_CPU_ARCH}") +set(TARGET_MLU_ARCH ${TARGET_MLU_ARCH}) +message("-- TARGET_MLU_ARCH=${TARGET_MLU_ARCH}") +set(NEUWARE_HOME ${NEUWARE_HOME}) +message("-- NEUWARE_HOME=${NEUWARE_HOME}") + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} + "${CMAKE_SOURCE_DIR}/cmake" + "${NEUWARE_HOME}/cmake" + "${NEUWARE_HOME}/cmake/modules" +) + +find_package(BANG) +if(NOT BANG_FOUND) + message(FATAL_ERROR "BANG cannot be found.") +else () + if (NOT BANG_CNCC_EXECUTABLE) + message(FATAL_ERROR "cncc not found, please ensure cncc is in your PATH env or set variable BANG_CNCC_EXECUTABLE from cmake. Otherwise you should check path used by find_program(BANG_CNCC_EXECUTABLE) in FindBANG.cmake") + endif() +endif() +set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}/test") +set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/lib") + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -pthread -pipe") +set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${CMAKE_C_FLAGS} -g3 -O0") +set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${CMAKE_C_FLAGS} -O3") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fPIC -std=c++17 -pthread -pipe") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${CMAKE_CXX_FLAGS} -g3 -O0") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CMAKE_CXX_FLAGS} -O3") + +set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS_RELEASE} -Wl,--gc-sections -fPIC") + +set(BANG_CNCC_FLAGS "-Wall -Werror -Wdeprecated-declarations -fPIC -std=c++17 -pthread --target=${TARGET_CPU_ARCH}") + +if ( "${_cncc_version}" VERSION_LESS "5.0.0") # [CNNLCORE-19128] + message(STATUS "Default rounding mode will be rn when computing float numbers, otherwise will be tz when computing int numbers") + # This compile option was enabled by JIRA: CNNLCORE-12027 + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -Xbang-cnas --deprecated-cvt-default-round-mode-rn") +endif() + +if(${TARGET_CPU_ARCH} MATCHES ".*x86_64.*") + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -mcmodel=large") +endif() + +string(TOLOWER ${CMAKE_BUILD_TYPE} _CMAKE_BUILD_TYPE_LOWER) +if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "debug") + message(STATUS "Build debug mode") + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g3 -O0") +endif() + +if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "release") + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3 -DNDEBUG") +endif() + +if(${TARGET_MLU_ARCH} MATCHES "CNFATBIN") + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=mtp_592 --bang-mlu-arch=mtp_613 --no-neuware-version-check") + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64") +else() + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=${TARGET_MLU_ARCH}") + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64") +endif() + +# setup predefined macro for host sources, only for single mlu arch, useful for edge +if (${TARGET_MLU_ARCH} MATCHES "^(m?tp_)?([0-9]+)$") + # convert mtp_xxx or tp_xxx to xxx + string(REGEX REPLACE "^(m?tp_)?([0-9]+)$" "\\2" _TARGET_MLU_ARCH ${TARGET_MLU_ARCH}) + add_definitions(-DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH}) + set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH}") +endif() + +################################################################################ +# Neuware Evironment +################################################################################ +if(EXISTS ${NEUWARE_HOME}) + include_directories("${NEUWARE_HOME}/include") + link_directories("${NEUWARE_HOME}/lib64") + link_directories("${NEUWARE_HOME}/lib") +else() + message(FATAL_ERROR "NEUWARE cannot be found, refer README.md to prepare NEUWARE_HOME environment.") +endif() + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}") + +################################################################################ +# Build TMO kernels +################################################################################ +# aux_source_directory(src DIR_SRCS) +file(GLOB_RECURSE bang_src_files FOLLOW_SYMLINKS "${CMAKE_CURRENT_SOURCE_DIR}/*.mlu") + +bang_add_library(tmo_kernels STATIC "${bang_src_files}") +target_link_libraries(tmo_kernels cnnl cnrt cndrv dl) diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mlu new file mode 100644 index 0000000..8e5a6b9 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mlu @@ -0,0 +1,28 @@ +#include "add_scalar.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { + +#define ONCHIP_DATA_NUM ((int)(__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int))) +__nram__ int nram_buffer[ONCHIP_DATA_NUM]; + +__mlu_global__ void MLUBlockAddScalar(int *dst, int *src, int count, int scalar) { + int offset = ONCHIP_DATA_NUM * taskId; + int deal_num = std::min(ONCHIP_DATA_NUM, count - offset); + if (deal_num <= 0) return; + __memcpy(nram_buffer, src + offset, deal_num * sizeof(int), GDRAM2NRAM); + __bang_add_scalar(nram_buffer, nram_buffer, scalar, deal_num); + __memcpy(dst + offset, nram_buffer, deal_num * sizeof(int), NRAM2GDRAM); +} +} // namespace kernels + +KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar) { + uint32_t task_dim = (count + ONCHIP_DATA_NUM - 1) / ONCHIP_DATA_NUM; + cnrtDim3_t dim{task_dim, 1, 1}; + kernels::MLUBlockAddScalar<<>>(dst, src, count, scalar); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mluh new file mode 100644 index 0000000..cbde7fd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mluh @@ -0,0 +1,29 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_ADD_SCALAR_MLUH_ +#define CSRC_KERNELS_ADD_SCALAR_MLUH_ +#include "kernel_utils.h" + +namespace tmo { +/** + * @brief Add src with a scalar and save the result to dst. + * @param queue: The queue for mlu. + * @param dst: Pointer to the MLU memory of dst. + * @param src: Pointer to the MLU memory of src. + * @param count: The elements number in src. + * @param scalar: The scalar to add. + * @note: only support int. dst can overlap with src. + */ +KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar); +} // namespace tmo + +#endif // CSRC_KERNELS_ADD_SCALAR_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/build.sh b/torch_mlu_ops-v1.3.2/csrc/kernels/build.sh new file mode 100755 index 0000000..7209179 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/build.sh @@ -0,0 +1,205 @@ +#!/bin/bash +set -e + +TOP_DIR="$( cd "$( dirname "$0" )" && pwd )" +cd ${TOP_DIR} + +################################################################################ +# Evironment Variables + +# BUILD_MODE: release/debug +# BUILD_DIR: build(default) +# TARGET_MLU_ARCH: CNFATBIN/MLU590 +# TARGET_CPU_ARCH: x86_64-linux-gnu +# TARGET_C_COMPILER: C comppiler full-path +# TARGET_CXX_COMPILER: CXX comppiler full-path +# STRIP strip tool path +################################################################################ +BUILD_MODE=${BUILD_MODE:-release} +BUILD_DIR="${BUILD_DIR:-build}" +BUILD_JOBS=${BUILD_JOBS:-32} +TARGET_MLU_ARCH=${TARGET_MLU_ARCH:-CNFATBIN} +TARGET_CPU_ARCH=${TARGET_CPU_ARCH:-$(uname -m)-linux-gnu} +TARGET_C_COMPILER=${TARGET_C_COMPILER:-gcc} +TARGET_CXX_COMPILER=${TARGET_CXX_COMPILER:-g++} +STRIP="${STRIP}" # empty by default, check later + +# to forward variable to other scripts +export BUILD_DIR + +################################################################################ +# Shell Common Functions +################################################################################ +check_deb_package() { + if [ -z "$(dpkg -l | grep ${1})" ]; then + echo "-- Please sudo apt install ${1}" + exit -1 + fi +} + +check_rpm_package() { + if [ -z "$(rpm -qa | grep ${1})" ]; then + echo "-- Please sudo yum install ${1}" + exit -1 + fi +} + +usage () { + echo "USAGE: build.sh " + echo + echo " If need specify neuware path, please:" + echo " export NEUWARE_HOME=/path/of/your/neuware" + echo + echo "OPTIONS:" + echo " -h, --help Print usage" + echo " If no --mluxxx specified, default arch is cnfatbin which contain all mlu arch" + echo " --mlu590 Build for target product MLU590: __BANG_ARCH__ = 592" + echo " cncc --bang-mlu-arch=mtp_592, cnas --mlu-arch mtp_592" + echo " -d, --debug Build test case with debug mode" + echo " -v, --verbose Build with verbose output" + echo " -j, --jobs=* Build parallel jobs" + echo " --cache Build without deleting BUILD_DIR contents first" +} + +################################################################################ +# Build Main Entry +################################################################################ +# 1. Check cmake tool for build, cmake-3.23.1 is recommended +if [ -f "/etc/os-release" ]; then + source /etc/os-release + if [[ "${NAME}" == Ubuntu* ]] || [[ "${NAME}" == Debian* ]]; then + check_deb_package cmake + CMAKE=cmake + elif [[ "${NAME}" == CentOS* ]] || [[ "${NAME}" == Kylin* ]]; then + if [[ "${VERSION_ID}" == 7 ]]; then + check_rpm_package cmake3 + CMAKE=cmake3 + else + check_rpm_package cmake + CMAKE=cmake + fi + elif [[ "${NAME}" == Anolis* ]];then + check_rpm_package cmake + CMAKE=cmake + else + echo "-- Not support build on this os!" + exit -1 + fi +else + echo "-- Not support build on this os!" + exit -1 +fi + +# 2. Create build dir +if [ ! -d "$BUILD_DIR" ]; then + mkdir "$BUILD_DIR" +fi + +# 3. Handle build options +cmdline_args=$(getopt -o h,d,v,j: --long help,debug,verbose,jobs:,mlu590,cache -n 'build.sh' -- "$@") +eval set -- "$cmdline_args" +if [ $? != 0 ]; then echo "Unknown options, use -h or --help" >&2 ; exit -1; fi +if [ $# != 0 ]; then + while true; do + case "$1" in + --mlu590) + TARGET_MLU_ARCH="mtp_592" + shift + ;; + -h | --help) + usage + exit 0 + ;; + -d | --debug) + BUILD_MODE="debug" + echo "-- Using debug mode." + shift + ;; + -v | --verbose) + BUILD_VERBOSE="VERBOSE=1" + shift + ;; + -j | --jobs) + shift + BUILD_JOBS=$1 + shift + ;; + --cache) + FLAG_KEEP_CACHE=1 + shift + ;; + --) + shift + break + ;; + *) + echo "-- Unknown options ${1}, use -h or --help" + usage + exit -1 + ;; + esac + done +fi + +# 5. Check NEUWARE_HOME and cncc +if [ ! -z "${NEUWARE_HOME}" ]; then + echo "-- using NEUWARE_HOME = ${NEUWARE_HOME}" +else + echo "-- NEUWARE_HOME is null, refer README.md to prepare NEUWARE_HOME environment." + exit -1 +fi + +# 6. Check device compiler +export PATH="${NEUWARE_HOME}/bin":$PATH +export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64":$LD_LIBRARY_PATH +if [ -z $(which cncc) ]; then + echo "-- ERROR: cannot find cncc" + exit -1 +fi +cncc --version || ( echo "-- ERROR: cncc is not for current CPU target" && exit -1 ) +echo "-- cncc: $(which cncc)" + +# Check host compiler +## check compiler version and consider activate devtoolset for CentOS 7 +if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then + if [ ! -f "/opt/rh/devtoolset-7/enable" ]; then + echo "You are using CentOS 7 but without 'devtoolset-7' installed." + echo "Please install devtoolset-7 or gnu-g++ that verion >= 5." + sleep 2 + else + source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \ + || echo "devtoolset-7 has installed on your server, but source failed." + fi +fi +if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then + echo "we do not support g++<5, try to use higher version" + exit 1 +fi +TARGET_C_COMPILER=$(which gcc) +TARGET_CXX_COMPILER=$(which g++) +echo "-- TARGET_C_COMPILER: " ${TARGET_C_COMPILER} +echo "-- TARGET_CXX_COMPILER: " ${TARGET_CXX_COMPILER} +export CC=$(basename ${TARGET_C_COMPILER}) +export CXX=$(basename ${TARGET_CXX_COMPILER}) + +################################################################################ +# Project Build +################################################################################ +CMAKE_EXTRA_OPTIONS=() + +SOURCE_DIR=${TOP_DIR} +pushd ${BUILD_DIR} + if [[ -z "${FLAG_KEEP_CACHE}" ]]; then + echo "Remove cmake cache ${PWD}" + rm -rf ./* + fi + ${CMAKE} -DCMAKE_BUILD_TYPE="${BUILD_MODE}" \ + -DNEUWARE_HOME="${NEUWARE_HOME}" \ + -DTARGET_MLU_ARCH="${TARGET_MLU_ARCH}" \ + -DTARGET_CPU_ARCH="${TARGET_CPU_ARCH}" \ + -DCMAKE_C_COMPILER="$(basename ${TARGET_C_COMPILER})" \ + -DCMAKE_CXX_COMPILER="$(basename ${TARGET_CXX_COMPILER})" \ + -DCMAKE_STRIP="${STRIP}" \ + ${CMAKE_EXTRA_OPTIONS[@]} ${SOURCE_DIR} +popd +${CMAKE} --build ${BUILD_DIR} -- ${BUILD_VERBOSE} -j${BUILD_JOBS} diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mlu new file mode 100644 index 0000000..9dd7f9a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mlu @@ -0,0 +1,192 @@ +#include +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "copy_blocks.mluh" +#include "kernel_utils.h" +// clang-format off +#include +// clang-format on + +namespace tmo { +#define NRAM_REMAIN_SIZE (32 * 1024) +#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) +#define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753 +#define LAYER_SIZE 128 +#define BLOCK_PAIR_SIZE 512 +#define ALIGN_BYTES 64 + +struct CopyBlocksInfo { + void *key_addrs[LAYER_SIZE]; + void *value_addrs[LAYER_SIZE]; + unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2]; + bool has_value_cache = true; +}; + +namespace kernels { +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; + +__mlu_func__ void copyBlocksNodld(CopyBlocksInfo info, + uint32_t num_per_core, + uint32_t block_mapping_offset, + int32_t num_layers, + uint32_t block_size_in_bytes) { + for (uint32_t i = 0; i < num_per_core; i++) { + uint32_t map_offset = block_mapping_offset + i * 2; + uint32_t src_idx = info.mapping_addrs[map_offset]; + uint32_t dst_idx = info.mapping_addrs[map_offset + 1]; + int64_t src_offset = block_size_in_bytes * src_idx; + int64_t dst_offset = block_size_in_bytes * dst_idx; + for (uint32_t j = 0; j < num_layers; j++) { + __memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset, + block_size_in_bytes, GDRAM2GDRAM); + if (info.has_value_cache) { + __memcpy((int8_t *)info.value_addrs[j] + dst_offset, + (int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM); + } + } + } +} + +__mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info, + int32_t num_pairs, + int32_t num_layers, + uint32_t block_size_in_bytes) { + uint32_t num_per_core = num_pairs / taskDim; + uint32_t remain_for_core = num_pairs % taskDim; + num_per_core += ((taskId < remain_for_core) ? 1 : 0); + uint32_t block_mapping_offset = + num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core); + block_mapping_offset *= 2; +#if (__BANG_ARCH__ >= 592) + if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) { + auto num_pair_data_width = sizeof(int32_t); + uint32_t align_num = ALIGN_BYTES / num_pair_data_width; + unsigned int num_per_core_2 = num_per_core * 2; + unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num; + unsigned int *gather_src_offset = (unsigned int *)nram_buffer; + unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align; + int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align); + uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int *) * num_per_core_2_align * 2; + unsigned int *scatter_dst_offset = gather_src_offset + num_per_core; + + uint32_t num_per_loop = nram_remain / block_size_in_bytes; + uint32_t repeat = num_per_core / num_per_loop; + uint32_t remain = num_per_core % num_per_loop; + + for (int i = 0; i < num_per_core; i++) { + unsigned int mapping_addrs_idx = block_mapping_offset + i * 2; + block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx]; + block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1]; + } + __bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes, + num_per_core_2); + __sync(); + for (uint32_t k = 0; k < num_layers; k++) { + for (uint32_t i = 0; i < repeat; i++) { + __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop, + (unsigned int)block_size_in_bytes, GDRAM2NRAM, + (unsigned int)block_size_in_bytes, num_per_loop); + __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop, + (unsigned int)block_size_in_bytes, NRAM2GDRAM, + (unsigned int)block_size_in_bytes, num_per_loop); + if (info.has_value_cache) { + __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop, + (unsigned int)block_size_in_bytes, GDRAM2NRAM, + (unsigned int)block_size_in_bytes, num_per_loop); + __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop, + (unsigned int)block_size_in_bytes, NRAM2GDRAM, + (unsigned int)block_size_in_bytes, num_per_loop); + } + } + if (remain != 0) { + uint32_t repeat_nums = repeat * num_per_loop; + __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums, + (unsigned int)block_size_in_bytes, GDRAM2NRAM, + (unsigned int)block_size_in_bytes, remain); + __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums, + (unsigned int)block_size_in_bytes, NRAM2GDRAM, + (unsigned int)block_size_in_bytes, remain); + if (info.has_value_cache) { + __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums, + (unsigned int)block_size_in_bytes, GDRAM2NRAM, + (unsigned int)block_size_in_bytes, remain); + __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums, + (unsigned int)block_size_in_bytes, NRAM2GDRAM, + (unsigned int)block_size_in_bytes, remain); + } + } + } + } else { + copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes); + } +#else + copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes); +#endif +} +} // namespace kernels + +KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue, + const std::vector &key_caches, + const std::vector &value_caches, + const std::vector &block_mapping_vec, + const size_t block_size_in_bytes) { + CNdev dev; + cnCtxGetDevice(&dev); + + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + cnrtFunctionType_t k_type = cnrtFuncTypeBlock; + if (key_caches.empty()) { + std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (!value_caches.empty() && key_caches.size() != value_caches.size()) { + std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches " + << "size if value_caches is not empty." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + int32_t mapping_size = block_mapping_vec.size(); + int32_t num_pairs = mapping_size / 2; + uint32_t task_dim = std::min(num_pairs, cluster_num * core_num); + cnrtDim3_t k_dim{task_dim, 1, 1}; + + int32_t num_layers = key_caches.size(); + int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE); + int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num); + int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE); + int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num); + + CopyBlocksInfo info; + if (value_caches.empty()) { + info.has_value_cache = false; + } + for (int32_t i = 0; i < layer_loop_num; i++) { + int32_t sub_num_layers = + std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop); + for (int32_t l = 0; l < sub_num_layers; l++) { + info.key_addrs[l] = key_caches[l + i * layer_num_per_loop]; + if (info.has_value_cache) { + info.value_addrs[l] = value_caches[l + i * layer_num_per_loop]; + } + } + for (int32_t j = 0; j < pair_loop_num; j++) { + int32_t sub_num_pairs = + std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop); + int32_t lens_block_mapping = sub_num_pairs * 2; + int32_t block_vec_offset = j * pair_num_per_loop * 2; + for (int32_t m = 0; m < lens_block_mapping; m++) { + info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset]; + } + kernels::launchCopyBlocksKernel<<>>(info, sub_num_pairs, sub_num_layers, + block_size_in_bytes); + } + } + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mluh new file mode 100644 index 0000000..af5aadd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mluh @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_COPY_BLOCKS_MLUH_ +#define CSRC_KERNELS_COPY_BLOCKS_MLUH_ + +#include +#include "cnnl.h" +#include "kernel_utils.h" + +namespace tmo { +/** + * @brief Perform copy_blocks operation. + * @param queue: The queue for mlu. + * @param key_caches: Output/Input. Pointer to the MLU memory that stores the key_caches + * vector which has shape [num_layers]. + * @param value_caches: Output/Input. Pointer to the MLU memory that stores the value_caches + * vector which has shape [num_layers]. + * @param block_mapping_vec: block_mapping vector. + * @param block_size_in_bytes: one block data size. + */ +KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue, + const std::vector &key_caches, + const std::vector &value_caches, + const std::vector &block_mapping_vec, + const size_t block_size_in_bytes); +} // namespace tmo + +#endif // CSRC_KERNELS_COPY_BLOCKS_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mlu new file mode 100644 index 0000000..be4b56e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mlu @@ -0,0 +1,271 @@ +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "create_cos_sin_table.mluh" + +namespace { +// constexpr int LINEAR_SCALING = 0; +// constexpr int FIX_NTK_SCALING = 1; +constexpr int DYNAMIC_NTK_SCALING = 2; +} // namespace + +namespace tmo { +namespace kernels { +__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - 32 * 1024]; +__nram__ const float range[64] = { + 0.0F, 2.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 20.0F, + 22.0F, 24.0F, 26.0F, 28.0F, 30.0F, 32.0F, 34.0F, 36.0F, 38.0F, 40.0F, 42.0F, + 44.0F, 46.0F, 48.0F, 50.0F, 52.0F, 54.0F, 56.0F, 58.0F, 60.0F, 62.0F, 64.0F, + 66.0F, 68.0F, 70.0F, 72.0F, 74.0F, 76.0F, 78.0F, 80.0F, 82.0F, 84.0F, 86.0F, + 88.0F, 90.0F, 92.0F, 94.0F, 96.0F, 98.0F, 100.0F, 102.0F, 104.0F, 106.0F, 108.0F, + 110.0F, 112.0F, 114.0F, 116.0F, 118.0F, 120.0F, 122.0F, 124.0F, 126.0F}; + +__mlu_func__ void genRangeDims(float *range_nram, int elem_count) { + int count = 64; + __bang_move(range_nram, range, std::min(count, elem_count) * sizeof(float)); + while (count < elem_count) { + __bang_add_scalar(range_nram + count, range_nram, (float)count * 2.0F, + std::min(count, elem_count - count)); + count *= 2; + } +} + +__mlu_func__ int getBatchMaxSeqLen(int *seq_lens_nram, int *seq_lens, int batch) { + __memcpy(seq_lens_nram, seq_lens, batch * sizeof(int), GDRAM2NRAM); + __bang_argmax((float *)seq_lens_nram, (float *)seq_lens_nram, batch); + return __load_nram(seq_lens_nram); +} + +__mlu_func__ float getNTKAlpha(int curr_seq_len, int max_position_embeddings, int kv_seq_len) { + int seq_len = kv_seq_len > max_position_embeddings ? curr_seq_len : kv_seq_len; + float context_value = std::log2((float)seq_len / (float)max_position_embeddings) + 1.0F; + float ntk_alpha = std::pow(2.0F, std::ceil(context_value)) - 1.0F; + return std::max(ntk_alpha, 1.0F); +} + +__mlu_func__ void getRotaryInvFreq(float *inv_freq_nram, + float *base_nram, + float *range_nram, + float base, + int rotary_dim, + int elem_count) { + // inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) + __bang_write_value(base_nram, elem_count, base); + __bang_mul_scalar(inv_freq_nram, range_nram, 1.0F / (float)rotary_dim, elem_count); + __bang_log(base_nram, base_nram, elem_count); + __bang_mul(inv_freq_nram, inv_freq_nram, base_nram, elem_count); + __bang_pow2(inv_freq_nram, inv_freq_nram, elem_count); + __bang_recip(inv_freq_nram, inv_freq_nram, elem_count); +} + +template +__mlu_func__ void convertCosSinTable(float *cos_table, float *sin_table, int elem_count) {} + +template <> +__mlu_func__ void convertCosSinTable(float *cos_table, float *sin_table, int elem_count) { + __bang_float2half((half *)cos_table, cos_table, elem_count); + __bang_float2half((half *)sin_table, sin_table, elem_count); +} + +template <> +__mlu_func__ void convertCosSinTable(float *cos_table, + float *sin_table, + int elem_count) { + __bang_float2bfloat16((bfloat16_t *)cos_table, cos_table, elem_count); + __bang_float2bfloat16((bfloat16_t *)sin_table, sin_table, elem_count); +} + +__mlu_global__ void MLUUpdateCachedAlpha(float *rotary_emb_alpha_cached, + int *seq_lens, + int max_position_embeddings, + int batch) { + int *seq_lens_nram = (int *)nram_buffer; // [batch] + int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch); + rotary_emb_alpha_cached[taskIdY] = + getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len); +} + +template +__mlu_global__ void MLUCreateCosSinTableKernel(void *cos_sin_table, + float *rotary_emb_alpha_cached, + int *seq_lens, + int max_position_embeddings, + int batch, + int batch_stride, + int rotary_seq_len, + int rotary_dim, + int rotary_stride, + float rotary_base, + float rotary_scaling, + int rotary_scaling_type, + int seq_seg, + bool interleaved, + cnnlDataType_t dtype) { + int half_rotary_dim = rotary_dim / 2; + float *base_nram = (float *)nram_buffer; // [rotary_dim / 2] + float *range_nram = base_nram + half_rotary_dim; // [rotary_dim / 2] + float *inv_freq_nram = range_nram + half_rotary_dim; // [rotary_dim / 2] + float *freqs_nram = inv_freq_nram + half_rotary_dim; // [rotary_dim / 2] + float *cos_nram = freqs_nram + half_rotary_dim; // [rotary_dim] + float *sin_nram = cos_nram + rotary_dim; // [rotary_dim] + float *swap_nram = sin_nram + rotary_dim; // [rotary_dim] + int *seq_lens_nram = (int *)(swap_nram + rotary_dim); // [batch] + + genRangeDims(range_nram, half_rotary_dim); + + float adjust_base = rotary_base; + if (rotary_scaling_type == DYNAMIC_NTK_SCALING) { + int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch); + float ntk_alpha = getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len); + if (rotary_emb_alpha_cached[taskIdY] == ntk_alpha) { + return; + } + adjust_base = rotary_base * std::pow(ntk_alpha, (float)rotary_dim / (float)(rotary_dim - 2)); + } + + getRotaryInvFreq(inv_freq_nram, base_nram, range_nram, adjust_base, rotary_dim, half_rotary_dim); + + int seq_start = taskIdX * seq_seg; + int seq_end = (taskIdX + 1) * seq_seg > rotary_seq_len ? rotary_seq_len : (taskIdX + 1) * seq_seg; + + T *cos_table = (T *)cos_sin_table + (size_t)taskIdY * batch_stride; + T *sin_table = cos_table + rotary_dim; + + for (int idx = seq_start; idx < seq_end; ++idx) { + __bang_mul_scalar(freqs_nram, inv_freq_nram, idx, half_rotary_dim); + __bang_cos(cos_nram, freqs_nram, half_rotary_dim); + __bang_sin(sin_nram, freqs_nram, half_rotary_dim); + convertCosSinTable(cos_nram, sin_nram, half_rotary_dim); + if (!interleaved) { + __memcpy(cos_table + idx * rotary_stride, cos_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM, + half_rotary_dim * sizeof(T), 0, 1); + __memcpy(sin_table + idx * rotary_stride, sin_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM, + half_rotary_dim * sizeof(T), 0, 1); + } else { + __bang_move((T *)cos_nram + half_rotary_dim, (T *)cos_nram, half_rotary_dim * sizeof(T)); + __bang_transpose((T *)swap_nram, (T *)cos_nram, 2, half_rotary_dim); + __memcpy(cos_table + idx * rotary_stride, (T *)swap_nram, half_rotary_dim * 2 * sizeof(T), + NRAM2GDRAM); + __bang_move((T *)sin_nram + half_rotary_dim, (T *)sin_nram, half_rotary_dim * sizeof(T)); + __bang_transpose((T *)cos_nram, (T *)sin_nram, 2, half_rotary_dim); + __memcpy((T *)sin_table + idx * rotary_stride, (T *)cos_nram, half_rotary_dim * 2 * sizeof(T), + NRAM2GDRAM); + } + } +} + +#if __BANG_ARCH__ < 592 +template <> +__mlu_global__ void MLUCreateCosSinTableKernel(void *cos_sin_table, + float *rotary_emb_alpha_cached, + int *seq_lens, + int max_position_embeddings, + int batch, + int batch_stride, + int rotary_seq_len, + int rotary_dim, + int rotary_stride, + float rotary_base, + float rotary_scaling, + int rotary_scaling_type, + int seq_seg, + bool interleaved, + cnnlDataType_t dtype) {} +#endif +} // namespace kernels + +KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue, + void *cos_sin_table, + float *rotary_emb_alpha_cached, + int *seq_lens, + int max_position_embeddings, + int batch, + int batch_stride, + int rotary_seq_len, + int rotary_dim, + int rotary_stride, + float rotary_base, + float rotary_scaling, + int rotary_scaling_type, + bool interleaved, + cnnlDataType_t data_type) { + bool is_supported_dtype = data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_FLOAT || + data_type == CNNL_DTYPE_BFLOAT16; + if (!is_supported_dtype) { + std::cerr << "[invokeCreateCosSinTable]: unsupport data type for create cos sin table kernel." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + // clang-format off + void (*create_sin_cos_kernels[])( + void*, /* cos_sin_table */ + float*, /* rotary_emb_alpha_cached */ + int*, /* seq_lens */ + int, /* max_position_embeddings */ + int, /* batch */ + int, /* batch_stride */ + int, /* rotary_seq_len */ + int, /* rotary_dim */ + int, /* rotary_stride */ + float, /* rotary_base */ + float, /* rotary_scaling */ + int, /* rotary_scaling_type */ + int, /* seq_seg */ + bool, /* interleaved */ + cnnlDataType_t /* data_type */ + ) = { + kernels::MLUCreateCosSinTableKernel, + kernels::MLUCreateCosSinTableKernel, + kernels::MLUCreateCosSinTableKernel + }; + // clang-format on + + int kernel_index = 0; + if (data_type == CNNL_DTYPE_HALF) { + kernel_index = 0; + } else if (data_type == CNNL_DTYPE_FLOAT) { + kernel_index = 1; + } else { + kernel_index = 2; + } + + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num = 1; + int core_num = 1; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + + int used_core_num = std::min(rotary_seq_len, cluster_num * core_num); + int seq_seg = (rotary_seq_len + used_core_num - 1) / used_core_num; + + cnrtDim3_t dim1; + dim1.x = used_core_num; + dim1.y = rotary_scaling_type == DYNAMIC_NTK_SCALING ? batch : 1; + dim1.z = 1; + + if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) { + std::cerr << "[invokeCreateCosSinTable]: MLU300 devices do not support bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + create_sin_cos_kernels[kernel_index]<<>>( + cos_sin_table, rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch, + batch_stride, rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling, + rotary_scaling_type, seq_seg, interleaved, data_type); + + if (rotary_scaling_type == DYNAMIC_NTK_SCALING) { + cnrtDim3_t dim2; + dim2.x = 1; + dim2.y = batch; + dim2.z = 1; + kernels::MLUUpdateCachedAlpha<<>>( + rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch); + } + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mluh new file mode 100644 index 0000000..001c8f7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mluh @@ -0,0 +1,62 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_ +#define CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" + +namespace tmo { + +/** + * @brief Create cos and sin table for rotary embedding. + * @param queue: The queue for mlu. + * @param cos_sin_table: Output. Pointer to the MLU memory that stores the cos and sin table. + * If rotary_scaling_type is linear, the shape is [rotary_seq_len, rotary_stride]. + * If rotary_scaling_type is dynamic ntk, the shape is [batch, rotary_seq_len, + * rotary_stride]. + * @param rotary_emb_alpha_cached: Output/Input. Pointer to the MLU memory that + * stores the ntk alpha cache. Only used in dynamic ntk, the shape is [batch]. + * @param seq_lens: Input. Pointer to the MLU memory that stores the true sequence len. + * The shape is [batch]. + * @param max_position_embeddings: The maximum rotary embedding positions. + * @param batch: Batch size. + * @param batch_stride: The stride for batch dim of cos_sin_table. + * Only used in dynamic ntk, the value is rotary_seq_len * rotary_stride. + * @param rotary_seq_len: The rotary sequence length of cos and sin table. + * @param rotary_dim: The rotary dim value of cos and sin table. + * @param rotary_stride: The stride of rotary_seq_len dim for cos and sin table. + * @param rotary_base: The rotary base, value is usually 10000. + * @param rotary_scaling: The rotary scaling, value is usually 1. + * @param rotary_scaling_type: The rotary scaling type, value is linear or dynamic ntk. + * @param interleaved: A boolean value indicates compute mode of rotary embedding. + * @param dtype: Data type of cos and sin table generated. + */ +KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue, + void *cos_sin_table, + float *rotary_emb_alpha_cached, + int *seq_lens, + int max_position_embeddings, + int batch, + int batch_stride, + int rotary_seq_len, + int rotary_dim, + int rotary_stride, + float rotary_base, + float rotary_scaling, + int rotary_scaling_type, + bool interleaved, + cnnlDataType_t dtype); + +} // namespace tmo + +#endif // CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mlu new file mode 100644 index 0000000..f7592f4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mlu @@ -0,0 +1,812 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include +#include +#include +#include "dequant_from_linear_cache.mluh" +#include "quant_utils.h" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { + +#pragma bang walign(16) +#define REM_FOR_STACK (32 * 1024) +#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024) +#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK) +#define DEQUANT_LINEAR_PERHEAD kernels::MLUDequantFromLinearCacheKernelPerHead +#define DEQUANT_LINEAR_PERCHANNEL kernels::MLUDequantFromLinearCacheKernelPerChannel +#define DEQUANT_FUNC_LEN (24) +#define DEQUANT_BATCH_NUM (1024) + +__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE]; +__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE]; +__nram__ uint8_t pre_table_nram[TRANS_TABLE_SIZE]; +// Uses 8K = 1K * (4 + 4) to process offsets +__nram__ int32_t n_lens[DEQUANT_BATCH_NUM]; +__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM]; + +__mlu_func__ void calcu_offsets_per_channel(int32_t &cache_id, + size_t &context_offset, + size_t &cache_offset, + size_t &scale_offset, + const int32_t *cache_bs_id, + const int32_t *cache_seq_offsets, + const int32_t cache_mem_len, + const int32_t seq_len, + const int32_t seq_begin, + const int32_t seq_offset, + const int32_t batch_idx, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_seq_stride, + const size_t scale_bs_stride) { + cache_id = cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx); + int32_t cache_seq_offset = + cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx); + if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) { + context_offset = context_seq_stride * (seq_offset + seq_begin); + cache_offset = cache_bs_stride * cache_id + cache_seq_stride * (cache_seq_offset + seq_begin); + scale_offset = scale_bs_stride * cache_id; + } else { + cache_id = -1; + } +} + +__mlu_func__ void calcu_offsets_per_head(int32_t &cache_id, + size_t &context_offset, + size_t &key_cache_offset, + size_t &value_cache_offset, + size_t &scale_offset, + const int32_t *cache_bs_id, + const int32_t *cache_seq_offsets, + const int32_t cache_mem_len, + const int32_t seq_len, + const int32_t seq_begin, + const int32_t seq_offset, + const int32_t batch_idx, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride) { + cache_id = cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx); + int32_t cache_seq_offset = + cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx); + if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) { + context_offset = context_seq_stride * (seq_offset + seq_begin); + key_cache_offset = + cache_bs_stride * cache_id + key_cache_seq_stride * (cache_seq_offset + seq_begin); + value_cache_offset = + cache_bs_stride * cache_id + value_cache_seq_stride * (cache_seq_offset + seq_begin); + scale_offset = cache_seq_offset + seq_begin + scale_bs_stride * cache_id; + } else { + cache_id = -1; + } +} + +template +__mlu_func__ void dequantize_per_channel(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + T *data, + Tc *cache, + Ts *scale, + const int32_t scale_num, + const int32_t seq_num, + const int32_t head_num, + const int32_t head_size, + const size_t context_offset, + const size_t cache_offset, + const size_t scale_offset, + const size_t context_seq_stride, + const size_t context_head_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + if (scale_bs_stride != 0) { + __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM, + head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1); + } + + if (std::is_same::value) { + __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM, + head_size >> 1, head_num - 1, scale_num >> 1, seq_num - 1, cache_head_stride, + head_num - 1, cache_seq_stride, seq_num - 1); + } else { + __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM, + head_size * sizeof_(Tc), head_num - 1, scale_num * sizeof_(Tc), seq_num - 1, + cache_head_stride * sizeof_(Tc), head_num - 1, cache_seq_stride * sizeof_(Tc), + seq_num - 1); + } + + dequantize((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf, + seq_num * scale_num, scale_num); + __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM, + context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T), + seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1); +} + +template +__mlu_func__ void dequantize_value_per_channel(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + int8_t *temp_nram, + T *data, + Tc *cache, + Ts *scale, + const int32_t scale_num, + const int32_t seq_num, + const int32_t head_num, + const int32_t head_size, + const size_t context_offset, + const size_t cache_offset, + const size_t scale_offset, + const size_t context_seq_stride, + const size_t context_head_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride, + const bool pad_front) { + /* Step 1. load scale [head_num, head_size]*/ + if (scale_bs_stride != 0) { + __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM, + head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1); + } + + /* Step 2. load cache [load_seq_num, head_num, head_size] */ + int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2); + int32_t deal_seq_num = load_seq_num << 1; + __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size, + head_num - 1, scale_num, load_seq_num - 1, cache_head_stride, head_num - 1, + cache_seq_stride, load_seq_num - 1); + /* Step 3. convert into int8 [load_seq_num, head_num, head_size, 2] */ + convert((int8_t *)output_nram, (int4x2_t *)input_nram, deal_seq_num * scale_num); + /* Step 4. transpose to [deal_seq_num (load_seq_num, 2), head_num, head_size] */ + trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram, + load_seq_num, head_num, head_size, 2); + /* Step 5. dequantize [save_seq_num, head_num, head_size] */ + int save_seq_num = pad_front ? seq_num - 1 : seq_num; + dequantize((T *)output_nram, (int8_t *)temp_nram + (pad_front ? scale_num : 0), + (Ts *)scale_nram, (Ts *)nbuf, save_seq_num * scale_num, scale_num); + /* Step 6. store [save_seq_num, head_num, head_size]*/ + __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM, + context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T), + save_seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), + save_seq_num - 1); +} + +template +__mlu_func__ void dequantize_per_head(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + T *temp_nram, + T *data, + Tc *cache, + Ts *scale, + const int32_t scale_num, + const int32_t seq_num, + const int32_t head_num, + const int32_t head_size, + const size_t context_offset, + const size_t cache_offset, + const size_t scale_offset, + const size_t context_seq_stride, + const size_t context_head_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t scale_head_stride) { + __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, seq_num * sizeof_(Ts), GDRAM2NRAM, + seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1); + if (std::is_same::value) { + __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM, + head_size >> 1, seq_num - 1, seq_num * (head_size >> 1), head_num - 1, + cache_seq_stride, seq_num - 1, cache_head_stride, head_num - 1); + } else { + __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM, + head_size * sizeof_(Tc), seq_num - 1, seq_num * head_size * sizeof_(Tc), head_num - 1, + cache_seq_stride * sizeof_(Tc), seq_num - 1, cache_head_stride * sizeof_(Tc), + head_num - 1); + } + + convert((float *)output_nram, (Tc *)input_nram, head_num * seq_num * head_size); + if (std::is_same::value) { + conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram, + head_num * seq_num, head_size, 1); + } else { + conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram, + head_num * seq_num, head_size, 1); + output_nram = (T *)temp_nram; + } + + __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM, + context_seq_stride * sizeof_(T), seq_num - 1, context_head_stride * sizeof_(T), + head_num - 1, head_size * sizeof_(T), seq_num - 1, head_size * seq_num * sizeof_(T), + head_num - 1); +} + +template +__mlu_func__ void dequantize_value_per_head(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + T *temp_nram, + T *data, + Tc *cache, + Ts *scale, + const int32_t scale_num, + const int32_t seq_num, + const int32_t head_num, + const int32_t head_size, + const size_t context_offset, + const size_t cache_offset, + const size_t scale_offset, + const size_t context_seq_stride, + const size_t context_head_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride, + const bool pad_front) { + int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2); + int32_t deal_seq_num = load_seq_num << 1; + /* Step1. load scale first, [head_num, deal_seq_num] */ + __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, deal_seq_num * sizeof_(Ts), GDRAM2NRAM, + deal_seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1); + /* Step2. load cache input, [head_num, load_seq_num, head_size, 2] for int4 */ + __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size, + load_seq_num - 1, load_seq_num * head_size, head_num - 1, cache_seq_stride, + load_seq_num - 1, cache_head_stride, head_num - 1); + convert((int8_t *)output_nram, (Tc *)input_nram, head_num * head_size * deal_seq_num); + /* Step3. trans to [head_num, load_seq_num, 2, head_size]*/ + trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram, + head_num * load_seq_num, head_size, 1, 2); + /* Step4. dequant to T [head_num, deal_seq_num, head_size] */ + convert((float *)output_nram, (int8_t *)temp_nram, head_num * deal_seq_num * head_size); + if (std::is_same::value) { + conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram, + head_num * deal_seq_num, head_size, 1); + } else { + conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram, + head_num * deal_seq_num, head_size, 1); + output_nram = (T *)temp_nram; + } + + /* Step5. save [head_num, save_seq_num, head_size]*/ + int32_t save_seq_num = pad_front ? seq_num - 1 : seq_num; + __memcpy((T *)data + context_offset, (T *)output_nram + (pad_front ? head_size : 0), + head_size * sizeof_(T), NRAM2GDRAM, context_seq_stride * sizeof_(T), save_seq_num - 1, + context_head_stride * sizeof_(T), head_num - 1, head_size * sizeof_(T), save_seq_num - 1, + head_size * deal_seq_num * sizeof_(T), head_num - 1); +} + +template +__mlu_global__ void MLUDequantFromLinearCacheKernelPerChannel(void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const int32_t *context_lens, + const int32_t *context_seq_offsets, + const int32_t *cache_bs_id, + const int32_t *cache_seq_offsets, + const int32_t max_context_len, + const int32_t batch_size, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t cache_mem_len, + const int32_t head_size, + const int32_t seq_block, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + bool has_key = (key && key_cache && key_scale); + bool has_value = (value && value_cache && value_scale); + if (!(has_key || has_value)) { + return; + } + + /* *********************************nram space ************************************** + * NRAM |scale[head_num, head_size]|output/input[seq_block, head_num, head_size]| + */ + int32_t scale_num = head_num * head_size; + Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts)); + float *output_nram = (float *)nbuf + scale_num; + Tc *input_nram = (Tc *)output_nram + + (std::is_same::value + ? (7 * seq_block * (scale_num >> 1)) + : (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc))); + // temp_nram for nram to store int8 input + int8_t *temp_nram = (int8_t *)output_nram + seq_block * scale_num * 3; + int32_t seq_offset; + int32_t seq_len; + int32_t seq_begin; + int32_t deal_seq_num; + int32_t cache_id; + int32_t cache_seq_offset; + size_t context_offset; + size_t cache_offset; + size_t scale_offset; + process_offsets((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens, + (int32_t *)context_seq_offsets, batch_size); + if (has_key) { + load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride, + scale_head_stride); + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + calcu_offsets_per_channel(cache_id, context_offset, cache_offset, scale_offset, + (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets, cache_mem_len, + seq_len, seq_begin, seq_offset, batch_idx, context_seq_stride, + cache_bs_stride, key_cache_seq_stride, scale_bs_stride); + if (cache_id < 0) continue; + dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key, + (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num, head_num, + head_size, context_offset, cache_offset, scale_offset, + context_seq_stride, context_head_stride, cache_head_stride, + key_cache_seq_stride, scale_bs_stride, scale_head_stride); + } + } + + if (has_value) { + if (std::is_same::value) { + __reshape_nhwc2nchw_smallc_init(pre_table_nram, 2); + } + load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride, + scale_head_stride); + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + if (std::is_same::value) { + seq_begin = taskIdZ * seq_block; + cache_id = + cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx); + cache_seq_offset = cache_seq_offsets == nullptr + ? 0 + : __load_gdram((int32_t *)cache_seq_offsets + batch_idx); + // move seq_begin left by 1 when cache_seq_offset is odd + seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + if (cache_id >= 0 && cache_seq_offset >= 0 && + (cache_seq_offset + seq_len) <= cache_mem_len) { + context_offset = + context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0)); + // value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t + cache_offset = cache_bs_stride * cache_id + + value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2); + scale_offset = scale_bs_stride * cache_id; + } else { + cache_id = -1; + } + if (cache_id < 0) continue; + dequantize_value_per_channel( + (T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (int8_t *)temp_nram, (T *)value, + (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num, head_num, head_size, + context_offset, cache_offset, scale_offset, context_seq_stride, context_head_stride, + cache_head_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride, + seq_begin == -1); + } else { + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + calcu_offsets_per_channel( + cache_id, context_offset, cache_offset, scale_offset, (int32_t *)cache_bs_id, + (int32_t *)cache_seq_offsets, cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx, + context_seq_stride, cache_bs_stride, value_cache_seq_stride, scale_bs_stride); + if (cache_id < 0) continue; + dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value, + (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num, + head_num, head_size, context_offset, cache_offset, scale_offset, + context_seq_stride, context_head_stride, cache_head_stride, + value_cache_seq_stride, scale_bs_stride, scale_head_stride); + } + } + } +} + +template +__mlu_global__ void MLUDequantFromLinearCacheKernelPerHead(void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const int32_t *context_lens, + const int32_t *context_seq_offsets, + const int32_t *cache_bs_id, + const int32_t *cache_seq_offsets, + const int32_t max_context_len, + const int32_t batch_size, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t cache_mem_len, + const int32_t head_size, + const int32_t seq_block, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + bool has_key = (key && key_cache && key_scale); + bool has_value = (value && value_cache && value_scale); + if (!(has_key || has_value)) { + return; + } + + /* *********************************nram space ************************************** + * NRAM |scale[seq_block, head_num]|output/input[head_size, seq_block, head_num]|temp| + */ + int32_t scale_num = seq_block * head_num; + Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts)); + float *output_nram = (float *)nbuf + scale_num; + Tc *input_nram = (Tc *)output_nram + + (std::is_same::value + ? (7 * head_size * (scale_num >> 1)) + : head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)); + // temp_nram for nram to store converted output + float *temp_nram = (float *)output_nram + head_size * scale_num; + int32_t seq_offset; + int32_t seq_len; + int32_t seq_begin; + int32_t deal_seq_num; + int32_t cache_id; + size_t context_offset; + size_t key_cache_offset; + size_t value_cache_offset; + size_t scale_offset; + __bang_write_value((float *)nbuf, head_size * 16, 1.0f); + mvNram2WramLT16((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16); + process_offsets((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens, + (int32_t *)context_seq_offsets, batch_size); + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + calcu_offsets_per_head(cache_id, context_offset, key_cache_offset, value_cache_offset, + scale_offset, (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets, + cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx, + context_seq_stride, cache_bs_stride, key_cache_seq_stride, + value_cache_seq_stride, scale_bs_stride); + if (cache_id < 0) continue; + if (has_key) { + dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram, + (T *)key, (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num, + head_num, head_size, context_offset, key_cache_offset, scale_offset, + context_seq_stride, context_head_stride, cache_head_stride, + key_cache_seq_stride, scale_head_stride); + } + if (has_value && std::is_same::value) { + dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram, + (T *)value, (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num, + head_num, head_size, context_offset, value_cache_offset, scale_offset, + context_seq_stride, context_head_stride, cache_head_stride, + value_cache_seq_stride, scale_head_stride); + } + } + + // process value int4 differently + if (has_value && std::is_same::value) { + int32_t cache_seq_offset; + __reshape_nhwc2nchw_smallc_init(pre_table_nram, 2); + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + seq_begin = taskIdZ * seq_block; + cache_id = + cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx); + cache_seq_offset = + cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx); + // move seq_begin left by 1 when cache_seq_offset is odd + seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) { + context_offset = + context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0)); + // value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t + value_cache_offset = cache_bs_stride * cache_id + + value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2); + scale_offset = cache_seq_offset + seq_begin + scale_bs_stride * cache_id; + } else { + cache_id = -1; + } + if (cache_id < 0) continue; + dequantize_value_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, + (T *)temp_nram, (T *)value, (Tc *)value_cache, (Ts *)value_scale, + scale_num, deal_seq_num, head_num, head_size, context_offset, + value_cache_offset, scale_offset, context_seq_stride, + context_head_stride, cache_head_stride, value_cache_seq_stride, + scale_bs_stride, scale_head_stride, seq_begin == -1); + } + } +} +} // namespace kernels + +#define DEQUANT_LINEAR_INIT(T, Tc, Ts, C, Name) \ + template __mlu_global__ void kernels::MLUDequantFromLinearCacheKernel##Name( \ + void *key, void *value, const void *key_cache, const void *value_cache, \ + const void *key_scale, const void *value_scale, const int32_t *context_lens, \ + const int32_t *context_seq_offsets, const int32_t *cache_bs_id, \ + const int32_t *cache_seq_offsets, const int32_t max_context_len, const int32_t batch_size, \ + const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num, \ + const int32_t cache_mem_len, const int32_t head_size, const int32_t seq_block, \ + const size_t context_head_stride, const size_t context_seq_stride, \ + const size_t cache_bs_stride, const size_t cache_head_stride, \ + const size_t key_cache_seq_stride, const size_t value_cache_seq_stride, \ + const size_t scale_bs_stride, const size_t scale_head_stride); + +DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerChannel) +DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerChannel) +DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerChannel) +DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerChannel) +DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerChannel) +DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerChannel) +DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerHead) +DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerHead) +DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerHead) +DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerHead) +DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerHead) +DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerHead) +DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerChannel) +DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerChannel) +DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerChannel) +DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerChannel) +DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerChannel) +DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerChannel) +DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerHead) +DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerHead) +DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerHead) +DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerHead) +DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerHead) +DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerHead) + +typedef void (*DequantFromLinearCachePointer)(void *, // key + void *, // value + const void *, // key_cache + const void *, // value_cache + const void *, // key_scale + const void *, // value_scale + const int32_t *, // context_lens + const int32_t *, // context_seq_offsets + const int32_t *, // cache_bs_id + const int32_t *, // cache_seq_offsets + const int32_t, // max_context_len + const int32_t, // batch_size + const int32_t, // head_num + const int32_t, // key_group_num + const int32_t, // value_group_num + const int32_t, // cache_mem_len + const int32_t, // head_size + const int32_t, // seq_block + const size_t, // context_head_stride + const size_t, // context_seq_stride + const size_t, // cache_bs_stride + const size_t, // cache_head_stride + const size_t, // key_cache_seq_stride + const size_t, // value_cache_seq_stride + const size_t, // scale_bs_stride + const size_t); // scale_head_stride + +static DequantFromLinearCachePointer DequantFromLinearCacheFuncArr[DEQUANT_FUNC_LEN] = { + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERCHANNEL, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD, + DEQUANT_LINEAR_PERHEAD}; + +uint32_t getDequantLinearIdx(cnnlDataType_t dtype, + int32_t quant_mode, + int32_t quant_bit, + const void *context_seq_offset) { + uint32_t idx = 0; + idx += (quant_mode != 0) ? 6 : 0; + idx += (quant_bit != 8) ? 3 : 0; + idx += (dtype == CNNL_DTYPE_BFLOAT16) ? 1 : 0; + idx += (dtype == CNNL_DTYPE_FLOAT) ? 2 : 0; + idx += (context_seq_offset == nullptr) ? 12 : 0; + + return idx; +} + +void getBlockAndDimForLinear(int32_t &seq_block, + cnrtDim3_t &task_dim, + cnrtFunctionType_t &task_type, + const int32_t max_context_len, + const int32_t head_num, + const int32_t batch_size, + const int32_t head_size, + const int32_t quant_mode, + const int32_t quant_bit, + const cnnlDataType_t dtype) { + int32_t core_dim; + int32_t cluster_dim; + int32_t nram_size = 480 * 1024; + int32_t wram_size = 512 * 1024; + int32_t sram_size = 2016 * 1024; + getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK); + if (quant_mode == 0) { + seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1); + if (quant_bit == 4) { + if (seq_block <= 1) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * head_size * sizeof_(float) should be less than " + << (nram_size >> 1) << " when quant_mode is 0." << std::endl; + } + } else { + if (seq_block <= 0) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * head_size * sizeof_(float) should be less than " << nram_size + << " when quant_mode is 0." << std::endl; + } + } + } else { + int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half); + seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) + + (int64_t)head_num * head_size * dtype_size)); + if (quant_bit == 4) { + if (seq_block <= 1) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + " + "context_dtype_size) " + << "should be less than " << (nram_size >> 1) << " when quant_mode is 1." + << std::endl; + } + } else { + if (seq_block <= 0) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + " + "context_dtype_size) " + << "should be less than " << nram_size << " when quant_mode is 1." << std::endl; + } + } + + /* head_size * 64B put in the wram. */ + if (head_size * ONE_LINE >= wram_size) { + std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than " + << wram_size << " when quant_mode is 1." << std::endl; + } + } + seq_block = std::min(seq_block, max_context_len); + if (seq_block > 16 && seq_block < max_context_len) { + seq_block = PAD_DOWN(seq_block, 16); + } + if (quant_bit == 4) { + seq_block = PAD_DOWN(seq_block, 2); + } + int seq_seg = DIV_UP(max_context_len, seq_block); + // need an extra seg block to dealwith int4 value_cache [...,seq_len/2, head_size] + if (quant_bit == 4) { + seq_seg += 1; + } + uint32_t core_num = cluster_dim * core_dim; + if (batch_size * seq_seg <= (core_num / 2)) { + int times = core_num / batch_size / seq_seg; + seq_block = std::max(seq_block / times, 2); + if (quant_bit == 4) { + seq_block = PAD_DOWN(seq_block, 2); + } + seq_seg = DIV_UP(max_context_len, seq_block); + // same as above to dealwise int4 value_cache with an extra seg block + if (quant_bit == 4) { + seq_seg += 1; + } + } + + task_dim.x = 1; + task_dim.y = uint32_t(std::min(batch_size, cluster_dim * core_dim)); + task_dim.z = uint32_t(seq_seg); + task_type = cnrtFuncTypeBlock; +} + +KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue, + void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const void *context_lens, + const void *context_seq_offsets, + const void *cache_bs_id, + const void *cache_seq_offsets, + const int32_t max_context_len, + const int32_t batch_size, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t cache_mem_len, + const int32_t head_size, + const int32_t quant_mode, + const int32_t quant_bit, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride, + const cnnlDataType_t dtype) { + if (dtype == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) { + std::cerr << "[invokeDequantFromPagedCache]: " + "MLU300 devices do not support bfloat16." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + int32_t index; + int32_t seq_block; + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + getBlockAndDimForLinear(seq_block, k_dim, k_type, max_context_len, head_num, batch_size, + head_size, quant_mode, quant_bit, dtype); + index = getDequantLinearIdx(dtype, quant_mode, quant_bit, context_seq_offsets); + auto dequant_linear_func = DequantFromLinearCacheFuncArr[index]; + dequant_linear_func<<>>( + (void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache, + (const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens, + (const int32_t *)context_seq_offsets, (const int32_t *)cache_bs_id, + (const int32_t *)cache_seq_offsets, max_context_len, batch_size, head_num, key_group_num, + value_group_num, cache_mem_len, head_size, seq_block, context_head_stride, context_seq_stride, + cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride, + scale_bs_stride, scale_head_stride); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mluh new file mode 100644 index 0000000..9b66dfb --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mluh @@ -0,0 +1,108 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_ +#define CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" + +namespace tmo { +/** + * @brief De-quantizes the key and value tensors from the provided linear cache and scale. + * @param queue: The queue for mlu. + * @param key: Pointer to the MLU memory that stores the key tensor, + * with shape [total_seqlen, head_num, head_size]. Data type can be float32, half, + * or bfloat16. This parameter can be nullptr. + * @param value: Pointer to the MLU memory that stores the value tensor, + * with shape [total_seqlen, head_num, head_size]. Data type can be float32, + * half, or bfloat16.This parameter can be nullptr. + * @param key_cache: Pointer to the MLU memory that stores the key cache tensor, + * with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization + * or [max_bs, head_num, cache_mem_len, head_size//2] for 4-bit quantization. + * Data type must be int8. This parameter can be nullptr. + * @param value_cache: Pointer to the MLU memory that stores the value cache tensor, + * with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization + * or [max_bs, head_num, cache_mem_len//2, head_size] for 4-bit quantization. + * Data type must be int8. This parameter can be nullptr. + * @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale. + * Shape depends on quantization mode: + * - For per-channel quantization (quant_mode = 0): [head_num, head_size]. + * - For per-token quantization (quant_mode = 1): [max_batch_size, head_num, cache_mem_len]. + * Data type must be float32. This parameter can be nullptr. + * @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale, + * with the same shape as key_scale. Data type must be float32. This parameter can be + nullptr. + * @param context_lens: Pointer to the MLU memory that stores the sequence lengths. + * The shape must be [batch]. + * @param context_seq_offset: Pointer to the MLU memory that stores the sequence offset in the + context. + * The shape must be [batch]. If nullptr, the default value is the cumulative sum of + context_lengths. + * @param cache_bs_id: Pointer to the MLU memory that stores the batch index in the cache. + * The shape must be [batch]. If nullptr, the default value is {0, 1, 2, ..., batch - 1}. + * @param cache_seq_offset: Pointer to the MLU memory that stores the sequence offset in the cache. + * The shape must be [batch]. If nullptr, the default value is 0 for every batch. + * @param max_contxt_len: The maximum sequence length of context. + * @param batch: Batch size. + * @param head_num: Head number. + * @param key_group_num: group number of key group-wise quantization. + * @param value_group_num: group number of value group-wise quantization. + * @param cache_mem_len: The maximum sequence length of cache. + * @param head_size: Head size. + * @param quant_mode: An integer value indicating the quantization mode: + * 0 for per-channel quantization and 1 for per-token quantization. + * @param quant_bit: An integer value indicating the quantization bit width: + * 8 for 8-bit quantization and 4 for 4-bit quantization. + * @param contxt_head_stride: The stride of head_num in context. + * @param contxt_seq_stride: The stride of max_contxt_len in context. + * @param cache_bs_stride: The stride of batch in cache. + * @param cache_head_stride: The stride of head_num in cache. + * @param key_cache_seq_stride: The stride of cache_mem_len in key cache. + * @param value_cache_seq_stride: The stride of cache_mem_len in value cache. + * @param cache_scale_bs_stride: The stride of batch in cache scale, only valid if quant_per_quant. + * @param cache_scale_head_stride: The stride of head in cache scale. + * @param dtype: The data type of the key and value tensors. + * @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key. + * If any of value/value_cache/value_scale is nullptr, no operation is performed on the value. + */ +KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue, + void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const void *context_lens, + const void *context_seq_offsets, + const void *cache_bs_ids, + const void *cache_seq_offsets, + const int32_t max_context_len, + const int32_t batch, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t cache_mem_len, + const int32_t head_size, + const int32_t quant_mode, + const int32_t quant_bit, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t cache_scale_bs_stride, + const size_t cache_scale_head_stride, + const cnnlDataType_t dtype); +} // namespace tmo + +#endif // CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mlu new file mode 100644 index 0000000..9e535d4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mlu @@ -0,0 +1,616 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include +#include +#include +#include "dequant_from_paged_cache.mluh" +#include "quant_utils.h" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { + +#pragma bang walign(16) +#define REM_FOR_STACK (32 * 1024) +#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024) +#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK) +#define DEQUANT_PAGED_PERHEAD kernels::MLUDequantFromPagedCacheKernelPerHead +#define DEQUANT_PAGED_PERCHANNEL kernels::MLUDequantFromPagedCacheKernelPerChannel +#define DEQUANT_FUNC_LEN (24) +#define DEQUANT_BATCH_NUM (1024) + +__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE]; +__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE]; +// Uses 8K = 1K * (4 + 4) to process offsets +__nram__ int32_t n_lens[DEQUANT_BATCH_NUM]; +__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM]; + +template +__mlu_func__ void dequantize_per_channel(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + T *data, + const int32_t scale_num, + const int32_t seq_num, + const int32_t head_num, + const int32_t head_size, + const size_t context_offset, + const size_t context_seq_stride, + const size_t context_head_stride) { + dequantize((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf, + seq_num * scale_num, scale_num); + __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM, + context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T), + seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1); +} + +template +__mlu_func__ void dequantize_per_head(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + T *temp_nram, + T *data, + const int32_t seq_num, + const int32_t head_num, + const int32_t block_size, + const int32_t head_size, + const size_t context_offset, + const size_t context_seq_stride, + const size_t context_head_stride) { + int block_count = DIV_UP(seq_num, block_size); + int rem_token = seq_num % block_size; + convert((float *)output_nram, (Tc *)input_nram, block_count * head_num * block_size * head_size); + T *res_nram = std::is_same::value ? (T *)output_nram : (T *)temp_nram; + // dequantize [block_count, head_num, block_size, head_size] + conv_fuse_mul_cvt((T *)res_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram, + block_count * head_num * block_size, head_size, 1); + // copy to [seq_num, head_num, head_size] + int whole_block_count = block_count - int(rem_token > 0); + if (whole_block_count) { + for (int i = 0; i < head_num; ++i) { + // copy from [whole_block_count, i, block_size, head_size] + // to [whole_block_count * block_size, 1, head_size] + __memcpy((T *)data + context_offset + i * context_head_stride, + (T *)res_nram + i * block_size * head_size, head_size * sizeof_(T), NRAM2GDRAM, + context_seq_stride * sizeof_(T), block_size - 1, + block_size * context_seq_stride * sizeof_(T), whole_block_count - 1, + head_size * sizeof_(T), block_size - 1, + head_num * block_size * head_size * sizeof_(T), whole_block_count - 1); + } + } + + if (rem_token) { + // copy from [last, head_num, block_size(rem_token), head_size] + // to [rem_token, head_num, head_size] + __memcpy((T *)data + context_offset + whole_block_count * block_size * context_seq_stride, + (T *)res_nram + whole_block_count * head_num * block_size * head_size, + head_size * sizeof_(T), NRAM2GDRAM, context_head_stride * sizeof_(T), head_num - 1, + context_seq_stride * sizeof_(T), rem_token - 1, block_size * head_size * sizeof_(T), + head_num - 1, head_size * sizeof_(T), rem_token - 1); + } +} + +template +__mlu_func__ void load_input_per_channel(Tc *input_nram, + Tc *cache, + Tc *temp_nram, + int32_t *block_offsets, + uint32_t *cache_offsets, + const int32_t *block_tables, + const int32_t batch_idx, + const int32_t scale_num, + const int32_t max_block_num, + const int32_t head_num, + const int32_t block_size, + const int32_t head_size, + const int32_t seq_begin, + const int32_t deal_seq_num, + const size_t cache_bs_stride, + const size_t cache_head_stride) { + int32_t block_start = batch_idx * max_block_num + seq_begin / block_size; + int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size; + int32_t block_count = block_end - block_start + 1; + // make sure elements in block_tables >= 0 + __memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start, + block_count * sizeof_(int32_t), GDRAM2NRAM); + __bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets, + (uint32_t)cache_bs_stride * sizeof(Tc), block_count); +#if __BANG_ARCH__ >= 500 + // gather [block_count, head_num, block_size, head_size] + __gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets, + (uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM, + (uint32_t)cache_bs_stride * sizeof(Tc), block_count); + if (head_num != 1 && block_size != 1) { + // mv to [head_num, whole_block_count, block_size, head_size] + __memcpy((Tc *)temp_nram, (Tc *)input_nram, block_size * head_size * sizeof(Tc), NRAM2NRAM, + block_size * head_size * sizeof(Tc), block_count - 1, + block_count * block_size * head_size * sizeof(Tc), head_num - 1, + head_num * block_size * head_size * sizeof(Tc), block_count - 1, + block_size * head_size * sizeof(Tc), head_num - 1); + // mv to [whole_block_count, block_size, head_num, head_size] + __memcpy((Tc *)input_nram, (Tc *)temp_nram, head_size * sizeof(Tc), NRAM2NRAM, + head_size * sizeof(Tc), head_num - 1, head_num * head_size * sizeof(Tc), + block_count * block_size - 1, block_count * block_size * head_size * sizeof(Tc), + head_num - 1, head_size * sizeof(Tc), block_count * block_size - 1); + } +#endif +} + +template +__mlu_func__ void load_input_per_head(Tc *input_nram, + Ts *scale_nram, + Tc *cache, + Ts *scale, + Tc *temp_nram, + int32_t *block_offsets, + uint32_t *cache_offsets, + uint32_t *scale_offsets, + const int32_t *block_tables, + const int32_t batch_idx, + const int32_t max_block_num, + const int32_t head_num, + const int32_t block_size, + const int32_t head_size, + const int32_t seq_begin, + const int32_t deal_seq_num, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + int32_t block_start = batch_idx * max_block_num + seq_begin / block_size; + int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size; + int32_t block_count = block_end - block_start + 1; + // make sure elements in block_tables >= 0 + __memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start, + block_count * sizeof_(int32_t), GDRAM2NRAM); + __bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets, + (uint32_t)cache_bs_stride * sizeof(Tc), block_count); + __bang_mul_scalar((uint32_t *)scale_offsets, (uint32_t *)block_offsets, + (uint32_t)scale_bs_stride * sizeof(Ts), block_count); +#if __BANG_ARCH__ >= 500 + // gather [block_count, head_num, block_size, head_size] + __gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets, + (uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM, + (uint32_t)cache_bs_stride * sizeof(Tc), block_count); + // gather [block_count, head_num, block_size] + __gather((Ts *)scale_nram, (Ts *)scale, (uint32_t *)scale_offsets, + (uint32_t)scale_bs_stride * sizeof(Ts), GDRAM2NRAM, + (uint32_t)scale_bs_stride * sizeof(Ts), block_count); +#endif +} + +template +__mlu_global__ void MLUDequantFromPagedCacheKernelPerChannel(void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const int32_t *context_lens, + const int32_t *context_seq_offsets, + const int32_t *block_tables, + const int32_t max_context_len, + const int32_t max_block_num, + const int32_t batch_size, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t block_size, + const int32_t head_size, + const int32_t seq_block, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + bool has_key = (key && key_cache && key_scale); + bool has_value = (value && value_cache && value_scale); + if (!(has_key || has_value)) { + return; + } + + /* *********************************nram space ************************************** + * NRAM |scale[head_num, head_size] fp32|output/input[seq_block, head_num, head_size] fp32| + */ + int32_t scale_num = head_num * head_size; + Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts)); + float *output_nram = (float *)nbuf + scale_num; + Tc *input_nram = (Tc *)output_nram + + (std::is_same::value + ? (7 * seq_block * (scale_num >> 1)) + : (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc))); + int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + seq_block * scale_num); + uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size); + int32_t seq_offset; + int32_t seq_len; + int32_t seq_begin; + int32_t deal_seq_num; + size_t context_offset; + process_offsets((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens, + (int32_t *)context_seq_offsets, batch_size); + if (has_key) { + load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride, + scale_head_stride); + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + // seq_begin % block_size != 0 only when seq_block < block_size + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + context_offset = context_seq_stride * (seq_offset + seq_begin); + load_input_per_channel((Tc *)input_nram, (Tc *)key_cache, (Tc *)output_nram, + (int32_t *)block_offsets, (uint32_t *)cache_offsets, + (int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num, + block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride, + cache_head_stride); + dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key, + scale_num, deal_seq_num, head_num, head_size, context_offset, + context_seq_stride, context_head_stride); + } + } + + if (has_value) { + load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride, + scale_head_stride); + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + // seq_begin % block_size != 0 only when seq_block < block_size + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + context_offset = context_seq_stride * (seq_offset + seq_begin); + load_input_per_channel((Tc *)input_nram, (Tc *)value_cache, (Tc *)output_nram, + (int32_t *)block_offsets, (uint32_t *)cache_offsets, + (int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num, + block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride, + cache_head_stride); + dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value, + scale_num, deal_seq_num, head_num, head_size, context_offset, + context_seq_stride, context_head_stride); + } + } +} + +template +__mlu_global__ void MLUDequantFromPagedCacheKernelPerHead(void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const int32_t *context_lens, + const int32_t *context_seq_offsets, + const int32_t *block_tables, + const int32_t max_context_len, + const int32_t max_block_num, + const int32_t batch_size, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t block_size, + const int32_t head_size, + const int32_t seq_block, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + bool has_key = (key && key_cache && key_scale); + bool has_value = (value && value_cache && value_scale); + if (!(has_key || has_value)) { + return; + } + + /* *********************************nram space ************************************** + * NRAM |scale[seq_block, head_num] * fp32|output/input[seq_block, head_num, head_size] * fp32| + * |temp[seq_block, head_num, head_size] * output_dtype| + * WRAM |head_size * 64B| + */ + int32_t scale_num = seq_block * head_num; + Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts)); + float *output_nram = (float *)nbuf + scale_num; + Tc *input_nram = + (Tc *)output_nram + (head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)); + // temp_nram for nram to store temp output + float *temp_nram = (float *)output_nram + head_size * scale_num; + int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + head_size * scale_num); + uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size); + uint32_t *scale_offsets = (uint32_t *)cache_offsets + DIV_UP(seq_block, block_size); + int32_t seq_len; + int32_t seq_begin; + int32_t seq_offset; + int32_t deal_seq_num; + size_t context_offset; + __bang_write_value((float *)nbuf, head_size * 16, 1.0f); + mvNram2WramLT16((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16); + process_offsets((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens, + (int32_t *)context_seq_offsets, batch_size); + if (has_key) { + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + context_offset = context_seq_stride * (seq_offset + seq_begin); + load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)key_cache, (Ts *)key_scale, + (Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets, + (uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx, + max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num, + cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride); + dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram, + (T *)key, deal_seq_num, head_num, block_size, head_size, context_offset, + context_seq_stride, context_head_stride); + } + } + + if (has_value) { + for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) { + load_len_offset(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets, + (int32_t *)context_lens, (int32_t *)context_seq_offsets, + batch_idx); + seq_begin = taskIdZ * seq_block; + deal_seq_num = std::min(seq_len - seq_begin, seq_block); + if (deal_seq_num <= 0 || seq_offset < 0) continue; + context_offset = context_seq_stride * (seq_offset + seq_begin); + load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)value_cache, (Ts *)value_scale, + (Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets, + (uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx, + max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num, + cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride); + dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram, + (T *)value, deal_seq_num, head_num, block_size, head_size, context_offset, + context_seq_stride, context_head_stride); + } + } +} +} // namespace kernels + +#define DEQUANT_PAGED_INIT(T, Tc, Ts, C, Name) \ + template __mlu_global__ void kernels::MLUDequantFromPagedCacheKernel##Name( \ + void *key, void *value, const void *key_cache, const void *value_cache, \ + const void *key_scale, const void *value_scale, const int32_t *context_lens, \ + const int32_t *context_seq_offsets, const int32_t *block_tables, \ + const int32_t max_context_len, const int32_t max_block_num, const int32_t batch_size, \ + const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num, \ + const int32_t block_size, const int32_t head_size, const int32_t seq_block, \ + const size_t context_head_stride, const size_t context_seq_stride, \ + const size_t cache_bs_stride, const size_t cache_head_stride, \ + const size_t key_cache_seq_stride, const size_t value_cache_seq_stride, \ + const size_t scale_bs_stride, const size_t scale_head_stride); + +DEQUANT_PAGED_INIT(half, int8_t, float, false, PerChannel) +DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerChannel) +DEQUANT_PAGED_INIT(float, int8_t, float, false, PerChannel) +// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerChannel) +// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerChannel) +// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerChannel) +DEQUANT_PAGED_INIT(half, int8_t, float, false, PerHead) +DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerHead) +DEQUANT_PAGED_INIT(float, int8_t, float, false, PerHead) +// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerHead) +// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerHead) +// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerHead) +DEQUANT_PAGED_INIT(half, int8_t, float, true, PerChannel) +DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerChannel) +DEQUANT_PAGED_INIT(float, int8_t, float, true, PerChannel) +// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerChannel) +// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerChannel) +// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerChannel) +DEQUANT_PAGED_INIT(half, int8_t, float, true, PerHead) +DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerHead) +DEQUANT_PAGED_INIT(float, int8_t, float, true, PerHead) +// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerHead) +// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerHead) +// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerHead) + +typedef void (*DequantFromPagedCachePointer)(void *, // key + void *, // value + const void *, // key_cache + const void *, // value_cache + const void *, // key_scale + const void *, // value_scale + const int32_t *, // context_lens + const int32_t *, // context_seq_offsets + const int32_t *, // block_tables + const int32_t, // max_context_len + const int32_t, // max_block_num + const int32_t, // batch_size + const int32_t, // head_num + const int32_t, // key_group_num + const int32_t, // value_group_num + const int32_t, // block_size + const int32_t, // head_size + const int32_t, // seq_block + const size_t, // context_head_stride + const size_t, // context_seq_stride + const size_t, // cache_bs_stride + const size_t, // cache_head_stride + const size_t, // key_cache_seq_stride + const size_t, // value_cache_seq_stride + const size_t, // scale_bs_stride + const size_t); // scale_head_stride + +static DequantFromPagedCachePointer DequantFromPagedCacheFuncArr[DEQUANT_FUNC_LEN] = { + DEQUANT_PAGED_PERCHANNEL, + DEQUANT_PAGED_PERCHANNEL, + DEQUANT_PAGED_PERCHANNEL, nullptr, nullptr, nullptr, + // DEQUANT_PAGED_PERCHANNEL, + // DEQUANT_PAGED_PERCHANNEL, + // DEQUANT_PAGED_PERCHANNEL, + DEQUANT_PAGED_PERHEAD, + DEQUANT_PAGED_PERHEAD, + DEQUANT_PAGED_PERHEAD, nullptr, nullptr, nullptr, + // DEQUANT_PAGED_PERHEAD, + // DEQUANT_PAGED_PERHEAD, + // DEQUANT_PAGED_PERHEAD, + DEQUANT_PAGED_PERCHANNEL, + DEQUANT_PAGED_PERCHANNEL, + DEQUANT_PAGED_PERCHANNEL, nullptr, nullptr, nullptr, + // DEQUANT_PAGED_PERCHANNEL, + // DEQUANT_PAGED_PERCHANNEL, + // DEQUANT_PAGED_PERCHANNEL, + DEQUANT_PAGED_PERHEAD, + DEQUANT_PAGED_PERHEAD, + DEQUANT_PAGED_PERHEAD, nullptr, nullptr, nullptr}; +// DEQUANT_PAGED_PERHEAD, +// DEQUANT_PAGED_PERHEAD, +// DEQUANT_PAGED_PERHEAD}; + +uint32_t getDequantPagedIdx(cnnlDataType_t dtype, + int32_t quant_mode, + int32_t quant_bit, + const void *context_seq_offset) { + uint32_t idx = 0; + idx += (quant_mode != 0) ? 6 : 0; + idx += (quant_bit != 8) ? 3 : 0; + idx += (dtype == CNNL_DTYPE_BFLOAT16) ? 1 : 0; + idx += (dtype == CNNL_DTYPE_FLOAT) ? 2 : 0; + idx += (context_seq_offset == nullptr) ? 12 : 0; + + return idx; +} + +void getBlockAndDimForPaged(int32_t &seq_block, + cnrtDim3_t &task_dim, + cnrtFunctionType_t &task_type, + const int32_t max_context_len, + const int32_t head_num, + const int32_t batch_size, + const int32_t head_size, + const int32_t block_size, + const int32_t quant_mode, + const int32_t quant_bit, + const cnnlDataType_t dtype) { + int32_t core_dim; + int32_t cluster_dim; + int32_t nram_size = 480 * 1024; + int32_t wram_size = 512 * 1024; + int32_t sram_size = 2016 * 1024; + getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK); + if (quant_mode == 0) { + seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1); + if (seq_block < block_size) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * head_size * sizeof_(float) should be less than " + << nram_size / block_size << " when quant_mode is 0." << std::endl; + } + } else { + int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half); + seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) + + (int64_t)head_num * head_size * dtype_size)); + if (seq_block < block_size) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + " + "context_dtype_size) " + << "should be less than " << nram_size / block_size << " when quant_mode is 1." + << std::endl; + } + + /* head_size * 64B put in the wram. */ + if (head_size * ONE_LINE >= wram_size) { + std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than " + << wram_size << " when quant_mode is 1." << std::endl; + } + } + // seq_block should be a multiply of block_size + seq_block = PAD_DOWN(seq_block, block_size); + int seq_seg = DIV_UP(max_context_len, seq_block); + int32_t core_num = cluster_dim * core_dim; + if (batch_size * seq_seg <= (core_num / 2)) { + int times = core_num / batch_size / seq_seg; + seq_block = std::max(seq_block / times, 2); + if (seq_block > block_size) { + seq_block = PAD_DOWN(seq_block, block_size); + } else { + seq_block = block_size; + } + seq_seg = DIV_UP(max_context_len, seq_block); + } + + task_dim.x = 1; + task_dim.y = uint32_t(std::min(batch_size, core_num)); + task_dim.z = uint32_t(seq_seg); + task_type = cnrtFuncTypeBlock; +} + +KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue, + void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const void *context_lens, + const void *context_seq_offsets, + const void *block_tables, + const int32_t max_context_len, + const int32_t max_block_num, + const int32_t batch_size, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t block_size, + const int32_t head_size, + const int32_t quant_mode, + const int32_t quant_bit, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t scale_bs_stride, + const size_t scale_head_stride, + const cnnlDataType_t dtype) { + if (is_arch300()) { + std::cerr << "[invokeDequantFromPagedCache]: kernel does not support MLU300 devices." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + int32_t index; + int32_t seq_block; + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + getBlockAndDimForPaged(seq_block, k_dim, k_type, max_context_len, head_num, batch_size, head_size, + block_size, quant_mode, quant_bit, dtype); + index = getDequantPagedIdx(dtype, quant_mode, quant_bit, context_seq_offsets); + auto dequant_paged_func = DequantFromPagedCacheFuncArr[index]; + dequant_paged_func<<>>( + (void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache, + (const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens, + (const int32_t *)context_seq_offsets, (const int32_t *)block_tables, max_context_len, + max_block_num, batch_size, head_num, key_group_num, value_group_num, block_size, head_size, + seq_block, context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride, + key_cache_seq_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mluh new file mode 100644 index 0000000..149ed8f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mluh @@ -0,0 +1,106 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_ +#define CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" + +namespace tmo { +/** + * @brief De-quantizes the key and value tensors from the provided paged cache and scale. + * @param queue: The queue for mlu. + * @param key: Pointer to the MLU memory that stores the key tensor, + * with shape [total_seqlen, head_num, head_size]. Data type can be half + * or bfloat16. This parameter can be nullptr. + * @param value: Pointer to the MLU memory that stores the value tensor, + * with shape [total_seqlen, head_num, head_size]. Data type can be + * half or bfloat16.This parameter can be nullptr. + * @param key_cache: Pointer to the MLU memory that stores the key cache tensor, + * with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization. + * Data type must be int8. This parameter can be nullptr. + * @param value_cache: Pointer to the MLU memory that stores the value cache tensor, + * with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization. + * Data type must be int8. This parameter can be nullptr. + * @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale. + * Shape depends on quantization mode: + * - For per-channel quantization (quant_mode = 0): [head_num, head_size]. + * - For per-token quantization (quant_mode = 1): [total_blocks, head_num, block_size]. + * Data type must be float32. This parameter can be nullptr. + * @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale, + * with the same shape as key_scale. Data type must be float32. This parameter can be + nullptr. + * @param context_lens: Pointer to the MLU memory that stores the sequence lengths. + * The shape must be [batch]. + * @param context_seq_offset: Pointer to the MLU memory that stores the sequence offset in the + context. + * The shape must be [batch]. If nullptr, the default value is the cumulative sum of + context_lengths. + * @param block_tables: Pointer to the MLU memory that stores the block tables for indexing. + * The shape must be [batch, max_block_num]. + * @param max_contxt_len: The maximum sequence length of context. + * @param max_block_num: The maximum block number of each batch. + * @param batch: Batch size. + * @param head_num: Head number. + * @param key_group_num: group number of key group-wise quantization. + * @param value_group_num: group number of value group-wise quantization. + * @param block_size: The block size of the cache. + * @param head_size: Head size. + * @param quant_mode: An integer value indicating the quantization mode: + * 0 for per-channel quantization and 1 for per-token quantization. + * @param quant_bit: An integer value indicating the quantization bit width: + * 8 for 8-bit quantization. + * @param contxt_head_stride: The stride of head_num in context. + * @param contxt_seq_stride: The stride of max_contxt_len in context. + * @param cache_bs_stride: The stride of batch in cache. + * @param cache_head_stride: The stride of head_num in cache. + * @param key_cache_seq_stride: The stride of cache_mem_len in key cache. + * @param value_cache_seq_stride: The stride of cache_mem_len in value cache. + * @param cache_scale_bs_stride: The stride of batch in cache scale, only valid if quant_per_quant. + * @param cache_scale_head_stride: The stride of head in cache scale. + * @param dtype: The data type of the key and value tensors. + + * @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key. + * If any of value/value_cache/value_scale is nullptr, no operation is performed on the value. + */ +KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue, + void *key, + void *value, + const void *key_cache, + const void *value_cache, + const void *key_scale, + const void *value_scale, + const void *context_lens, + const void *context_seq_offsets, + const void *block_tables, + const int32_t max_context_len, + const int32_t max_block_num, + const int32_t batch, + const int32_t head_num, + const int32_t key_group_num, + const int32_t value_group_num, + const int32_t block_size, + const int32_t head_size, + const int32_t quant_mode, + const int32_t quant_bit, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t cache_scale_bs_stride, + const size_t cache_scale_head_stride, + const cnnlDataType_t dtype); +} // namespace tmo + +#endif // CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu new file mode 100644 index 0000000..5773d2e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu @@ -0,0 +1,254 @@ +#include +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "dequantify.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { +template +struct PackValueNum { + const static int value = 1; +}; + +template <> +struct PackValueNum { + const static int value = 2; +}; + +__nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)]; + +__nram__ int8_t nram_buf_scale[8192]; + +__mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) { + __bang_int82float(dst, src, count, 0); + __bang_mul_scalar(dst, dst, scale, count); +} + +__mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) { + __bang_int42float_rn(dst, src, count, 0); + __bang_mul_scalar(dst, dst, scale, count); +} + +__mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) { + __bang_int82half(dst, src, count, 0); + __bang_mul_scalar(dst, dst, (half)scale, count); +} + +__mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) { + __bang_int42half_rn(dst, src, count, 0); + __bang_mul_scalar(dst, dst, (half)scale, count); +} + +template +__mlu_func__ void swap(T *&ping, T *&pong) { + T *tmp = ping; + ping = pong; + pong = tmp; +} + +template +__mlu_global__ void dequantifyPerTensor(void *all_dst, + const void *all_src, + size_t all_src_count, + float scale) { + scale = 1.0f / scale; + size_t src_per_core = all_src_count / taskDim; + size_t src_remain = all_src_count % taskDim; + size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain); + const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0); + TDst *dst = reinterpret_cast(all_dst) + start * PackValueNum::value; + const TSrc *src = reinterpret_cast(all_src) + start; + + constexpr int size_unit = sizeof(nram_buf) / 2 / // divide by 2 for ping pong + (sizeof(TSrc) + sizeof(TDst) * PackValueNum::value) / 128 * + 128; // align to 128 + constexpr int src_num_unit = size_unit / sizeof(TSrc); + constexpr int dst_num_unit = src_num_unit * PackValueNum::value; + int8_t *nram_buf_ping = nram_buf; + int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2; + + TSrc *nram_src_ping = reinterpret_cast(nram_buf_ping); + TDst *nram_dst_ping = + reinterpret_cast(nram_buf_ping + static_cast(sizeof(TSrc)) * size_unit); + TSrc *nram_src_pong = reinterpret_cast(nram_buf_pong); + TDst *nram_dst_pong = + reinterpret_cast(nram_buf_pong + static_cast(sizeof(TSrc)) * size_unit); + + int loop_count = src_count / src_num_unit; + int remain_count = src_count % src_num_unit; + + // L + __memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); + swap(nram_src_ping, nram_src_pong); + swap(nram_dst_ping, nram_dst_pong); + __sync_io_move_compute(); + + // L C + __memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); + convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale); + swap(nram_src_ping, nram_src_pong); + swap(nram_dst_ping, nram_dst_pong); + __sync_io_move_compute(); + + // L C S + for (int i = 0; i < loop_count - 2; ++i) { + __memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit, + GDRAM2NRAM); + __memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM); + convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale); + swap(nram_src_ping, nram_src_pong); + swap(nram_dst_ping, nram_dst_pong); + __sync_io_move_compute(); + } + + // C S + __memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, + NRAM2GDRAM); + convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale); + swap(nram_src_ping, nram_src_pong); + swap(nram_dst_ping, nram_dst_pong); + __sync_io_move_compute(); + + // S + __memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, + NRAM2GDRAM); + + __sync_io_move_compute(); + + if (remain_count > 0) { + __memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count, + GDRAM2NRAM); + convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum::value, scale); + __memcpy(dst + loop_count * dst_num_unit, nram_dst_ping, + sizeof(TDst) * remain_count * PackValueNum::value, NRAM2GDRAM); + } +} + +// does not use a pipeline because per channel is more complicated but it's a one-time operation, so +// performance doesn't matter. +template +__mlu_global__ void dequantifyPerChannel(void *all_dst, + const void *all_src, + int src_ci, + int all_co, + const void *scale) { + const int co_per_core = all_co / taskDim; + const int co_remain = all_co % taskDim; + const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain); + const int co_count = co_per_core + (taskId < co_remain ? 1 : 0); + assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst)); + + constexpr int size_unit = sizeof(nram_buf) / + (sizeof(TSrc) + sizeof(TDst) * PackValueNum::value) / 128 * + 128; // align to 128 + // yes, we only deal with 1 channel at a time + // no, there's no need to optimize a one-time operation + const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci); + const int dst_num_unit = src_num_unit * PackValueNum::value; + TSrc *const nram_src = reinterpret_cast(nram_buf); + TDst *const nram_dst = + reinterpret_cast(nram_buf + static_cast(sizeof(TSrc)) * size_unit); + + const TDst *nram_scale = reinterpret_cast(nram_buf_scale); + + const int loop_one_channel = src_ci / src_num_unit; + const int remain_one_channel = src_ci % src_num_unit; + + for (int o = start_co; o < start_co + co_count; ++o) { + const TSrc *src = reinterpret_cast(all_src) + o * src_ci; + TDst *dst = reinterpret_cast(all_dst) + o * src_ci; + const TDst scale_value = 1. / nram_scale[o]; + for (int i = 0; i < loop_one_channel; ++i) { + __memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM); + convert(nram_dst, nram_src, dst_num_unit, scale_value); + __memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM); + } + if (remain_one_channel > 0) { + __memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel, + GDRAM2NRAM); + convert(nram_dst, nram_src, remain_one_channel * PackValueNum::value, scale_value); + __memcpy(dst + loop_one_channel * dst_num_unit, nram_dst, + sizeof(TDst) * remain_one_channel * PackValueNum::value, NRAM2GDRAM); + } + } +} + +} // namespace kernels +static const std::map, + decltype(&kernels::dequantifyPerTensor)> + per_tensor_func_map = { + {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor}, + {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor}, + {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor}, + {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor}, +}; + +KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle, + const void *src, + int src_bitwidth, + void *dst, + cnnlDataType_t dst_dtype, + size_t src_count, + float scale) { + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + CNdev dev; + cnnlGetDevice(handle, &dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1}; + auto iter = per_tensor_func_map.find(std::make_pair(src_bitwidth, dst_dtype)); + if (iter == per_tensor_func_map.end()) { + std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth + << " dst_dtype: " << dst_dtype; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + iter->second<<>>(dst, src, src_count, scale); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +static const std::map, + decltype(&kernels::dequantifyPerChannel)> + per_channel_func_map = { + {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel}, + {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel}, + {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel}, + {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel}, +}; + +KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle, + const void *src, + int src_bitwidth, + void *dst, + cnnlDataType_t dst_dtype, + int src_ci, + int co, + const void *scale) { + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + CNdev dev; + cnnlGetDevice(handle, &dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1}; + auto iter = per_channel_func_map.find(std::make_pair(src_bitwidth, dst_dtype)); + if (iter == per_channel_func_map.end()) { + std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth + << " dst_dtype: " << dst_dtype; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + iter->second<<>>(dst, src, src_ci, co, scale); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mluh new file mode 100644 index 0000000..b6355c6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mluh @@ -0,0 +1,57 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_DEQUANTIFY_MLUH_ +#define CSRC_KERNELS_DEQUANTIFY_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Dequantify per tensor. + * @param handle: The handle of cnnl. + * @param src: Input. Pointer to the MLU memory that stores the input. + * @param src_bitwidth: The bitwidth of input quantized data. + * @param dst: Output. Pointer to the MLU memory that stores the output. + * @param dst_dtype: The data type of output. + * @param src_count: The number of elements in input. + * @param scale: The scale for dequantify. + */ +KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle, + const void *src, + int src_bitwidth, + void *dst, + cnnlDataType_t dst_dtype, + size_t src_count, + float scale); + +/** + * @brief Dequantify per channel. + * @param handle: The handle of cnnl. + * @param src: Input. Pointer to the MLU memory that stores the input. + * @param src_bitwidth: The bitwidth of input quantized data. + * @param dst: Output. Pointer to the MLU memory that stores the output. + * @param dst_dtype: The data type of output. + * @param src_ci: The ci of input. + * @param co: The co of input. + * @param scale: Pointer to the MLU memory that stores the scale for dequantify. + */ +KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle, + const void *src, + int src_bitwidth, + void *dst, + cnnlDataType_t dst_dtype, + int src_ci, + int co, + const void *scale); +} // namespace tmo + +#endif // CSRC_KERNELS_DEQUANTIFY_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu new file mode 100644 index 0000000..611f0dc --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu @@ -0,0 +1,310 @@ +#include +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "embedding.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { +#define MAX_UINT32 (4294967295) +#define MAX_SINT32 (2147483647) +#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) +__nram__ int8_t nram_buffer[NRAM_SIZE]; + +__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) { + int base = total / num; + int tail = total - base * num; + every = base + (id < tail ? 1 : 0); + offset = base * id + (id < tail ? id : tail); +} + +#define PAD_DOWN(x, y) (((x) / (y)) * (y)) +#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y)) + +template +__mlu_func__ void embeddingImpl_500(T *filter, + int *input_ids, + T *output, + int vocab_offset, + int vocab_size, + int input_size, + int total_seq) { + if (__is_mpu()) { + return; + }; + + int bs_core = 0; + int bs_offset = 0; + split(total_seq, taskDim, taskId, bs_core, bs_offset); + // 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex must be divisible + // by 8 + int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) / + (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t)); + + int vocab_start = vocab_offset; + int vocab_end = vocab_offset + vocab_size - 1; + + T *zeros_nram = (T *)nram_buffer; // input_size * sizeof(T) + T *emb_nram = zeros_nram + input_size; // limit * input_size * sizeof(T) + int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int) + int *idxs_nram = ones_nram + limit; // limit * sizeof(int) + int *mask_nram = idxs_nram + limit; // limit_pad * sizeof(int) + int *temp_nram = mask_nram + PAD_UP(limit, 8); // limit * sizeof(int) + uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit); // limit * sizeof(int8_t) + __bang_write_zero(zeros_nram, input_size); + __bang_write_zero(zeros_offset_nram, limit); + __bang_write_value(ones_nram, limit, 1); + + int repeat = bs_core / limit; + int remain = bs_core % limit; + + for (int i = 0; i < repeat + 1; i++) { + if ((i == repeat) && (remain == 0)) { + return; + } + int num = (i == repeat) ? remain : limit; + int num_pad = PAD_UP(num, 8); // for __bang_eq_bitindex + __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM); + __sync(); + __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num); + __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num); + __bang_mul(mask_nram, mask_nram, temp_nram, num); + __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram, + num_pad); // gather valid mask + __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num); // gather invalid mask + __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num); // true index + __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram, + (unsigned int)input_size * sizeof(T), num); // gather offset + __sync(); + __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T), + GDRAM2NRAM, input_size * sizeof(T), num); + __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T), + NRAM2NRAM, input_size * sizeof(T), num); + __sync(); + __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram, + num * input_size * sizeof(T), NRAM2GDRAM); + __sync(); + } +} + +template +__mlu_func__ void write_zero(T *dst, unsigned int elem_count) { + __bang_write_zero(dst, elem_count); +} + +template <> +__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) { +#if __BANG_ARCH__ >= 500 + __bang_write_zero(dst, elem_count); +#endif +} + +template +__mlu_func__ void embeddingImpl_300(T *filter, + int *input_ids, + T *output, + int vocab_offset, + int vocab_size, + int input_size, + int total_seq) { + if (__is_mpu()) { + return; + }; + + int bs_core = 0; + int bs_offset = 0; + split(total_seq, taskDim, taskId, bs_core, bs_offset); + int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int)); + limit = PAD_DOWN(limit, 2); + int repeat = bs_core / limit; + int remain = bs_core % limit; + int vocab_start = vocab_offset; + int vocab_end = vocab_offset + vocab_size - 1; + + T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T) + int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int) + + for (int i = 0; i < repeat + 1; i++) { + if ((i == repeat) && (remain == 0)) { + return; + } + int num = (i == repeat) ? remain : limit; + __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM); + __sync(); + + int idx1 = idxs_nram[0]; + int idx2 = idxs_nram[1]; + bool first = (idx1 >= vocab_start && idx1 <= vocab_end); + bool second = (idx2 >= vocab_start && idx2 <= vocab_end); + for (int n = 0; n < num / 2 * 2; n += 2) { + if (first && second) { + __memcpy_async(emb_nram + n * input_size, + filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T), + GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T), + 1); + } else if (!first && !second) { + write_zero(emb_nram + n * input_size, 2 * input_size); + } else if (first && !second) { + write_zero(emb_nram + (n + 1) * input_size, input_size); + __memcpy_async(emb_nram + n * input_size, + filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T), + GDRAM2NRAM); + } else { + write_zero(emb_nram + n * input_size, input_size); + __memcpy_async(emb_nram + (n + 1) * input_size, + filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T), + GDRAM2NRAM); + } + idx1 = idxs_nram[n + 2]; + idx2 = idxs_nram[n + 3]; + first = (idx1 >= vocab_start && idx1 <= vocab_end); + second = (idx2 >= vocab_start && idx2 <= vocab_end); + } // copy loop + + // last idx copy + if (num % 2 == 1) { + if (first) { + __memcpy_async(emb_nram + (num - 1) * input_size, + filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T), + GDRAM2NRAM); + } else { + write_zero(emb_nram + (num - 1) * input_size, input_size); + } + } + __sync(); + + __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram, + num * input_size * sizeof(T), NRAM2GDRAM); + __sync(); + } +} + +template +__mlu_func__ void embeddingImpl_generic(T *filter, + int *input_ids, + T *output, + int vocab_offset, + int vocab_size, + int input_size, + int total_seq) { + if (__is_mpu()) { + return; + }; + + int bs_core = 0; + int bs_offset = 0; + split(total_seq, taskDim, taskId, bs_core, bs_offset); + int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int)); + limit = PAD_DOWN(limit, 2); + int repeat = bs_core / limit; + int remain = bs_core % limit; + int vocab_start = vocab_offset; + int vocab_end = vocab_offset + vocab_size - 1; + + T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T) + int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int) + + for (int i = 0; i < repeat + 1; i++) { + if ((i == repeat) && (remain == 0)) { + return; + } + int num = (i == repeat) ? remain : limit; + __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM); + __sync(); + + int idx = idxs_nram[0]; + bool hit = (idx >= vocab_start && idx <= vocab_end); + for (int n = 0; n < num; n++) { + if (hit) { + __memcpy_async(emb_nram + n * input_size, + filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T), + GDRAM2NRAM); + } else { + write_zero(emb_nram + n * input_size, input_size); + } + idx = idxs_nram[n + 1]; + hit = (idx >= vocab_start && idx <= vocab_end); + } + __sync(); + + __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram, + num * input_size * sizeof(T), NRAM2GDRAM); + __sync(); + } +} + +template +__mlu_global__ void MLUEmbeddingKernel(T *filter, + int *input_ids, + T *output, + int vocab_offset, + int vocab_size, + int total_vocab_size, + int input_size, + int total_seq) { +#if __BANG_ARCH__ > 372 + // __gather index maximum dtype is unsigned int + if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) { + embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq); + } else { + embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size, + total_seq); + } +#else + // __memcpy 2D src_stride dtype is int + if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) { + embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq); + } else { + embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size, + total_seq); + } +#endif +} +} // namespace kernels + +KernelStatus invokeEmbedding(cnrtQueue_t queue, + void *filter, + void *input_ids, + void *output, + const cnnlDataType_t dtype, + int vocab_offset, + int vocab_size, + int total_vocab_size, + int input_size, + int total_seq) { + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1}; + + if (dtype == CNNL_DTYPE_FLOAT) { + kernels::MLUEmbeddingKernel<<>>( + static_cast(filter), (int *)input_ids, static_cast(output), vocab_offset, + vocab_size, total_vocab_size, input_size, total_seq); + } else if (dtype == CNNL_DTYPE_HALF) { + kernels::MLUEmbeddingKernel<<>>( + static_cast(filter), (int *)input_ids, static_cast(output), vocab_offset, + vocab_size, total_vocab_size, input_size, total_seq); + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + if (!isBf16Supported()) { + std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + kernels::MLUEmbeddingKernel<<>>( + static_cast(filter), (int *)input_ids, static_cast(output), + vocab_offset, vocab_size, total_vocab_size, input_size, total_seq); + } + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh new file mode 100644 index 0000000..3d4de7f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh @@ -0,0 +1,63 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_EMBEDDING_MLUH_ +#define CSRC_KERNELS_EMBEDDING_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Look up table for ids which greater than vocab_offset and less than + * vocab_offset + vocab_size, and write the results back to the position + * corresponding to the ids. For ids that are not in the range, write 0 + * to the corresponding position. + * @example + * filter: + * [[1, 2, 3, 4], + * [5, 6, 7, 8], + * [4, 3, 2, 1]] + * input_ids: + * [[1, 5, 6, 7, 8, 9]] + * vocab_offset = 5 + * vocab_size = 3 + * input_size = 4 + * total_seq = 6 + * output: + * [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8], + * [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]] + * @param queue: The queue for mlu. + * @param filter: Input. Pointer to the MLU memory that stores the embedding table, + * the shape must be [vocab_size, input_size]. + * @param input_ids: Input. Pointer to the MLU memory that stores the token id, + * the shape must be [batch, seq]. + * @param output: Output. Pointer to the MLU memory that stores the output, + * the shape must be [batch, seq, input_size]. + * @param dtype: Data type. + * @param vocab_offset: embedding table offset. + * @param vocab_size: embedding table size. + * @param total_vocab_size: total embedding table size. + * @param input_size: embedding dim. + * @param total_seq: Total sequence length. + */ +KernelStatus invokeEmbedding(cnrtQueue_t queue, + void *filter, + void *input_ids, + void *output, + const cnnlDataType_t dtype, + int vocab_offset, + int vocab_size, + int total_vocab_size, + int input_size, + int total_seq); +} // namespace tmo + +#endif // CSRC_KERNELS_EMBEDDING_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mlu new file mode 100644 index 0000000..41f60cd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mlu @@ -0,0 +1,658 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "cnnl.h" +#include "cnrt.h" +#include "fused_rope.mluh" +// clang-format off +#include +// clang-format on + +using bf16 = bfloat16_t; + +namespace tmo { + +namespace kernels { + +#ifndef PAD_UP +#define PAD_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0) * (y)) +#endif + +#if __BANG_ARCH__ > 500 +#include +template +using bang_cycle_fusor = bang::experimental::cycle_fusor; +#endif + +#define NRAM_BUFFER_SIZE (480 * 1024) +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; +__nram__ float nram_mask[256]; +__nram__ int nram_rope_offsets[256]; +__nram__ float nram_zeros[1024] = {0.f}; + +template +__mlu_func__ void toFloat(float *dst, T *src, int num) { + if (std::is_same::value) { + __bang_half2float(dst, (half *)src, num); + } else if (std::is_same::value) { +#if __BANG_ARCH__ > 500 + __bang_bfloat162float(dst, (bf16 *)src, num); +#endif + } +} + +template +__mlu_func__ void floatTo(T *dst, float *src, int num) { + if (std::is_same::value) { + __bang_float2half_rn((half *)dst, src, num); + } else if (std::is_same::value) { + __bang_float2bfloat16_rn((bf16 *)dst, src, num); + } +} + +__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin, + int *cache_seq_offsets_begin, + int *slot_mapping_begin, + int *nram_k_cache_offsets, + int *nram_v_cache_offsets, + int *nram_v_onchip_offsets, + int *nram_kv_scale_offsets, + float *nram_cache_mask, + float *nram_zeros, + float *nram_temp, + int task_deal_batch, + int task_begin_batch, + int head_num_k, + int head_size, + int max_decode_len, + int block_size, + int kv_out_size, + int group_num, + bool discrete_batch, + bool paged_cache, + bool mixed_cache) { + // 目前先用标量化计算offset,便于理解(性能无影响) + int bh = task_deal_batch * head_num_k; + if (paged_cache) { + int cache_seq_stride = head_size; + int cache_head_stride = block_size * head_size; + int cache_scale_head_stride = block_size * group_num; + int cache_block_stride = head_num_k * cache_head_stride; + int cache_scale_block_stride = head_num_k * cache_scale_head_stride; + int *nram_slot_mapping = (int *)nram_mask; + __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM); + for (int i = 0; i < task_deal_batch; i++) { + int mapping_idx = __load_nram(nram_slot_mapping + i); + if (mapping_idx < 0) { + __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1); + continue; + } + int block_idx = mapping_idx / block_size; + int seq_idx = mapping_idx % block_size; + int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride; + int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride; + int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num; + int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size; + + for (int j = 0; j < head_num_k; j++) { + __store_nram( + nram_k_cache_offsets + i * head_num_k + j, + (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1))); + if (mixed_cache) { + __store_nram(nram_v_cache_offsets + i * head_num_k + j, + (int)(v_seq_offset + j * cache_head_stride / 2)); + __store_nram(nram_kv_scale_offsets + i * head_num_k + j, + (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float))); + __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size); + } + } + } + } else { + int *nram_seq_offsets = (int *)nram_mask; + int *nram_bs_id = nram_seq_offsets + 32; + int cache_seq_stride = head_size; + int cache_head_stride = max_decode_len * head_size; + int cache_scale_head_stride = max_decode_len * group_num; + int cache_bs_stride = head_num_k * cache_head_stride; + int cache_scale_bs_stride = head_num_k * cache_scale_head_stride; + __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM); + if (discrete_batch) { + __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM); + } + for (int i = 0; i < task_deal_batch; i++) { + int bs_idx = __load_nram(nram_bs_id + i); + int seq_idx = __load_nram(nram_seq_offsets + i); + int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i; + int temp_seq_idx = seq_idx; + bool masked = temp_bs_idx < 0 || temp_seq_idx < 0; + if (masked) { + __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1); + continue; + } + int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride; + int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num; + int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride; + int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size; + + for (int j = 0; j < head_num_k; j++) { + __store_nram( + nram_k_cache_offsets + i * head_num_k + j, + (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1))); + if (mixed_cache) { + __store_nram(nram_v_cache_offsets + i * head_num_k + j, + (int)(v_seq_offset + j * cache_head_stride / 2)); + __store_nram(nram_kv_scale_offsets + i * head_num_k + j, + (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float))); + __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size); + } + } + } + } + + // 此处是为了做上scatter指令的 mask,如果bs offset或seq offset小于0则需要mask掉 + __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0); + __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8)); +} + +__mlu_func__ void layernormImpl(float *nram_k, + float *norm_params, + int task_deal_batch, + int k_hidden, + float eps) { +#if __BANG_ARCH__ > 500 + float *buffer = nram_k + task_deal_batch * k_hidden; + for (int i = 0; i < task_deal_batch; i++) { + float *k_ = nram_k + i * k_hidden; + __bang_mul(buffer, k_, k_, k_hidden); + float mean = __bang_sum(k_, k_hidden); + mean = mean / k_hidden; + float rstd = __bang_sum(buffer, k_hidden); + rstd = rstd / k_hidden - mean * mean; + rstd = rstd < 0 ? eps : rstd + eps; + rstd = 1.f / std::sqrt(rstd); + __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden); + } + __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden, + task_deal_batch * k_hidden, k_hidden); +#endif +} + +__mlu_func__ void foldRotaryImpl(float *nram_qk, + float *nram_qk_rot, + float *nram_table, + int task_deal_batch, + int head_num_qk, + int head_size) { + int rotary_low_dim = task_deal_batch * head_size; + __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim); + __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size); + __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size, + head_num_qk * rotary_low_dim, rotary_low_dim); + __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim); +} + +template +__mlu_func__ void quantify(T *input, + float *float_input, + void *output_hp, + void *output_lp, + float *nram_trans, + float *scale_hp, + float *scale_lp, + float *scale_lp_temp, + int batch, + int head_num, + int head_size, + int group_num, + int group_size, + bool quant_kv_hp, + bool mixed_cache) { + if (quant_kv_hp) { + int hidden = head_num * head_size; + int bh = batch * head_num; + toFloat(float_input, input, batch * hidden); + __bang_recip(scale_hp, scale_hp, hidden); +#if __BANG_ARCH__ > 500 + __asm__ __volatile__( + "fuse.nram.crn.s8.f32 " + "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])" + ";\n\t" ::[dst] "r"(output_hp), + [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input), + [src1] "r"(scale_hp), [pos] "i"(0)); +#endif + if (mixed_cache) { + __bang_transpose(nram_trans, float_input, bh * group_num, group_size); + __bang_abs(float_input, nram_trans, bh * head_size); + __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1, + 1); + __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num); + __bang_recip(scale_lp_temp, scale_lp, bh * group_num); + __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size, + bh * group_num); + __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0); + __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num); + } + } +} + +template +__mlu_func__ void fuseRopeImpl(T *input, + void *key_cache_hp, + void *value_cache_hp, + void *key_cache_lp, + void *value_cache_lp, + T *sin_table, + T *cos_table, + int *rope_offsets, + T *gamma, + T *beta, + float *key_scale_hp, + float *value_scale_hp, + float *key_scale_lp, + float *value_scale_lp, + int *cache_bs_id_hp, + int *cache_seq_offsets_hp, + int *cache_bs_id_lp, + int *cache_seq_offsets_lp, + int *slot_mapping_hp, + int *slot_mapping_lp, + int rotary_stride, + int task_deal_batch, + int task_begin_batch, + int head_num_q, + int head_num_k, + int head_size, + int max_decode_len_hp, + int max_decode_len_lp, + int block_size_hp, + int block_size_lp, + int group_size, + int batch_cap, + float eps) { +#if __BANG_ARCH__ > 500 + /* + 由于需要支持mixed cache,kernel支持cache的功能组合比较多,现规定只存在以下几种: + 1.只存在hp_cache的情况(通过lp tensor不为0判断),cache支持bf16,fp16,量化下支持离线perchannel int8 + 支持linear和paged,key和value cache形状一致,key/value_scale_hp形状为[head_num, head_size] + 2.mixed cache的情况hp支持离线perchannel int8量化,支持linear和paged,key和value cache形状一致, + key/value_scale_hp 形状为[head_num, head_size]. lp支持int4在线pertoken group量化, + key_cache形状为 [batch, head_num_k, max_decode_len_lp, head_size / 2] + paged情况也是head_size / 2, value_cache的形状为[batch, head_num_l, max_decode_len_lp / 2, + head_size],paged cache形状为 [num_blocks, head_num_k, block_size / 2, + head_size],key/value_scale_lp形状为 [batch, head_num_k, max_decode_len_lp, + group_num],paged_cache 为 [num_blocks, head_num_k, block_size, group_num] + */ + bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr; + bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr; + bool discrete_batch_hp = cache_bs_id_hp != nullptr; + bool discrete_batch_lp = cache_bs_id_lp != nullptr; + bool paged_cache_hp = slot_mapping_hp != nullptr; + bool paged_cache_lp = slot_mapping_lp != nullptr; + int head_num_qk = head_num_q + head_num_k; + int head_num_qkv = head_num_q + head_num_k * 2; + int qkv_hidden = head_num_qkv * head_size; + int qk_hidden = head_num_qk * head_size; + int q_hidden = head_num_q * head_size; + int k_hidden = head_num_k * head_size; + int float_size = sizeof(float); + int dtype_size = sizeof(T); + int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size; + int group_num = mixed_cache ? head_size / group_size : 1; + + // task ddr offset + T *input_begin = input + task_begin_batch * qkv_hidden; + int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch; + int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch; + int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch; + + // nram_buffer + float *nram_qk = (float *)nram_buffer; + float *nram_qk_rot = nram_qk + batch_cap * qk_hidden; + float *nram_v = nram_qk_rot + batch_cap * qk_hidden; + float *nram_kv_trans = nram_v + batch_cap * k_hidden; + float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden; + float *norm_params = nram_table + 2 * batch_cap * head_size; + float *nram_k_scale_hp = norm_params + 2 * head_size; + float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden; + float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden; + float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num; + int8_t *nram_kv_hp = + (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num); + int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden; + int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden; + int *nram_kv_cache_offsets_hp = + (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2); + int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k; + int *nram_v_cache_offsets_lp = + nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k; + int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k; + int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k; + float *cache_mask_hp = + (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k); + float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8); + + // 这里将qk和qk_rot放在一起是为升位宽可以一起做,减少指令,同样还有sincostable和norm的gamma和beta + T *qk_in = (T *)nram_qk_rot; + T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden); + T *v_in = + (T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden); + T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size); + T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size); + int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t); + + // 生成 kv cache的offset和mask,供scatter kv到kvcache使用 + genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp, + nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp, + nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k, + head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1, + discrete_batch_hp, paged_cache_hp, false); + if (mixed_cache) { + int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch; + int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch; + int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch; + genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp, + nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets, + nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch, + task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp, + 1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache); + } + + /* + ----------------------- + load v | + ----------------------- + load qk | quant v + ----------------------- + store v | rope qk + ----------------------- + store_q | layernorm k + | quant k + ----------------------- + store k | + */ + // prepare v v_scale cache_v rope_offset + __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM, + k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1); + if (quant_kv_hp) { + __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM); + __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM); + } + __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int), + GDRAM2NRAM); + + __sync_io(); + if (mixed_cache) { + __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp, + head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t), + task_deal_batch * head_num_k); + } + __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size, + task_deal_batch); + __sync_compute(); + /*==============================================================================================*/ + + // load_qk,rope_table | quant v + __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, + task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1, + qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size, + head_num_qk - 1); + __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size, + GDRAM2NRAM, head_size * dtype_size, task_deal_batch); + __gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets, + head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch); + __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM); + __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM); + + int8_t *nram_temp = (int8_t *)nram_qk; + if (mixed_cache) { + __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0, + 0); + __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2); + } + quantify(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp, + nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num, + group_size, quant_kv_hp, mixed_cache); + if (mixed_cache) { + __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp, + head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t), + task_deal_batch * head_num_k); + __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden); + __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0, + 0); + } + __sync_io_move_compute(); + /*==============================================================================================*/ + + // rope | store v + // 将qk的左右部分交换,用于生成qk_rot + __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM, + head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1); + __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM, + head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1); + toFloat(nram_qk, qk_in, 2 * batch_cap * qk_hidden); + toFloat(nram_table, table_in, 2 * task_deal_batch * head_size); + toFloat(norm_params, norm_params_in, 2 * head_size); + __bang_write_value(nram_mask, head_size / 2, (float)-1); + __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1); + foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size); + floatTo((T *)nram_qk, nram_qk, task_deal_batch * q_hidden); + + int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v; + __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp, + cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp, + head_num_k * task_deal_batch); + if (mixed_cache) { + __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp, + cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM, + head_size * sizeof(int8_t), head_num_k * task_deal_batch); + __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets, + cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float), + head_num_k * task_deal_batch); + } + __sync_io_move_compute(); + /*==============================================================================================*/ + // layernrom k quant k | store q + // 从qk的nram buffer中提取出k做layernorm和量化 + float *nram_k = nram_qk_rot; + __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM, + head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1, + task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size, + task_deal_batch - 1); + layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps); + quantify(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp, + nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, + head_size, group_num, group_size, quant_kv_hp, mixed_cache); + if (mixed_cache) { + __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0); + } + if (!quant_kv_hp) { + floatTo((T *)nram_k, nram_k, task_deal_batch * k_hidden); + } + + // store q + __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size, + task_deal_batch - 1, head_size * dtype_size, head_num_q - 1, + head_size * dtype_size, task_deal_batch - 1, + task_deal_batch * head_size * dtype_size, head_num_q - 1); + // =============================================================================================== + + int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k; + __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp, + head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp, + head_num_k * task_deal_batch); + if (mixed_cache) { + __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp, + head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t), + head_num_k * task_deal_batch); + __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp, + group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float), + head_num_k * task_deal_batch); + } + __sync_io_move_compute(); +#endif +} + +template +__mlu_global__ void MLUFuseRope(T *input, + void *key_cache_hp, + void *value_cache_hp, + void *key_cache_lp, + void *value_cache_lp, + T *sin_table, + T *cos_table, + int *rope_offsets, + T *gamma, + T *beta, + float *key_scale_hp, + float *value_scale_hp, + float *key_scale_lp, + float *value_scale_lp, + int *cache_bs_id_hp, + int *cache_seq_offsets_hp, + int *cache_bs_id_lp, + int *cache_seq_offsets_lp, + int *slot_mapping_hp, + int *slot_mapping_lp, + int rotary_stride, + int batch, + int head_num_q, + int head_num_k, + int head_size, + int max_decode_len_hp, + int max_decode_len_lp, + int block_size_hp, + int block_size_lp, + int group_size, + int batch_cap, + int task_avg_batch, + float eps) { + int task_begin_batch = taskId * task_avg_batch; + int task_deal_batch = std::min(batch - task_begin_batch, task_avg_batch); + if (task_deal_batch <= 0 || __is_mpu()) { + return; + } + + int task_loop = (task_deal_batch + batch_cap - 1) / batch_cap; + int once_batch = (task_deal_batch + task_loop - 1) / task_loop; + + for (int i = 0; i < task_loop; i++) { + int cur_batch = std::min(task_deal_batch - i * once_batch, once_batch); + int batch_offset = task_begin_batch + once_batch * i; + + fuseRopeImpl(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table, + cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp, + key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp, + cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp, + rotary_stride, cur_batch, batch_offset, head_num_q, head_num_k, head_size, + max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, + once_batch, eps); + } +} +} // namespace kernels + +KernelStatus invokeFusedRope(cnrtQueue_t queue, + void *input, + void *key_cache_hp, + void *value_cache_hp, + void *key_cache_lp, + void *value_cache_lp, + const void *sin_table, + const void *cos_table, + const void *rope_offsets, + const void *gamma, + const void *beta, + const void *key_scale_hp, + const void *value_scale_hp, + void *key_scale_lp, + void *value_scale_lp, + const void *cache_bs_id_hp, + const void *cache_seq_offsets_hp, + const void *cache_bs_id_lp, + const void *cache_seq_offsets_lp, + const void *slot_mapping_hp, + const void *slot_mapping_lp, + int rotary_stride, + int batch_size, + int head_num_q, + int head_num_kv, + int head_size, + int max_decode_len_hp, + int max_decode_len_lp, + int block_size_hp, + int block_size_lp, + int group_size, + cnnlDataType_t dtype, + float eps) { + if (is_arch300()) { + std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + uint32_t taskdimx = cluster_num * core_num; + + int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx; + int float_size = sizeof(float); + int group_num = head_size / group_size; + bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr; + bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr; + + int nram_avalible_bytes = 480 * 1024; + int task_max_batch = 32; + int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1); + int nram_params_bytes = 2 * head_size * float_size; + int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size; + int nram_remain_bytes = + nram_avalible_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes; + int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2; + int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1); + int nram_table_bytes = 2 * head_size * float_size; + int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size; + int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size; + int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size; + int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2; + int nram_cache_offsets_hp = head_num_kv * sizeof(int); + int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int); + int batch_cap = + nram_remain_bytes / + (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes + + nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp); + batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch; + + cnrtDim3_t dim{taskdimx, 1, 1}; + if (dtype == CNNL_DTYPE_HALF) { + kernels::MLUFuseRope<<>>( + (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, + (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta, + (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp, + (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp, + (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp, + (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size, + max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap, + task_avg_batch, eps); + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + kernels::MLUFuseRope<<>>( + (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, + (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta, + (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp, + (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp, + (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp, + (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size, + max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap, + task_avg_batch, eps); + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mluh new file mode 100644 index 0000000..9d19ada --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mluh @@ -0,0 +1,119 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_ +#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Apply query and kery rotary embedding, key layernorm and + * quantize key and value to kv cache. + * @param queue: The queue for mlu. + * @param input: Input/Output. Pointer to the MLU memory that stores the input, + * the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size]. + * @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key + * cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks, + * head_num_kv, block_size, head_size]. + * @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision + * value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks, + * head_num_kv, block_size, head_size]. + * @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key + * cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks, + * head_num_kv, block_size, head_size]. + * @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision + * value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks, + * head_num_kv, block_size, head_size]. + * @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be + * continous. The shape must be [rotary_seq_len, rotary_dim]. + * @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be + * continous. The shape must be [rotary_seq_len, rotary_dim]. + * @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each + * batch. The shape must be [batch]. + * @param norm_gamma: Input. Pointer to the MLU memory that stores the gamma param of layernorm. + * @param norm_beta: Input. Pointer to the MLU memory that stores the beta param of layernorm. + * @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision + * key. The shape must be [head_num_kv, head_size]. If key_scale is nullptr, + * that means key do not need to be quantized. + * @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision + * value. The shape must be [head_num_kv, head_size]. If value_scale is nullptr, + * that means value do not need to be quantized. + * @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low + * precision key. The shape must be [batch_size, head_num_kv, max_deocde_len, group_num] or + * [num_blocks, head_num_kv, block_size, group_num]. + * @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low + * precision value. The shape must be [batch_size, head_num_kv, max_deocde_len, group_num] or + * [num_blocks, head_num_kv, block_size, group_num]. + * @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch + * offset of high precision cache, the shape must be [batch], if it's nullptr, the + * default value is {0, 1, 2 ... batch - 1}. + * @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence + * offset of high precision cache, the shape must be [batch]. + * @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch + * offset of low precision cache, the shape must be [batch], if it's nullptr, the + * default value is {0, 1, 2 ... batch - 1}. + * @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence + * offset of low precision cache, the shape must be [batch]. + * @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor + * which has shape [batch]. Data type of slot mapping must be int32_t. + * @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor + * which has shape [batch]. Data type of slot mapping must be int32_t. + * @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table. + * @param batch_size: Batch size. + * @param head_num_q: Head number of query. + * @param head_num_kv: Head number of key and value. + * @param head_size: Head size. For simplify, the rotary dim must be the same as head_size. + * @param max_decode_len_hp: The maximum sequence length of high precision cache. + * @param max_decode_len_lp: The maximum sequence length of low precision cache. + * @param block_size_hp: Number of tokens per block of high precision cache. + * @param block_size_lp: Number of tokens per block of low precision cache. + * @param data_type: Data type of all inputs and outputs. + * @param eps: float number use for layernorm. + * @note: Head_num_q and head_num_kv must be in range [1, 32]. + * Head_size must be in range [1, 128], and must be divided by 2. + */ +KernelStatus invokeFusedRope(cnrtQueue_t queue, + void *input, + void *key_cache_hp, + void *value_cache_hp, + void *key_cache_lp, + void *value_cache_lp, + const void *sin_table, + const void *cos_table, + const void *rope_offsets, + const void *gamma, + const void *beta, + const void *key_scale_hp, + const void *value_scale_hp, + void *key_scale_lp, + void *value_scale_lp, + const void *cache_bs_id_hp, + const void *cache_seq_offsets_hp, + const void *cache_bs_id_lp, + const void *cache_seq_offsets_lp, + const void *slot_mapping_hp, + const void *slot_mapping_lp, + int rotary_stride, + int batch_size, + int head_num_q, + int head_num_kv, + int head_size, + int max_decode_len_hp, + int max_decode_len_lp, + int block_size_hp, + int block_size_lp, + int group_size, + cnnlDataType_t dtype, + float eps); +} // namespace tmo + +#endif // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mlu new file mode 100644 index 0000000..78bb2c1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mlu @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "generate_alibi_slope.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { + +#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) +__nram__ int8_t nram_buffer[NRAM_SIZE]; + +__nram__ float range_1[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}; + +__nram__ float range_2[64] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, + 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, + 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, + 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, + 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; + +__mlu_func__ void genRange(float *range_nram, + float *range_base_nram, + int fill_num, + int base_num, + int offset = 0) { + int loop = (fill_num + base_num - 1) / base_num; + for (int i = 0; i < loop; i++) { + int num = std::min((fill_num - i * base_num), base_num); + float *fill_nram = range_nram + i * base_num; + __bang_move(fill_nram, range_base_nram, num * sizeof(float)); + __bang_add_scalar(fill_nram, fill_nram, i * base_num + offset, num); + } +} + +__mlu_global__ void MLUAlibiSlopeKernel(float *alibi_slopes, + int *true_seq_lens, + int batch_num, + int head_start, + int head_num, + int head_num_total, + int max_sequence_length, + bool use_dynamic, + int closest_power_of_2, + int farthest_power_of_2, + float base, + float extra_base) { + float *range_nram = (float *)nram_buffer; + float *base_nram = range_nram + head_num; + float *slope_nram = base_nram + head_num; + + float scale = 1.0; + float dynamic_base = base; + if (use_dynamic) { + float a0 = 1.0; + float a = a0 * true_seq_lens[taskIdX] / max_sequence_length; + a = std::max(a, 1.0f); + scale = powf(a, (1.0 / (head_num_total - 1))); + dynamic_base = base / scale; + } + + int close_head_num = 0; + if (head_start >= closest_power_of_2) { + close_head_num = 0; + } else if (head_start + head_num <= closest_power_of_2) { + close_head_num = head_num; + } else { + close_head_num = closest_power_of_2 - head_start; + } + int far_head_num = head_num - close_head_num; + + // fill range: 1, 2..., n1, 1, 3, (n - n1) * 2 - 1 + if (close_head_num) { + genRange(range_nram, range_1, close_head_num, 64, head_start); + __bang_write_value(base_nram, close_head_num, dynamic_base); + } + if (far_head_num) { + genRange(range_nram + close_head_num, range_2, far_head_num, 64, + (head_start + close_head_num - closest_power_of_2) * 2); + __bang_write_value(base_nram + close_head_num, far_head_num, extra_base); + } + + // base_nram ** range_nram + __bang_log(base_nram, base_nram, head_num); + __bang_mul(slope_nram, base_nram, range_nram, head_num); + __bang_pow2(slope_nram, slope_nram, head_num); + + if (use_dynamic) { + __bang_mul_scalar(slope_nram, slope_nram, scale, close_head_num); + } + + __memcpy(alibi_slopes + taskIdX * head_num, slope_nram, head_num * sizeof(float), NRAM2GDRAM); +} + +} // namespace kernels + +KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue, + void *alibi_slopes, + void *true_seq_lens, + int batch_num, + int head_start, + int head_num, + int head_num_total, + int max_sequence_length, + bool use_dynamic) { + cnrtDim3_t dim{.x = (uint32_t)batch_num, .y = 1, .z = 1}; + + int closest_power_of_2 = pow(2, floor(log2(head_num_total))); + int farthest_power_of_2 = closest_power_of_2 * 2; + float base = pow(2, (-pow(2, -(log2(closest_power_of_2) - 3)))); + float extra_base = pow(2, (-pow(2, -(log2(2 * closest_power_of_2) - 3)))); + + kernels::MLUAlibiSlopeKernel<<>>( + (float *)alibi_slopes, (int *)true_seq_lens, batch_num, head_start, head_num, head_num_total, + max_sequence_length, use_dynamic, closest_power_of_2, farthest_power_of_2, base, extra_base); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mluh new file mode 100644 index 0000000..032975a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mluh @@ -0,0 +1,43 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_ +#define CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Generate causal mask for context satge. + * @param queue: The queue for mlu. + * @param alibi_slopes: Output. Pointer to the MLU memory that stores the output, the shape must be + * [batch_num, head_num]. + * @param true_seq_lens: Input. Pointer to the MLU memory that stores the actual sequence length of + * each batch, the shape must be [batch_num]. + * @param batch_num: Batch number. + * @param head_start: The index of first head. + * @param head_num: Head number in this card. + * @param head_num_total: Total head number in all cards. + * @param max_sequence_length: The maximum sequence length used during training. + * @param use_dynamic: A boolean value indicates whether to use dynamic NTK. + */ +KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue, + void *alibi_slopes, + void *true_seq_lens, + int batch_num, + int head_start, + int head_num, + int head_num_total, + int max_sequence_length, + bool use_dynamic); +} // namespace tmo + +#endif // CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mlu new file mode 100644 index 0000000..543bad2 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mlu @@ -0,0 +1,214 @@ +#include +#include +#include "cn_api.h" +#include "cnnl.h" +#include "cnrt.h" +#include "generate_mask.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { +template +__mlu_func__ void write_value(void *dst, unsigned int elem_count, T value) { + __bang_write_value(dst, elem_count, value); +} + +template <> +__mlu_func__ void write_value(void *dst, unsigned int elem_count, bfloat16_t value) { +#if __BANG_ARCH__ >= 500 + __bang_write_value(dst, elem_count, value); +#endif +} + +// [once_len, once_len] +__nram__ int8_t nram_small[(__MLU_NRAM_SIZE__ * 1 / 4 * 1024)]; +// [1 + once_len, 2 * once_len] +__nram__ int8_t nram_large[(__MLU_NRAM_SIZE__ * 2 / 4 * 1024 + 1024)]; +// [once_len * 2 + 1] +__nram__ int8_t nram_tiny[2048]; +template +class GenerateMask { + constexpr static int once_len = sizeof(T) == 4 ? 160 : 256; + // [once_len, once_len] + T *nram_upper = (T *)(nram_small); + // [1 + once_len, 2 * once_len] + T *nram_buf = (T *)(nram_large); + // [once_len, once_len], reuse upper part of nram_buf + T *nram_filled = nram_buf; + // [once_len, once_len], reuse lower part of nram_buf + T *nram_zeros = nram_buf + once_len * once_len; + // [once_len] + T *nram_ones_zeros = (T *)nram_tiny; + + __mlu_func__ void initBuffers(T fill_value = -10000) { + /* nram_buf: + |---once_len---||---once_len---| + 0, 1, 1, 1, ..., 1, 0, 0, 0, ... + 0, 0, 1, 1, ..., 1, 1, 0, 0, ... + 0, 0, 0, 1, ..., 1, 1, 1, 0, ... + ... */ + nram_buf[0] = 0; + constexpr static int copy_size = (once_len * 2 + 1) * sizeof(T); + __memcpy(nram_buf + 1, nram_ones_zeros, copy_size, NRAM2NRAM, copy_size, 0, once_len - 1); + __memcpy(nram_upper, nram_buf, once_len * sizeof(T), NRAM2NRAM, once_len * sizeof(T), + once_len * 2 * sizeof(T), once_len - 1); + // nram_buf is nolonger needed + write_value(nram_filled, once_len * once_len, (T)fill_value); + write_value(nram_zeros, once_len * once_len, (T)0); + } + + __mlu_func__ void dealOneBatch(T *output, // [max_seq_len, max_seq_len] + int max_seq_len, + int seq_len) { + /* + | once_len | + +----------+-----------------------------------+ + | | | | + | upper | fill_value | | + | | | | + +----------+----------+ | | + | | | | | + | | upper | | fill | + | | | | value | + | +----------+----------+ | | + | | | | | + | | upper | | | + | 0 | | | | + | +----------+---+ | + | | u | | + |--------------------------------+---+ | + | | + | fill_value | + | | + +----------------------------------------------+ + */ + int tile_count = seq_len / once_len; + int tile_remain = seq_len % once_len; + int boarder_len = max_seq_len - seq_len; + int row = 0; + for (; row < tile_count * once_len; row += once_len) { + // fill left with zeros + // assume that max_seq_len <= once_len^2 + if (row > 0) { + __memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM, + max_seq_len * sizeof(T), 0, once_len - 1); + } + // fill middle with upper + __memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, once_len * sizeof(T), + NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), once_len - 1); + // fill right with fill_value + if (row + once_len < max_seq_len) { + __memcpy_async(output + (size_t)row * max_seq_len + row + once_len, nram_filled, + (max_seq_len - row - once_len) * sizeof(T), NRAM2GDRAM, + max_seq_len * sizeof(T), 0, once_len - 1); + } + } + + if (tile_remain) { + // fill left with zeros + if (row > 0) { + __memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM, + max_seq_len * sizeof(T), 0, tile_remain - 1); + } + // fill middle with upper + __memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, tile_remain * sizeof(T), + NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), tile_remain - 1); + // fill right with fill_value + if (row + tile_remain < max_seq_len) { + __memcpy_async(output + (size_t)row * max_seq_len + row + tile_remain, nram_filled, + (max_seq_len - row - tile_remain) * sizeof(T), NRAM2GDRAM, + max_seq_len * sizeof(T), 0, tile_remain - 1); + } + } + + if (boarder_len) { + // fill right boarder with fill_value + __memcpy_async(output + seq_len, nram_filled, boarder_len * sizeof(T), NRAM2GDRAM, + max_seq_len * sizeof(T), 0, (max_seq_len - boarder_len) - 1); + // fill bottom boarder with fill_value + __memcpy_async(output + (size_t)seq_len * max_seq_len, nram_filled, max_seq_len * sizeof(T), + NRAM2GDRAM, max_seq_len * sizeof(T), 0, boarder_len - 1); + } + __sync_io(); + } + + public: + __mlu_func__ void execute(T *output_ddr, // [total_batch, max_seq_len, max_seq_len] + int *batch_seq_len, + int total_batch, + int max_seq_len, + T fill_value = -10000) { + int batch_each = total_batch / taskDimY; + int batch_remain = total_batch % taskDimY; + int batch_start = taskIdY * batch_each + (taskIdY < batch_remain ? taskIdY : batch_remain); + int batch_count = batch_each + (taskIdY < batch_remain ? 1 : 0); + write_value(nram_ones_zeros, once_len, (T)fill_value); + write_value(nram_ones_zeros + once_len, once_len + 1, (T)0); + initBuffers(); + + for (int n = batch_start; n < batch_start + batch_count; n++) { + T *output = output_ddr + (size_t)n * max_seq_len * max_seq_len; + int seq_len = batch_seq_len[n]; + dealOneBatch(output, max_seq_len, seq_len); + } + } +}; + +template +__mlu_global__ void MLUUnion1GenerateMask(T *output_ddr, // [total_batch, max_seq_len, max_seq_len] + int *batch_seq_len, + int total_batch, + int max_seq_len, + T fill_value = -10000) { + if (coreId != 0) { + return; // we only use 1 core in a cluster + } + GenerateMask().execute(output_ddr, batch_seq_len, total_batch, max_seq_len, fill_value); +} +} // namespace kernels + +KernelStatus invokeGenerateMask(cnnlHandle_t handle, + void *output_ddr, + int *batch_seq_len, + int total_batch, + int max_seq_len, + cnnlDataType_t data_type, + float fill_value) { + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + CNdev dev; + cnnlGetDevice(handle, &dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + cnrtDim3_t dim; + dim.x = 4; + dim.y = cluster_num; + dim.z = 1; + if (data_type == CNNL_DTYPE_FLOAT) { + kernels::MLUUnion1GenerateMask<<>>( + static_cast(output_ddr), batch_seq_len, total_batch, max_seq_len, + static_cast(fill_value)); + } else if (data_type == CNNL_DTYPE_HALF) { + kernels::MLUUnion1GenerateMask<<>>( + static_cast(output_ddr), batch_seq_len, total_batch, max_seq_len, + static_cast(fill_value)); + } else if (data_type == CNNL_DTYPE_BFLOAT16) { + if (!isBf16Supported()) { + std::cerr << "[invokeGenerateMask]: MLU300 devices do not support bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + kernels::MLUUnion1GenerateMask<<>>( + static_cast(output_ddr), batch_seq_len, total_batch, max_seq_len, + static_cast(fill_value)); + } else { + std::cerr << "[invokeGenerateMask]: invokeGenerateMask: data_type is not supported" + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mluh new file mode 100644 index 0000000..7767cd8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mluh @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_GENERATE_MASK_MLUH_ +#define CSRC_KERNELS_GENERATE_MASK_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Generate causal mask for context stage. + * @param handle: The handle of cnnl. + * @param output_ddr: Output. Pointer to the MLU memory that stores the output. + * @param batch_seq_len: Input. Pointer to the MLU memory that stores the sequence length. + * @param total_batch: Batch size. + * @param max_seq_len: The maximum sequence length of context. + * @param data_type: Data type. + * @param fill_value: The fill value of the pad part. + */ +KernelStatus invokeGenerateMask(cnnlHandle_t handle, + void *output_ddr, + int *batch_seq_len, + int total_batch, + int max_seq_len, + cnnlDataType_t data_type, + float fill_value); +} // namespace tmo + +#endif // CSRC_KERNELS_GENERATE_MASK_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mlu new file mode 100644 index 0000000..7f91000 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mlu @@ -0,0 +1,60 @@ +#include +#include "cnrt.h" +#include "get_glm_position_id.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { +__nram__ int nram_buffer[__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)]; + +__mlu_global__ void MLUBlockSliceIndividualPosId(int *context_pos_id, + int *generate_pos_id, + int batch, + int context_seq_len, + int pos_id_dimension /* 1 for 1D, 2 for 2D */) { + if (taskId != 0) return; + __memcpy(nram_buffer, context_pos_id + context_seq_len - 1, sizeof(int), GDRAM2NRAM, sizeof(int), + context_seq_len * sizeof(int), pos_id_dimension * batch - 1); + if (pos_id_dimension == 2) { + for (int i = 1; i < 2 * batch; i += 2) { + nram_buffer[i] += 1; + } + } + __memcpy(generate_pos_id, nram_buffer, pos_id_dimension * batch * sizeof(int), NRAM2GDRAM); +} + +__mlu_global__ void MLUBlockIncrement2DPosId(int *generate_pos_id, int batch) { + if (taskId != 0) return; + __memcpy(nram_buffer, generate_pos_id, 2 * batch * sizeof(int), GDRAM2NRAM); + for (int i = 1; i < 2 * batch; i += 2) { + nram_buffer[i] += 1; + } + __memcpy(generate_pos_id, nram_buffer, 2 * batch * sizeof(int), NRAM2GDRAM); +} +} // namespace kernels +KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue, + int *context_pos_id, + int *generate_pos_id, + int batch, + int context_seq_len, + int pos_id_dimension /* 1 for 1D, 2 for 2D */) { + if (pos_id_dimension != 1 && pos_id_dimension != 2) { + std::cerr << "[invokeSliceIndividualPosId]: pos_id_dimension must be 1 or 2, but got " + << pos_id_dimension << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + cnrtDim3_t dim{4, 1, 1}; + kernels::MLUBlockSliceIndividualPosId<<>>( + context_pos_id, generate_pos_id, batch, context_seq_len, pos_id_dimension); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch) { + cnrtDim3_t dim{4, 1, 1}; + kernels::MLUBlockIncrement2DPosId<<>>(generate_pos_id, batch); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mluh new file mode 100644 index 0000000..9fa5a5d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mluh @@ -0,0 +1,59 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_ +#define CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Get generate position id from context position id, when position id is 2D, + * increase block position id by one. + * @example + * in GLM network, context_pos_id shape is [batch, 2, context_seq_len], data is + * [[[0, 1, 2, 2, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1]], + * [[0, 1, 2, 3, 4, 5, 5], [0, 0, 0, 0, 0, 0, 1]]] + * after invoke this kernel, the data is + * [[[2], [2]], + * [[5], [2]]] + * @param queue: The queue for mlu. + * @param context_pos_id: Input. Pointer to the MLU memory that stores the position id of + * context. + * @param generate_pos_id: Output. Pointer to the MLU memory that stores the position id of + * generate. + * @param batch: Batch size. + * @param context_seq_len: The sequence length of context. + * @param pos_id_dimension: The dimension of position id, 1 for 1D, 2 for 2D. + */ +KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue, + int *context_pos_id, + int *generate_pos_id, + int batch, + int context_seq_len, + int pos_id_dimension); + +/** + * @brief Increase block position id by one in generate stage. + * @example + * in GLM network, generate_pos_id shape is [batch, 2, 1], data is + * [[[2], [1]], [[5], [1]]] + * after invoke this kernel, the data is + * [[[2], [2]], [[5], [2]]] + * @param queue: The queue for mlu. + * @param generate_pos_id: Output/Input. Pointer to the MLU memory that stores the position id of + * generate. + * @param batch: Batch size. + */ +KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch); +} // namespace tmo + +#endif // CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/kernel_utils.h b/torch_mlu_ops-v1.3.2/csrc/kernels/kernel_utils.h new file mode 100644 index 0000000..e1ccc12 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/kernel_utils.h @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_KERNEL_UTILS_H_ +#define CSRC_KERNELS_KERNEL_UTILS_H_ + +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" + +namespace tmo { +const std::string arch_370 = "MLU370"; +enum class KernelStatus { KERNEL_STATUS_SUCCESS = 0, KERNEL_STATUS_FAILED }; + +#ifndef PAD_DOWN +#define PAD_DOWN(x, y) (((x) / (y)) * (y)) +#endif + +#ifndef PAD_UP +#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y)) +#endif + +inline bool isMlu300(const std::string &dev_name) { + if (dev_name.find("MLU3") != std::string::npos) { + return true; + } else { + return false; + } +} + +inline bool is_arch300() { + int card_id = -1; + cnrtDeviceProp_t dev_info; + CNRT_CHECK(cnrtGetDevice(&card_id)); + CNRT_CHECK(cnrtGetDeviceProperties(&dev_info, card_id)); + std::string dev_name = dev_info.name; + return isMlu300(dev_name); +} + +inline bool isBf16Supported() { + return !is_arch300(); +} +} // namespace tmo +#endif // CSRC_KERNELS_KERNEL_UTILS_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu new file mode 100644 index 0000000..c5452c3 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu @@ -0,0 +1,521 @@ +#include +#include +#include +#include +#include +#include "add_bias_activation.mluh" +#include "cnnl.h" +#include "cnrt.h" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { +#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y)) +#define USE_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 20 * 1024) +#define USE_SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 20 * 1024) + +__nram__ int8_t nram_buffer[USE_NRAM_SIZE]; +__mlu_shared__ int8_t sram_buffer[USE_SRAM_SIZE]; + +__mlu_func__ void get_expert_info(int *nram_count, + int *count_sram, + uint32_t *gather_offset, + int real_inner, + int tokens_start, + int tokens_end, + int tokens_load, + int expert_deal_start, + int expert_deal_end, + uint32_t &expert_start, + uint32_t &expert_end, + uint32_t &tokens_deal_first, + int dtype_size) { + bool record_start = false; + // loop expert to find first and last deal expert in current core + for (int expert_id = expert_deal_start; expert_id <= expert_deal_end; expert_id++) { + if (__load_nram(nram_count + expert_id + 1) > tokens_start && !record_start) { + expert_start = expert_id; + tokens_deal_first = + std::min(__load_nram(nram_count + expert_id + 1) - 1, tokens_end) - tokens_start + 1; + record_start = true; + } + if (__load_nram(nram_count + expert_id + 1) > tokens_end) { + expert_end = expert_id; + break; + } + } + + // record expert offset to gather bias + __bang_write_zero(gather_offset, tokens_load); + int tokens_load_total = 0; + for (int expert_id = expert_start; expert_id <= expert_end; expert_id++) { + int tokens_expand = __load_sram((int *)count_sram + expert_id); + if (expert_id == expert_start) { + tokens_expand = tokens_deal_first; + } else if (expert_id == expert_end) { + tokens_expand = tokens_load - tokens_load_total; + } + if (tokens_expand == 0) { + continue; + } + __bang_write_value(gather_offset + tokens_load_total, tokens_expand, + (int)((expert_id - expert_deal_start) * real_inner * dtype_size)); + tokens_load_total += tokens_expand; + } +} + +/*************** functions for compute basic operation ***************/ + +template +__mlu_func__ void add_bias(T *dst_src, T *bias, int number) { + // cycle add bias + __bang_add((T *)dst_src, (T *)dst_src, (T *)bias, number); +} + +template <> +__mlu_func__ void add_bias(bfloat16_t *dst_src, bfloat16_t *bias, int number) { +#if __BANG_ARCH__ > 500 + __bang_add((bfloat16_t *)dst_src, (bfloat16_t *)dst_src, (bfloat16_t *)bias, number); +#endif +} + +template +__mlu_func__ void mul_left_right(T *left, T *right, int number) { + __bang_mul((T *)left, (T *)left, (T *)right, number); +} + +template <> +__mlu_func__ void mul_left_right(bfloat16_t *left, bfloat16_t *right, int number) { +#if __BANG_ARCH__ > 500 + __bang_mul((bfloat16_t *)left, (bfloat16_t *)left, (bfloat16_t *)right, number); +#endif +} + +__mlu_func__ void do_activation(float *input_left, + float *act_space, + int number, + float active_coef, + cnnlActivationMode_t act_type) { + if (act_type == CNNL_ACTIVATION_GELU) { + __bang_active_gelu((float *)input_left, (float *)input_left, number); + } else if (act_type == CNNL_ACTIVATION_SWISH) { + float *tmp = input_left; + if (active_coef != 1.0f) { + __bang_mul_scalar(act_space, input_left, active_coef, number); + tmp = act_space; + } + __bang_active_sigmoid((float *)act_space, (float *)tmp, number); + __bang_mul((float *)input_left, (float *)input_left, (float *)act_space, number); + } +} + +/*************** functions for steps of each loop ***************/ + +template +__mlu_func__ void gather_bias(T *bias_sram, + T *bias_nram, + uint32_t *gather_offset, + int expert_start, + int expert_end, + int expert_deal_start, + int tokens_deal_first, + int tokens_deal, + int inner_size, + bool is_gated) { +#if __BANG_ARCH__ > 500 + if (is_gated) { + __gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM, + inner_size * sizeof(T), tokens_deal); + __gather_async((T *)bias_nram + tokens_deal * inner_size, (T *)bias_sram + inner_size, + gather_offset, inner_size * sizeof(T), SRAM2NRAM, inner_size * sizeof(T), + tokens_deal); + } else { + __gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM, + inner_size * sizeof(T), tokens_deal); + } +#else + for (int i = 0; i < tokens_deal; i++) { + if (is_gated) { + __memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i], + inner_size * sizeof(T), SRAM2NRAM); + __memcpy_async((T *)bias_nram + (tokens_deal * inner_size + i * inner_size), + (int8_t *)bias_sram + inner_size * sizeof(T) + gather_offset[i], + inner_size * sizeof(T), SRAM2NRAM); + } else { + __memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i], + inner_size * sizeof(T), SRAM2NRAM); + } + } +#endif +} + +template +__mlu_func__ void loadBiasInput(T *input, + T *left, + T *right, + T *bias_nram, + T *bias_sram, + uint32_t *gather_offset, + size_t input_offset, + int tokens_deal, + int inner_size, + uint32_t expert_start, + uint32_t expert_end, + int expert_deal_start, + uint32_t tokens_deal_first, + bool is_gated, + bool has_bias) { + if (is_gated) { + // if gated, stride io load input, left/right inner to input_left/right + __memcpy_async((T *)left, (T *)input + input_offset, inner_size * sizeof(T), GDRAM2NRAM, + inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1); + __memcpy_async((T *)right, (T *)input + input_offset + inner_size, inner_size * sizeof(T), + GDRAM2NRAM, inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1); + } else { + // if not gated, load input to input_left total + __memcpy_async((T *)left, (T *)input + input_offset, tokens_deal * inner_size * sizeof(T), + GDRAM2NRAM); + } + if (has_bias) { + __sync_compute(); + gather_bias((T *)bias_sram, (T *)bias_nram, gather_offset, expert_start, expert_end, + expert_deal_start, tokens_deal_first, tokens_deal, inner_size, is_gated); + } +} + +template +__mlu_func__ void computeAddActivation(T *bias_nram, + T *left_dst, + T *input_right, + float *input_left, + float *act_space, + int tokens_deal, + int inner_size, + bool is_gated, + bool has_bias, + float active_coef, + cnnlActivationMode_t act_type) { + int number = tokens_deal * inner_size; + if (has_bias) { + add_bias((T *)left_dst, (T *)bias_nram, number); + if (is_gated) { + add_bias((T *)input_right, (T *)bias_nram + tokens_deal * inner_size, number); + } + } + + // cast half/bfloat16 to float to acvication, if float, left_dst is same as input_left + if (std::is_same::value) { + __bang_half2float((float *)input_left, (half *)left_dst, number); + } +#if __BANG_ARCH__ > 500 + if (std::is_same::value) { + __bang_bfloat162float((float *)input_left, (bfloat16_t *)left_dst, number); + } +#endif + + // activation + do_activation(input_left, act_space, number, active_coef, act_type); + + // if half/bfloat16, cast float to T to mul + if (std::is_same::value) { + __bang_float2half((half *)input_left, (float *)input_left, number); + } +#if __BANG_ARCH__ > 500 + if (std::is_same::value) { + __bang_float2bfloat16((bfloat16_t *)input_left, (float *)input_left, number); + } +#endif + + if (is_gated) { + mul_left_right((T *)input_left, (T *)input_right, tokens_deal * inner_size); + } +} + +template +__mlu_func__ void storeOutput(T *output, + T *output_nram, + size_t output_offset, + int output_stride, + int tokens_deal, + int inner_size) { + __memcpy_async((T *)output + output_offset, (T *)output_nram, inner_size * sizeof(T), NRAM2GDRAM, + output_stride * sizeof(T), inner_size * sizeof(T), tokens_deal - 1); +} + +template +__mlu_global__ void MLUAddBiasActivationKernel(T *output, + const T *input, + const T *bias, + const int *cusum_token_count, + int num_expert, + int total_tokens, + int inner_size, + int output_stride, + bool is_gated, + cnnlActivationMode_t act_type, + int start_expert_id, + int expert_size, + float active_coef) { + // if bias and token_count is nullptr, not add bias, only activation and gated mul. + bool has_bias = (bias != nullptr); + + // 1. distrubute nram space + /* if not gated + ---------------------------- ---------------------------------- + | nram_token_count | bias_ping/pong | + | num_expert * sizeof(int) | 2 * x * inner_size * sizeof(T) | + ---------------------------- ---------------------------------- + -------------------------------------- + | input_ping/pong | + | 2 * x * inner_size * sizeof(float) | + -------------------------------------- + ------------------------------------------------- ------------------- + | act_space | gather_offset | + | gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) | + ------------------------------------------------- ------------------- + */ + + /* if gated + ---------------------------- ---------------------------------- + | nram_token_count | bias_ping/pong | + | num_expert * sizeof(int) | 2 * x * real_inner * sizeof(T) | + ---------------------------- ---------------------------------- + -------------------------------------- ---------------------------------- + | input_left_ping/pong | input_right_ping/pong | + | 2 * x * inner_size * sizeof(float) | 2 * x * inner_size * sizeof(T) | + -------------------------------------- ---------------------------------- + ------------------------------------------------- ------------------- + | act_space | gather_offset | + | gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) | + ------------------------------------------------- ------------------- + */ + + // distribute sram + int8_t *count_sram = (int8_t *)sram_buffer; + int8_t *bias_sram = (int8_t *)count_sram + num_expert * sizeof(int); + + // distrubute nram + int real_inner = (is_gated) ? inner_size * 2 : inner_size; + int bias_nram_size = (has_bias) ? real_inner * sizeof(T) : 0; + int act_space_size = (act_type == CNNL_ACTIVATION_GELU) ? 0 : inner_size * sizeof(float); + int gated_ext_size = is_gated ? sizeof(T) : 0; + int max_token_deal = (USE_NRAM_SIZE - (num_expert + 1) * sizeof(int)) / + (2 * inner_size * (sizeof(float) + gated_ext_size) + act_space_size + + 2 * bias_nram_size + sizeof(int)); + + int8_t *nram_count = (int8_t *)nram_buffer; + int8_t *bias_nram = (int8_t *)nram_count + (num_expert + 1) * sizeof(int); + int8_t *input_left = (int8_t *)bias_nram + 2 * max_token_deal * bias_nram_size; + int8_t *input_right = + (int8_t *)input_left + 2 * ((is_gated) ? max_token_deal * inner_size * sizeof(float) : 0); + int8_t *act_space = (int8_t *)input_right + + 2 * max_token_deal * inner_size * (is_gated ? sizeof(T) : sizeof(float)); + int8_t *gather_offset = (int8_t *)act_space + max_token_deal * act_space_size; + + // 2. cusum_token_count load to nram, because need to reuse in load bias. + if (has_bias) { + __memcpy((int *)nram_count, (int *)cusum_token_count, (num_expert + 1) * sizeof(int), + GDRAM2NRAM); + if (taskIdX == 0) { + __bang_sub((int *)bias_nram, (int *)nram_count + 1, (int *)nram_count, num_expert); + __sync(); + __memcpy((int *)count_sram, (int *)bias_nram, num_expert * sizeof(int), NRAM2SRAM); + } + __sync_cluster(); + } + + // 3. sram loop to compute + // compute once load bias to sram due to sram_limit + int max_expert_deal = (USE_SRAM_SIZE - num_expert * sizeof(int)) / (real_inner * sizeof(T)); + int real_expert = cusum_token_count == nullptr ? num_expert : expert_size; + int sram_loop_rem = real_expert % max_expert_deal; + int sram_loop = real_expert / max_expert_deal + (int)(sram_loop_rem != 0); + if (!has_bias) { + max_expert_deal = real_expert; + sram_loop = 1; + sram_loop_rem = 0; + } + for (int deal_loop = 0; deal_loop < sram_loop; deal_loop++) { + // load current bias, compute each core deal number + int expert_deal = + (deal_loop == (sram_loop - 1) && sram_loop_rem != 0) ? sram_loop_rem : max_expert_deal; + int expert_deal_start = deal_loop * max_expert_deal + start_expert_id; + int expert_deal_end = expert_deal_start + expert_deal - 1; + __sync_all(); + if (has_bias && __is_mpu()) { + __memcpy((T *)bias_sram, (T *)bias + deal_loop * max_expert_deal * real_inner, + expert_deal * real_inner * sizeof(T), GDRAM2SRAM); + } + __sync_all(); + + // get tokens info of each core + int tokens_total_cur = total_tokens; + if (has_bias) { + tokens_total_cur = + __load_nram((int *)nram_count + expert_deal_end + 1) - + (start_expert_id == 0 ? 0 : __load_nram((int *)nram_count + expert_deal_start)); + } else if (cusum_token_count != nullptr) { + tokens_total_cur = __load_gdram(cusum_token_count + expert_deal_end + 1) - + __load_gdram(cusum_token_count + expert_deal_start); + } + + if (sram_loop != 1) { + tokens_total_cur = ((int *)nram_count)[expert_deal_end + 1] - + ((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]); + if (deal_loop == 0 && start_expert_id != 0) { + tokens_total_cur -= __load_nram((int *)nram_count + expert_deal_start); + } + } + + int tokens_core_rem = tokens_total_cur % taskDim; + int tokens_cur_core = tokens_total_cur / taskDim + (taskId < tokens_core_rem); + if (tokens_cur_core == 0) { + continue; + } + + // if ep, input start in current token, have a real start in total network + int real_start = + cusum_token_count != nullptr ? __load_gdram((int *)cusum_token_count + start_expert_id) : 0; + int tokens_core_start = tokens_cur_core * taskId + + (taskId < tokens_core_rem ? 0 : tokens_core_rem) + + ((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]); + if (deal_loop != 0) { + tokens_core_start -= real_start; + } + + uint32_t expert_start = 0; + uint32_t expert_end = 0; + uint32_t tokens_deal_first = 0; + + // 4. nram loop compute + int nram_loop_rem = tokens_cur_core % max_token_deal; + int nram_loop = tokens_cur_core / max_token_deal + (int)(nram_loop_rem != 0); + int tokens_load = max_token_deal; + int tokens_compute = max_token_deal; + int tokens_store = max_token_deal; + + for (int loop = 0; loop < nram_loop + 2; loop++) { + int inner_io_offset = (loop % 2) * max_token_deal * inner_size; + int inner_com_offset = ((loop + 1) % 2) * max_token_deal * inner_size; + int real_io_offset = (loop % 2) * max_token_deal * real_inner; + int real_com_offset = ((loop + 1) % 2) * max_token_deal * real_inner; + if (nram_loop_rem != 0) { + if (loop > 1 && (loop - 2) == (nram_loop - 1)) { + tokens_store = nram_loop_rem; + } + if (loop > 0 && (loop - 1) == (nram_loop - 1)) { + tokens_compute = nram_loop_rem; + } + if (loop == (nram_loop - 1)) { + tokens_load = nram_loop_rem; + } + } + int tokens_cur_start = tokens_core_start + loop * max_token_deal; + int tokens_cur_end = tokens_cur_start + tokens_load - 1; + // get current load info + if (loop < nram_loop && has_bias) { + get_expert_info((int *)nram_count, (int *)count_sram, (uint32_t *)gather_offset, real_inner, + tokens_cur_start + real_start, tokens_cur_end + real_start, tokens_load, + expert_deal_start, expert_deal_end, expert_start, expert_end, + tokens_deal_first, sizeof(T)); + } + + // store + if (loop > 1) { + size_t output_offset = (tokens_core_start + (loop - 2) * max_token_deal) * output_stride; + storeOutput((T *)output, (T *)((float *)input_left + inner_io_offset), output_offset, + output_stride, tokens_store, inner_size); + } + + // compute + if (loop > 0 && loop <= nram_loop) { + T *left_dst = (T *)((float *)input_left + inner_com_offset) + + ((std::is_same::value) ? 0 : tokens_compute * inner_size); + computeAddActivation((T *)bias_nram + real_com_offset, (T *)left_dst, + (T *)input_right + inner_com_offset, + (float *)input_left + inner_com_offset, (float *)act_space, + tokens_compute, inner_size, is_gated, has_bias, active_coef, act_type); + } + + // load + if (loop < nram_loop) { + T *left_dst = (T *)((float *)input_left + inner_io_offset) + + ((std::is_same::value) ? 0 : tokens_load * inner_size); + size_t input_offset = tokens_cur_start * real_inner; + loadBiasInput((T *)input, (T *)left_dst, (T *)input_right + inner_io_offset, + (T *)bias_nram + real_io_offset, (T *)bias_sram, (uint32_t *)gather_offset, + input_offset, tokens_load, inner_size, expert_start, expert_end, + expert_deal_start, tokens_deal_first, is_gated, has_bias); + } + __sync(); + } + } +} + +} // namespace kernels + +KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue, + void *output, + const void *input, + const void *bias, + const int *cusum_token_count, + int num_expert, + int total_tokens, + int inner_size, + int output_stride, + cnnlDataType_t dtype, + bool is_gated, + cnnlActivationMode_t act_type, + int start_expert_id, + int expert_size, + float active_coef) { + if (bias != NULL && cusum_token_count == NULL) { + std::cerr << "[invokeGroupAddBiasActivationKernel]: " + << "when have bias, cusum_token_count can not be nullptr."; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (act_type != CNNL_ACTIVATION_GELU && act_type != CNNL_ACTIVATION_SWISH) { + std::cerr << "[invokeGroupAddBiasActivationKernel]: " + << "activation mode only supports gelu and swish now."; + return KernelStatus::KERNEL_STATUS_FAILED; + } + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1}; + + if (dtype == CNNL_DTYPE_FLOAT) { + kernels::MLUAddBiasActivationKernel<<>>( + (float *)output, (const float *)input, (const float *)bias, cusum_token_count, num_expert, + total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size, + active_coef); + } else if (dtype == CNNL_DTYPE_HALF) { + kernels::MLUAddBiasActivationKernel<<>>( + (half *)output, (const half *)(input), (const half *)bias, cusum_token_count, num_expert, + total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size, + active_coef); + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + if (!isBf16Supported()) { + std::cerr << "[invokeGroupAddBiasActivationKernel]: MLU300 devices do not support bfloat16." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + kernels::MLUAddBiasActivationKernel<<>>( + (bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)bias, + cusum_token_count, num_expert, total_tokens, inner_size, output_stride, is_gated, act_type, + start_expert_id, expert_size, active_coef); + } else { + std::cerr << "[invokeGroupAddBiasActivationKernel]: add_bias_activation data_type not support, " + << "only support float/half/bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mluh new file mode 100644 index 0000000..68f5d92 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mluh @@ -0,0 +1,84 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_ +#define CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_ +#include "../kernel_utils.h" +#include "cnnl.h" +namespace tmo { +/** + * @brief Add bias and activate to all tokens. Different expert with different bias. + * If in gated mode, add bias to all input. But only activate in the + * input of [:, :inner_size]. Then multiply it to the input of [:, inner_size:]. + * Else, add bias and activate to all input, has no multiply process. + * @example + * is_gated = true, num_expert = 4, total_tokens = 6, inner_size = 2 + * input: (6, 4) = [[2, 4, 5, 6], [1, 4, 5, 3], + * [3, 5, 7, 8], [6, 8, 5, 3], + * [2, 3, 4 ,5], [2, 9, 2, 3]] + * bias: (4, 4) = [[1, 0, 1, 0], [0, 1, 2, 2], [2, 3, 2, 3], [1, 2, 3, 4]] + * token_count = [2, 2, 1, 1] + * first step: add bias + * [[2+1, 4+0, 5+1, 6+0], [1+1, 4+0, 5+1, 3+0], + * [3+0, 5+1, 7+2, 8+2], [6+0, 8+1, 5+2, 3+2], + * [2+2, 3+3, 4+2, 5+3], [2+1, 9+2, 2+3, 3+4]] + * second step: act and mul + * output: (6, 2) = [[act(3)*6, act(4)*6], [act(2)*6, act(4)*3], + * [act(3)*9, act(6)*10], [act(6)*7, act(9)*5], + * [act(4)*6, act(6)*8], [act(3)*5, act(11)*7]] + * @param queue: The queue for mlu. + * @param output: Output. Pointer to the MLU memory that stores the result. + * When is_gated is true, The shape is [total_tokens, input_size / 2]. + * In this case, the input_size must be even. Otherwise the shape is [total_tokens, + * input_size]. The memory can be discontinuous in total_tokens dim. The stride is output_stride. + * @param input: Input. Pointer to the MLU memory that stores the input tokens. + * The shape is [total_tokens, input_size]. + * When is_gated is true, the shape is [total_tokens, 2 * inner_size]. + * Otherwise the shape is [total_tokens, inner_size]. + * @param bias: Input. Pointer to the MLU memory that stores the bias. The memory must be + * continuous. When is_gated is true, the shape is [num_expert, 2 * inner_size]. Otherwise the shape + * is [num_expert, inner_size]. Bias can be nullptr. If bias is nullptr, has no add bias process. + * @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of token + * counts. The shape is [num_expert + 1]. If cusum_token_count + * is not nullptr, cusum_token_count, start_expert_id and + * expert_size together determine which tokens to process. + * If cusum_token_count is nullptr, process all tokens, + * the number of which is total_tokens. When bias is not nullptr, + * cusum_token_count must also not be nullptr. + * @param num_expert: The number of expert. + * @param total_tokens: The total number of tokens. + * @param inner_size: The inner size of output. + * @param output_stride: The stride of output, must be greater than or equal to inner_size. + * @param dtype: Data type. + * @param is_gated: Gated or not. + * @param act_type: The type of activation. Support gelu and swish. + * @param start_expert_id: The index of the start expert. + * @param expert_size: The number of experts to process. + * @param active_coef: The coefficient used in the swish activation. + */ +KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue, + void *output, + const void *input, + const void *bias, + const int *cusum_token_count, + int num_expert, + int total_tokens, + int inner_size, + int output_stride, + cnnlDataType_t dtype, + bool is_gated, + cnnlActivationMode_t act_type, + int start_expert_id, + int expert_size, + float active_coef); +} // namespace tmo + +#endif // CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu new file mode 100644 index 0000000..123fa9f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu @@ -0,0 +1,646 @@ +#include +#include +#include +#include +#include +#include "cast_gating.mluh" +#include "cnnl.h" +#include "cnrt.h" +// clang-format off +#include +// clang-format on +namespace tmo { +#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y))) + +#define NRAM_BUFFER_SIZE (496 * 1024) +#define WRAM_BUFFER_SIZE (512 * 1024) +#define SRAM_BUFFER_SIZE (2032 * 1024) + +#ifndef ONE_LINE +#define ONE_LINE 64 +#endif + +#ifndef LT_NUM +#define LT_NUM 64 +#endif + +struct castGatingTileInfo { + int32_t block = 64; + int32_t split_k_num = 8; + int32_t block_k = 256; +}; + +namespace kernels { +#pragma bang walign(16) +#ifndef ROW_PER_LT +#define ROW_PER_LT 4 +#endif + +#ifndef LT_SIZE +#define LT_SIZE 16 +#endif + +#ifndef WRAM_LT_MAP16_STRIDE +#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16) +#endif + +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; +__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE]; +__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE]; + +#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \ + do { \ + uint32_t align_num = 64 / src_dsize; \ + uint32_t n = PAD_DOWN(size / src_dsize, align_num); \ + uint32_t rem = size % 64; \ + if (n) { \ + __asm__ __volatile__( \ + "move.tiling.async.nram.sram.b16" \ + " [%[dst_addr]], [%[src_addr]], " \ + "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \ + "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \ + "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \ + "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \ + "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \ + [src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \ + [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \ + [src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \ + [src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \ + [dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \ + [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \ + } \ + \ + if (rem) { \ + __asm__ __volatile__( \ + "move.tiling.async.nram.sram.b16" \ + " [%[dst_addr]], [%[src_addr]], " \ + "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \ + "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \ + "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \ + "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \ + "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \ + [src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \ + [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \ + [src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \ + [dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \ + [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \ + } \ + } while (false) + +__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) { +#if __BANG_ARCH__ >= 500 + SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()"); +#endif +} + +__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) { +#if __BANG_ARCH__ >= 500 + SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()"); +#endif +} + +__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) { + __memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM); +} + +#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \ + convert_type) \ + int align_n = PAD_DOWN(n, LT_NUM); \ + int sn0 = ONE_LINE; \ + int size_sn0 = sn0 / src_dsize; \ + int sn1 = ONE_LINE / src_dsize; \ + int ss1 = total_k * src_dsize; \ + int sn3 = k / size_sn0; \ + int sn4 = align_n / sn1; \ + int ss4 = sn1 * ss1; \ + int dn0 = sn0; \ + int dn1 = ROW_PER_LT; \ + int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \ + int ds1 = dst_k * dst_dsize; \ + int dn2 = sn1 / ROW_PER_LT; \ + int ds2 = WRAM_LT_MAP16_STRIDE; \ + int ds3 = sn0 * dst_dsize / src_dsize; \ + int dn4 = LT_SIZE / dn2; \ + int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \ + int dn5 = align_n / LT_NUM; \ + int ds5 = ROW_PER_LT * dst_k * dst_dsize; \ + int rem_k = k % size_sn0; \ + int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \ + int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \ + if (align_n > 0 && sn3 > 0) { \ + __asm__ __volatile__( \ + "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \ + "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \ + "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \ + "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \ + "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \ + ";\n\t" ::[dst_addr] "r"(wram_dst), \ + [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \ + [src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \ + [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \ + [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \ + [dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \ + sram_src += align_n * total_k; \ + wram_dst += align_n / LT_SIZE * dst_k; \ + } \ + align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \ + if (align_n > 0 && sn3 > 0) { \ + sn1 = align_n; \ + dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \ + __asm__ __volatile__( \ + "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \ + "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \ + "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \ + "1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \ + [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \ + [src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \ + [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \ + [dst_s3] "r"(ds3)); \ + sram_src += align_n * total_k; \ + wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \ + } \ + if (rem_k > 0) { \ + align_n = PAD_UP(n, LT_NUM); \ + sn0 = rem_k * src_dsize; \ + dn0 = sn0; \ + __asm__ __volatile__( \ + "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \ + "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \ + "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \ + "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \ + "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \ + ";\n\t" ::[dst_addr] "r"(wram_dst2), \ + [src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \ + [src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \ + [src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \ + [dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \ + [dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \ + [dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \ + [dst_s5] "r"(0)); \ + } + +__mlu_func__ void warp_prompt_weight(float *wram_dst, + half *sram_src, + int32_t warp_n, + int32_t len_k, + int32_t total_k) { +#if __BANG_ARCH__ >= 500 + SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half), + ".cvt.rn.f32.f16()"); +#endif +} + +__mlu_func__ void warp_prompt_weight(float *wram_dst, + bfloat16_t *sram_src, + int32_t warp_n, + int32_t len_k, + int32_t total_k) { +#if __BANG_ARCH__ >= 500 + SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), + sizeof(bfloat16_t), ".cvt.rn.f32.bf16()"); +#endif +} + +template +__mlu_func__ void warp_prompt_weight(T *wram_dst, + T *sram_src, + int32_t n, + int32_t len_k, + int32_t total_k) { + int32_t type_size = sizeof(T); + int32_t data_size = len_k * type_size; + int32_t ds0 = PAD_UP(data_size, ONE_LINE); + int32_t ss0 = total_k * type_size; + int32_t count = n / LT_NUM; + for (int32_t i = 0; i < count; ++i) { + __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1, + WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0); + wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0); + sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0); + } + count = n % LT_NUM / ROW_PER_LT; + if (count > 0) { + __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1, + WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0); + wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE); + sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0); + } + count = n % ROW_PER_LT; + if (count > 0) { + __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1); + } +} + +__mlu_func__ void assignTaskEvenly(const int32_t num_total_task, + const int32_t &taskid, + const int32_t &taskdim, + int32_t &task_offset, + int32_t &num_cur_task) { + int32_t num_per_task = num_total_task / taskdim; + int32_t rem_idx = num_total_task % taskdim; + if (taskid < rem_idx) { + task_offset = taskid * (num_per_task + 1); + num_cur_task = num_per_task + 1; + } else { + task_offset = taskid * num_per_task + rem_idx; + num_cur_task = num_per_task; + } +} + +__mlu_func__ void bidirectionBarrierOp() { + int32_t bcnt = coreDim + 1; + if (__is_ipu()) { + __asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt)); + __asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt)); + } else { + __asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt)); + __asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt)); + } +} + +__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) { + __bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n); +} + +__mlu_func__ void warp_store(void *ddr_dst, + void *nram_src, + const int32_t data_num, + const int32_t dst_stride, + const int32_t src_stride, + const int32_t count, + const int32_t dt_size) { + if (src_stride == data_num && dst_stride == data_num) { + __memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM); + } else { + __memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM, (size_t)dst_stride * dt_size, + src_stride * dt_size, count - 1); + } +} + +template +__mlu_func__ void splitKReduce(Tcc *workspace, + Tc *output, + int32_t M, + int32_t N, + int32_t split_k_num, + int32_t ldc) { + int32_t offset_m, cta_m; + assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m); + if (cta_m <= 0) return; + int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc); + int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0); + int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m; + Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N; + Tc *output_ddr = (Tc *)output + offset_m * ldc; + for (int32_t i = 0; i < repeat; i++) { + int32_t current_m = i == repeat - 1 ? rem_m : block_m; + int32_t data_size = N * sizeof(Tc); + int32_t data_num = current_m - 1; + if (ldc == N) { + data_size = current_m * N * sizeof(Tc); + data_num = 0; + } + __memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM, + current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1); + __bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1, + split_k_num, 1, 1, 1); + __memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc), + N * sizeof(Tc), data_num); + workspace_ddr = workspace_ddr + block_m * N; + output_ddr = output_ddr + block_m * ldc; + } +} + +template +__mlu_global__ void MLUCastGating(Ta *A, + Tb *B, + Tc *C, + Tcc *workspace, + int32_t M, + int32_t N, + int32_t K, + int32_t lda, + int32_t ldb, + int32_t ldc, + castGatingTileInfo split_info) { +#if __BANG_ARCH__ >= 500 + int32_t block_k = split_info.block_k; + int32_t grid_dimx = split_info.split_k_num; + int32_t block = split_info.block; + int32_t grid_idx = clusterId % grid_dimx; + int32_t grid_idy = clusterId / grid_dimx; + int32_t offset_k = 0, problem_k = 0; + assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k); + int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k; + int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0); + int32_t cta_k = k_loop == 1 ? rem_k : block_k; + int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0; + int32_t warp_m = cta_m, warp_offset_m = 0; + int32_t warp_n = cta_n, warp_offset_n = 0; + int32_t outer_loop = 0, outer_rem = 0; + if (EXCHANGE_AB) { + assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n); + assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n); + if (cta_n > block && cta_n % block != 0) { + int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM); + if (block_tmp < block) block = block_tmp; + } + outer_loop = (cta_n + block - 1) / block; + outer_rem = cta_n % block == 0 ? block : cta_n % block; + } else { + assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m); + assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m); + if (cta_m > block && cta_m % block != 0) { + int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim); + if (block_tmp < block) block = block_tmp; + } + outer_loop = (cta_m + block - 1) / block; + outer_rem = cta_m % block == 0 ? block : cta_m % block; + } + + int32_t size_nram_buf = + NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB)); + int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac); + Tac *nbuf_a = (Tac *)nram_buffer; + Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf); + Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c; + + int32_t size_sram_buf = SRAM_BUFFER_SIZE; + int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta); + int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb); + Ta *sbuf_a = (Ta *)sram_buffer; + Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k)); + + int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc); + Tbc *wbuf_b = (Tbc *)wram_buffer; + + int32_t a_dsize = sizeof(Ta); + int32_t b_dsize = sizeof(Tb); + int32_t k_loop_count = 0; + for (int32_t j = 0; j < outer_loop; j++) { + Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB); + Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB); + int32_t current_block = j == outer_loop - 1 ? outer_rem : block; + if (EXCHANGE_AB) { + assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n); + } else { + assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m); + } + int32_t compute_total = warp_m * warp_n; + if (compute_total > 0 && __is_ipu()) { + if (!EXCHANGE_AB) { + __sync_io_move_compute(true, false, false, false, false, true); + } + __bang_write_zero((Tcc *)nbuf_c, compute_total); + } + int32_t i = 0; + for (; i < k_loop; i++) { + Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a; + Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b; + cta_k = i == k_loop - 1 ? rem_k : block_k; + if (__is_mpu()) { + if (EXCHANGE_AB) { + __memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize, + current_block - 1); + __asm__ volatile( + "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], " + "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a), + [src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize), + [src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1)); + } else { + __memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize, + current_block - 1); + __asm__ volatile( + "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], " + "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b), + [src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize), + [src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1)); + } + a_ddr = (Ta *)a_ddr + block_k; + b_ddr = (Tb *)b_ddr + block_k; + } + bidirectionBarrierOp(); + if (__is_ipu() && compute_total > 0) { + __sync_io_move_compute(false, true, false, false, false, true); + __sync_io_move_compute(false, false, true, false, true, false); + if (i >= 1) { + __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram, + (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k); + } + warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram, + sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta)); + // mvdma bound for EXCHANGE_AB when n==32 + warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram, + (Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k); + } + k_loop_count += 1; + } + if (compute_total > 0) { + __sync_io_move_compute(false, true, false, false, false, true); + __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram, + (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k); + if (EXCHANGE_AB) { + __sync_io_move_compute(true, false, false, false, false, true); + __bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n); + } + int32_t total_offset = + grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M + : (offset_m + warp_offset_m + block * j) * N); + Tcc *wks = (Tcc *)workspace + total_offset; + int32_t store_c_size = sizeof(Tcc); + int8_t *store_ddr = (int8_t *)wks; + int32_t dst_str = EXCHANGE_AB ? M : N; + if (grid_dimx == 1) { + // convert Tcc to Tc + dst_str = ldc; + store_ddr = + (int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc + : (offset_m + warp_offset_m + block * j) * ldc)); + } + __asm__ volatile("sync.psimd.cio;\n\t"); + if (EXCHANGE_AB) { + warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size); + } else { + warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size); + } + } + } + if (grid_dimx != 1) { + __sync_all(); + splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N, + split_info.split_k_num, ldc); + } +#endif // __BANG_ARCH__ >= 500 +} +} // namespace kernels + +int32_t getBlock(int32_t m, + int32_t n, + int32_t core_num, + int32_t block_k, + int32_t a_dtype_size, + int32_t b_dtype_size, + int32_t compute_dtype_size, + bool EXCHANGE_AB) { + int32_t block = 0; + if (EXCHANGE_AB) { + int32_t block_m = n; + int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) / + (2 * block_m * compute_dtype_size) * core_num; + int32_t wram_block_n = + WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num; + int32_t sram_block_n = + (SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2); + int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n); + int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM); + return block_n > 0 ? block_n : block_n_tmp; + } else { + int32_t block_n = n; + int32_t nram_block_m = + NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2); + int32_t sram_block_m = + (SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2); + block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num)); + return block; + } +} + +void gatingTiling(int32_t m, + int32_t n, + int32_t k, + size_t a_dtype_size, + size_t b_dtype_size, + size_t compute_dtype_size, + size_t workspace_size, + int32_t union_number, + int32_t core_num, + int32_t &block, + int32_t &split_k_num, + int32_t &block_k, + bool &EXCHANGE_AB) { + block_k = std::min(k, int32_t(512 / a_dtype_size)); + split_k_num = 1; + // swap A and B to reduce computing waste caused by LT_NUM-align of co dimensian + if (m >= core_num * LT_NUM && + float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) { + EXCHANGE_AB = true; + } + int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size, + compute_dtype_size, EXCHANGE_AB); + int32_t total_blocks = DIV_UP((size_t)m, tmp_block); + block = tmp_block; + if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) { + for (int32_t i = total_blocks; i <= union_number; i++) { + if (union_number % i == 0) { + int32_t tmp_split_k = union_number / i; + size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size; + if (workspace_size >= workspace_size_need) { + split_k_num = tmp_split_k; + block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block); + if (EXCHANGE_AB && block > LT_NUM * core_num) { + block = PAD_DOWN(block, LT_NUM * core_num); + } + break; + } + } + } + } +} + +void getContxtInfo(int32_t *union_number, int32_t *core_num) { + CNdev dev; + cnCtxGetDevice(&dev); + CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev)); +} + +KernelStatus invokeCastGating(cnrtQueue_t queue, + void *input, + void *filter, + void *output, + int input_row, + int expert_num, + int hidden_size, + cnnlDataType_t a_dtype, + void *workspace, + size_t workspace_size_bytes) { + if (is_arch300()) { + std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + if (expert_num > 128) { + std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) { + std::cerr + << "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + if (workspace_size_bytes > 0 && workspace == NULL) { + std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is " + "greater than 0." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + int32_t union_number, core_num; + getContxtInfo(&union_number, &core_num); + cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num); + cnrtDim3_t dim; + dim.x = (int32_t)func_type; + dim.y = 1; + dim.z = 1; + + cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT; + cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT; + size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0; + cnnlGetSizeOfDataType(a_dtype, &a_dtype_size); + cnnlGetSizeOfDataType(b_dtype, &b_dtype_size); + cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size); + castGatingTileInfo split_info; + bool EXCHANGE_AB = false; + gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size, + workspace_size_bytes, union_number, core_num, split_info.block, + split_info.split_k_num, split_info.block_k, EXCHANGE_AB); + if (a_dtype == CNNL_DTYPE_BFLOAT16) { + if (EXCHANGE_AB) { + kernels::MLUCastGating + <<>>((float *)filter, (bfloat16_t *)input, (float *)output, + (float *)workspace, expert_num, input_row, hidden_size, + hidden_size, hidden_size, expert_num, split_info); + } else { + kernels::MLUCastGating + <<>>((bfloat16_t *)input, (float *)filter, (float *)output, + (float *)workspace, input_row, expert_num, hidden_size, + hidden_size, hidden_size, expert_num, split_info); + } + } else if (a_dtype == CNNL_DTYPE_HALF) { + if (EXCHANGE_AB) { + kernels::MLUCastGating + <<>>((float *)filter, (half *)input, (float *)output, + (float *)workspace, expert_num, input_row, hidden_size, + hidden_size, hidden_size, expert_num, split_info); + } else { + kernels::MLUCastGating + <<>>((half *)input, (float *)filter, (float *)output, + (float *)workspace, input_row, expert_num, hidden_size, + hidden_size, hidden_size, expert_num, split_info); + } + } else { + std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh new file mode 100644 index 0000000..014ed31 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh @@ -0,0 +1,50 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_CAST_GATING_MLUH_ +#define CSRC_KERNELS_CAST_GATING_MLUH_ + +#include "../kernel_utils.h" +#include "cnnl.h" +namespace tmo { +/** + * @brief Convert input to float32 and do gating operation. + * @param queue: The queue for mlu. + * @param input: Input. Pointer to the MLU memory that stores the input, + * the shape must be [input_row, hidden_size]. + * @param filter: Input. Pointer to the MLU memory that stores the weight, + * the shape must be [expert_num, hidden_size]. + * @param output: Output. Pointer to the MLU memory that stores the output, + * the shape must be [input_row, expert_num]. + * @param input_row: Input. + * @param expert_num: Input. + * @param hidden_size: Input. + * @param a_dtype: Input. The data-type of input. + * @param workspace: Input. Pointer to the MLU workspace. + * @param workspace_size_bytes: Input. The size of workspace in bytes. + * @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF. + * expert_num must be in range [1, 128]. + * If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024. + * The data-type of filter and output must be float. + * cast_gating only supports MLU500 device or higher. + */ +KernelStatus invokeCastGating(cnrtQueue_t queue, + void *input, + void *filter, + void *output, + int input_row, + int expert_num, + int hidden_size, + cnnlDataType_t a_dtype, + void *workspace, + size_t workspace_size_bytes); +} // namespace tmo +#endif // CSRC_KERNELS_CAST_GATING_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu new file mode 100644 index 0000000..6264ef3 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu @@ -0,0 +1,760 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "cnrt.h" +#include "combine_result.mluh" +// clang-format off +#include +#include +// clang-format on + +#if __BANG_ARCH__ >= 592 +#include +template +using bang_fusor = bang::experimental::fusor; +#endif + +namespace tmo { +namespace kernels { +#define NRAM_REMAIN_SIZE (32 * 1024) +#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) + +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; + +template +__mlu_func__ void swap(T *&ping, T *&pong) { + T *temp = ping; + ping = pong; + pong = temp; +} + +#define GATHER_ASYNC_IO0(offset_type) \ + __asm__ __volatile__( \ + "gather.vector.async.nram.gdram.nram." #offset_type \ + ".io0 [%[dst]], [%[src]], [%[offset]], " \ + "%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst), \ + [src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \ + [transfer_num] "r"(token_count), [stride] "r"(transfer_size)) + +#define FUSE_MUL_CVT(dst_dtype) \ + __asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype \ + ".f32 [%[dst]], [%[src0]], %[src1]," \ + " %[size];\n\t" ::[dst] "r"(dst), \ + [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size)); + +#define FUSE_MULADD_CVT(dst_dtype) \ + __asm__ __volatile__("muladd.nram.crn." #dst_dtype \ + ".f32 [%[dst]], [%[src0]], %[src1], [%[dst]]," \ + " %[size], %[size];\n\t" ::[dst] "r"(dst), \ + [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size)); + +template +__mlu_func__ void toFloat(float *dst, T *src, int count) { + if (std::is_same::value) { + __bang_half2float(dst, (half *)src, count); + } else if (std::is_same::value) { + __bang_bfloat162float(dst, (bfloat16_t *)src, count); + } else if (std::is_same::value) { + __bang_add_scalar((float *)dst, (float *)src, (float)0, count); + } +} + +template +__mlu_func__ void floatTo(T *dst, float *src, int count) { + if (std::is_same::value) { + __bang_float2half_rn((half *)dst, src, count); + } else if (std::is_same::value) { + __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count); + } else if (std::is_same::value) { + __bang_add_scalar((float *)dst, (float *)src, (float)0, count); + } +} + +__mlu_func__ void loadAsync2d(void *dst, + void *src, + int size, + int dststride, + int srcstride, + int seg_num) { +#if __BANG_ARCH__ > 500 + __asm__ __volatile__( + "ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]]," + " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst), + [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride), + [segnum] "r"(seg_num)); +#else + __memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num); +#endif +} + +__mlu_func__ void storeAsync2d(void *dst, + void *src, + int size, + int dststride, + int srcstride, + int seg_num) { +#if __BANG_ARCH__ > 500 + __asm__ __volatile__( + "st.async.stride.gdram.nram.io1 [%[dst]], [%[src]]," + " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst), + [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride), + [segnum] "r"(seg_num)); +#else + __memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num); +#endif +} + +template +__mlu_func__ void gatherTokensAsync(void *dst, + void *src_gdram, + T_IDX *nram_offset, + int transfer_size, + int token_count) { + if (token_count <= 0 || src_gdram == nullptr) return; +#if __BANG_ARCH__ > 500 + if (std::is_same::value) { + GATHER_ASYNC_IO0(u32); + } else { + GATHER_ASYNC_IO0(u64); + } +#else + for (int k = 0; k < token_count; k++) { + __memcpy_async((int8_t *)dst + k * transfer_size, + (int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM); + } +#endif +} + +__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx, + int *nram_mask, + uint8_t *nram_mask_char, + int *nram_mask_buffer, + int begin_expert_acc_tokens, + int end_expert_acc_tokens, + int token_count, + bool expert_parallelism) { + if (!expert_parallelism) { + return token_count; + } + + __bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count); +#if __BANG_ARCH__ >= 592 + bang_fusor(nram_mask, nram_token_idx, token_count) + .ge(begin_expert_acc_tokens) + .land(nram_mask_buffer) + .cvt(0); +#else + __bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count); + __bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count); + __bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0); +#endif + __bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count); + int active_token_count = __bang_count((float *)nram_mask, token_count); + return active_token_count; +} + +__mlu_func__ void computeOffset0(uint64_t *nram_offset, + int *nram_idx, + uint64_t mul_scalar, + int64_t add_scalar, + uint32_t token_count) { +#if __BANG_ARCH__ > 592 + __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0); +#else + __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count); +#endif + __bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count); + __bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count); +} + +__mlu_func__ void computeOffset0(uint32_t *nram_offset, + int *nram_idx, + uint32_t mul_scalar, + int64_t add_scalar, + uint32_t token_count) { + __bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar, + token_count); +} + +template +__mlu_func__ void computeOffset(T_IDX *nram_token_offset, + T_IDX *nram_bias_offset, + int *nram_token_idx, + int *nram_expert_tables, + int expert_num, + int token_count, + int active_token_count, + int hidden_size, + int local_hidden_begin, + int dtype_size, + int start_expert_id, + int expert_size, + int begin_expert_acc_tokens, + bool has_bias) { + // for large tensor, convert int322int64 then do multiply and add seperately. + if (active_token_count <= 0) return; + if (has_bias) { + int *nram_bias_offset_temp = (int *)nram_token_offset; + __bang_write_zero(nram_bias_offset, active_token_count); + for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) { + __bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i], + active_token_count); + __bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp, + active_token_count); + } + __bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count); + computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size, + (T_IDX)local_hidden_begin * dtype_size, active_token_count); + } + + int64_t offset = + ((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size; + computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset, + active_token_count); +} + +template +__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) { +#if __BANG_ARCH__ > 500 + if (std::is_same::value) { + FUSE_MUL_CVT(bf16); + } else if (std::is_same::value) { + FUSE_MUL_CVT(f16); + } else if (std::is_same::value) { + __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size); + } +#else + __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size); + floatTo((T *)dst, (float *)dst, size); +#endif +} + +template +__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) { +#if __BANG_ARCH__ > 500 + if (std::is_same::value) { + FUSE_MULADD_CVT(bf16); + } else if (std::is_same::value) { + FUSE_MULADD_CVT(f16); + } else if (std::is_same::value) { + __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size, + size); + } +#else + __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size, + size); + floatTo((T *)dst, (float *)dst, size); +#endif +} + +// weightedReduceSum with EP split +// input [token_count, k, hidden_size], weight [token_count, k] +// 1. input * weight +// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add +template ::type = nullptr> +__mlu_func__ void weightedReduceSum(T *output, + T *input, + float *weight, + T *input_buffer, + int8_t *og_mask, + int topk, + int hidden_size, + int token_count, + bool &is_ping) { + float *nram_input_buffer = + (float *)((half *)input_buffer + + ((std::is_same::value || !is_ping) ? 0 : hidden_size)); + T *output_base = output - ((std::is_same::value || is_ping) ? 0 : hidden_size); + int32_t index[32]; + float reg_weight[128]; + int8_t *index_ = (int8_t *)index; + int topk_divide_4 = PAD_UP(topk, 4) / 4; + int token_use_count = 0; + for (int t_i = 0; t_i < token_count; t_i++) { + float *output_begin = (float *)(output_base + t_i * hidden_size); + for (int i = 0; i < topk_divide_4; i++) { + index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i); + float *weight_begin = weight + t_i * topk + i * 4; + reg_weight[i * 4] = __load_nram(weight_begin); + if (i * 4 + 1 < topk) { + reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1); + } + if (i * 4 + 2 < topk) { + reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2); + } + if (i * 4 + 3 < topk) { + reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3); + } + } + + int first_in_expert = 0; + float expert_coeff; + for (; first_in_expert < topk - 1; first_in_expert++) { + bool in_expert_range = index_[first_in_expert]; + if (!in_expert_range) continue; + + expert_coeff = reg_weight[first_in_expert]; + toFloat(output_begin, input + token_use_count * hidden_size, hidden_size); + __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size); + token_use_count++; + break; + } + if (first_in_expert == topk - 1) { + if (index_[topk - 1]) { + expert_coeff = reg_weight[topk - 1]; + toFloat(nram_input_buffer, input + token_use_count * hidden_size, hidden_size); + token_use_count++; + mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size); + } else { + __bang_write_zero((T *)output_begin, hidden_size); + } + } else { + for (int j = first_in_expert + 1; j < topk - 1; j++) { + bool in_expert_range = index_[j]; + if (!in_expert_range) continue; + + expert_coeff = reg_weight[j]; + toFloat(nram_input_buffer, input + token_use_count * hidden_size, hidden_size); + token_use_count++; + __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin, + hidden_size, hidden_size); + } + if (index_[topk - 1]) { + expert_coeff = reg_weight[topk - 1]; + toFloat(nram_input_buffer, input + token_use_count * hidden_size, hidden_size); + token_use_count++; + mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size); + } else { + floatTo((T *)output_begin, (float *)output_begin, hidden_size); + } + } + } + if (!is_ping && sizeof(T) < sizeof(float)) { + __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size, + hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1, + token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1, + token_count * hidden_size * sizeof(T), 0); + } + is_ping = !is_ping; +} + +// weightedReduceSum without EP split +// input [token_count, k, hidden_size], weight [token_count, k] +// 1. input * weight +// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add +template ::type = nullptr> +__mlu_func__ void weightedReduceSum(T *output, + T *input, + float *weight, + T *input_buffer, + int8_t *og_mask, + int topk, + int hidden_size, + int token_count, + bool &is_ping) { + float *nram_input_buffer = + (float *)((half *)input_buffer + + ((std::is_same::value || !is_ping) ? 0 : hidden_size)); + T *output_base = output - ((std::is_same::value || is_ping) ? 0 : hidden_size); + if (topk == 1) { + for (int i = 0; i < token_count; i++) { + float expert_coeff = __load_nram(weight + i); + toFloat(nram_input_buffer, input + i * hidden_size, hidden_size); + mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size); + } + return; + } + for (int t_i = 0; t_i < token_count; t_i++) { + float *output_begin = (float *)(output_base + t_i * hidden_size); + float expert_coeff = __load_nram(weight + t_i * topk); + toFloat(output_begin, input + t_i * topk * hidden_size, hidden_size); + toFloat(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size); + __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size); + expert_coeff = __load_nram(weight + t_i * topk + 1); + for (int k_i = 2; k_i < topk; k_i++) { + __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin, + hidden_size, hidden_size); + expert_coeff = __load_nram(weight + t_i * topk + k_i); + toFloat(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size); + } + mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size); + } + if (!is_ping && sizeof(T) < sizeof(float)) { + __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size, + hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1, + token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1, + token_count * hidden_size * sizeof(T), 0); + } + is_ping = !is_ping; +} + +template +__mlu_global__ void MLUCombineMoeResultKernel(T *output, + T *input, + T *bias, + T *residual, + float *reduce_weight, + int *cusum_token_count, + int *gather_idx, + int num_token, + int topk, + int num_expert, + int hidden_size, + int start_expert_id, + int expert_size, + int HIDDEN_BLOCK, + int TOKEN_BLOCK) { + if (__is_mpu()) { + return; + } + int local_hidden_begin = taskIdX * HIDDEN_BLOCK; + int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin); + int task_avg_tokens = num_token / taskDimY; + int task_remain_tokens = num_token % taskDimY; + int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens); + int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens); + if (local_hidden_size <= 0) return; + if (task_tokens <= 0) return; + + constexpr int int32_dtype_size = (int)sizeof(int); + constexpr int fp32_dtype_size = (int)sizeof(float); + int pad_num_expert = PAD_UP(num_expert + 1, 32); + bool has_bias = bias != nullptr; + bool has_residual = residual != nullptr; + bool using_acc_sum = cusum_token_count != nullptr; + bool expert_parallelism = expert_size < num_expert; + + int block_size = TOKEN_BLOCK * topk; + int pad_block_size = PAD_UP(block_size, 64); + + int *nram_expert_tables = (int *)nram_buffer; + int *nram_token_idx = nram_expert_tables + pad_num_expert; + T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size); + T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size); + int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size); + T *nram_input_ping = (T *)(nram_mask + pad_block_size); + T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK; + T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK; + T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK; + T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK; + T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK; + float *nram_weight_ping = + (float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK); + float *nram_weight_pong = nram_weight_ping + pad_block_size; + int buffer_block_num = sizeof(T) > 2 ? 2 : 3; + T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size); + T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK; + T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) + + buffer_block_num * HIDDEN_BLOCK * sizeof(half)); + int *nram_mask_buffer = (int *)nram_token_offset; + uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK); + + int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk; + int begin_expert_acc_tokens = 0; + int end_expert_acc_tokens = num_token * topk; + if (using_acc_sum) { + __memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size, + GDRAM2NRAM); + } + __memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk, + init_token_count * sizeof(int), GDRAM2NRAM); + __sync_io(); + + if (expert_parallelism) { + begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id); + end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size); + } + + int active_token_count = getMaskAndActiveTokenCount( + nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens, + end_expert_acc_tokens, init_token_count, expert_parallelism); + + computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert, + init_token_count, active_token_count, hidden_size, local_hidden_begin, + (int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias); + __sync_io_move_compute(true, false, false, false, false, true); + __sync_io_move_compute(false, false, true, true, false, false); + + int next_active_token_count = active_token_count; + int previous_global_token_begin = 0; + int previous_token_count = 0; + bool is_ping = false; + for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) { + int next_token_begin = (task_begin + 1) * TOKEN_BLOCK; + int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK; + bool is_last_loop = next_token_begin >= task_tokens; + bool is_last_2_loop = next_next_token_begin >= task_tokens; + int current_token_begin = task_begin * TOKEN_BLOCK; + int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin); + int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin); + int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin); + int current_global_token_begin = task_token_begin + current_token_begin; + int next_global_token_begin = task_token_begin + next_token_begin; + int next_next_global_token_begin = task_token_begin + next_next_token_begin; + + if (!is_last_loop) { + if (!is_last_2_loop) { + loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk, + next_next_token_count * topk * sizeof(int), 0, 0, 0); + } + loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk, + next_token_count * topk * fp32_dtype_size, 0, 0, 0); + if (has_residual) { + loadAsync2d(nram_residual_ping, + residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin, + local_hidden_size * sizeof(T), local_hidden_size * sizeof(T), + hidden_size * sizeof(T), next_token_count - 1); + } + gatherTokensAsync(nram_input_ping, input, nram_token_offset, + local_hidden_size * sizeof(T), next_active_token_count); + gatherTokensAsync(nram_bias_ping, bias, nram_bias_offset, + local_hidden_size * sizeof(T), next_active_token_count); + } + + if (task_begin >= 1) { + storeAsync2d( + output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin, + nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T), + local_hidden_size * sizeof(T), previous_token_count - 1); + } + + if (task_begin >= 0) { + if (has_bias && active_token_count) { + __bang_add(nram_input_pong, nram_input_pong, nram_bias_pong, + active_token_count * local_hidden_size); + } + if (expert_parallelism) { + weightedReduceSum(nram_output_ping, nram_input_pong, nram_weight_pong, + nram_input_buffer, (int8_t *)nram_mask_char, topk, + local_hidden_size, current_token_count, is_ping); + } else { + weightedReduceSum(nram_output_ping, nram_input_pong, nram_weight_pong, + nram_input_buffer, (int8_t *)nram_mask_char, topk, + local_hidden_size, current_token_count, is_ping); + } + if (has_residual) { + __bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong, + current_token_count * local_hidden_size); + } + } + + __sync_io_move_compute(); + active_token_count = next_active_token_count; + if (expert_parallelism && !is_last_loop) { + __bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk); + } + if (!is_last_2_loop) { + next_active_token_count = getMaskAndActiveTokenCount( + nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens, + end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism); + computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, + num_expert, next_next_token_count * topk, next_active_token_count, hidden_size, + local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size, + begin_expert_acc_tokens, has_bias); + } + + swap(nram_input_ping, nram_input_pong); + swap(nram_bias_ping, nram_bias_pong); + swap(nram_residual_ping, nram_residual_pong); + swap(nram_weight_ping, nram_weight_pong); + swap(nram_output_ping, nram_output_pong); + previous_global_token_begin = current_global_token_begin; + previous_token_count = current_token_count; + } + storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin, + nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T), + local_hidden_size * sizeof(T), previous_token_count - 1); +} + +#if __BANG_ARCH__ < 500 +template <> +__mlu_global__ void MLUCombineMoeResultKernel(bfloat16_t *output, + bfloat16_t *input, + bfloat16_t *bias, + bfloat16_t *residual, + float *reduce_weight, + int *cusum_token_count, + int *gather_ids, + int num_token, + int topk, + int num_expert, + int hidden_size, + int start_expert_id, + int expert_size, + int HIDDEN_BLOCK, + int TOKEN_BLOCK) {} + +template <> +__mlu_global__ void MLUCombineMoeResultKernel(bfloat16_t *output, + bfloat16_t *input, + bfloat16_t *bias, + bfloat16_t *residual, + float *reduce_weight, + int *cusum_token_count, + int *gather_ids, + int num_token, + int topk, + int num_expert, + int hidden_size, + int start_expert_id, + int expert_size, + int HIDDEN_BLOCK, + int TOKEN_BLOCK) {} +#endif + +} // namespace kernels +KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue, + void *output, + const void *input, + const void *bias, + const void *residual, + const float *reduce_weight, + const int *cusum_token_count, + const int *gather_idx, + int num_token, + int topk, + int num_expert, + int hidden_size, + int start_expert_id, + int expert_size, + cnnlDataType_t dtype) { + if (topk > 128 || num_expert > 1024 || hidden_size < 256) { + std::cerr << "[invokeMoeCombineResultKernel]: " + << "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256."; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + if (bias != nullptr) { + std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias."; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) { + std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, " + << "cusum_token_count can not be nullptr."; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + size_t data_bytes = 0; + cnnlGetSizeOfDataType(dtype, &data_bytes); + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + + // 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer. + // TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset. + int convert_buffer = data_bytes == 2 + ? 3 * hidden_size * data_bytes + : 2 * hidden_size * data_bytes; // buffer for convert bf16/fp16->fp32 + int max_input_size = (432 * 1024 - convert_buffer) / + (2 * topk * data_bytes + /*input size, double buffer*/ + (bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/ + (residual != nullptr) * 2 * data_bytes + /*residual size, double buffer*/ + 2 * data_bytes); /*output size, one buffer*/ + + int TOKEN_BLOCK = 1; + int HIDDEN_BLOCK = 1; + int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64; + if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) { + HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK; + TOKEN_BLOCK = 1; + } else { + HIDDEN_BLOCK = hidden_size; + } + // for latency case, hidden_size is large but token is small. + if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) { + HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num; + } + + HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024); + uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK; + task_dim_x = + (task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num); + uint32_t pad_dim_x = task_dim_x; + while (pad_dim_x <= cluster_num * core_num) { + if ((cluster_num * core_num % pad_dim_x == 0)) { + task_dim_x = pad_dim_x; + break; + } + pad_dim_x += core_num; + } + HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x; + HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64; + if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) { + TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK; + } + TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk); + + float max_cluster_num = core_num * cluster_num / task_dim_x; + uint32_t task_dim_y = std::min(max_cluster_num, num_token); + task_dim_y = task_dim_y < 1 ? 1 : task_dim_y; + cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1}; + + bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX; + if (dtype == CNNL_DTYPE_FLOAT) { + if (!is_large_tensor) { + kernels::MLUCombineMoeResultKernel<<>>( + (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight, + (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size, + start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK); + } else { + kernels::MLUCombineMoeResultKernel<<>>( + (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight, + (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size, + start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK); + } + } else if (dtype == CNNL_DTYPE_HALF) { + if (!is_large_tensor) { + kernels::MLUCombineMoeResultKernel<<>>( + (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight, + (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size, + start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK); + } else { + kernels::MLUCombineMoeResultKernel<<>>( + (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight, + (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size, + start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK); + } + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + if (!isBf16Supported()) { + std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (!is_large_tensor) { + kernels::MLUCombineMoeResultKernel<<>>( + (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual, + (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk, + num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK); + } else { + kernels::MLUCombineMoeResultKernel<<>>( + (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual, + (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk, + num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK); + } + } else { + std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is " + << "among float/half/bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh new file mode 100644 index 0000000..b8b97a2 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh @@ -0,0 +1,85 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_ +#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_ + +#include "../kernel_utils.h" +#include "cnnl.h" +namespace tmo { +/** + * @brief Sort tokens grouped by different experts based on index. Each token + * selects the topk hidden vectors, multiplies them by corresponding weights, + * and finally reduces the topk vectors for each token. This process involves + * bias and residual, calculated as (x + bias) * weight + residual. + * @example + * input: + * [[[1, 2, 1, 1], + * [1, 1, 1, 2]], + * [[2, 1, 1, 1], + * [1, 1, 1, 1]]] + * num_token = 2, topk = 2 + * cusum_token_count = [0, 2, 4] + * index: + * [0, 1, 2, 3] + * weight: + * [0, 0, 1, 1] + * bias: + * [[0, 0, 0, 0], + * [1, 1, 1, 1]] + * residual: + * [[1, 1, 1, 1], + * [0, 0, 0, 0]] + * output: + * [[1, 1, 1, 1], + * [5, 4, 4, 4]] + * @param queue: The queue for mlu. + * @param output: Output. Pointer to the MLU memory that stores the result. + * The shape is [num_token, hidden_size]. + * @param input: Input. Pointer to the MLU memory that stores input tokens. + * The shape is [num_token * topk, hidden_size]. + * @param bias: Input. Pointer to the MLU memory that stores bias. + * The shape is [num_expert, hidden_size]. + * @param residual: Input. Pointer to the MLU memory that stores residual. + * The shape is [num_token, hidden_size]. + * @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight. + * The shape is [num_token * topk]. + * @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the + * token number of each expert. The shape is [num_expert + 1]. + * @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx. + * The shape is [num_token * topk]. + * @param num_token: The total number of tokens. + * @param topk: The number of expert. + * @param num_expert: The number of expert. + * @param hidden_size: The size of lowest dimension. + * @param start_expert_id: The id of the first processed expert. + * @param expert_size: The number of processed experts. + * @param dtype: Data type. + * @note Currently does not support bias. + */ +KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue, + void *output, + const void *input, + const void *bias, + const void *residual, + const float *reduce_weight, + const int *cusum_token_count, + const int *gather_idx, + int num_token, + int topk, + int num_expert, + int hidden_size, + int start_expert_id, + int expert_size, + cnnlDataType_t dtype); +} // namespace tmo + +#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu new file mode 100644 index 0000000..c349801 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu @@ -0,0 +1,219 @@ +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "expand_input.mluh" + +namespace tmo { + +namespace kernels { + +#define RESERVED_SIZE (32 * 1024) +#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE) +#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE) +#define MEMCPY_BURST_SIZE 128 + +__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE]; +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; + +// T_offset: uint32_t or uint64_t +template +__mlu_func__ void ExpandInputKernel(void *output, + void *input, + int *index, + int num_token, + int hidden_size, + int num_index) { + void *input_ptr = input; + uint64_t input_size = data_size * num_token * hidden_size; + // whether SRAM_BUFFER_SIZE can hold the input data + // sram_enable is true ==> is_nram_output is true + bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE); + if (sram_enable) { + input_ptr = (void *)sram_buffer; + __memcpy(input_ptr, input, input_size, GDRAM2SRAM); + __sync_cluster(); + } + if (__is_mpu()) { + return; + } + + // Each ipu core processes no less than 128B of data, and the remaining cores can idle + int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE; + uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1)); + uint32_t total_num = num_index; + uint32_t base = total_num / maxTaskDim; + uint32_t tail = total_num - base * maxTaskDim; + if (taskId >= maxTaskDim) { + return; + } + uint32_t batch_per_core = base + (taskId < tail ? 1 : 0); + uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail); + + // nram + /* + * first: compute offset: index[i] * data_size + * second: gather data + * ------------------------------------------------- + * addr || index/offset | output | + * type || int32_t/T_offset | T | + * num || n | n * hidden_size | + * ------------------------------------------------- + */ + + uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size; + // whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM + bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE; + uint32_t per_num = + is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset); + + int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size; + int *index_base = index + batch_step; + T_offset *nram_offset = (T_offset *)nram_buffer; + int32_t *nram_index; + if (std::is_same::value) { + nram_index = (int32_t *)nram_offset + per_num; + } else { + nram_index = (int32_t *)nram_offset; + } + int8_t *nram_output = (int8_t *)(nram_offset + per_num); + + uint32_t repeat = batch_per_core / per_num; + uint32_t remain = batch_per_core - repeat * per_num; + uint32_t deal_num = per_num; + uint32_t is_remain = remain != 0 ? 1 : 0; + + for (int32_t i = 0; i < repeat + is_remain; i++) { + if (i == repeat) { + deal_num = remain; + } + int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size; + int32_t *index_ptr = index_base + i * per_num; + + // index -> offset + __memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM); + if (std::is_same::value) { +#if __BANG_ARCH__ > 592 + __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0); +#else + __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num); +#endif + } + __bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num); + + // copy + if (is_nram_output) { + __bang_write_zero((int8_t *)nram_output, deal_num * hidden_size); + mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM; + // GDRAM or SRAM -> NRAM -> GDRAM +#if __BANG_ARCH__ >= 592 // gather requires + __gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir, + hidden_size * data_size, deal_num); +#else + for (int32_t j = 0; j < deal_num; j++) { + T_offset offset_value = *(nram_offset + j); + int8_t *input_offset = (int8_t *)input_ptr + offset_value; + __memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size, + dir); + } +#endif // __BANG_ARCH__ + __memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM); + } else { + // GDRAM -> GDRAM +#if __BANG_ARCH__ >= 592 // gather requires + __gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM, + hidden_size * data_size, deal_num); +#else + for (int32_t j = 0; j < deal_num; j++) { + T_offset offset_value = *(nram_offset + j); + int8_t *input_offset = (int8_t *)input + offset_value; + __memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset, + hidden_size * data_size, GDRAM2GDRAM); + } +#endif // __BANG_ARCH__ + } + } +} + +// T_offset: uint32_t or uint64_t +template +__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state, + void *hidden_state, + int *gather_idx, + int *cusum_token_count, + int num_token, + int hidden_size, + int topk, + int total_expert_num, + int start_expert_id, + int expert_count) { + int32_t num_index = num_token * topk; + int *gather_start_idx = (int *)gather_idx; + if (cusum_token_count != nullptr) { + num_index = *((int *)cusum_token_count + start_expert_id + expert_count) - + *((int *)cusum_token_count + start_expert_id); + gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id); + } + ExpandInputKernel(expand_hidden_state, hidden_state, gather_start_idx, + num_token, hidden_size, num_index); +} +// instantiate kernels +#define INSTANTIATE_ONE(data_size, T_offset) \ + template __mlu_global__ void MLUExpandInputKernel( \ + void *, void *, int *, int *, int, int, int, int, int, int); + +INSTANTIATE_ONE(1, uint32_t) +INSTANTIATE_ONE(2, uint32_t) +INSTANTIATE_ONE(4, uint32_t) +INSTANTIATE_ONE(8, uint32_t) +// large tensor +INSTANTIATE_ONE(1, uint64_t) +INSTANTIATE_ONE(2, uint64_t) +INSTANTIATE_ONE(4, uint64_t) +INSTANTIATE_ONE(8, uint64_t) +} // namespace kernels + +KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue, + void *expand_hidden_state, + const void *hidden_state, + const int *gather_idx, + const int *cusum_token_count, + int num_token, + int hidden_size, + int topk, + cnnlDataType_t data_type, + int total_expert_num, + int start_expert_id, + int expert_count) { + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + + size_t type_size = 0; + cnnlGetSizeOfDataType(data_type, &type_size); + + int max_cluster_num = + (uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE); + cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num); + + cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1}; + + void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = { + kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>, + kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>, + kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>, + kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>}; + + bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX; + int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0); + expand_input_kernels[kernel_index]<<>>( + (void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx, + (int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id, + expert_count); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh new file mode 100644 index 0000000..6426a78 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh @@ -0,0 +1,81 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_ +#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_ + +#include "../kernel_utils.h" +#include "cnnl.h" + +namespace tmo { + +/** + * @brief Gathers slices from hidden_state at axis 1 according to gather_idx and cusum_token_count. + * @example + * hidden_state: + * [[1, 2, 3, 4], + * [5, 6, 7, 8], + * [9, 10, 11, 12]] + * gather_idx: + * [[1, 0, 2, 2, 1, 0]] + * cusum_token_count: NULL + * num_token = 3 + * hidden_size = 4 + * topk = 2 + * expand_hidden_state: + * [[5, 6, 7, 8], + * [1, 2, 3, 4], + * [9, 10, 11, 12], + * [9, 10, 11, 12], + * [5, 6, 7, 8], + * [1, 2, 3, 4]] + * @param queue: The queue for mlu. + * @param hidden_state: Input. Pointer to the MLU memory that store the input, + * the shape must be [num_token, hidden_size]. + * @param gather_idx: Input. Pointer to the MLU memory that stores the index, + * the shape must be [num_token * topk]. + * @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of + * token_count. If cusum_token_count is not NULL, the shape must be [total_expert_num + 1]. The + * gather operation will be performed as follows: if cusum_token_count is not NULL: index = + * gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]] + * expand_hidden_state = hidden_state[index] + * else: + * index = gather_idx[:] + * expand_hidden_state = hidden_state[index] + * @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output, + * if cusum_token_count is not NULL, the shape shoule be [num_index * topk ,hidden_size] in + * which num_index = + * cusum_token_count[start_expert_id+expert_count]-cusum_token_count[start_expert_id]. Otherwise, + * the shape should be [num_token * topk, hidden_size]. + * @param num_token: the number of token. + * @param hidden_size: the slice size. + * @param topk: the number of topk. + * @param data_type: Data type of hidden_state. + * @param total_expert_num: the total number of expert. + * @param start_expert_id: the first expert id. + * @param expert_count: the number of experts currently being processed. + */ + +KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue, + void *expand_hidden_state, + const void *hidden_state, + const int *gather_idx, + const int *cusum_token_count, + int num_token, + int hidden_size, + int topk, + cnnlDataType_t data_type, + int total_expert_num, + int start_expert_id, + int expert_count); +} // namespace tmo + +#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu new file mode 100644 index 0000000..d6ae3f4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu @@ -0,0 +1,935 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include +#include +#include +#include "gen_idx.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { + +#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024) +#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024) +#define ALIGN_16 (16) + +#define EXPERT_AVG_COUNT_TEST (0) + +__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE]; +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; +__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}; + +// Generate integer sequence data from 0 to length-1 +__mlu_func__ void generateIntSeq(int *dst, int length) { + int count = 64; + __bang_move(dst, range, std::min(count, length) * sizeof(int)); + while (count < length) { + __bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count)); + count *= 2; + } +} + +// genIdx Block kernel, use only 1 core to process +__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx, + int *gather_combine_idx, + int *token_count, + int *cusum_token_count, + const void *expert_id, + const int num_token, + const int num_expert, + const int topk) { + /* NRAM space */ + // Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int) + // -------------------------------------------------------------- + // | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result| + // | combine_idx | expand_idx | | scatter_offset | + // |num_token*topk|num_token*topk|num_token*topk| num_token*topk | + // -------------------------------------------------------------- + // ------------------------------ + // |token_count|token_count_presum| + // | | | + // | num_expert| num_expert | + // ------------------------------ + + uint32_t token_total_num = num_token * topk; + // num align to 16, size align to 64B + uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4; + + int8_t *expert_id_onchip = (int8_t *)nram_buffer; + int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int); + int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int); + int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int); + int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int); + int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int); + + int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space +#if __BANG_ARCH__ >= 592 + int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space +#endif + int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space + + // Load current core input expert_id and generate int sequence + __memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int), + GDRAM2NRAM); + generateIntSeq((int *)gen_idx_onchip, token_total_num); + __sync(); + + // Initialize sort idx offset + uint32_t sorted_idx_offset = 0; + // Initialize token count first presum with 0 + ((int *)token_count_presum_onchip)[0] = 0; + bool need_cusum_token_count = bool(cusum_token_count != nullptr); + + // Loop on each expert, eq, count, filter index + for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) { + __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, + token_total_num); + // Use filter to sort gen_idx, output with sorted_idx_offset + uint32_t cur_expert_count = + __bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip, + (float *)cur_expert_result, token_total_num); + + sorted_idx_offset += cur_expert_count; + ((int *)token_count_onchip)[cur_expert] = cur_expert_count; + + // Compute cusum token count and store + if (need_cusum_token_count) { + ((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset; + } + } + +#if EXPERT_AVG_COUNT_TEST + // NOTE: test avg expert code here: + uint32_t token_count_avg = token_total_num / num_expert; + uint32_t expert_remain_num = token_total_num % num_expert; + + for (int i = 0; i < num_expert; i++) { + ((int *)token_count_onchip)[i] = + (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg; + ((int *)token_count_presum_onchip)[i + 1] = + ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i]; + } +#endif + + __sync_compute(); + // Store token_count and cusum token count + __memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), + NRAM2GDRAM); + if (need_cusum_token_count) { + __memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip, + (num_expert + 1) * sizeof(int), NRAM2GDRAM); + } + + // Use sorted idx to generate gather idx for expand and combine +#if __BANG_ARCH__ >= 592 + // scatter_offset = sorted_idx mul_scalar sizeof(int); + __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)), + token_total_num); +#else + // scatter dst GDRAM addr should align to 64B + int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6); + int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr); + + __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip, + combine_idx_align_offset, (int)(sizeof(int)), token_total_num); +#endif + __sync_compute(); + +#if __BANG_ARCH__ >= 592 + // scatter_async to NRAM + __scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset, + sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num); +#endif + // expand_idx = sorted_idx div(topk) + __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num); + + // Store expand_idx and combine_idx + __sync_compute(); + __memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int), + NRAM2GDRAM); +#if __BANG_ARCH__ >= 592 + __sync_move(); + __memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip, + token_total_num * sizeof(int), NRAM2GDRAM); +#else + // 370 directly scatter to GDRAM + __scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset, + sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num); +#endif +} + +// Only MLU500 series support NRAM2SRAM scatter direction +__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) { +#if __BANG_ARCH__ >= 592 + // When length larger than 65535(maximum segnum in bang_scatter), + // and src/offset address should align to 64B + int seg_repeat = length / 32768; + int seg_remain = length % 32768; + int seg_offset = 0; + + for (int seg = 0; seg < seg_repeat; seg++) { + __scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, + sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768); + seg_offset += 32768; + } + + if (seg_remain > 0) { + __scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, + sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain); + } +#endif +} + +// Scatter sequence, transfer size is sizeof(int) +__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) { + // When length larger than 65535(maximum segnum in bang_scatter), + // and src/offset address should align to 64B + int seg_repeat = length / 32768; + int seg_remain = length % 32768; + int seg_offset = 0; + + for (int seg = 0; seg < seg_repeat; seg++) { +#if __BANG_ARCH__ >= 592 + __scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, + sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768); +#else + __scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int), + NRAM2GDRAM, sizeof(int), (unsigned short)32768); +#endif + seg_offset += 32768; + } + if (seg_remain > 0) { +#if __BANG_ARCH__ >= 592 + __scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, + sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain); +#else + __scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int), + NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain); +#endif + } +} + +// 1. Get token count +__mlu_func__ void getTokenCount(int *token_count, + int *expert_id, + int token_cur_core, + int cur_token_start, + int num_expert) { + // 1. Partition on [num_token*topk], + // each core for-loop on all expert_id, use eq and count instructions, + // use AtomicAdd to accumulate all expert_id token counts, on GDRAM. + // And sync for all cores. + // NRAM: + // ------------------------------------------------------ + // |expert_id_onchip|cur_expert_result|expert_count_onchip| + // | deal_num | deal_num | num_expert | + // ------------------------------------------------------ + + uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2; + int8_t *expert_id_onchip = (int8_t *)nram_buffer; + int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int); + int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int); + + // Current core data loop + uint32_t repeat = token_cur_core / deal_num; + uint32_t remain = token_cur_core % deal_num; + uint32_t total_repeat = repeat + (int)(remain > 0); + uint32_t token_addr_offset = cur_token_start; + + // Initialize token_count with 0 + if (taskId == 0) { + __gdramset((int *)token_count, num_expert, 0); + } + // Sync for initialize token_count + __sync_all_ipu(); + + // Initialize expert count onchip with 0 + if (token_cur_core > 0) { + __bang_write_zero((int *)expert_count_onchip, num_expert); + } + + // actual num in loop + int cur_deal_num = deal_num; + for (int i = 0; i < total_repeat; i++) { + if (i == total_repeat - 1 && remain > 0) { + cur_deal_num = remain; + } + // Load current core input expert_id + __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset, + cur_deal_num * sizeof(int), GDRAM2NRAM); + token_addr_offset += cur_deal_num; + + // Loop on each expert, eq, count + for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) { + __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num); + // NOTE: __bang_count() only support floating data type + uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num); + ((int *)expert_count_onchip)[cur_expert] += cur_expert_count; + } + } + + // AtomicAdd(reduce) all cores token count results + if (token_cur_core > 0) { + __bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert); + } + // Sync for all cores, get accumulate of token_count + __sync_all_ipu(); +} + +// 2. Get token count presum, for each expert index start address after sorting +__mlu_func__ void getTokenCountPresum(int *token_count_presum, + int *token_count, + const int num_expert) { + // 2. After first process, already get token_count. + // Then use one core to pre-sum on token_count, consider size of int32, + // first expert id start address should be zero. + // to get each expert id start address after sorting, store to workspace, + // token_count_presum. + // And sync for all cores. + // NRAM: + // load token_count to token_count_presum[1~num_expert+1], + // for i = 0 to num_expert: + // token_count_presum[i+1] += token_count_presum[i] + // store token_count_presum[0~num_expert] + // ------------------------- + // |token_count_presum_onchip| + // | {0}, num_expert | + // ------------------------- + + if (taskId == 0) { + // Initialize count presum onchip with a first 0 + int8_t *token_count_presum_onchip = nram_buffer; + ((int *)token_count_presum_onchip)[0] = 0; + // Load token_count with an offset of 1 + __memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int), + GDRAM2NRAM); + + // Calculate presum of token count by each expert + for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) { + ((int *)token_count_presum_onchip)[cur_expert + 1] += + ((int *)token_count_presum_onchip)[cur_expert]; + } + + // Store token count presum to workspace + __memcpy((int *)token_count_presum, (int *)token_count_presum_onchip, + (num_expert + 1) * sizeof(int), NRAM2GDRAM); + } + // Sync for all cores, get presum of token count + __sync_all_ipu(); +} + +__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum, + int *token_count, + const uint32_t token_total_num, + const int num_expert) { + uint32_t token_count_avg = token_total_num / num_expert; + uint32_t expert_remain_num = token_total_num % num_expert; + + int8_t *token_count_onchip = nram_buffer; + int8_t *token_count_presum_onchip = token_count_onchip + num_expert * sizeof(int); + + ((int *)token_count_presum_onchip)[0] = 0; + + for (int i = 0; i < num_expert; i++) { + ((int *)token_count_onchip)[i] = + (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg; + ((int *)token_count_presum_onchip)[i + 1] = + ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i]; + } + + __memcpy((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), NRAM2GDRAM); + __memcpy((int *)token_count_presum, (int *)token_count_presum_onchip, + (num_expert + 1) * sizeof(int), NRAM2GDRAM); +} + +// 3. Get expert position index after sorting +__mlu_func__ void getSortedIdx(int *sorted_idx, + int *expert_id, + int *token_count_presum, + const int token_total_num, + const int num_expert, + const int expert_cur_core, + const int cur_expert_start, + const int cur_expert_end) { + // 3. Partition on num_expert, each core generate position index from 0, + // and for-loop on all expert_id data, use eq with own each expert_id, + // and filter on index, stores to each expert_id start address of + // sorted_idx on workspace. + // And sync for all cores. + // NRAM: + // ------------------------------------------------------------------- + // |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip| + // | deal_num | deal_num | deal_num | deal_num | + // ------------------------------------------------------------------- + // |expert_start_addr| + // | num_expert | + // ----------------- + + // Calculate new deal_num of sorting process + int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4; + + // Each core deal with whole token expert_id data + int repeat = token_total_num / deal_num; + int remain = token_total_num % deal_num; + int token_addr_offset = 0; + + int8_t *expert_id_onchip = nram_buffer; + int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int); + int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int); + int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int); + int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int); + + // When num_expert < taskDim, not all cores need to sort + if (expert_cur_core > 0) { + // Generate position index from 0 + if (deal_num <= token_total_num) { + generateIntSeq((int *)gen_idx_onchip, deal_num); + } else { // only remainder part + generateIntSeq((int *)gen_idx_onchip, token_total_num); + } + + // Initialize expert start address with presum of token count + __memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int), + GDRAM2NRAM); + + // repeat part + for (int i = 0; i < repeat; i++) { + // Load current core expert_id + __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset, + deal_num * sizeof(int), GDRAM2NRAM); + token_addr_offset += deal_num; + + // Loop for current core expert, eq, filter position index + // use filter, store to sorted_idx[expert_start_addr] + for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) { + __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num); + + int cur_expert_offset = ((int *)expert_start_addr)[cur_expert]; + + // NOTE: __bang_filter() only support floating data type + uint32_t cur_expert_count = + __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip, + (float *)cur_expert_result, deal_num); + + // Store to the corresponding address of sorted_idx + if (cur_expert_count > 0) { + __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip, + cur_expert_count * sizeof(int), NRAM2GDRAM); + + // Update address offset of current expert + ((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count; + } + } + + // Update position index for each data loop + __bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num); + } + + // remainder part + if (remain > 0) { + __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset, + remain * sizeof(int), GDRAM2NRAM); + + for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) { + __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain); + + int cur_expert_offset = ((int *)expert_start_addr)[cur_expert]; + + // NOTE: __bang_filter() only support floating data type + uint32_t cur_expert_count = + __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip, + (float *)cur_expert_result, remain); + // Store to the corresponding address of sorted_idx + if (cur_expert_count > 0) { + __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip, + cur_expert_count * sizeof(int), NRAM2GDRAM); + } + } + } + } + + // Sync for all cores, get position index after sorting + __sync_all_ipu(); +} + +// 4. Get gather index for expand and combine +template +__mlu_func__ void getGatherIdx(int *gather_expand_idx, + int *gather_combine_idx, + int *sorted_idx, + const int token_cur_core, + const int cur_token_start, + const int topk) { + // 4. Partition on [num_token*topk], + // load sorted_idx onchip, + // generate sequence according to position index from 0, add token offset + // gather_combine_idx = scatter(seq, sorted_idx) + // gather_expand_idx = sorted_idx / topk + // update sequence + // NRAM: + // ------------------------------------------------------------------- + // |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence| + // | deal_num | deal_num | deal_num | deal_num | + // ------------------------------------------------------------------- + + // Calculate new deal_num of generate gather index + // NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints + int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4; + int repeat = token_cur_core / deal_num; + int remain = token_cur_core % deal_num; + int token_addr_offset = cur_token_start; + + // scatter dst GDRAM addr should align to 64B + int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6); + int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr); + + int8_t *sorted_idx_onchip = nram_buffer; + int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int); + int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int); + int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int); + + // Generate position index from 0 + // Add base offset to sequence according to current core token start address + if (token_cur_core > 0) { + if (deal_num <= token_cur_core) { + generateIntSeq((int *)scatter_sequence, deal_num); + __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset, + deal_num); + } else { // only remainder part + generateIntSeq((int *)scatter_sequence, token_cur_core); + __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset, + token_cur_core); + } + } + + // repeat part + for (int i = 0; i < repeat; i++) { + // Load current core sorted_idx + __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset, + deal_num * sizeof(int), GDRAM2NRAM); + + // offset = sorted_idx * sizeof(int), counted in bytes + if (is_sram_scatter) { + __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)), + deal_num); + } else { + // GDRAM addr should align to 64B + __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip, + combine_idx_align_offset, (int)(sizeof(int)), deal_num); + } + // Sync for scatter + __sync_compute(); + + if (is_sram_scatter) { + scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, + deal_num); + } else { + // Scatter to output gather_combine_idx + scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence, + (uint32_t *)scatter_offset, deal_num); + } + + // expand_idx_onchip = sorted_idx / topk + __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num); + // Store expand idx + __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip, + deal_num * sizeof(int), NRAM2GDRAM); + if (is_sram_scatter) { + // if scatter to SRAM, need to sync compute with mv + __sync_move(); + } + // Add offset to sequence and token_address + __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num); + token_addr_offset += deal_num; + } + + // remainder part + if (remain > 0) { + // Load current core sorted_idx + __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset, + remain * sizeof(int), GDRAM2NRAM); + + // offset = sorted_idx * sizeof(int), counted in bytes + if (is_sram_scatter) { + __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)), + remain); + } else { + // GDRAM addr should align to 64B + __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip, + combine_idx_align_offset, (int)(sizeof(int)), remain); + } + + // Sync for scatter + __sync_compute(); + + if (is_sram_scatter) { + scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, + remain); + } else { + // Scatter to output gather_combine_idx + scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence, + (uint32_t *)scatter_offset, remain); + } + + // expand_idx_onchip = sorted_idx / topk + __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain); + // Store expand idx + __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip, + remain * sizeof(int), NRAM2GDRAM); + } +} + +// 4.1 Get gather combine index on SRAM +__mlu_func__ void getCombineIdxSram(int *sorted_idx, + const int token_cur_core, + const int cur_token_start) { + // 4.1 Partition on [num_token*topk], with only 1 union + // load sorted_idx onchip, + // generate sequence according to position index from 0, add token offset + // gather_combine_idx = scatter(seq, sorted_idx) + // update sequence + // NRAM: + // ------------------------------- + // |scatter_offset|scatter_sequence| + // | deal_num | deal_num | + // ------------------------------- + + // Calculate new deal_num of generate gather index + // NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints + int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2; + int repeat = token_cur_core / deal_num; + int remain = token_cur_core % deal_num; + int token_addr_offset = cur_token_start; + + int8_t *scatter_offset = nram_buffer; + int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int); + + // Generate position index from 0 + // Add base offset to sequence according to current core token start address + if (token_cur_core > 0) { + if (deal_num <= token_cur_core) { + generateIntSeq((int *)scatter_sequence, deal_num); + __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset, + deal_num); + } else { // only remainder part + generateIntSeq((int *)scatter_sequence, token_cur_core); + __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset, + token_cur_core); + } + } + + // repeat part + for (int i = 0; i < repeat; i++) { + // Load current core sorted_idx + __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int), + GDRAM2NRAM); + + // offset = sorted_idx * sizeof(int), counted in bytes + __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num); + // Sync for scatter + __sync_compute(); + + // Scatter to SRAM + scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, + deal_num); + __sync_move(); + + // Add offset to sequence and token_address + __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num); + token_addr_offset += deal_num; + } + + // remainder part + if (remain > 0) { + // Load current core sorted_idx + __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int), + GDRAM2NRAM); + + // offset = sorted_idx * sizeof(int), counted in bytes + __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain); + // Sync for scatter + __sync_compute(); + + scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain); + } +} + +// 4.2 Get gather expand index +__mlu_func__ void getExpandIdx(int *gather_expand_idx, + int *sorted_idx, + const int token_cur_core, + const int cur_token_start, + const int topk) { + // 4.2 Partition on [num_token*topk], + // load sorted_idx onchip, + // gather_expand_idx = sorted_idx / topk + // NRAM: + // ----------------------------------- + // |sorted_idx_onchip|expand_idx_onchip| + // | deal_num | deal_num | + // ----------------------------------- + + // Calculate new deal_num of generate gather index + int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2; + int repeat = token_cur_core / deal_num; + int remain = token_cur_core % deal_num; + int token_addr_offset = cur_token_start; + + int8_t *sorted_idx_onchip = nram_buffer; + int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int); + + // repeat part + for (int i = 0; i < repeat; i++) { + // Load current core sorted_idx + __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset, + deal_num * sizeof(int), GDRAM2NRAM); + + // expand_idx_onchip = sorted_idx / topk + __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num); + // Store expand idx + __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip, + deal_num * sizeof(int), NRAM2GDRAM); + token_addr_offset += deal_num; + } + + // remainder part + if (remain > 0) { + // Load current core sorted_idx + __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset, + remain * sizeof(int), GDRAM2NRAM); + // expand_idx_onchip = sorted_idx / topk + __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain); + // Store expand idx + __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip, + remain * sizeof(int), NRAM2GDRAM); + } +} + +__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx, + int *gather_combine_idx, + int *token_count, + int *cusum_token_count, + void *workspace, + const void *expert_id, + const int num_token, + const int num_expert, + const int topk) { + // Store token count presum result, shape [num_expert + 1] + int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace; + // Store position index after sorting, shape [num_token*topk] + int *sorted_idx = ((int *)workspace) + num_expert + 1; + + // Calculate partition information for different processes + // Partition on [num_token*topk] + uint32_t token_total_num = num_token * topk; + uint32_t token_cur_core = token_total_num / taskDim; + uint32_t token_remain_num = token_total_num % taskDim; + token_cur_core += (uint32_t)(taskId < token_remain_num); + // Current core range according to partition on [num_token*topk] + uint32_t cur_token_start = (taskId < token_remain_num) + ? token_cur_core * taskId + : token_cur_core * taskId + token_remain_num; + + // Partition on [num_expert] + uint32_t expert_cur_core = num_expert / taskDim; + uint32_t expert_remain_num = num_expert % taskDim; + expert_cur_core += (uint32_t)(taskId < expert_remain_num); + // Current core range according to partition on [num_expert] + uint32_t cur_expert_start = (taskId < expert_remain_num) + ? expert_cur_core * taskId + : expert_cur_core * taskId + expert_remain_num; + uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1; + + // Use Union1 SRAM to scatter, only MLU500 series support now +#if __BANG_ARCH__ >= 592 + bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE; +#else + bool is_sram_scatter = false; +#endif + + if (__is_ipu()) { + // 1. Get token count + getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start, + num_expert); + // 2. Get presum of token count + getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert); + + // 3. Get expert position index after sorting + getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num, + num_expert, expert_cur_core, cur_expert_start, cur_expert_end); + } + +#if EXPERT_AVG_COUNT_TEST + // NOTE: test avg expert code here: + if (__is_ipu() && taskId == 0) { + modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num, + num_expert); + } + __sync_cluster(); +#endif + + // 4. Get gather index for expand and combine + if (is_sram_scatter) { + // Only use Union1 SRAM + uint32_t scatter_idx_cur_core = token_total_num / 4; + uint32_t scatter_idx_remain_num = token_total_num % 4; + scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num); + uint32_t cur_idx_start = (taskId < scatter_idx_remain_num) + ? scatter_idx_cur_core * taskId + : scatter_idx_cur_core * taskId + scatter_idx_remain_num; + + // Only Union1 task type, + // deal once num is same with deal_num in getGatherIdx, + // which means only 1 repeat to generate both expand and combine idx on NRAM + const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4; + if (taskDim <= 4 || token_total_num < deal_once_num) { + if (taskId < 4) { + if (__is_ipu()) { + getGatherIdx((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx, + scatter_idx_cur_core, cur_idx_start, topk); + // sync for ipu and mpu + __sync_cluster(); + } else { + // sync for ipu and mpu + __sync_cluster(); + __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer, + token_total_num * sizeof(int), SRAM2GDRAM); + } + } + } else { + // If taskDim > 4, use first union to generate combine idx, + // use other union to generate expand idx + if (taskId < 4) { + if (__is_ipu()) { + // Scatter combine idx to SRAM + getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start); + __sync_cluster(); + } else { + __sync_cluster(); + __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer, + token_total_num * sizeof(int), SRAM2GDRAM); + } + } else { + // Other union generate expand idx + if (__is_ipu()) { + uint32_t expand_dim = taskDim - 4; + uint32_t expand_id = taskId - 4; + uint32_t expand_token_cur_core = token_total_num / expand_dim; + uint32_t expand_token_remain_num = token_total_num % expand_dim; + expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num); + + uint32_t expand_cur_token_start = + (expand_id < expand_token_remain_num) + ? expand_token_cur_core * expand_id + : expand_token_cur_core * expand_id + expand_token_remain_num; + getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core, + expand_cur_token_start, topk); + } + } + } + } else { + // not use SRAM to generate both expand and combine idx + if (__is_ipu()) { + getGatherIdx((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx, + token_cur_core, cur_token_start, topk); + } + } + + // step 5 does not need MPU + if (__is_mpu()) { + return; + } +} // end of kernel + +} // namespace kernels + +KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue, + int *gather_expand_idx, + int *gather_combine_idx, + int *token_count, + int *cusum_token_count, + void *workspace, + const void *expert_id, + const int num_token, + const int num_expert, + const int topk) { + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + + const int token_total_num = num_token * topk; + + // For partition on num_token*topk, single core processes at least 128 num + const int single_core_num_limit = 1024; + int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit); + // When partition on num_expert, each core at least processes one expert + need_core_num = std::max(num_expert, need_core_num); + + // When consider UnionX cnrt func type, reset cluster_num + if (token_total_num <= 4096) { // Block + cnrtFunctionType_t k_type = cnrtFuncTypeBlock; + cnrtDim3_t k_dim{1, 1, 1}; + // Block kernel does not need workspace + kernels::launchMoeGenIdxBlockKernel<<>>( + gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token, + num_expert, topk); + + return KernelStatus::KERNEL_STATUS_SUCCESS; + } else if (need_core_num <= 4) { // Union1 + cluster_num = 1; + } else if (need_core_num <= 8) { // Union2 + cluster_num = std::min(cluster_num, 2); + } else if (need_core_num <= 16) { // Union4 + cluster_num = std::min(cluster_num, 4); + } else if (need_core_num <= 32) { // Union8 + cluster_num = std::min(cluster_num, 8); + } + + cnrtFunctionType_t k_type; + cnrtDim3_t k_dim{1, 1, 1}; + + // Find max UnionX cnrt func type + if (cluster_num == 1) { + k_type = cnrtFuncTypeUnion1; + k_dim.x = 4; + } else if (cluster_num < 4) { // cluster num is 2 or 3 + k_type = cnrtFuncTypeUnion2; + k_dim.x = 8; + } else if (cluster_num < 8) { // cluster num is 4,5,6,7 + k_type = cnrtFuncTypeUnion4; + k_dim.x = 16; + } else { // cluster num larger than 8 + k_type = cnrtFuncTypeUnion8; + k_dim.x = 32; + } + + // The expert_id is int data type + kernels::launchMoeGenIdxKernel<<>>( + gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id, + num_token, num_expert, topk); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +#undef EXPERT_AVG_COUNT_TEST // undef test macro + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh new file mode 100644 index 0000000..b71be28 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh @@ -0,0 +1,58 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_ +#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_ + +#include +#include "../kernel_utils.h" +#include "cnnl.h" +namespace tmo { +/** + * @brief Apply generate MOE index operation, which performs the following + * tasks: + * - 1. Generate gather_expand_idx and gather_combine_idx. + * - 2. Output token_count, the token number of each expert. + * - 3. Prepare inputs and outputs address for group_gemm. + * @param queue: The queue of mlu. + * @param gather_expand_idx: Output. Pointer to the MLU memory that stores the + * gather index for expand hidden state operation, the shape must be + * [num_token * topk]. + * @param gather_combine_idx: Output. Pointer to the MLU memory that stores the + * gather index for combine MOE operation, the shape must be + * [num_token * topk]. + * @param token_count: Output. Pointer to the MLU memory that stores the token + * number of each expert, the shape must be [num_expert]. + * @param cusum_token_count: Output. Pointer to the MLU memory that stores the + * cumulative sum of the token number of each expert, the shape must be + * [num_expert + 1]. It can be set to nullptr if don't need cusum output. + * @param workspace: Input. A pointer to the extra workspace required in the + * operation, the size must be larger than + * (num_expert + 1 + num_token * topk) multiplied by the size of uint32. + * @param expert_id: Input. Pointer to the MLU memory that stores the expert id + * of each token, the shape must be [num_token, topk]. + * @param num_token: The number of tokens. + * @param num_expert: The number of experts. + * @param topk: The number of expert selected by each token. + */ +KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue, + int *gather_expand_idx, + int *gather_combine_idx, + int *token_count, + int *cusum_token_count, + void *workspace, + const void *expert_id, + const int num_token, + const int num_expert, + const int topk); +} // namespace tmo + +#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh new file mode 100644 index 0000000..1839fa0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_MOE_MOE_MLUH_ +#define CSRC_KERNELS_MOE_MOE_MLUH_ + +#include "add_bias_activation.mluh" +#include "combine_result.mluh" +#include "expand_input.mluh" +#include "gen_idx.mluh" +#include "softmax_topk.mluh" + +#endif // CSRC_KERNELS_MOE_MOE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu new file mode 100644 index 0000000..c7f6fa1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu @@ -0,0 +1,602 @@ +#include +#include +#include +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "softmax_topk.mluh" + +namespace tmo { + +namespace kernels { +#define SCATTER_ALIGN (64) // align for __scatter() + +#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) +#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024) +#define TILING_ALIGN (64) +#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0)) +__nram__ int8_t nram_buffer[NRAM_SIZE]; +__mlu_shared__ int8_t sram_buffer[SRAM_SIZE]; + +#define __TRANS_TILING(TYPE, CONVERT) \ + __asm__ volatile("trans.tiling." TYPE \ + " [%[dst]], [%[src]]," \ + "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \ + "%[is4], %[in5], %[is5]," \ + "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \ + "%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \ + [src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \ + [is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \ + [in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \ + [dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \ + [ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5)); + +template +__mlu_func__ void __mlvm_trans(DST_DTYPE *dst, + const SRC_DTYPE *src, + const uint32_t in0, + const uint32_t in1, + const uint32_t is1, + const uint32_t in2, + const uint32_t is2, + const uint32_t in3, + const uint32_t is3, + const uint32_t in4, + const uint32_t is4, + const uint32_t in5, + const uint32_t is5, + const uint32_t dn0, + const uint32_t dn1, + const uint32_t ds1, + const uint32_t dn2, + const uint32_t ds2, + const uint32_t dn3, + const uint32_t ds3, + const uint32_t dn4, + const uint32_t ds4, + const uint32_t dn5, + const uint32_t ds5) { + if (SRAM2NRAM == dir && std::is_same::value) { + if (std::is_same::value) { + __TRANS_TILING("nram.sram.b32", ";") + } else if (std::is_same::value) { + __TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();") +#if __BANG_ARCH__ >= 500 + } else if (std::is_same::value) { + __TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();") +#endif + } + } +} + +/* 将shape为[h,w]的数据转置为[w,h](带转数),分4块分别进行处理。 + * dst: dst地址 + * src: src地址 + * h: h方向大小 + * w: w方向大小 + */ +template +__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) { + uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE); + uint32_t w_align = w / align_num; + uint32_t w_rem = w % align_num; + uint32_t h_align = h / align_num; + uint32_t h_rem = h % align_num; + uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN; + uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE); + uint32_t in3 = w_align, is3 = TILING_ALIGN; + uint32_t in4 = h_align, is4 = w * TILING_ALIGN; + uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE); + uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE); + uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE); + /* 1. h_align * w_align */ + if (w_align > 0 && h_align > 0) { + __mlvm_trans(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0, + dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0); + } + /* 2. h_align * w_rem */ + if (w_rem > 0 && h_align > 0) { + SRC_DTYPE *src_temp = src + w_align * align_num; + DST_DTYPE *dst_temp = dst + w_align * align_num * h; + in0 = w_rem * sizeof(SRC_DTYPE); + dn0 = TILING_ALIGN; + in1 = align_num; + is1 = w * sizeof(SRC_DTYPE); + in4 = h_align; + is4 = w * TILING_ALIGN; + dn1 = w_rem; + ds1 = h * sizeof(DST_DTYPE); + dn4 = in4; + ds4 = align_num * sizeof(DST_DTYPE); + __mlvm_trans(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4, + 1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0); + } + /* 3. h_rem * w_align */ + if (w_align > 0 && h_rem > 0) { + SRC_DTYPE *src_temp = src + h_align * align_num * w; + DST_DTYPE *dst_temp = dst + h_align * align_num; + in0 = TILING_ALIGN; + dn0 = h_rem * sizeof(SRC_DTYPE); + in1 = h_rem; + is1 = w * sizeof(SRC_DTYPE); + in4 = w_align; + is4 = TILING_ALIGN; + dn1 = align_num; + ds1 = h * sizeof(DST_DTYPE); + dn4 = in4; + ds4 = h * align_num * sizeof(DST_DTYPE); + __mlvm_trans(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4, + 1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0); + } + /* 4. h_rem * w_rem */ + if (w_rem > 0 && h_rem > 0) { + SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num; + DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num; + in0 = w_rem * sizeof(SRC_DTYPE); + dn0 = h_rem * sizeof(SRC_DTYPE); + in1 = h_rem; + is1 = w * sizeof(SRC_DTYPE); + dn1 = w_rem; + ds1 = h * sizeof(DST_DTYPE); + __mlvm_trans(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1, + 0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0); + } +} + +__mlu_func__ void getTopk(float *value_buffer, + uint32_t *index_buffer, + float *src_buffer, + float *compute_buffer, + float *max_buffer, + float *temp_buffer, + uint32_t *i_buffer, + uint32_t *col_buffer, + uint32_t topk, + uint32_t num_expert_group, + uint32_t col, + uint32_t row, + uint32_t value_index_stride, + uint32_t group_size, + bool is_deal_group) { + __bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector + for (int k = 0; k < topk; k++) { + if (is_deal_group) { + __bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group, + 1, num_expert_group, 1, 1); + __bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col, + col); + } else { + __bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1, + value_index_stride); + __bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col); + } +#if __BANG_ARCH__ >= 592 + __bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte + __scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t), + col); // replace max value with -inf +#else + for (int i = 0; i < col; i++) { + uint32_t index = __load_nram(col_buffer + i); + max_buffer[index] = -INFINITY; + } +#endif +#if __BANG_ARCH__ < 500 + if (is_deal_group) { + for (int i = 0; i < col; i++) { + uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i); + __memcpy(compute_buffer + i * row + index * group_size, + src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM); + } + } +#endif + } +#if __BANG_ARCH__ >= 592 + if (is_deal_group) { + __bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col); + __bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col); + __bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0, + topk - 1); + __bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col); + __bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float), + (uint32_t *)compute_buffer, col * topk, col * topk); + __gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float), + NRAM2NRAM, group_size * sizeof(float), col * topk); + __bang_write_value(src_buffer, row * col, -INFINITY); + __scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM, + group_size * sizeof(float), col * topk); + } +#endif +} + +template +__mlu_func__ void computeSoftmaxTopk(T *sram_buffer, + T *load_buffer, + float *src_buffer, + float *compute_buffer, + float *group_max_buffer, + float *nramout_value, + uint32_t *nramout_index, + uint32_t *i_buffer, + uint32_t *col_buffer, + float *softmax_buffer, + uint32_t row, + uint32_t nram_compute_col_num, + uint32_t mask_num, + uint32_t nram_max_col_num, + uint32_t topk, + int num_expert_group, + uint32_t topk_group, + uint32_t top_num, + uint32_t nram_col_offset, + int normalize_mode, + bool valid_mask, + bool split_mask) { + uint32_t nram_compute_num = nram_compute_col_num * row; + // convert to float for half/bf16 datatype + if (std::is_same::value) { + __bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num); + } else if (std::is_same::value) { + __bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num); + } + // transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool. + __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row); + + // compute softmax + int tmp = 0x3fb8aa3b; + float log2e = *(float *)&tmp; // for exp + // src_buffer reuse as buffer for max/sum. + __bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max + __bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num, + nram_compute_col_num); + __bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max) + __bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum + __bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum + __bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num, + nram_compute_col_num); + __sync_cluster(); + // move mask and compute + if (valid_mask) { + if (!split_mask) { + __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num); + if (std::is_same::value) { + __memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T), + SRAM2NRAM); + __bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row, + mask_num * row); + } else if (std::is_same::value) { + __memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer, + mask_num * row * sizeof(T), SRAM2NRAM); + __bang_bfloat162float((float *)compute_buffer, + (bfloat16_t *)compute_buffer + mask_num * row, mask_num * row); + } else { + __memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM); + } + __bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row, + mask_num * row); + __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row); + } else { + transhw2wh(src_buffer, sram_buffer + nram_col_offset * row, + nram_compute_col_num, row); + __sync(); + __bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row); + } + } + if (normalize_mode == 2) { + __bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); + } + + if (num_expert_group <= 1) { + // num_expert_group <= 1, maintain original topk calculation logic + getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer, + i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row, + nram_max_col_num * topk * sizeof(float), 0, false); + } else { + // num_expert_group > 1, use grouped_topk calculation logic + uint32_t group_size = row / num_expert_group; + __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num); + __bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group, + group_size, 1, group_size, 1, 1); + __bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY); + // get topk_group + getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer, + (float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group, + nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true); + // get topk +#if __BANG_ARCH__ < 500 + __bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row); + getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer, + i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row, + nram_max_col_num * top_num * sizeof(float), 0, false); +#else + __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row); + getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer, + i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row, + nram_max_col_num * top_num * sizeof(float), 0, false); +#endif + } // end else + + // normalize result + if (normalize_mode == 1) { + // compute_buffer reuse as buffer for sum. + __bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1); + __bang_recip(compute_buffer, compute_buffer, nram_compute_col_num); + __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num, + nram_compute_col_num); + } else if (normalize_mode == 2) { + __bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num); + __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num, + nram_compute_col_num); + } + + // transpose back. src and dst of transpose can not be the same address. + __bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num); + __bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num); +} + +template +__mlu_global__ void MLUSoftmaxTopkKernel(T *input, + T *mask, + int *index_out, + float *value_out, + int col, + int row, + int mask_num, + int topk, + int num_expert_group, + int topk_group, + int normalize_mode) { + bool valid_mask = (mask != nullptr); + int top_num = topk >= topk_group ? topk : topk_group; + uint32_t nram_low_space = + PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float), + SCATTER_ALIGN); + if (num_expert_group <= 1) { + nram_low_space = + PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN); + } + uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space; + if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) { + nram_max_col_num = col / taskDim + (col % taskDim > 0); + } + nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float)); + if (nram_max_col_num <= 0) { + nram_max_col_num = SCATTER_ALIGN / sizeof(float); + } + uint32_t nram_deal_num = nram_max_col_num * row; + uint32_t batch = col / mask_num; + + // nram split: + // |--------------------------|--------------------------|--------------------|... + // | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|... + // | src_buffer | compute_buffer | group_max_buffer |... + // |--------------------------|--------------------------|--------------------|... + + // |----------------------------------------|---------------|--------------| + // | nram_col_num*3 | col*topk | col*topk | + // | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index| + // |----------------------------------------|---------------|--------------| + float *src_buffer = (float *)nram_buffer; + float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float)); + float *group_max_buffer = compute_buffer + nram_deal_num; + uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num; + if (num_expert_group <= 1) { + i_buffer = (uint32_t *)group_max_buffer; + } + uint32_t *col_buffer = i_buffer + nram_max_col_num; + float *softmax_buffer = (float *)col_buffer + nram_max_col_num; + if (normalize_mode != 2) { + softmax_buffer = (float *)col_buffer; + } + + float *nramout_value = softmax_buffer + nram_max_col_num; + uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num; + if (num_expert_group <= 1) { + nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num; + } + T *load_buffer = (T *)src_buffer; + if (std::is_same::value || std::is_same::value) { + load_buffer = load_buffer + nram_deal_num; + } + + // set i_buffer + for (uint32_t i = 0; i < nram_max_col_num; i++) { + i_buffer[i] = i; + } + + // input[batch, mask, low], mask[mask, low] + if (nram_max_col_num >= mask_num) { // nram can deal complete mask + bool split_mask = false; + uint32_t batch_seg = nram_max_col_num / mask_num; + uint32_t batch_rem = batch % batch_seg; + uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0); + int repeat = DIV_UP(batch_seg_num, taskDim); + for (int i = 0; i < repeat; i++) { + uint32_t seg_id = i * taskDim + taskId; + uint32_t sram_load_num = mask_num * row; + uint32_t sram_load_offset = 0; + uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0) + ? batch_rem * mask_num + : batch_seg * mask_num; + uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0; + uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0; + uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row; + uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk; + + // Load + if (valid_mask) { + __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM); + } + if (nram_load_num > 0) { + __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM); + } + + // Compute + computeSoftmaxTopk((T *)sram_buffer, load_buffer, src_buffer, compute_buffer, + group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer, + softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num, + topk, num_expert_group, topk_group, top_num, 0, normalize_mode, + valid_mask, split_mask); + + // Store + if (nram_store_num > 0) { + __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float), + NRAM2GDRAM); + __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int), + NRAM2GDRAM); + } + __sync_cluster(); + } + } else { + bool split_mask = true; + uint32_t mask_seg = nram_max_col_num; + uint32_t mask_rem = mask_num % mask_seg; + uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0); + uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim); + uint32_t sram_mask_rem = mask_num % sram_mask_seg_num; + uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num; + for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) { + uint32_t batch_idx = i / sram_mask_seg_num; + uint32_t mask_idx = i % sram_mask_seg_num; + uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem); + uint32_t sram_load_num = sram_deal_mask_num * row; + uint32_t sram_mask_offset = mask_idx < sram_mask_rem + ? mask_idx * (sram_average_mask_num + 1) + : mask_idx * sram_average_mask_num + sram_mask_rem; + uint32_t sram_load_offset = sram_mask_offset * row; + uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX; + uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX; + uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem); + uint32_t nram_load_num = nram_deal_mask_num * row; + uint32_t nram_col_offset = taskIdX < nram_mask_rem + ? taskIdX * (nram_average_mask_num + 1) + : taskIdX * nram_average_mask_num + nram_mask_rem; + uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row; + uint32_t nram_store_num = nram_deal_mask_num * topk; + uint32_t nram_store_offset = + (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk; + // Load + if (valid_mask) { + __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM); + } + if (nram_load_num > 0) { + __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM); + } + + // Compute + computeSoftmaxTopk((T *)sram_buffer, load_buffer, src_buffer, compute_buffer, + group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer, + softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num, + topk, num_expert_group, topk_group, top_num, nram_col_offset, + normalize_mode, valid_mask, split_mask); + + // Store + if (nram_store_num > 0) { + __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float), + NRAM2GDRAM); + __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int), + NRAM2GDRAM); + } + __sync_cluster(); + } + } +} + +} // namespace kernels + +KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue, + float *reduce_weight, + int *expert_id, + const void *input, + const void *mask, + const int num_token, + const int num_expert, + const int num_mask, + const int topk, + const int num_expert_group, + const int topk_group, + const cnnlDataType_t dtype, + const int normalize_mode) { + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1}; + int top_num = topk >= topk_group ? topk : topk_group; + if (num_expert_group <= 1) { + if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported." + << "Supported max num_expert:" + << (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float) + << ". Current num_expert:" << num_expert; + return KernelStatus::KERNEL_STATUS_FAILED; + } + } else { + if (num_expert > + (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported." + << "Supported max num_expert:" + << (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float) + << ". Current num_expert:" << num_expert; + return KernelStatus::KERNEL_STATUS_FAILED; + } + } + if (topk > num_expert) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert." + << "topk:" << topk << ". num_expert:" << num_expert; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (num_expert_group > 1) { + if (mask != nullptr) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr"; + } + if (num_expert % num_expert_group != 0) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be" + << "divisible by num_expert_group, but now num_expert:" << num_expert + << ", num_expert_group:" << num_expert_group; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (topk_group <= 0 || topk_group > num_expert_group) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be" + << "larger than 0 and less than or equal to num_expert_group, but now topk_group" + << topk_group << ", num_expert group:" << num_expert_group; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (topk > (num_expert / num_expert_group) * topk_group) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less" + << "than or equal to (num_expert / num_expert_group) * topk_group, but now" + << "topk :" << topk << ", num_expert:" << num_expert + << ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group; + return KernelStatus::KERNEL_STATUS_FAILED; + } + } + + if (dtype == CNNL_DTYPE_FLOAT) { + kernels::MLUSoftmaxTopkKernel<<>>( + (float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask, + topk, num_expert_group, topk_group, normalize_mode); + } else if (dtype == CNNL_DTYPE_HALF) { + kernels::MLUSoftmaxTopkKernel<<>>( + (half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask, + topk, num_expert_group, topk_group, normalize_mode); + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + if (!isBf16Supported()) { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + kernels::MLUSoftmaxTopkKernel<<>>( + (bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert, + num_mask, topk, num_expert_group, topk_group, normalize_mode); + } else { + std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported "; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh new file mode 100644 index 0000000..385db1e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_ +#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_ + +#include "../kernel_utils.h" +#include "cnnl.h" +namespace tmo { +/** + * @brief Execute MOE Softmax Top-K Kernel. + * + * This function executes the MOE Softmax Top-K Kernel, which computes + * the Top-K values along a specified dimension after applying softmax to the input data. + * It is specifically designed for reduction along the lowest dimension. + * + * @param queue CNRT queue used to specify the queue for execution. + * @param reduce_weight Pointer to store the Top-K values. + * The shape must be [num_token, topk]. + * @param expert_id Pointer to store the indices of the Top-K values. + * The shape must be [num_token, topk]. + * @param input Pointer to the input data containing the values to be computed. + * The shape must be [num_token, num_expert]. + * @param mask Pointer to the input data containing the mask value to be computed after + * computing softmax, Mask can be nullptr, which means no need to compute, + * otherwise the shape and datatype of mask should be the same as input. + * @param num_token Number of channels in the input data. + * @param num_expert Specified dimension. Note that num_expert should not exceed 32768. + * @param num_mask Number of channels in the mask data. + * @param topk Number of Top-K values to compute. topk should not be larger than num_expert. + * @param num_expert_group Group numbers of num_expert. If num_expert_group > 0, num_expert + * should be divisible by num_expert_group. Otherwise, num_expert_group and topk_group + * is not valid. + * @param topk_group Number of Top-K group values to compute. Topk_group should not be larger + * than num_expert_group. + * @param dtype Data type of the input data, should match the actual data type. + * float, half, bfloat16 is supported. + * @param normalize_mode Whether and how to normalize the output, if normalize_mode == 0, no + normalization is performed; if normalize_mode == 1, the normalized denominator is + the sum of topk; if normalize_mode == 2, the normalized denominator is the sum of + * the products of softmax_result mask. + */ +KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue, + float *reduce_weight, + int *expert_id, + const void *input, + const void *mask, + const int num_token, + const int num_expert, + const int num_mask, + const int topk, + const int num_expert_group, + const int topk_group, + const cnnlDataType_t dtype, + const int normalize_mode); +} // namespace tmo + +#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_linear_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_linear_cache.mlu new file mode 100644 index 0000000..b7438b7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_linear_cache.mlu @@ -0,0 +1,425 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "offline_quant_to_linear_cache.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { + +#define NRAM_BUFFER_SIZE (480 * 1024) +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; + +#define sizeof_(T) (uint32_t)sizeof(T) + +template +__mlu_func__ void quantify(int8_t *nram_output, + float *nram_input_float, + T *nram_input, + float *nram_scale, + int input_num, + int scale_num) { + if (std::is_same::value) { + __bang_half2float(nram_input_float, (half *)nram_input, input_num); + } else if (std::is_same::value) { +#if __BANG_ARCH__ > 500 + __bang_bfloat162float(nram_input_float, (bfloat16_t *)nram_input, input_num); +#endif + } + __bang_cycle_mul(nram_input_float, nram_input_float, nram_scale, input_num, scale_num); + __bang_float2int8_rn(nram_output, (float *)nram_input_float, input_num, 0); +} + +template +__mlu_func__ void quantPerHead(int8_t *output_gdram, + int8_t *output_nram, + const T *input_gdram, + T *input_nram, + const float *scale_gdram, + float *scale_nram, + float *input_nram_float, + T *trans_nram, + int seq, + int head_num, + int head_size, + size_t in_hstr_bytes, // context head_num stide bytes + size_t in_sstr_bytes, // context seq stide bytes + size_t scale_hstr_bytes, // scale head_num stide bytes + size_t out_hstr_bytes, // cache head_num stride bytes + size_t out_sstr_bytes // cache seq stride bytes +) { + constexpr int dtype_size = sizeof_(T); + // nram_input: (head_num, seq, head_size) + int io1_size = head_size * dtype_size; + __memcpy(trans_nram, input_gdram, io1_size, GDRAM2NRAM, seq * io1_size, head_num - 1, io1_size, + seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1); + + // nram_scale:(head_num, seq); + int io2_size = seq * sizeof_(float); + __memcpy(scale_nram, scale_gdram, io2_size, GDRAM2NRAM, io2_size, scale_hstr_bytes, head_num - 1); + __bang_recip(scale_nram, scale_nram, head_num * seq); + + if (std::is_same::value || std::is_same::value) { + __bang_transpose((half *)input_nram, (half *)trans_nram, head_num * seq, head_size); + } else { + __bang_transpose(input_nram_float, (float *)trans_nram, head_num * seq, head_size); + } + quantify(output_nram, input_nram_float, input_nram, scale_nram, head_size * head_num * seq, + head_num * seq); + __bang_transpose((int8_t *)trans_nram, output_nram, head_size, head_num * seq); + + __memcpy(output_gdram, trans_nram, head_size, NRAM2GDRAM, out_hstr_bytes, head_num - 1, + out_sstr_bytes, seq - 1, seq * head_size, head_num - 1, head_size, seq - 1); +} + +template +__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerHead(int8_t *key_cache, + int8_t *value_cache, + const float *key_cache_scale, + const float *value_cache_scale, + const int *cache_bs_offsets, + const int *cache_seq_offsets, + const T *key, + const T *value, + const int *context_seq_offsets, + const int *context_lens, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const size_t context_bs_stride, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t cache_scale_head_stride, + const bool packed, + const int seq_block) { + bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr); + bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr); + if ((!handle_key) && (!handle_value)) { + return; + } + + constexpr int dtype_size = sizeof_(T); + size_t in_hstr_bytes = context_head_stride * dtype_size; + size_t in_sstr_bytes = context_seq_stride * dtype_size; + size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t); + size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t); + size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float); + + /* ***************************nram space **************************** + * | scale | input/output | trans | + * scale size:[head_num, seq_block], float + * input size:[head_num, seq_block, head_size], float + * trans size:, [head_size, head_num, seq_block], T + */ + float *scale_nram = (float *)nram_buffer; + float *input_nram_float = nullptr; + T *trans_nram = nullptr, *input_nram = nullptr; + input_nram_float = scale_nram + head_num * seq_block; + trans_nram = (T *)(input_nram_float + head_num * seq_block * head_size); + + if (std::is_same::value || std::is_same::value) { + // need cast from input_nram to input_nram_float + input_nram = (T *)input_nram_float + seq_block * head_num * head_size; + } else { + input_nram = (T *)input_nram_float; + } + int8_t *output_nram = (int8_t *)input_nram_float; // output and input share nram space + + for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) { + int context_len = __load_gdram(context_lens + bs_idx); + int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len; + int task_seq_begin = taskIdZ * seq_block; + if (task_seq_begin >= seq_len) continue; + int seq = std::min(seq_len - task_seq_begin, seq_block); + + // context offset + size_t context_offset = 0; + if (packed) { + context_offset = (context_len + task_seq_begin) * context_seq_stride; + } else { + int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx]; + context_offset = + (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride); + } + + // cache offset + int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]); + int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx]; + if (cache_seq_offset < 0 || cache_bs_offset < 0) { + continue; + } + cache_seq_offset += task_seq_begin; + size_t cache_offset = (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride); + // per_head, nram input[head_num, seq, head_size], nram scale[head_num, seq] + if (handle_key) { + quantPerHead(key_cache + cache_offset, output_nram, key + context_offset, input_nram, + key_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram, + seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes, + out_hstr_bytes, out_sstr_bytes); + } + + if (handle_value) { + quantPerHead(value_cache + cache_offset, output_nram, value + context_offset, input_nram, + value_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram, + seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes, + out_hstr_bytes, out_sstr_bytes); + } + } +} + +template +__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerChannel( + int8_t *key_cache, + int8_t *value_cache, + const float *key_cache_scale, + const float *value_cache_scale, + const int *cache_bs_offsets, + const int *cache_seq_offsets, + const T *key, + const T *value, + const int *context_seq_offsets, + const int *context_lens, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const size_t context_bs_stride, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t cache_scale_head_stride, + const bool packed, + const int seq_block) { + bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr); + bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr); + if ((!handle_key) && (!handle_value)) { + return; + } + + constexpr int dtype_size = sizeof_(T); + size_t in_hstr_bytes = context_head_stride * dtype_size; + size_t in_sstr_bytes = context_seq_stride * dtype_size; + size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t); + size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t); + size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float); + + /* *********************************nram space ************************************** + * per_chennel: |scale[head_num, head_size]| input[seq_block, head_num, head_size]| + */ + float *scale_nram = (float *)nram_buffer; + float *input_nram_float = scale_nram + head_num * head_size; + + T *input_nram = (T *)input_nram_float; + if (std::is_same::value || std::is_same::value) { + // need cast from input_nram to input_nram_float + input_nram = (T *)input_nram_float + seq_block * head_num * head_size; + } + int8_t *output_nram = (int8_t *)input_nram_float; // output and input share nram space + int size1 = head_size * sizeof_(float); + int size2 = head_size * dtype_size; + int scale_num = head_num * head_size; + + if (handle_key) { + // load offline scale nram_scale:(head_num, head_size); + __memcpy(scale_nram, key_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes, head_num - 1); + __bang_recip(scale_nram, scale_nram, scale_num); + + for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) { + int context_len = __load_gdram(context_lens + bs_idx); + int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len; + int task_seq_begin = taskIdZ * seq_block; + if (task_seq_begin >= seq_len) continue; + int seq = std::min(seq_len - task_seq_begin, seq_block); + + // context offset + size_t context_offset = 0; + if (packed) { + context_offset = (context_len + task_seq_begin) * context_seq_stride; + } else { + int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx]; + context_offset = + (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride); + } + + // cache offset + int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]); + int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx]; + if (cache_seq_offset < 0 || cache_bs_offset < 0) { + continue; + } + cache_seq_offset += task_seq_begin; + size_t cache_offset = + (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride); + + __memcpy(input_nram, key + context_offset, size2, GDRAM2NRAM, size2, head_num - 1, + head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1); + + quantify((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num, + scale_num); + + __memcpy(key_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes, + head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1); + } + } + + if (handle_value) { + // load offline scale nram_scale:(head_num, head_size); + __memcpy(scale_nram, value_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes, + head_num - 1); + __bang_recip(scale_nram, scale_nram, scale_num); + + for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) { + int context_len = __load_gdram(context_lens + bs_idx); + int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len; + int task_seq_begin = taskIdZ * seq_block; + if (task_seq_begin >= seq_len) continue; + int seq = std::min(seq_len - task_seq_begin, seq_block); + + // context offset + size_t context_offset = 0; + if (packed) { + context_offset = (context_len + task_seq_begin) * context_seq_stride; + } else { + int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx]; + context_offset = + (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride); + } + + // cache offset + int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]); + int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx]; + if (cache_seq_offset < 0 || cache_bs_offset < 0) { + continue; + } + cache_seq_offset += task_seq_begin; + size_t cache_offset = + (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride); + + __memcpy(input_nram, value + context_offset, size2, GDRAM2NRAM, size2, head_num - 1, + head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1); + + quantify((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num, + scale_num); + + __memcpy(value_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes, + head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1); + } + } +} +} // namespace kernels + +#define LAUNCH_OFFLINE_QUANT_KERNEL(Dtype, Name) \ + kernels::MLUOfflineQuantToLinearCacheKernel##Name<<>>( \ + (int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, \ + (float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (Dtype *)key, \ + (Dtype *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size, \ + max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride, \ + cache_bs_stride, cache_head_stride, cache_seq_stride, cache_scale_head_stride, packed, \ + seq_block); + +KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue, + void *key_cache, + void *value_cache, + const void *key_cache_scale, + const void *value_cache_scale, + const void *cache_bs_offsets, + const void *cache_seq_offsets, + const void *key, + const void *value, + const void *context_seq_offsets, + const void *context_lens, + const cnnlDataType_t dtype, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const size_t context_bs_stride, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t cache_scale_head_stride, + const bool packed, + const int quant_mode) { + constexpr int nram_size = 480 * 1024; + int dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof(float) : sizeof(half); + int seq_block = 0; + if (quant_mode == 0) { + seq_block = nram_size / (head_num * head_size * sizeof(float)) - 1; + if (seq_block <= 0) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * head_size * sizeof(float) should be less than 240KB when " + "quant_mode is 0." + << std::endl; + } + } else { + seq_block = nram_size / + (head_num * sizeof(float) + head_num * head_size * (sizeof(float) + dtype_size)); + if (seq_block <= 0) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * sizeof(float) + head_num * head_size * (sizeof(float) + " + "context_dtype_size)) " + << " should be less than 480KB when quant_mode is not 0." << std::endl; + } + } + seq_block = std::min(seq_block, max_context_len); + if (seq_block > 16 && seq_block < max_context_len) { + seq_block = seq_block / 16 * 16; + } + int seq_seg = (max_context_len + seq_block - 1) / seq_block; + + CNdev dev; + int cluster_dim, core_dim; + cnCtxGetDevice(&dev); + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev)); + + uint32_t core_num = cluster_dim * core_dim; + uint32_t task_y_dim = std::min((uint32_t)batch, core_num); + cnrtDim3_t dim{1, task_y_dim, (uint32_t)seq_seg}; + + if (dtype == CNNL_DTYPE_HALF) { + if (quant_mode == 0) { + LAUNCH_OFFLINE_QUANT_KERNEL(half, PerChannel); + } else { + LAUNCH_OFFLINE_QUANT_KERNEL(half, PerHead); + } + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + if (quant_mode == 0) { + LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerChannel); + } else { + LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerHead); + } + } else { + if (quant_mode == 0) { + LAUNCH_OFFLINE_QUANT_KERNEL(float, PerChannel); + } else { + LAUNCH_OFFLINE_QUANT_KERNEL(float, PerHead); + } + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_linear_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_linear_cache.mluh new file mode 100644 index 0000000..575bfb4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_linear_cache.mluh @@ -0,0 +1,103 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_ +#define CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_ + +#include "kernel_utils.h" + +namespace tmo { +/** + * @brief Quantize current key and value, Then store key and value to key_cache and value_cache. + * @param queue: The queue for mlu. + * @param key_cache: Pointer to the MLU memory that stores the key cache, + * the shape must be [max_batch, head_num, cache_mem_len, head_size]. + * Data type of key_cache must be int8. key_cache could be nullptr. + * @param value_cache: Pointer to the MLU memory that stores the value cache, + * the shape must be [max_batch, head_num, cache_mem_len, head_size]. + * Data type of value_cache must be int8. value_cache could be nullptr. + * @param key_cache_scale: Pointer to the MLU memory that stores the key cache scale, + * the shape must be [head_num, cache_mem_len] when quant_mode is not zero, + * and [head_num, head_size] when quant_mode is zero. Data type of key_cache_scale + * must be float. value_cache could be nullptr. + * @param value_cache_scale: Pointer to the MLU memory that stores the value cache scale, + * the shape must be [head_num, cache_mem_len] when quant_mode is not zero, + * and [head_num, head_size] when quant_mode is zero. Data type of value_cache_scale + * must be float. value_cache_scale could be nullptr. + * @param cache_bs_offsets: Pointer to the MLU memory that stores the batch + * offset of cache, the shape must be [batch], if it's nullptr, the + * default value is {0, 1, 2 ... batch - 1}. + * @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence + * offset of cache, the shape must be [batch], if it's nullptr, the + * default value is 0 for every batch. + * @param key: Pointer to the MLU memory that stores the key, + * the shape must be [batch, max_contxt_len, head_num, head_size]. + * Data type of key couble be float/half/bfloat16. key could be nullptr. + * @param value: Pointer to the MLU memory that stores the value, + * the shape must be [batch, max_contxt_len, head_num, head_size]. + * Data type of value couble be float/half/bfloat16, value could be nullptr. + * @param context_seq_offsets: Pointer to the MLU memory that stores the + * sequence offset of context, the shape must be [batch]. if it's nullptr, + * the default value is 0 for every batch. It must be nullptr when packed is true. + * @param context_lens: Pointer to the MLU memory that stores the sequence length or cumulative + * sequence length of context. when packed is false, the shape must be [batch], which + * indicates sequence length of context. when packed is true, the shape must be [batch + 1], + which + * indicates cumulative sequence length of context. + * @param dtype: Data type. + * @param batch: Batch size. + * @param head_num: Head number. + * @param head_size: Head size. + * @param max_contxt_len: The maximum sequence length of context. + * @param cache_mem_len: The maximum sequence length of cache. + * @param contxt_bs_stride: The stride of batch in context, does not work when packed is true. + * @param contxt_head_stride: The stride of head_num in context. + * @param contxt_seq_stride: The stride of max_contxt_len in context. + * @param cache_bs_stride: The stride of batch in cache. + * @param cache_head_stride: The stride of head_num in cache. + * @param cache_seq_stride: The stride of cache_mem_len in cache. + * @param cache_scale_bs_stride: The stride of batch in cache scale. + * @param cache_scale_head_stride: The stride of head in cache scale. + * @param packed: A boolean value indicates whether to use pack mode. + * @param quant_mode: A int value indicates the quantify mode, 0 means quantify by per_channel, and + others value means quantify by per_head. + * @note If one of key/key_cache/key_cache_scale is nullptr, nothing todo for key. + If one of value/value_cache/value_cache_scale is nullptr, nothing todo for value. + */ +KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue, + void *key_cache, + void *value_cache, + const void *key_cache_scale, + const void *value_cache_scale, + const void *cache_bs_offsets, + const void *cache_seq_offsets, + const void *key, + const void *value, + const void *context_seq_offsets, + const void *context_lens, + const cnnlDataType_t dtype, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const size_t context_bs_stride, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t cache_seq_stride, + const size_t cache_scale_head_stride, + const bool packed, + const int quant_mode); +} // namespace tmo + +#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_paged_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_paged_cache.mlu new file mode 100644 index 0000000..be4dbf7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_paged_cache.mlu @@ -0,0 +1,232 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include "offline_quant_to_paged_cache.mluh" + +namespace tmo { +namespace kernels { + +#define sizeof_(T) (uint32_t)sizeof(T) +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +#define REM_FOR_STACK (32 * 1024) +__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK]; +__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + +template +__mlu_func__ void quantifyToInt8(T *nram_input, float *nram_scale, int token_handle, int head_len) { + // quantify + if (std::is_same::value) { + __bang_half2float((float *)nram_input, + (half *)((int8_t *)nram_input + token_handle * head_len * sizeof_(half)), + token_handle * head_len); + } + if (std::is_same::value) { +#if __BANG_ARCH__ > 500 + __bang_bfloat162float( + (float *)nram_input, + (bfloat16_t *)((int8_t *)nram_input + token_handle * head_len * sizeof_(bfloat16_t)), + token_handle * head_len); +#endif + } + __bang_cycle_mul((float *)nram_input, (float *)nram_input, nram_scale, token_handle * head_len, + head_len); + __bang_float2int8_rn((int8_t *)nram_input, (float *)nram_input, token_handle * head_len, 0); +} + +template +__mlu_global__ void MLUOfflineQuantToPagedCacheKernel(T *key, + T *value, + int8_t *key_cache, + int8_t *value_cache, + float *key_cache_scale, + float *value_cache_scale, + int *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int tokens_num, + int head_num, + int block_size, + int head_size, + int tokens_block) { +/*******************************************************nram space*********************** + * nram:| input | scale | cache_offset | scale_offset | mask | temp | index | + * input size: tokens_block * head_num * head_size * sizeof(float) + * scale size: head_num * head_size * sizeof(float) + * cache_offset size: tokens_block * head_num * sizeof(float) + * scale_offset size: equal to cache_offset size + * mask size: CEIL_DIV(tokens_size * head_num, 8) * sizeof(int8_t) + * temp size: CEIL_ALIGN(token_size * head_num, 8) * sizeof(int) + * index size: head_num * sizeof(int) + ****************************************************************************************/ +#if __BANG_ARCH__ > 500 + int token_begin = taskId * tokens_block; + if (token_begin >= tokens_num) return; + int token_handle = std::min(tokens_block, tokens_num - token_begin); + + int seq_len = token_handle * head_num; + int head_len = head_num * head_size; + int pad8_num = CEIL_DIV(seq_len, CHAR_BIT) * CHAR_BIT; + int input_size = seq_len * head_size * sizeof_(float); + int8_t *nram_input = nram_buffer; + float *nram_scale = (float *)(nram_buffer + input_size); + int *cache_offset = (int *)(nram_scale + head_len); + int *scale_offset = cache_offset + pad8_num; + int *nram_mask = scale_offset + pad8_num; + int *nram_temp = nram_mask + pad8_num; + int *head_index = nram_temp + pad8_num; + + // generate range: (0, 1, 2, ..., (head_num - 1)) + __memcpy(head_index, nram_range_32, std::min(head_num, 32) * sizeof_(int), NRAM2NRAM); + int begin = 32; + while (begin < head_num) { + int count = std::min(begin, head_num - begin); + __bang_add_scalar(head_index + begin, head_index, begin, count); + begin += count; + } + + // load slot(token_handle) -> expand(head_num, token_handle) ->transpose(token_handle, head_num) + int token_size = token_handle * sizeof_(int); + __memcpy(scale_offset, slot_mapping + token_begin, token_size, GDRAM2NRAM); + __memcpy(nram_temp, scale_offset, token_size, NRAM2NRAM, token_size, 0, head_num - 1); + __bang_transpose(scale_offset, nram_temp, head_num, token_handle); + + __bang_write_zero((float *)nram_temp, pad8_num); + __bang_ge_bitindex((float *)nram_mask, (float *)scale_offset, (float *)nram_temp, pad8_num); + + // calculate cache/scale scatter offset + __bang_div(cache_offset, scale_offset, (int)block_size, seq_len); + __bang_rem(scale_offset, scale_offset, (int)block_size, seq_len); + __bang_mul_scalar(cache_offset, cache_offset, head_num * block_size, seq_len); + __bang_mul_scalar(head_index, head_index, block_size, head_num); + __bang_cycle_add(cache_offset, cache_offset, head_index, seq_len, head_num); + __bang_add(scale_offset, cache_offset, scale_offset, seq_len); + __bang_mul_scalar(cache_offset, scale_offset, head_size, seq_len); + __bang_mul_scalar(scale_offset, scale_offset, sizeof_(float), seq_len); + + int hidden_bytes = head_num * head_size * sizeof_(T); + bool half_size = (sizeof(T) == sizeof(half)); + if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr) { + // load key_cache_scale + __memcpy(nram_scale, key_cache_scale, head_len * sizeof_(float), GDRAM2NRAM); + __bang_recip(nram_scale, nram_scale, head_len); + // (token_handle, head_num, head_size) + __memcpy(nram_input + half_size * token_handle * hidden_bytes, key + token_begin * key_stride0, + hidden_bytes, GDRAM2NRAM, hidden_bytes, key_stride0 * sizeof_(T), token_handle - 1); + // quantify + quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len); + // scatter to gdram + __scatter(key_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size, + NRAM2GDRAM, head_size, seq_len); + } + + if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr) { + // load key_cache_scale + __memcpy(nram_scale, value_cache_scale, head_len * sizeof_(float), GDRAM2NRAM); + __bang_recip(nram_scale, nram_scale, head_len); + // (token_handle, head_num, head_size) + __memcpy(nram_input + half_size * token_handle * hidden_bytes, + value + token_begin * value_stride0, hidden_bytes, GDRAM2NRAM, hidden_bytes, + value_stride0 * sizeof_(T), token_handle - 1); + // quantify + quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len); + // scatter to gdram + __scatter(value_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size, + NRAM2GDRAM, head_size, seq_len); + } +#endif +} +} // namespace kernels + +KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue, + cnnlDataType_t data_type, + void *key, + void *value, + void *key_cache, + void *value_cache, + void *key_cache_scale, + void *value_cache_scale, + void *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_num, + int block_size, + int head_size) { + if (is_arch300()) { + std::cerr << "[invokeOfflineQuantToPagedCache]: kernel does not support MLU300 devices." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + int dtype_size = 1; + if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) { + dtype_size = 2; + } else if (data_type == CNNL_DTYPE_FLOAT) { + dtype_size = 4; + } else { + std::cerr << "invokeOfflineQuantToPagedCache: unsupport data type\n"; + return KernelStatus::KERNEL_STATUS_FAILED; + } + int64_t kv_cache_range = block_num * block_size * num_heads * head_size * dtype_size; + if (kv_cache_range > UINT32_MAX) { + std::cerr + << "invokeOfflineQuantToPagedCache: The addressing range of kv_cache cannot exceed 4G." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + // nram_size_need: token_block * head_num * head_size + head_num * head_size * sizeof(float) + // token_block * head_num * 4 * sizeof(int) + head_num * sizeof(int) + // nram uesd: 480KB + int nram_size = 480 * 1024 - num_heads * sizeof(int) - num_heads * head_size * sizeof(float); + int hidden_bytes = num_heads * head_size * sizeof(float) + + 4 * CEIL_DIV(num_heads, CHAR_BIT) * CHAR_BIT * sizeof(int); + int seq_block = nram_size / hidden_bytes; + if (seq_block <= 0) { + std::cerr << "invokeOfflineQuantToPagedCache: " + << "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) " + << "should be less than 480KB.\n"; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (seq_block > 16) { + seq_block = seq_block / 16 * 16; + } + int cluster_num, core_dim; + CNdev dev; + cnCtxGetDevice(&dev); + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev)); + int core_num = core_dim * cluster_num; + seq_block = std::min(seq_block, CEIL_DIV(num_tokens, core_num)); + uint32_t task_dim = CEIL_DIV(num_tokens, seq_block); + cnrtDim3_t dim{1, task_dim, 1}; + + if (data_type == CNNL_DTYPE_FLOAT) { + kernels::MLUOfflineQuantToPagedCacheKernel<<>>( + (float *)key, (float *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0, + value_stride0, num_tokens, num_heads, block_size, head_size, seq_block); + } else if (data_type == CNNL_DTYPE_HALF) { + kernels::MLUOfflineQuantToPagedCacheKernel<<>>( + (half *)key, (half *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0, + value_stride0, num_tokens, num_heads, block_size, head_size, seq_block); + } else { + kernels::MLUOfflineQuantToPagedCacheKernel<<>>( + (bfloat16_t *)key, (bfloat16_t *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0, + value_stride0, num_tokens, num_heads, block_size, head_size, seq_block); + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_paged_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_paged_cache.mluh new file mode 100644 index 0000000..120b2de --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/offline_quant_to_paged_cache.mluh @@ -0,0 +1,62 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_ +#define CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_ + +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Perform offline_quant_to_paged_cache operation. + * @param queue[in]: The queue for mlu. + * @param data_type[in]: The cnnl data type of key. + * @param key[in]: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens, + * num_heads, head_size]. Data type of key must be half/bfloat16_t/float. + * @param value[in]: Pointer to the MLU memory that stores the value tensor which has shape + * [num_tokens, num_heads, head_size]. Data type of value must be half/bfloat16_t/float. + * @param key_cache[out]: Pointer to the MLU memory that stores the key_cache tensor which has + * shape [num_blocks, num_heads, block_size, head_size]. Data type of key cache must be int8_t. + * @param value_cache[out]: Pointer to the MLU memory that stores the value_cache tensor which has + * shape [num_blocks, num_heads, block_size, head_size]. Data type of value cache must be int8_t. + * @param key_cache_scale[in]: Pointer to the MLU memory that stores the key_cache_scale tensor + * which has shape [num_heads, head_size]. Data type of key cache scale must be float. + * @param value_cache_scale[in]: Pointer to the MLU memory that stores the value_cache_scale tensor + * which has shape [num_heads, head_size]. Data type of value cache scale must be float. + * @param slot_mapping[in]: Pointer to the MLU memory that stores the slot_mapping tensor which has + * shape [num_tokens]. Data type of slot mapping must be int32_t. + * @param key_stride0[in]: The first dimension stride length of key_cache tensor. + * @param value_stride0[in]: The first dimension stride length of value_cache tensor. + * @param num_tokens[in]: Total number of tokens. + * @param num_heads[in]: Head number. + * @param block_num[in]: Total number of blocks. + * @param block_size[in]: Number of tokens per block. + * @param head_size[in]: Head size. + * @note: offline_quant_to_paged_cache does not support MLU300 device. + */ +KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue, + cnnlDataType_t data_type, + void *key, + void *value, + void *key_cache, + void *value_cache, + void *key_cache_scale, + void *value_cache_scale, + void *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_num, + int block_size, + int head_size); +} // namespace tmo + +#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mlu new file mode 100644 index 0000000..0c7c452 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mlu @@ -0,0 +1,156 @@ +#include +#include "cnrt.h" +#include "operate_cu_seq_lens.mluh" + +namespace { +constexpr int pair_elem_num = 2; +} + +namespace tmo { +namespace kernels { +#define ONCHIP_DATA_NUM ((int)((__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) / sizeof(int))) +__nram__ int nram_buffer[ONCHIP_DATA_NUM]; +__nram__ const int acc_seq_lens[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +__mlu_func__ void genSeqLens(int *seq_len_nram, int start, int multi, int elem_count) { + constexpr int acc_seq_lens_size = 16; + int count = std::min(acc_seq_lens_size, elem_count); + int add_on = multi * acc_seq_lens_size; + __bang_mul_scalar(seq_len_nram, acc_seq_lens, multi, count); + __bang_add_scalar(seq_len_nram, seq_len_nram, start, count); + while (count < elem_count) { + __bang_add_scalar(seq_len_nram + count, seq_len_nram, add_on, + std::min(count, elem_count - count)); + count *= 2; + add_on *= 2; + } +} + +__mlu_global__ void MLUSliceCuSeqlens(int *cu_seq_lens, + int *sliced_cu_seq_lens, + int batch, + int every, + int remain, + int loop) { + int cu_seq_lens_elem_count = batch + 1; + int sliced_cu_seq_lens_elem_count = batch + loop; + + int *cu_seq_lens_narm = nram_buffer; + int *sliced_cu_seq_lens_narm = cu_seq_lens_narm + cu_seq_lens_elem_count; + int *sliced_cu_seq_lens_narm_start = sliced_cu_seq_lens_narm; + + __memcpy(cu_seq_lens_narm, cu_seq_lens, cu_seq_lens_elem_count * sizeof(int), GDRAM2NRAM); + __bang_write_zero(sliced_cu_seq_lens_narm, sliced_cu_seq_lens_elem_count); + for (int i = 0; i < loop; ++i) { + int elem_num = 1 + (i == loop - 1 && remain != 0 ? remain : every); + __bang_sub_scalar(sliced_cu_seq_lens_narm, cu_seq_lens_narm, cu_seq_lens_narm[0], elem_num); + cu_seq_lens_narm += elem_num - 1; + sliced_cu_seq_lens_narm += elem_num; + } + __memcpy(sliced_cu_seq_lens, sliced_cu_seq_lens_narm_start, + sliced_cu_seq_lens_elem_count * sizeof(int), NRAM2GDRAM); +} + +__mlu_global__ void MLUGenerateKVCuSeqlens(int *gen_cu_seq_lens, + int every, + int remain, + int loop, + int seq_len, + bool is_causal_mask, + int seg_data_num, + int task_num) { + int offset = seg_data_num * taskIdX; + int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset); + int seq_len_elem_num = total_elem_num / pair_elem_num; + + int *gen_cu_seq_lens_narm = nram_buffer; + __bang_write_zero(gen_cu_seq_lens_narm, total_elem_num); + + if (is_causal_mask) { + int *seq_lens_narm = gen_cu_seq_lens_narm + total_elem_num; + genSeqLens(seq_lens_narm, every * offset / pair_elem_num, every, seq_len_elem_num); + __memcpy(gen_cu_seq_lens_narm + 1, seq_lens_narm, sizeof(int), NRAM2NRAM, + pair_elem_num * sizeof(int), sizeof(int), seq_len_elem_num - 1); + if (remain != 0 && taskIdX == task_num - 1) { + gen_cu_seq_lens_narm[total_elem_num - 1] -= (every - remain); + } + } else { + __bang_write_value(gen_cu_seq_lens_narm + 1, 1, seq_len, pair_elem_num * sizeof(int), + seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0); + } + __memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int), + NRAM2GDRAM); +} + +__mlu_global__ void MLUGenerateQCuSeqlens(int *gen_cu_seq_lens, + int every, + int remain, + int loop, + int seg_data_num, + int task_num) { + int offset = seg_data_num * taskIdX; + int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset); + int seq_len_elem_num = total_elem_num / pair_elem_num; + + int *gen_cu_seq_lens_narm = nram_buffer; + + __bang_write_zero(gen_cu_seq_lens_narm, total_elem_num); + __bang_write_value(gen_cu_seq_lens_narm + 1, 1, every, pair_elem_num * sizeof(int), + seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0); + if (remain != 0 && taskIdX == task_num - 1) { + gen_cu_seq_lens_narm[total_elem_num - 1] = remain; + } + __memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int), + NRAM2GDRAM); +} +} // namespace kernels + +KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue, + int *cu_seq_lens, + int *sliced_cu_seq_lens, + int batch, + int parallel_num) { + int every = (batch + parallel_num - 1) / parallel_num; + int repeat = batch / every; + int remain = batch % every; + int loop = repeat + (remain != 0); + + cnrtDim3_t dim{1, 1, 1}; + kernels::MLUSliceCuSeqlens<<>>(cu_seq_lens, sliced_cu_seq_lens, + batch, every, remain, loop); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue, + int *gen_cu_seq_lens, + int seq_len, + int parallel_num, + bool is_causal_mask, + bool is_kv_seq_len) { + int every = (seq_len + parallel_num - 1) / parallel_num; + int repeat = seq_len / every; + int remain = seq_len % every; + int loop = repeat + (remain != 0); + + int seg_data_num = ONCHIP_DATA_NUM / 2; + if (is_kv_seq_len && is_causal_mask) { + // max segnum for 2d memcpy is 64k + seg_data_num = std::min(pair_elem_num * 64 * 1024, seg_data_num); + } + int total_elem_num = loop * pair_elem_num; + int task_num = (total_elem_num + seg_data_num - 1) / seg_data_num; + + cnrtDim3_t dim{(unsigned int)task_num, 1, 1}; + + if (is_kv_seq_len) { + kernels::MLUGenerateKVCuSeqlens<<>>( + gen_cu_seq_lens, every, remain, loop, seq_len, is_causal_mask, seg_data_num, task_num); + } else { + kernels::MLUGenerateQCuSeqlens<<>>( + gen_cu_seq_lens, every, remain, loop, seg_data_num, task_num); + } + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mluh new file mode 100644 index 0000000..8593f5f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mluh @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_ +#define CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" + +namespace tmo { + +/** + * @brief slice cu_seq_lens and cu_k_seq_lens for parallel context when attention split batch; + * @example + * cu_seq_lens: [0, 2, 5, 10, 20, 33, 46, 51, 77] + * batch: 8, parallel_num: 3 + * sliced_cu_seq_lens: [0, 2, 5, 10, 0, 10, 23, 36, 0, 5, 31] + * @param queue: The queue for mlu. + * @param cu_seq_lens: Input. Pointer to the MLU memory that stores the current seq lens, the shape + * is [batch + 1]. + * @param sliced_cu_seq_lens: Output. Pointer to the MLU memory that stores the sliced current seq + * lens, the shape is [batch + loop_time]. + * @param batch: Batch size. + * @param parallel_num: Parallel num of batch. + */ +KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue, + int *cu_seq_lens, + int *sliced_cu_seq_lens, + int batch, + int parallel_num); + +/** + * @brief generate cu_seq_lens and cu_k_seq_lens for parallel context when attention split seq; + * @example + * seq_len: 11, parallel_num: 3 + * gen_cu_seq_lens for q: [0, 4, 0, 4, 0, 3] + * @example + * seq_len: 11, parallel_num: 3 + * is_causal_mask false, gen_cu_seq_lens for kv: [0, 11, 0, 11, 0, 11] + * is_causal_mask true , gen_cu_seq_lens for kv: [0, 4, 0, 8, 0, 11] + * @param queue: The queue for mlu. + * @param gen_cu_seq_lens: Output. Pointer to the MLU memory that stores the generated current seq + * lens, the shape is [2 * loop_time]. + * @param seq_len: Sequence length. + * @param parallel_num: Parallel num of sequence length. + * @param is_causal_mask: Whether self attention use causal mask. + * @param is_kv_seq_len: The gen_cu_seq_lens is for q or kv. + */ +KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue, + int *gen_cu_seq_lens, + int seq_len, + int parallel_num, + bool is_causal_mask, + bool is_kv_seq_len); + +} // namespace tmo + +#endif // CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu new file mode 100644 index 0000000..663feea --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "preload.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { + +#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024) +__mlu_shared__ int8_t sram_buffer[SRAM_SIZE]; + +__mlu_func__ void split(const int64_t total, + const int64_t num, + const int64_t id, + size_t &every, + size_t &offset) { + int64_t base = total / num; + int64_t tail = total - base * num; + every = base + (id < tail ? 1 : 0); + offset = base * id + (id < tail ? id : tail); +} + +__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) { +#if __BANG_ARCH__ > 372 + size_t cluster_preload_size = 0; + size_t cluster_preload_offset = 0; + split(preload_size, taskDimY, taskIdY, cluster_preload_size, cluster_preload_offset); + + size_t load_repeat = cluster_preload_size / SRAM_SIZE; + size_t load_remain = cluster_preload_size % SRAM_SIZE; + + for (size_t i = 0; i < load_repeat + 1; i++) { + if (i == load_repeat && load_remain == 0) { + break; + } + size_t loop_load_size = (i < load_repeat ? SRAM_SIZE : load_remain); + int8_t *gdram_ptr = (int8_t *)filter_ptr + cluster_preload_offset + i * SRAM_SIZE; + if (loop_load_size > 0) { + __memcpy(sram_buffer, gdram_ptr, loop_load_size, GDRAM2SRAM); + } + } +#endif +} + +} // namespace kernels + +KernelStatus invokePreload(cnrtQueue_t queue, + void *filter_ptr, + size_t filter_size, + size_t preload_size) { + if (preload_size == 0) { + std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + if (preload_size > filter_size) { + preload_size = filter_size; + } + + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + cnrtDim3_t dim{.x = 4, .y = (uint32_t)cluster_num, .z = 1}; + if (cluster_num == 1) { + dim.y = 1; + } else if (cluster_num >= 2) { + dim.y = 2; + } + + kernels::MLUUnion1Preload<<>>(filter_ptr, preload_size); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mluh new file mode 100644 index 0000000..7c6e67a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/preload.mluh @@ -0,0 +1,34 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_PRELOAD_MLUH_ +#define CSRC_KERNELS_PRELOAD_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief When tp is greater than 1, while executing reducesum, the weight of ffn + * or selfattention to be calculated is loaded into LLC in advance. + * @param queue: The queue for mlu. + * @param filter_ptr: Input. Pointer to the MLU memory that stores the weight of ffn or + * selfattention. + * @param filter_size: The weight size of ffn or selfattention. + * @param preload_size: The size of the preload weight. + * @note The weights of ffn or selfattention must be continuous in filter_ptr. + */ +KernelStatus invokePreload(cnrtQueue_t queue, + void *filter_ptr, + size_t filter_size, + size_t preload_size); +} // namespace tmo + +#endif // CSRC_KERNELS_PRELOAD_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_linear_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_linear_cache.mlu new file mode 100644 index 0000000..85110fd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_linear_cache.mlu @@ -0,0 +1,476 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include +#include "quant_to_linear_cache.mluh" + +namespace tmo { + +namespace kernels { + +#define NRAM_BUFFER_SIZE (480 * 1024) +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; +__nram__ uint8_t post_table_nram[64]; + +#define sizeof_(T) (uint32_t)sizeof(T) +#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0)) + +#define __reshape_nhwc2nchw_smallhw(TYPE) \ + asm volatile( \ + "trans.tiling.nram.nram." TYPE \ + "[%[dst]], [%[src]], " \ + "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \ + "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \ + ".posttable.nram([%[post]]); \n\t" ::[dst] "r"(dst), \ + [src] "r"(src), [post] "r"(post_table), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), \ + [in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \ + [is4] "r"(batch_offset), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1), \ + [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \ + [ds4] "r"(batch_offset), [dn5] "i"(1), [ds5] "i"(0)); + +template +__mlu_func__ void __reshape_nhwc2nchw_smallhw_init(uint8_t *post_table_nram, uint32_t hw) { + uint32_t align_num = 64 / sizeof(T); + int tmp = hw - 1; + asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp)); + int align_hw = 1 << (tmp + 1); + int repeat = align_num / align_hw; + for (int i = 0; i < 64; i++) { + int idx = i / sizeof_(T); + int tmp_idx = (idx % hw) * repeat + idx / hw; + int real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T); + int mask = idx < repeat * hw ? 0x80 : 0x0; + post_table_nram[i] = (uint8_t)real_idx + mask; + } +} + +template +__mlu_func__ void trans_nhwc2nchw(T *dst, + const T *src, + uint8_t *post_table, + const uint32_t n, + const uint32_t hw, + const uint32_t c) { + uint32_t align_num = 64 / sizeof_(T); + int tmp = hw - 1; + asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp)); + int align_hw = 1 << (tmp + 1); + int repeat = align_num / align_hw; + int in0 = 64; + int in1 = hw; + int is1 = c * sizeof(T); + int in3 = c / align_num; + int is3 = in0; + int batch_offset = hw * c * sizeof(T); + int dn0 = hw * repeat * sizeof_(T); + int dn1 = align_hw; + int ds1 = dn0; + int dn3 = in3; + int ds3 = dn0 * dn1; + align_hw = in3 > 0 ? align_hw : 0; + if (align_hw == 2) { + __reshape_nhwc2nchw_smallhw("b256"); + } else if (align_hw == 4) { + __reshape_nhwc2nchw_smallhw("b128"); + } else if (align_hw == 8) { + __reshape_nhwc2nchw_smallhw("b64"); + } else if (align_hw == 16) { + __reshape_nhwc2nchw_smallhw("b32"); + } else if (align_hw == 32) { + __reshape_nhwc2nchw_smallhw("b16"); + } + + constexpr uint32_t bw = 8 * sizeof_(T); + int in3_rem = c % align_num; + int tail_in0 = in3_rem * sizeof_(T); + int tail_dn0 = hw * sizeof_(T); + if (in3_rem > 0) { + asm volatile( + "trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \ + %[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5], \ + %[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" :: + [bw] "i"(bw), + [dst] "r"(dst + dn3 * ds3 / sizeof_(T)), [src] "r"(src + is3 * in3 / sizeof_(T)), + [in0] "r"(tail_in0), [in1] "r"(hw), [is1] "r"(c * sizeof_(T)), [in2] "i"(1), [is2] "i"(0), + [in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(batch_offset), [in5] "i"(1), + [is5] "i"(0), [dn0] "r"(tail_dn0), [dn1] "r"(in3_rem), [ds1] "r"(tail_dn0), [dn2] "i"(1), + [ds2] "i"(0), [dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(batch_offset), + [dn5] "i"(1), [ds5] "i"(0)); + } +} + +template +__mlu_func__ void quantify(T *src_input, + float *nram_input, + float *transpose, + float *scale, + float *scale_recip, + int high_dim, + int low_dim, + int quant_bit) { + float recip_data = 1.0f / ((1 << (quant_bit - 1)) - 1); + if (std::is_same::value) { + __bang_transpose((half *)transpose, (half *)src_input, high_dim, low_dim); + __bang_abs((half *)src_input, (half *)transpose, high_dim * low_dim); + __bang_maxpool((half *)scale, (half *)src_input, high_dim, low_dim, 1, low_dim, 1, 1, 1); + __bang_half2float(scale_recip, (half *)scale, high_dim); + } else if (std::is_same::value) { +#if __BANG_ARCH__ > 500 + __bang_transpose((bfloat16_t *)transpose, (bfloat16_t *)src_input, high_dim, low_dim); + __bang_abs((bfloat16_t *)src_input, (bfloat16_t *)transpose, high_dim * low_dim); + __bang_maxpool((bfloat16_t *)scale, (bfloat16_t *)src_input, high_dim, low_dim, 1, low_dim, 1, + 1, 1); + __bang_bfloat162float(scale_recip, (bfloat16_t *)scale, high_dim); +#endif + } else { + __bang_transpose(transpose, (float *)src_input, high_dim, low_dim); + __bang_abs((float *)src_input, transpose, high_dim * low_dim); + __bang_maxpool(scale_recip, (float *)src_input, high_dim, low_dim, 1, low_dim, 1, 1, 1); + } + __bang_mul_scalar(scale, scale_recip, recip_data, high_dim); + __bang_recip(scale_recip, scale, high_dim); + if (std::is_same::value) { + __bang_half2float(nram_input, (half *)transpose, high_dim * low_dim); + __bang_cycle_mul(transpose, nram_input, scale_recip, high_dim * low_dim, high_dim); + } else if (std::is_same::value) { +#if __BANG_ARCH__ > 500 + __bang_bfloat162float(nram_input, (bfloat16_t *)transpose, high_dim * low_dim); + __bang_cycle_mul(transpose, nram_input, scale_recip, high_dim * low_dim, high_dim); +#endif + } else { + __bang_cycle_mul(transpose, transpose, scale_recip, high_dim * low_dim, high_dim); + } +} + +__mlu_func__ void castKeyToIntx(int8_t *dst, float *src, int high_dim, int low_dim, int quant_bit) { + if (quant_bit == 8) { + __bang_float2int8_rn((int8_t *)src, (float *)dst, high_dim * low_dim, 0); + __bang_transpose((int8_t *)dst, (int8_t *)src, low_dim, high_dim); + } else if (quant_bit == 4) { + __bang_transpose(src, (float *)dst, low_dim, high_dim); + __bang_float2int4_rn((int4x2_t *)dst, src, high_dim * low_dim, 0); + } +} + +__mlu_func__ void castValueToIntx(int8_t *value_cache_begin, + int8_t *value_cache_end, + uint8_t *post_table_nram, + float *dst, + float *src, + size_t cache_head_stride, + int seq, + int head_num, + int head_size, + int group_num, + int quant_bit, + bool need_pad_front, + bool need_pad_back) { + if (quant_bit == 8) { + __bang_float2int8_rn((int8_t *)src, dst, seq * head_num * head_size, 0); + __bang_transpose((int8_t *)dst, (int8_t *)src, head_size / group_num, + seq * head_num * group_num); + } else { + __bang_transpose(src, dst, head_size / group_num, seq * head_num * group_num); + __reshape_nhwc2nchw_smallhw_init(post_table_nram, 2); + if (!(need_pad_front || need_pad_back)) { + __bang_float2int8_rn((int8_t *)src, src, seq * head_num * head_size, + 0); // [head_num, seq, head_size] + trans_nhwc2nchw((int8_t *)dst, (int8_t *)src, post_table_nram, seq / 2, 2, + head_num * head_size); + __bang_int82int4_rn((int4x2_t *)dst, (int8_t *)dst, seq * head_num * head_size, 0, 0); + } else { + int origin_seq = seq; + if (need_pad_front) { + __memcpy((int8_t *)dst, value_cache_begin, head_size, GDRAM2NRAM, head_size, + cache_head_stride, head_num - 1); + __bang_band_scalar((int8_t *)dst, (int8_t *)dst, 0x0F, head_num * head_size); + seq += 1; + } + if (need_pad_back) { + __memcpy((int8_t *)dst + seq * head_num * head_size, value_cache_end, head_size, GDRAM2NRAM, + head_size, cache_head_stride, head_num - 1); + __bang_srl((int8_t *)dst + seq * head_num * head_size, + (int8_t *)dst + seq * head_num * head_size, 4, head_num * head_size); + seq += 1; + } + __bang_float2int8_rn((int8_t *)dst + need_pad_front * head_num * head_size, src, + origin_seq * head_num * head_size, 0); // [new_seq, head_num, head_size] + trans_nhwc2nchw((int8_t *)src, (int8_t *)dst, post_table_nram, seq / 2, 2, + head_num * head_size); // [seq / 2, 2, head_num, head_size] + __bang_int82int4_rn((int4x2_t *)dst, (int8_t *)src, seq * head_num * head_size, 0, 0); + } + } +} + +// [head_num, batch, seq_seg] +template +__mlu_global__ void MLUQuantToLinearCacheKernel(int8_t *key_cache, + int8_t *value_cache, + float *key_cache_scale, + float *value_cache_scale, + int *cache_bs_offsets, + int *cache_seq_offsets, + T *key, + T *value, + int *context_seq_offsets, + int *context_lens, + int batch, + int head_num, + int head_size, + int max_context_len, + int cache_mem_len, + size_t context_bs_stride, + size_t context_head_stride, + size_t context_seq_stride, + size_t cache_bs_stride, + size_t cache_head_stride, + size_t key_cache_seq_stride, + size_t value_cache_seq_stride, + size_t cache_scale_bs_stride, + size_t cache_scale_head_stride, + bool packed, + int seq_block, + int quant_bit, + int group_num) { + float *nram_input = (float *)nram_buffer; + T *src_input = (T *)nram_input; + if (sizeof_(T) == sizeof_(half)) { + src_input = (T *)(nram_buffer + seq_block * head_num * head_size * sizeof_(T)); + } + + float *nram_trans = nram_input + seq_block * head_num * head_size; + float *nram_scale = nram_trans + seq_block * head_num * head_size; + float *nram_scale_recip = nram_scale + seq_block * head_num * group_num; + constexpr int dtype_size = sizeof_(T); + int head_size_store = head_size * quant_bit / 8; + + for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) { + int seq_offset = + (packed || context_seq_offsets == nullptr) ? 0 : __load_gdram(context_seq_offsets + bs_idx); + int cache_seq_offset = + cache_seq_offsets == nullptr ? 0 : __load_gdram(cache_seq_offsets + bs_idx); + int context_len = __load_gdram(context_lens + bs_idx); + int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len; + int key_seq_begin = taskIdZ * seq_block; + int first_value_seq = (quant_bit == 4 && cache_seq_offset % 2 == 1) + ? std::min(seq_len, seq_block - 1) + : std::min(seq_len, seq_block); + int value_seq_begin = taskIdZ == 0 ? 0 : first_value_seq + (taskIdZ - 1) * seq_block; + + int key_seq = std::min(seq_len - key_seq_begin, seq_block); + int value_seq = taskIdZ == 0 ? first_value_seq : std::min(seq_len - value_seq_begin, seq_block); + size_t key_context_offset = 0; + size_t value_context_offset = 0; + if (packed) { + key_context_offset += (context_len + key_seq_begin) * context_seq_stride * dtype_size; + value_context_offset += (context_len + value_seq_begin) * context_seq_stride * dtype_size; + } else { + key_context_offset += + (bs_idx * context_bs_stride + (key_seq_begin + seq_offset) * context_seq_stride) * + dtype_size; + value_context_offset += + (bs_idx * context_bs_stride + (value_seq_begin + seq_offset) * context_seq_stride) * + dtype_size; + } + int key_cache_seq_offset = cache_seq_offset; + int value_cache_seq_offset = cache_seq_offset; + int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx]; + if (key_cache_seq_offset < 0 || cache_bs_offset < 0) { + continue; + } + key_cache_seq_offset += key_seq_begin; + value_cache_seq_offset += value_seq_begin; + + size_t key_cache_offset = + (cache_bs_offset * cache_bs_stride + key_cache_seq_offset * key_cache_seq_stride); + size_t key_cache_scale_offset = + cache_bs_offset * cache_scale_bs_stride + key_cache_seq_offset * group_num; + + if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr && + key_seq_begin < seq_len) { + int8_t *key_cache_begin = key_cache + key_cache_offset; + float *key_cache_scale_begin = key_cache_scale + key_cache_scale_offset; + char *key_begin = (char *)key + key_context_offset; + + // nram_input: (head_num, seq, head_size) + __memcpy(src_input, key_begin, head_size * dtype_size, GDRAM2NRAM, + key_seq * head_size * dtype_size, head_num - 1, head_size * dtype_size, key_seq - 1, + context_head_stride * dtype_size, head_num - 1, context_seq_stride * dtype_size, + key_seq - 1); + // [head_num, seq, head_size] + quantify(src_input, nram_input, nram_trans, nram_scale, nram_scale_recip, + key_seq * head_num * group_num, head_size / group_num, quant_bit); + castKeyToIntx((int8_t *)nram_trans, nram_input, key_seq * head_num * group_num, + head_size / group_num, quant_bit); + + // after quantify: (head_num, seq, head_size * quant_bit / 8) + __memcpy(key_cache_begin, nram_trans, head_size_store, NRAM2GDRAM, key_cache_seq_stride, + key_seq - 1, cache_head_stride, head_num - 1, head_size_store, key_seq - 1, + key_seq * head_size_store, head_num - 1); + + // nram_scale: (head_num, seq, group_num) + __memcpy(key_cache_scale_begin, nram_scale, key_seq * group_num * sizeof_(float), NRAM2GDRAM, + cache_scale_head_stride * sizeof_(float), key_seq * group_num * sizeof_(float), + head_num - 1); + } + + if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr && + value_seq_begin < seq_len) { + size_t value_cache_offset = + cache_bs_offset * cache_bs_stride + value_cache_seq_offset * value_cache_seq_stride; + if (quant_bit == 4) { + value_cache_offset = cache_bs_offset * cache_bs_stride + + (value_cache_seq_offset / 2) * value_cache_seq_stride; + } + size_t value_cache_scale_offset = + cache_bs_offset * cache_scale_bs_stride + value_cache_seq_offset * group_num; + + int8_t *value_cache_begin = value_cache + value_cache_offset; + int8_t *value_cache_end = + value_cache + cache_bs_offset * cache_bs_stride + + (DIV_UP(cache_seq_offset + seq_len, 2) - 1) * value_cache_seq_stride; + + float *value_cache_scale_begin = value_cache_scale + value_cache_scale_offset; + char *value_begin = (char *)value + value_context_offset; + bool need_pad_front = (quant_bit == 4) && (taskIdZ == 0) && (cache_seq_offset % 2 == 1); + bool need_pad_back = (quant_bit == 4) && (value_seq_begin + value_seq >= seq_len) && + ((cache_seq_offset + seq_len) % 2 == 1); + // quant_bit == 8 : nram_input: (head_num, seq, head_size) + if (quant_bit == 8) { + __memcpy(src_input, value_begin, head_size * dtype_size, GDRAM2NRAM, + value_seq * head_size * dtype_size, head_num - 1, head_size * dtype_size, + value_seq - 1, context_head_stride * dtype_size, head_num - 1, + context_seq_stride * dtype_size, value_seq - 1); + } else if (quant_bit == 4) { + // quant_bit == 4 : nram_input: (seq, head_num. head_size) + __memcpy(src_input, value_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, + head_num - 1, head_num * head_size * dtype_size, value_seq - 1, + context_head_stride * dtype_size, head_num - 1, context_seq_stride * dtype_size, + value_seq - 1); + } + quantify(src_input, nram_input, nram_trans, nram_scale, nram_scale_recip, + value_seq * head_num * group_num, head_size / group_num, quant_bit); + castValueToIntx(value_cache_begin, value_cache_end, post_table_nram, nram_trans, nram_input, + cache_head_stride, value_seq, head_num, head_size, group_num, quant_bit, + need_pad_front, need_pad_back); + + // [head_num, seq, head_size] + if (quant_bit == 8) { + __memcpy(value_cache_scale_begin, nram_scale, value_seq * group_num * sizeof_(float), + NRAM2GDRAM, cache_scale_head_stride * sizeof_(float), + value_seq * group_num * sizeof_(float), head_num - 1); + __memcpy(value_cache_begin, nram_trans, head_size, NRAM2GDRAM, value_cache_seq_stride, + value_seq - 1, cache_head_stride, head_num - 1, head_size, value_seq - 1, + value_seq * head_size, head_num - 1); + } else if (quant_bit == 4) { + int new_seq = value_seq + ((int)need_pad_front + (int)need_pad_back); + __sync(); + __memcpy_async(nram_scale_recip, nram_scale, group_num * sizeof_(float), NRAM2NRAM, + value_seq * group_num * sizeof_(float), head_num - 1, + group_num * sizeof_(float), value_seq - 1, group_num * sizeof_(float), + head_num - 1, head_num * group_num * sizeof_(float), value_seq - 1); + __memcpy_async(value_cache_begin, nram_trans, head_size, NRAM2GDRAM, cache_head_stride, + head_num - 1, value_cache_seq_stride, new_seq / 2 - 1, head_size, + head_num - 1, head_num * head_size, new_seq / 2 - 1); + __sync(); + __memcpy(value_cache_scale_begin, nram_scale_recip, value_seq * group_num * sizeof_(float), + NRAM2GDRAM, cache_scale_head_stride * sizeof_(float), + value_seq * group_num * sizeof_(float), head_num - 1); + } + } // end if + } // end for +} +} // namespace kernels + +void invokeQuantToLinearCache(cnrtQueue_t queue, + void *key_cache, + void *value_cache, + void *key_cache_scale, + void *value_cache_scale, + const void *cache_bs_offsets, + const void *cache_seq_offsets, + void *key, + void *value, + const void *context_seq_offsets, + const void *context_lens, + const cnnlDataType_t dtype, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const size_t context_bs_stride, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t cache_scale_bs_stride, + const size_t cache_scale_head_stride, + const bool packed, + const int quant_bit, + const int group_size) { + constexpr int nram_size = 480 * 1024; + int group_num = head_size / group_size; + int hidden_bytes = head_num * (head_size + group_num) * sizeof(float) * 2; + int seq_block = nram_size / hidden_bytes; + if (seq_block <= 1) { + std::cerr << __func__ << "," << __LINE__ + << " :head_num * (head_size + group_num) * sizeof(float) should be less than 120KB." + << std::endl; + } + if (seq_block > 16) { + seq_block = seq_block / 16 * 16; + } else { + seq_block = seq_block / 2 * 2; + } + int seq_seg = max_context_len / seq_block + 1; + int cluster_num, core_dim; + CNdev dev; + cnCtxGetDevice(&dev); + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev)); + int core_num = core_dim * cluster_num; + uint32_t task_y_dim = std::min(batch, core_num); + cnrtDim3_t dim{1, task_y_dim, (uint32_t)seq_seg}; + + if (dtype == CNNL_DTYPE_HALF) { + kernels::MLUQuantToLinearCacheKernel<<>>( + (int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, + (float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (half *)key, + (half *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size, + max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride, + cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride, + cache_scale_bs_stride, cache_scale_head_stride, packed, seq_block, quant_bit, group_num); + } else if (dtype == CNNL_DTYPE_BFLOAT16) { + kernels::MLUQuantToLinearCacheKernel<<>>( + (int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, + (float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, + (bfloat16_t *)key, (bfloat16_t *)value, (int *)context_seq_offsets, (int *)context_lens, + batch, head_num, head_size, max_context_len, cache_mem_len, context_bs_stride, + context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride, + key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride, + cache_scale_head_stride, packed, seq_block, quant_bit, group_num); + } else { + kernels::MLUQuantToLinearCacheKernel<<>>( + (int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, + (float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (float *)key, + (float *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size, + max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride, + cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride, + cache_scale_bs_stride, cache_scale_head_stride, packed, seq_block, quant_bit, group_num); + } +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_linear_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_linear_cache.mluh new file mode 100644 index 0000000..c2f70f0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_linear_cache.mluh @@ -0,0 +1,105 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_QUANT_TO_LINEAR_CACHE_MLUH_ +#define CSRC_KERNELS_QUANT_TO_LINEAR_CACHE_MLUH_ + +#include "cnnl.h" +namespace tmo { +/** + * @brief Quantize current key and value, Then store key and value to key_cache and value_cache. + * @param queue: The queue for mlu. + * @param key_cache: Pointer to the MLU memory that stores the key cache, + * the shape must be [max_batch, head_num, cache_mem_len, head_size]. + * Data type of key_cache must be int8. key_cache could be nullptr. +* @param value_cache: Pointer to the MLU memory that stores the value cache, + * the shape must be [max_batch, head_num, cache_mem_len, head_size]. + * Data type of value_cache must be int8. value_cache could be nullptr. +* @param key_cache_scale: Pointer to the MLU memory that stores the key cache scale, + * the shape must be [max_batch, head_num, cache_mem_len]. + * Data type of key_cache_scale must be float. key_cache_scale could be nullptr. +* @param value_cache_scale: Pointer to the MLU memory that stores the value cache scale, + * the shape must be [max_batch, head_num, cache_mem_len]. + * Data type of value_cache_scale must be float. value_cache_scale could be nullptr. + * @param cache_bs_offsets: Pointer to the MLU memory that stores the batch + * offset of cache, the shape must be [batch], if it's nullptr, the + * default value is {0, 1, 2 ... batch - 1}. + * @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence + * offset of cache, the shape must be [batch], if it's nullptr, the + * default value is 0 for every batch. + * @param key: Pointer to the MLU memory that stores the key, + * the shape must be [batch, max_contxt_len, head_num, head_size]. + * Data type of key couble be float/half/bfloat16. key could be nullptr. + * @param value: Pointer to the MLU memory that stores the value, + * the shape must be [batch, max_contxt_len, head_num, head_size]. + * Data type of value couble be float/half/bfloat16, value could be nullptr. + * @param context_seq_offsets: Pointer to the MLU memory that stores the + * sequence offset of context, the shape must be [batch]. if it's nullptr, + * the default value is 0 for every batch. It must be nullptr when packed is true. + * @param context_lens: Pointer to the MLU memory that stores the sequence length or cumulative + * sequence length of context. when packed is false, the shape must be [batch], which + * indicates sequence length of context. when packed is true, the shape must be [batch + 1], which + * indicates cumulative sequence length of context. + * @param dtype: Data type. + * @param batch: Batch size. + * @param head_num: Head number. + * @param head_size: Head size. + * @param max_contxt_len: The maximum sequence length of context. + * @param cache_mem_len: The maximum sequence length of cache. + * @param contxt_bs_stride: The stride of batch in context, does not work when packed is true. + * @param contxt_head_stride: The stride of head_num in context. + * @param contxt_seq_stride: The stride of max_contxt_len in context. + * @param cache_bs_stride: The stride of batch in cache. + * @param cache_head_stride: The stride of head_num in cache. + * @param key_cache_seq_stride: The stride of cache_mem_len in key_cache. + * @param value_cache_seq_stride: The stride of cache_mem_len in value_cache. + * @param cache_scale_bs_stride: The stride of batch in cache scale. + * @param cache_scale_head_stride: The stride of head in cache scale. + * @param packed: A boolean value indicates whether to use pack mode. + * @param quant_bit: Bit width of quantified results. + * @param group_size: Size of a group during group quantization, + * @note If one of key/key_cache/key_cache_scale is nullptr, nothing todo for key. + If one of value/value_cache/value_cache_scale is nullptr, nothing todo for value. + A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for + the corresponding batch. + */ +void invokeQuantToLinearCache(cnrtQueue_t queue, + void *key_cache, + void *value_cache, + void *key_cache_scale, + void *value_cache_scale, + const void *cache_bs_offsets, + const void *cache_seq_offsets, + void *key, + void *value, + const void *context_seq_offsets, + const void *context_lens, + const cnnlDataType_t dtype, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const size_t context_bs_stride, + const size_t context_head_stride, + const size_t context_seq_stride, + const size_t cache_bs_stride, + const size_t cache_head_stride, + const size_t key_cache_seq_stride, + const size_t value_cache_seq_stride, + const size_t cache_scale_bs_stride, + const size_t cache_scale_head_stride, + const bool packed, + const int quant_bit, + const int group_size); +} // namespace tmo + +#endif // CSRC_KERNELS_QUANT_TO_LINEAR_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_paged_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_paged_cache.mlu new file mode 100644 index 0000000..a1be66a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_paged_cache.mlu @@ -0,0 +1,241 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include "quant_to_paged_cache.mluh" + +namespace tmo { +namespace kernels { + +#define sizeof_(T) (uint32_t)sizeof(T) +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +#define REM_FOR_STACK (32 * 1024) +__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK]; +#if __BANG_ARCH__ > 500 +__nram__ const int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; +#endif + +template +__mlu_func__ void quantifyToInt8(int8_t *nram_output, + float *nram_float, + T *nram_input, + float *nram_scale, + float *nram_temp, + int seq_len, + int head_size) { + // quantify + __bang_transpose((T *)nram_input, (T *)nram_scale, seq_len, head_size); + if (std::is_same::value) { + __bang_half2float(nram_float, (half *)nram_input, head_size * seq_len); + } + if (std::is_same::value) { +#if __BANG_ARCH__ > 500 + __bang_bfloat162float(nram_float, (bfloat16_t *)nram_input, head_size * seq_len); +#endif + } + __bang_abs(nram_scale, nram_float, head_size * seq_len); + __bang_maxpool(nram_scale, nram_scale, seq_len, head_size, 1, head_size, 1, 1, 1); + __bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, seq_len); + __bang_recip(nram_temp, nram_scale, seq_len); + __bang_cycle_mul(nram_float, nram_float, nram_temp, head_size * seq_len, seq_len); + __bang_float2int8_rn((int8_t *)nram_float, nram_float, head_size * seq_len, 0); + __bang_transpose((int8_t *)nram_output, (int8_t *)nram_float, head_size, seq_len); +} + +template +__mlu_global__ void MLUQuantToPagedCacheKernel(T *key, + T *value, + int8_t *key_cache, + int8_t *value_cache, + float *key_cache_scale, + float *value_cache_scale, + int *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int tokens_num, + int head_num, + int block_size, + int head_size, + int tokens_block) { +/*******************************************************nram space*********************** + * nram:| input | scale | cache_offset | scale_offset | mask | temp | index | + * input size: tokens_block * head_num * head_size * sizeof(float) + * scale size: equal to input size + * cache_offset size: tokens_block * head_num * sizeof(float) + * scale_offset size: equal to cache_offset size + * mask size: CEIL_DIV(tokens_size * head_num, 8) * sizeof(int8_t) + * temp size: CEIL_ALIGN(token_size * head_num, 8) * sizeof(int) + * index size: head_num * sizeof(int) + ****************************************************************************************/ +#if __BANG_ARCH__ > 500 + int token_begin = taskId * tokens_block; + if (token_begin >= tokens_num) return; + int token_handle = std::min(tokens_block, tokens_num - token_begin); + + int seq_len = token_handle * head_num; + int pad8_num = CEIL_DIV(seq_len, CHAR_BIT) * CHAR_BIT; + int input_size = seq_len * head_size * sizeof_(float); + int8_t *nram_input = nram_buffer; + float *nram_scale = (float *)(nram_buffer + input_size); + int *cache_offset = (int *)((int8_t *)nram_scale + input_size); + int *scale_offset = cache_offset + pad8_num; + int *nram_mask = scale_offset + pad8_num; + int *nram_temp = nram_mask + pad8_num; + int *head_index = nram_temp + pad8_num; + + // generate range: (0, 1, 2, ..., (head_num - 1)) + __memcpy(head_index, nram_range_32, std::min(head_num, 32) * sizeof_(int), NRAM2NRAM); + int begin = 32; + while (begin < head_num) { + int count = std::min(begin, head_num - begin); + __bang_add_scalar(head_index + begin, head_index, begin, count); + begin += count; + } + + // load slot(token_handle) -> expand(head_num, token_handle) ->transpose(token_handle, head_num) + int token_size = token_handle * sizeof_(int); + __memcpy(scale_offset, slot_mapping + token_begin, token_size, GDRAM2NRAM); + __memcpy(nram_temp, scale_offset, token_size, NRAM2NRAM, token_size, 0, head_num - 1); + __bang_transpose(scale_offset, nram_temp, head_num, token_handle); + + __bang_write_zero((float *)nram_temp, pad8_num); + __bang_ge_bitindex((float *)nram_mask, (float *)scale_offset, (float *)nram_temp, pad8_num); + + // calculate cache/scale scatter offset + __bang_div(cache_offset, scale_offset, (int)block_size, seq_len); + __bang_rem(scale_offset, scale_offset, (int)block_size, seq_len); + __bang_mul_scalar(cache_offset, cache_offset, head_num * block_size, seq_len); + __bang_mul_scalar(head_index, head_index, block_size, head_num); + __bang_cycle_add(cache_offset, cache_offset, head_index, seq_len, head_num); + __bang_add(scale_offset, cache_offset, scale_offset, seq_len); + __bang_mul_scalar(cache_offset, scale_offset, head_size, seq_len); + __bang_mul_scalar(scale_offset, scale_offset, sizeof_(float), seq_len); + + int hidden_bytes = head_num * head_size * sizeof_(T); + int8_t *nram_output = (int8_t *)(nram_input + seq_len * head_size * sizeof_(half)); + T *nram_input_origin = (T *)nram_input; + if (!std::is_same::value) { + nram_input_origin = (T *)(nram_input + seq_len * head_size * (sizeof_(float) - sizeof_(T))); + } + if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr) { + // (token_handle, head_num, head_size) + __memcpy(nram_scale, key + token_begin * key_stride0, hidden_bytes, GDRAM2NRAM, hidden_bytes, + key_stride0 * sizeof_(T), token_handle - 1); + // quantify + quantifyToInt8(nram_output, (float *)nram_input, nram_input_origin, nram_scale, + (float *)nram_temp, seq_len, head_size); + // scatter to gdram + __scatter(key_cache, (int8_t *)nram_output, (uint32_t *)cache_offset, nram_mask, head_size, + NRAM2GDRAM, head_size, seq_len); + __scatter(key_cache_scale, nram_scale, (uint32_t *)scale_offset, nram_mask, sizeof_(float), + NRAM2GDRAM, sizeof_(float), seq_len); + } + + if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr) { + // (token_handle, head_num, head_size) + __memcpy(nram_scale, value + token_begin * value_stride0, hidden_bytes, GDRAM2NRAM, + hidden_bytes, value_stride0 * sizeof_(T), token_handle - 1); + // quantify + quantifyToInt8(nram_output, (float *)nram_input, nram_input_origin, nram_scale, + (float *)nram_temp, seq_len, head_size); + // scatter to gdram + __scatter(value_cache, (int8_t *)nram_output, (uint32_t *)cache_offset, nram_mask, head_size, + NRAM2GDRAM, head_size, seq_len); + __scatter(value_cache_scale, nram_scale, (uint32_t *)scale_offset, nram_mask, sizeof_(float), + NRAM2GDRAM, sizeof_(float), seq_len); + } +#endif +} +} // namespace kernels + +KernelStatus invokeQuantToPagedCache(cnrtQueue_t queue, + cnnlDataType_t data_type, + void *key, + void *value, + void *key_cache, + void *value_cache, + void *key_cache_scale, + void *value_cache_scale, + void *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_num, + int block_size, + int head_size) { + int dtype_size = 1; + if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) { + dtype_size = 2; + } else if (data_type == CNNL_DTYPE_FLOAT) { + dtype_size = 4; + } else { + std::cerr << "invokeQuantToPagedCache: unsupport data type\n"; + return KernelStatus::KERNEL_STATUS_FAILED; + } + int64_t kv_range = block_num * block_size * num_heads * head_size * dtype_size; + if (kv_range > UINT32_MAX) { + std::cerr << "invokeQuantToPagedCache: The addressing range of kv_cache cannot exceed 4G." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (is_arch300()) { + std::cerr << "[invokeQuantToPagedCache]: kernel does not support MLU300 devices." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + // nram_size_need: token_block * head_num * head_size * 2 + + // token_block * head_num * 4 * sizeof(int) + head_num * sizeof(int) + // nram uesd: 480KB + int nram_size = 480 * 1024 - num_heads * sizeof(int); + int hidden_bytes = num_heads * head_size * 2 * sizeof(float) + 4 * num_heads * sizeof(int); + int seq_block = nram_size / hidden_bytes; + if (seq_block <= 0) { + std::cerr << "invokeQuantToPagedCache: " + << "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) " + << "should be less than 480KB.\n"; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (seq_block > 16) { + seq_block = seq_block / 16 * 16; + } + int cluster_num, core_dim; + CNdev dev; + cnCtxGetDevice(&dev); + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev)); + int core_num = core_dim * cluster_num; + seq_block = std::min(seq_block, CEIL_DIV(num_tokens, core_num)); + uint32_t task_dim = CEIL_DIV(num_tokens, seq_block); + cnrtDim3_t dim{1, task_dim, 1}; + + if (data_type == CNNL_DTYPE_FLOAT) { + kernels::MLUQuantToPagedCacheKernel<<>>( + (float *)key, (float *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0, + value_stride0, num_tokens, num_heads, block_size, head_size, seq_block); + } else if (data_type == CNNL_DTYPE_HALF) { + kernels::MLUQuantToPagedCacheKernel<<>>( + (half *)key, (half *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0, + value_stride0, num_tokens, num_heads, block_size, head_size, seq_block); + } else { + kernels::MLUQuantToPagedCacheKernel<<>>( + (bfloat16_t *)key, (bfloat16_t *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0, + value_stride0, num_tokens, num_heads, block_size, head_size, seq_block); + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_paged_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_paged_cache.mluh new file mode 100644 index 0000000..fff7fd1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_to_paged_cache.mluh @@ -0,0 +1,63 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_QUANT_TO_PAGED_CACHE_MLUH_ +#define CSRC_KERNELS_QUANT_TO_PAGED_CACHE_MLUH_ + +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Perform quant_to_paged_cache operation. + * @param handle: The handle of cnnl. + * @param data_type: The cnnl data type of key. + * @param key: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens, + * num_heads, head_size]. Data type of key must be half/bfloat16_t/float. + * @param value: Pointer to the MLU memory that stores the value tensor which has shape [num_tokens, + * num_heads, head_size]. Data type of key must be half/bfloat16_t/float. + * @param key_cache: Pointer to the MLU memory that stores the key_cache tensor which has + * shape [num_blocks, num_heads, block_size, head_size]. Data type of key cache must be int8_t. + * @param value_cache: Pointer to the MLU memory that stores the value_cache tensor which has + * shape [num_blocks, num_heads, block_size, head_size]. Data type of value cache must be int8_t. + * @param key_cache_scale: Pointer to the MLU memory that stores the key_cache_scale tensor which + * has shape [num_blocks, num_heads, block_size]. Data type of key cache scale must be float. + * @param value_cache_scale: Pointer to the MLU memory that stores the value_cache_scale tensor + * which has shape [num_blocks, num_heads, block_size]. Data type of value cache scale must be + * float. + * @param slot_mapping: Pointer to the MLU memory that stores the slot_mapping tensor which has + * shape [num_tokens]. Data type of slot mapping must be int32_t. + * @param key_stride0: The first dimension stride length of key_cache tensor. + * @param value_stride0: The first dimension stride length of value_cache tensor. + * @param num_tokens: Total number of tokens. + * @param num_heads: Head number. + * @param block_num: Total number of blocks. + * @param block_size: Number of tokens per block. + * @param head_size: Head size. + * @note: quant_to_paged_cache does not support MLU300 device. + */ +KernelStatus invokeQuantToPagedCache(cnrtQueue_t queue, + cnnlDataType_t data_type, + void *key, + void *value, + void *key_cache, + void *value_cache, + void *key_cache_scale, + void *value_cache_scale, + void *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_num, + int block_size, + int head_size); +} // namespace tmo + +#endif // CSRC_KERNELS_QUANT_TO_PAGED_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quant_utils.h b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_utils.h new file mode 100644 index 0000000..67a29bf --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quant_utils.h @@ -0,0 +1,313 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_QUANT_UTILS_H_ +#define CSRC_KERNELS_QUANT_UTILS_H_ + +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "kernel_utils.h" + +namespace tmo { + +#ifndef LT_NUM +#define LT_NUM (64) +#endif + +#ifndef ANT_LT_ROW +#define ANT_LT_ROW (4) +#endif + +#ifndef LT_NUM_ANT +#define LT_NUM_ANT (16) +#endif + +#ifndef ONE_LINE +#define ONE_LINE (64) +#endif + +#ifndef sizeof_ +#define sizeof_(T) (uint32_t)sizeof(T) +#endif + +#ifndef WRAM_LT_MAP16_STRIDE +#define WRAM_LT_MAP16_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 16) +#endif + +#ifndef TRANS_TABLE_SIZE +#define TRANS_TABLE_SIZE (64) +#endif + +#ifndef DIV_UP +#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0)) +#endif + +#ifndef CONV_FUSE_OP_CVT +#define CONV_FUSE_OP_CVT(dtype, op, cvt, op_data) \ + asm volatile("conv.nram.rn.f32" dtype dtype \ + "[%[dst]], [%[src]], [%[kernel]], %[src_channel], " \ + "%[src_height], 1, 1, 1, 1, 1, %[dst_channel]" op cvt \ + ";\n\t" ::[dst] "r"((Td *)output), \ + [src] "r"((Ts *)input), [kernel] "r"((Ts *)filter), [src_channel] "r"(k), \ + [src_height] "r"(m), [dst_channel] "r"(n), [operand0] "r"(op_data)); +#endif + +#define __reshape_nhwc2nchw_smallc(TYPE) \ + asm volatile( \ + "trans.tiling.nram.nram." TYPE \ + "[%[dst]], [%[src]], " \ + "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4], %[is4], %[in5], %[is5]," \ + "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]," \ + ".pretable.nram([%[pre]]); \n\t" ::[dst] "r"((T *)dst), \ + [src] "r"((T *)src), [pre] "r"((uint8_t *)pre_table), [in0] "r"(in0), [in1] "r"(in1), \ + [is1] "r"(in0), [in2] "i"(1), [is2] "i"(0), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(n), \ + [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0), [dn0] "r"(dn0), [dn1] "r"(dn1), \ + [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(n), \ + [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0)); + +__mlu_func__ void next_power_of_two(int32_t &align_num, const int32_t num) { + int32_t tmp = num - 1; + asm volatile("findlast1.gpr.b32 %[out], %[in];" : [out] "=r"(tmp) : [in] "r"(tmp)); + align_num = 1 << (tmp + 1); +} + +/* copy from cnnl utils/trans_small.py by xwm. */ +template +__mlu_func__ void __reshape_nhwc2nchw_smallc_init(uint8_t *pre_table_nram, uint32_t channel) { + int32_t align_c; + next_power_of_two(align_c, channel); + int32_t align_num = ONE_LINE / sizeof_(T); + int32_t repeat = align_num / align_c; + for (int i = 0; i < 64; ++i) { + int32_t idx = i / sizeof_(T); + int32_t tmp_idx = (idx % repeat) * channel + idx / repeat; + int32_t real_idx = tmp_idx * sizeof_(T) + i % sizeof_(T); + __store_nram((uint8_t *)pre_table_nram + i, (uint8_t)real_idx + 0x80); + } +} + +template +__mlu_func__ void trans_nhwc2nchw_smallc(T *dst, + T *src, + uint8_t *pre_table, + uint32_t n, + uint32_t h, + uint32_t w, + uint32_t c) { + int32_t align_c; + next_power_of_two(align_c, c); + int32_t align_num = 64 / sizeof_(T); + int32_t hw = h * w; + int32_t repeat = align_num / align_c; + int32_t in0 = c * repeat * sizeof_(T); + int32_t in1 = align_c; + int32_t in3 = hw / align_num; + int32_t is3 = in0 * in1; + int32_t n_stride = hw * c * sizeof_(T); + int32_t dn0 = 64; + int32_t dn1 = c; + int32_t ds1 = hw * sizeof_(T); + int32_t dn3 = in3; + int32_t ds3 = dn0; + align_c = in3 > 0 ? align_c : 0; + if (align_c == 2) { + __reshape_nhwc2nchw_smallc("b256"); + } else if (align_c == 4) { + __reshape_nhwc2nchw_smallc("b128"); + } else if (align_c == 8) { + __reshape_nhwc2nchw_smallc("b64"); + } else if (align_c == 16) { + __reshape_nhwc2nchw_smallc("b32"); + } else if (align_c == 32) { + __reshape_nhwc2nchw_smallc("b16"); + } + + constexpr int32_t bw = 8 * sizeof_(T); + int32_t in3_rem = hw % align_num; + int32_t tail_in0 = c * sizeof_(T); + int32_t tail_dn0 = in3_rem * sizeof_(T); + if (in3_rem) { + asm volatile( + "trans.tiling.nram.nram.b%[bw] [%[dst]], [%[src]], \ + %[in0], %[in1], %[is1], %[in2], %[is2], %[in3], \ + %[is3], %[in4], %[is4], %[in5], %[is5], \ + %[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], \ + %[ds3], %[dn4], %[ds4], %[dn5], %[ds5]; \n\t" ::[bw] "i"(bw), + [dst] "r"((T *)dst + dn3 * ds3 / sizeof_(T)), [src] "r"((T *)src + is3 * in3 / sizeof_(T)), + [in0] "r"(tail_in0), [in1] "r"(in3_rem), [is1] "r"(tail_in0), [in2] "i"(1), [is2] "i"(0), + [in3] "i"(1), [is3] "i"(0), [in4] "r"(n), [is4] "r"(n_stride), [in5] "i"(1), [is5] "i"(0), + [dn0] "r"(tail_dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), [dn2] "i"(1), [ds2] "i"(0), + [dn3] "i"(1), [ds3] "i"(0), [dn4] "r"(n), [ds4] "r"(n_stride), [dn5] "i"(1), [ds5] "i"(0)); + } +} + +__mlu_func__ void convert(float *dst, int8_t *src, int32_t num) { + __bang_int82float((float *)dst, (int8_t *)src, num, 0); +} + +__mlu_func__ void convert(float *dst, int4x2_t *src, int32_t num) { + __bang_int42float((float *)dst, (int4x2_t *)src, num, 0); +} + +__mlu_func__ void convert(half *dst, float *src, int32_t num) { + __bang_float2half((half *)dst, (float *)src, num); +} + +__mlu_func__ void convert(bfloat16_t *dst, float *src, int32_t num) { +#if __BANG_ARCH__ >= 500 + __bang_float2bfloat16((bfloat16_t *)dst, (float *)src, num); +#endif +} + +__mlu_func__ void convert(int8_t *dst, int4x2_t *src, int32_t num) { + __bang_int42int8((int8_t *)dst, (int4x2_t *)src, num, 0, 0); +} + +// if the dst dtype == src dtype, do nothing. if you want to mv, use mv directly +__mlu_func__ void convert(float *dst, float *src, int32_t num) {} + +__mlu_func__ void convert(int8_t *dst, int8_t *src, int32_t num) {} + +template +__mlu_func__ void transpose(T *dst, T *src, int32_t dim1, int32_t dim2) { + __bang_transpose((T *)dst, (T *)src, dim1, dim2); +} + +// if data type is int4x2_t, transpose is not supported directly +__mlu_func__ void transpose(int4x2_t *dst, int4x2_t *src, int32_t dim1, int32_t dim2) {} + +template +__mlu_func__ void mvNram2WramLT16(int8_t *wram_dst, + int8_t *nram_src, + int32_t n, + int32_t k, + int32_t total_k) { + int32_t data_size = k * sizeof_(T); + int32_t ds0 = PAD_UP(data_size, ONE_LINE); + int32_t ss0 = total_k * sizeof_(T); + int32_t count = DIV_UP(n, LT_NUM); + if (count > 0) { + for (int i = 0; i < count; ++i) { + __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1, + WRAM_LT_MAP16_STRIDE, LT_NUM_ANT - 1, ss0, LT_NUM - 1, 0, 0); + wram_dst += ANT_LT_ROW * ds0; + nram_src += LT_NUM * ss0; + } + } + + count = n % LT_NUM / ANT_LT_ROW; + if (count > 0) { + __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ANT_LT_ROW - 1, + WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ANT_LT_ROW - 1, 0, 0); + wram_dst += count * WRAM_LT_MAP16_STRIDE; + nram_src += count * ANT_LT_ROW * ss0; + } + + count = n % ANT_LT_ROW; + if (count) { + __memcpy((int8_t *)wram_dst, (int8_t *)nram_src, data_size, NRAM2WRAM, ds0, ss0, count - 1); + } +} + +template +__mlu_func__ void +conv_fuse_mul_cvt(Td *output, Ts *input, Ts *filter, float *partial, int m, int n, int k) { + if (std::is_same::value && std::is_same::value) { + CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.f16()", partial) + } else if (std::is_same::value && std::is_same::value) { +#if __BANG_ARCH__ > 500 + CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", ", .cvt.dst.rn.bf16()", partial) +#endif + } else if (std::is_same::value && std::is_same::value) { + CONV_FUSE_OP_CVT(".f32", ", .mul.partial.rn([%[operand0]])", "", partial) + } +} + +template +__mlu_func__ void process_offsets(int32_t *lens_nram, + int32_t *offsets_nram, + const int32_t *context_lens, + const int32_t *context_seq_offsets, + const int32_t batch_size) { + if constexpr (ProcessOffsets) { + __memcpy((int32_t *)lens_nram, (int32_t *)context_lens, sizeof_(int32_t) * batch_size, + GDRAM2NRAM); + int total_lens = 0; + for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + __store_nram((int32_t *)offsets_nram + batch_idx, total_lens); + total_lens += __load_nram((int32_t *)lens_nram + batch_idx); + } + } +} + +template +__mlu_func__ void load_len_offset(int32_t &seq_len, + int32_t &seq_offset, + const int32_t *lens_nram, + const int32_t *offsets_nram, + const int32_t *context_lens, + const int32_t *context_seq_offsets, + const int32_t batch_idx) { + if (ProcessOffsets) { + seq_len = __load_nram((int32_t *)lens_nram + batch_idx); + seq_offset = __load_nram((int32_t *)offsets_nram + batch_idx); + } else { + seq_len = __load_gdram((int32_t *)context_lens + batch_idx); + seq_offset = __load_gdram((int32_t *)context_seq_offsets + batch_idx); + } +} + +template +__mlu_func__ void load_scale_once(T *scale_nram, + const T *scale, + const int32_t head_num, + const int32_t head_size, + const size_t scale_bs_stride, + const size_t scale_head_stride) { + __memcpy((T *)scale_nram, (T *)scale, head_size * sizeof_(T), GDRAM2NRAM, head_size * sizeof_(T), + scale_head_stride * sizeof_(T), head_num - 1); +} + +template +__mlu_func__ void dequantize(T *output_nram, + Tc *input_nram, + Ts *scale_nram, + Ts *start_nram, + const int32_t input_num, + const int32_t scale_num) { + convert((float *)output_nram, (Tc *)input_nram, input_num); + convert((float *)start_nram, (Ts *)scale_nram, input_num); + __bang_cycle_mul((float *)output_nram, (float *)output_nram, (float *)start_nram, input_num, + scale_num); + convert((T *)output_nram, (float *)output_nram, input_num); +} + +inline void getDeviceCoreAndRam(int32_t &cluster_dim, + int32_t &core_dim, + int32_t &nram_size, + int32_t &wram_size, + int32_t &sram_size, + const int32_t rem_for_stack) { + CNdev dev; + cnCtxGetDevice(&dev); + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&nram_size, cnrtAttrNramSizePerMcore, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&wram_size, cnrtAttrWramSizePerMcore, dev)); + CNRT_CHECK(cnrtDeviceGetAttribute(&sram_size, cnrtAttrSramSizePerMcore, dev)); + nram_size -= rem_for_stack; + sram_size -= rem_for_stack; +} +} // namespace tmo +#endif // CSRC_KERNELS_QUANT_UTILS_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mlu new file mode 100644 index 0000000..979af8f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mlu @@ -0,0 +1,186 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include "cnnl.h" +#include "cnrt.h" +#include "quantize.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +#define NRAM_BUFFER_SIZE (480 * 1024) +namespace kernels { +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; + +template +__mlu_func__ void quantify(int8_t *nram_dst, + TSrc *nram_src, + TSrc *nram_scale_temp, + float *nram_scale, + float *scale_origin, + int core_deal_tokens, + int hidden) { + __bang_abs((TSrc *)nram_dst, nram_src, core_deal_tokens * hidden); + __bang_maxpool(nram_scale_temp, (TSrc *)nram_dst, core_deal_tokens, hidden, 1, hidden, 1, 1, 1); + if (std::is_same::value) { + __bang_half2float(nram_scale, (half *)nram_scale_temp, core_deal_tokens); + } + + __bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, core_deal_tokens); + __bang_recip(scale_origin, nram_scale, core_deal_tokens); + if (std::is_same::value) { + __bang_float2half_rn((half *)scale_origin, scale_origin, core_deal_tokens); + } + __bang_cycle_mul(nram_src, nram_src, (TSrc *)scale_origin, core_deal_tokens * hidden, + core_deal_tokens); + + if (std::is_same::value) { + __bang_half2int8_rn((int8_t *)nram_src, (half *)nram_src, core_deal_tokens * hidden, 0); + } else if (std::is_same::value) { + __bang_float2int8_rn((int8_t *)nram_src, (float *)nram_src, core_deal_tokens * hidden, 0); + } + __bang_transpose(nram_dst, (int8_t *)nram_src, hidden, core_deal_tokens); +} + +template +__mlu_global__ void MLUQuantizePerHead( + TDst *dst, // [bs, seq, head_num, head_size], may not be continuous + TScale *scale, // [bs, seq], must becontinuous + const TSrc *src, // [bs, seq, head_num, head_size], may not be continuous + int bs, + int seq_len, + int head_num, + int head_size, + int src_bs_stride, + int src_seq_stride, + int src_head_stride, + int dst_bs_stride, + int dst_seq_stride, + int dst_head_stride) { + int total_bs = bs * seq_len; + int hidden = head_num * head_size; + int core_average_tokens = (total_bs + taskDim - 1) / taskDim; + int core_begin_tokens = core_average_tokens * taskId; + int core_deal_tokens = std::min(total_bs - core_begin_tokens, core_average_tokens); + if (__is_mpu()) { + return; + } + if (core_deal_tokens <= 0) { + return; + } + + TScale *nram_scale = (TScale *)nram_buffer; + TScale *scale_origin = nram_scale + core_deal_tokens * head_num; + TSrc *nram_scale_temp = + (TSrc *)(nram_buffer + core_deal_tokens * head_num * (sizeof(TScale) - sizeof(TSrc))); + TSrc *nram_ping = (TSrc *)(scale_origin + core_deal_tokens * head_num); + TSrc *nram_temp = nram_ping + core_deal_tokens * hidden; + const TSrc *src_begin = src + core_begin_tokens * src_seq_stride; + TDst *dst_begin = dst + core_begin_tokens * dst_seq_stride; + TScale *scale_begin = scale + core_begin_tokens * head_num; + + // load + __memcpy(nram_ping, src_begin, head_size * sizeof(TSrc), GDRAM2NRAM, head_size * sizeof(TSrc), + head_num - 1, hidden * sizeof(TSrc), core_deal_tokens - 1, + src_head_stride * sizeof(TSrc), head_num - 1, src_seq_stride * sizeof(TSrc), + core_deal_tokens - 1); + + __bang_transpose(nram_temp, nram_ping, core_deal_tokens * head_num, head_size); + quantify((TDst *)nram_ping, nram_temp, nram_scale_temp, nram_scale, scale_origin, + core_deal_tokens * head_num, head_size); + // store scale + __memcpy(scale_begin, nram_scale, core_deal_tokens * head_num * sizeof(TScale), NRAM2GDRAM); + // store + __memcpy(dst_begin, nram_ping, head_size * sizeof(TDst), NRAM2GDRAM, + dst_head_stride * sizeof(TDst), head_num - 1, dst_seq_stride * sizeof(TDst), + core_deal_tokens - 1, head_size * sizeof(TDst), head_num - 1, hidden * sizeof(TDst), + core_deal_tokens - 1); +} +} // namespace kernels + +KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue, + void *dst, + void *scale, + const void *src, + cnnlDataType_t dst_dtype, + cnnlDataType_t scale_dtype, + cnnlDataType_t src_dtype, + int bs, + int seq_len, + int head_num, + int head_size, + int dst_bs_stride, + int dst_seq_stride, + int dst_head_stride, + int src_bs_stride, + int src_seq_stride, + int src_head_stride) { + // bs must be continuous, for pack mode, bs = 1, seq_len equals to sum of all bs seq_len. + if (dst_bs_stride != seq_len * dst_seq_stride) { + std::cerr + << "[invokeMluQuantizePerToken]: dst_bs_stride must equal to seq_len * dst_seq_stride." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (dst_head_stride != head_size) { + std::cerr << "[invokeMluQuantizePerToken]: dst_head_stride must equal to head_size." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (src_bs_stride != seq_len * src_seq_stride) { + std::cerr + << "[invokeMluQuantizePerToken]: src_bs_stride must equal to seq_len * src_seq_stride." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (src_head_stride != head_size) { + std::cerr << "[invokeMluQuantizePerToken]: src_head_stride must equal to head_size." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + + int scale_buffer_size = 64 * 1024; // scale on nram + int dtype_size = (src_dtype == CNNL_DTYPE_HALF || src_dtype == CNNL_DTYPE_BFLOAT16) ? 2 : 4; + int bs_once = (NRAM_BUFFER_SIZE - scale_buffer_size) / (2 * head_num * head_size * dtype_size); + int bs_once_ = scale_buffer_size / 2 / sizeof(float); + bs_once = std::min(bs_once, bs_once_); + uint32_t task_dim = std::min(bs * seq_len, cluster_num * core_num); + task_dim = std::max((uint32_t)(bs * seq_len + bs_once - 1) / bs_once, task_dim); + cnrtDim3_t dim{task_dim, 1, 1}; + if (src_dtype == CNNL_DTYPE_FLOAT) { + kernels::MLUQuantizePerHead<<>>( + (int8_t *)dst, (float *)scale, (const float *)src, bs, seq_len, head_num, head_size, + src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride, + dst_head_stride); + } else if (src_dtype == CNNL_DTYPE_HALF) { + kernels::MLUQuantizePerHead<<>>( + (int8_t *)dst, (float *)scale, (const half *)src, bs, seq_len, head_num, head_size, + src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride, + dst_head_stride); + } else if (src_dtype == CNNL_DTYPE_BFLOAT16) { + std::cerr << __func__ << "," << __LINE__ << " :currently does not support bfloat16." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mluh new file mode 100644 index 0000000..4ce56a2 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mluh @@ -0,0 +1,60 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_QUANTIZE_MLUH_ +#define CSRC_KERNELS_QUANTIZE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief quantize tensor by per head. + * @param queue: The queue for mlu. + * @param dst: Output. Pointer to the destination MLU memory. Shape is [bs, seq, head_num, + * head_size], may not be continuous. + * @param scale: Input. Pointer to the destination scale MLU memory. Shape is [bs, seq, head_num], + * must be continuous. + * @param src: Input. Pointer to the source MLU memory. Shape is [bs, seq, head_num, head_size], may + * not be continuous. + * @param dst_dtype: Data type of destination tensor. Must be int8. + * @param scale_dtype: Data type of destination scale tensor. Must be float32. + * @param src_dtype: Data type of src tensor. Must be float or half. + * @param bs: batch_size of dst or src tensor. + * @param seq_len: seq_len of dst or src tensor. + * @param head_num: head_num of dst or src tensor. + * @param head_size: head_size of dst or src tensor. + * @param dst_bs_stride: stride of batch_size dim of dst tensor. + * @param dst_seq_stride: stride of seq_len dim of dst tensor. + * @param dst_head_stride: stride of head_num dim of dst tensor. + * @param src_bs_stride: stride of batch_size dim of src tensor. + * @param src_seq_stride: stride of seq_len dim of src tensor. + * @param src_head_stride: stride of head_num dim of src tensor. + */ +KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue, + void *dst, + void *scale, + const void *src, + cnnlDataType_t dst_dtype, + cnnlDataType_t scale_dtype, + cnnlDataType_t src_dtype, + int bs, + int seq_len, + int head_num, + int head_size, + int dst_bs_stride, + int dst_seq_stride, + int dst_head_stride, + int src_bs_stride, + int src_seq_stride, + int src_head_stride); +} // namespace tmo + +#endif // CSRC_KERNELS_QUANTIZE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mlu new file mode 100644 index 0000000..e4d9d4a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mlu @@ -0,0 +1,134 @@ +#include "reshape_linear_cache.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { + +namespace kernels { + +// [head_num, batch, seq_seg] +__mlu_global__ void MLUReshapeLinearCacheKernel(int8_t *key_cache, + int8_t *value_cache, + int *cache_bs_offsets, + int *cache_seq_offsets, + int8_t *key, + int8_t *value, + int *context_seq_offsets, + int *context_lens, + int batch, + int head_num, + int head_size, + int max_context_len, + int cache_mem_len, + size_t context_bs_stride, + size_t context_head_stride, + size_t context_seq_stride, + size_t cache_bs_stride, + size_t cache_head_stride, + size_t cache_seq_stride, + bool packed, + int dtype_size, + int SEQ_BLOCK) { + int head_repeat = taskDimX > 1 ? 1 : head_num; + for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) { + int seq_offset = (packed || context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx]; + int task_seq_begin = taskIdZ * SEQ_BLOCK; + int seq_len = packed ? (context_lens[bs_idx + 1] - context_lens[bs_idx]) : context_lens[bs_idx]; + if (task_seq_begin >= seq_len) continue; + + int seq = std::min(seq_len - task_seq_begin, SEQ_BLOCK); + size_t context_offset = taskIdX * context_head_stride * dtype_size; + if (packed) { + context_offset += (context_lens[bs_idx] + task_seq_begin) * context_seq_stride * dtype_size; + } else { + context_offset += + (bs_idx * context_bs_stride + (task_seq_begin + seq_offset) * context_seq_stride) * + dtype_size; + } + + int cache_seq_offset = cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]; + int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx]; + if (cache_seq_offset < 0 || cache_bs_offset < 0) { + continue; + } + + cache_seq_offset += task_seq_begin; + if (key != nullptr && key_cache != nullptr) { + int8_t *key_cache_begin = + key_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride + + cache_seq_offset * cache_seq_stride) * + dtype_size; + int8_t *key_begin = key + context_offset; + __memcpy(key_cache_begin, key_begin, head_size * dtype_size, GDRAM2GDRAM, + cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size, + head_repeat - 1, context_seq_stride * dtype_size, seq - 1, + context_head_stride * dtype_size, head_repeat - 1); + } + + if (value != nullptr && value_cache != nullptr) { + int8_t *value_cache_begin = + value_cache + (cache_bs_offset * cache_bs_stride + taskIdX * cache_head_stride + + cache_seq_offset * cache_seq_stride) * + dtype_size; + int8_t *value_begin = value + context_offset; + __memcpy(value_cache_begin, value_begin, head_size * dtype_size, GDRAM2GDRAM, + cache_seq_stride * dtype_size, seq - 1, cache_head_stride * dtype_size, + head_repeat - 1, context_seq_stride * dtype_size, seq - 1, + context_head_stride * dtype_size, head_repeat - 1); + } + } +} +} // namespace kernels + +KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue, + void *key_cache, + void *value_cache, + const void *cache_bs_offsets, + const void *cache_seq_offsets, + void *key, + void *value, + const void *context_seq_offsets, + const void *context_lens, + const cnnlDataType_t dtype, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const int context_bs_stride, + const int context_head_stride, + const int context_seq_stride, + const int cache_bs_stride, + const int cache_head_stride, + const int cache_seq_stride, + const bool packed) { + constexpr int SEQ_BLOCK = 512; + int seq_seg = (max_context_len + SEQ_BLOCK - 1) / SEQ_BLOCK; + bool is_decoder_case = head_num * max_context_len < SEQ_BLOCK; + uint32_t task_x_dim = is_decoder_case ? 1 : head_num; + uint32_t task_y_dim = is_decoder_case ? std::min(batch, 48) : batch; + cnrtDim3_t dim{task_x_dim, task_y_dim, (uint32_t)seq_seg}; + + int dtype_size = 1; + if (dtype == CNNL_DTYPE_HALF || dtype == CNNL_DTYPE_BFLOAT16) { + dtype_size = 2; + } else if (dtype == CNNL_DTYPE_INT8) { + dtype_size = 1; + } else if (dtype == CNNL_DTYPE_FLOAT) { + dtype_size = 4; + } else { + std::cerr << "invokeReshapeLinearCache: unsupport dtype" << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + kernels::MLUReshapeLinearCacheKernel<<>>( + (int8_t *)key_cache, (int8_t *)value_cache, (int *)cache_bs_offsets, (int *)cache_seq_offsets, + (int8_t *)key, (int8_t *)value, (int *)context_seq_offsets, (int *)context_lens, batch, + head_num, head_size, max_context_len, cache_mem_len, context_bs_stride, context_head_stride, + context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed, dtype_size, + SEQ_BLOCK); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mluh new file mode 100644 index 0000000..2e85429 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_linear_cache.mluh @@ -0,0 +1,106 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_ +#define CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief In the context stage, concate the result of multi head attention + * key and value to key_cache and value_cache. + * @example + * input: + * cache: + * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + * [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + * [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + * context: + * [[1, 2, 3, 4, 5], + * [6, 7, 8, 9, 10]] + * cache_bs_offsets: [1, 2] + * cache_seq_offsets: [3, 4] + * context_seq_offsets: [0, 1] + * context_lens: [4, 3] + * output: + * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + * [0, 0, 0, 1, 2, 3, 4, 0, 0, 0], + * [0, 0, 0, 0, 7, 8, 9, 0, 0, 0]] + * @param queue: The queue for mlu. + * @param key_cache: Pointer to the MLU memory that stores the key cache, + * the shape must be [max_batch, head_num, cache_mem_len, head_size]. + * key_cache could be nullptr. +* @param value_cache: Pointer to the MLU memory that stores the value cache, + * the shape must be [max_batch, head_num, cache_mem_len, head_size]. + * value_cache could be nullptr. + * @param cache_bs_offsets: Pointer to the MLU memory that stores the batch + * offset of cache, the shape must be [batch], if it's nullptr, the + * default value is {0, 1, 2 ... batch - 1}. + * @param cache_seq_offsets: Input. Pointer to the MLU memory that stores the sequence + * offset of cache, the shape must be [batch], if it's nullptr, the + * default value is 0 for every batch. + * @param key: Pointer to the MLU memory that stores the key, + * the shape must be [batch, max_contxt_len, head_num, head_size]. + * key could be nullptr. + * @param value: Pointer to the MLU memory that stores the value, + * the shape must be [batch, max_contxt_len, head_num, head_size]. + * value could be nullptr. + * @param context_seq_offsets: Pointer to the MLU memory that stores the + * sequence offset of context, the shape must be [batch]. if it's nullptr, + * the default value is 0 for every batch. It must be nullptr when packed is true. + * @param context_lens: Input. Pointer to the MLU memory that stores the sequence length or + * cumulative sequence length of context. when packed is false, the shape must be [batch], which + * indicates sequence length of context. when packed is true, the shape must be [batch + 1], which + * indicates cumulative sequence length of context. + * @param dtype: Data type. + * @param batch: Batch size. + * @param head_num: Head number. + * @param head_size: Head size. + * @param max_contxt_len: The maximum sequence length of context. + * @param cache_mem_len: The maximum sequence length of cache. + * @param contxt_bs_stride: The stride of batch in context, does not work when packed is true. + * @param contxt_head_stride: The stride of head_num in context. + * @param contxt_seq_stride: The stride of max_contxt_len in context. + * @param cache_bs_stride: The stride of batch in cache. + * @param cache_head_stride: The stride of head_num in cache. + * @param cache_seq_stride: The stride of cache_mem_len in cache. + * @param packed: A boolean value indicates whether to use pack mode. + * @note If key and key_cache are nullptr, nothing todo for key. + If value and value_cache are nullptr, nothing todo for value. + A negative value in cache_bs_offsets or cache_seq_offsets means nothing to do for + the corresponding batch. + */ +KernelStatus invokeReshapeLinearCache(cnrtQueue_t queue, + void *key_cache, + void *value_cache, + const void *cache_bs_offsets, + const void *cache_seq_offsets, + void *key, + void *value, + const void *context_seq_offsets, + const void *context_lens, + const cnnlDataType_t dtype, + const int batch, + const int head_num, + const int head_size, + const int max_context_len, + const int cache_mem_len, + const int context_bs_stride, + const int context_head_stride, + const int context_seq_stride, + const int cache_bs_stride, + const int cache_head_stride, + const int cache_seq_stride, + const bool packed); +} // namespace tmo + +#endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mlu new file mode 100644 index 0000000..05f014b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mlu @@ -0,0 +1,166 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "reshape_paged_cache.mluh" + +namespace tmo { +namespace kernels { + +#define NRAM_BUFFER_SIZE (480 * 1024) +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; +__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + +#define sizeof_(T) (uint32_t)sizeof(T) + +__mlu_global__ void MLUReshapePagedCacheKernel(int8_t *key, + int8_t *value, + int8_t *key_cache, + int8_t *value_cache, + int *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_size, + int head_size, + int dtype_size, + int seq_block) { +#if __BANG_ARCH__ > 500 + int seq_begin = taskId * seq_block; + if (seq_begin >= num_tokens) return; + int seq = std::min(seq_block, num_tokens - seq_begin); + + int head_bytes = head_size * dtype_size; + int head_stride = block_size * head_bytes; + int block_stride = num_heads * head_stride; + int hidden_bytes = num_heads * head_bytes; + + int8_t *nram_input = nram_buffer; + int *nram_token_offset = (int *)(nram_input + seq * hidden_bytes); + int pad_8_size = (num_heads * seq + 7) / 8 * 8; + int *nram_block_offset = nram_token_offset + pad_8_size; + int *nram_offset = nram_block_offset + pad_8_size; + int *nram_mask = nram_offset + pad_8_size; + + __memcpy(nram_offset, slot_mapping + seq_begin, seq * sizeof_(int), GDRAM2NRAM); + __bang_rem(nram_token_offset, nram_offset, (int)block_size, seq); + __bang_mul_scalar(nram_token_offset, nram_token_offset, head_bytes, seq); + __bang_div(nram_block_offset, nram_offset, (int)block_size, seq); + __bang_mul_scalar(nram_block_offset, nram_block_offset, block_stride, seq); + // (num_heads, seq) + __memcpy(nram_offset, nram_token_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0, + num_heads - 1); + // (num_heads, seq) -> (seq, num_heads) + __bang_transpose(nram_token_offset, nram_offset, num_heads, seq); + + // (num_heads, seq) + __memcpy(nram_offset, nram_block_offset, seq * sizeof_(int), NRAM2NRAM, seq * sizeof_(int), 0, + num_heads - 1); + // (num_heads, seq) -> (seq, num_heads) + __bang_transpose(nram_block_offset, nram_offset, num_heads, seq); + + __bang_write_zero(nram_offset, pad_8_size); + __bang_ge_bitindex((float *)nram_mask, (float *)nram_token_offset, (float *)nram_offset, + pad_8_size); + + // generate range: (0, head_stride, 2 * head_stride, ..., (num_heads - 1) * head_stride) + __memcpy(nram_offset, nram_range_32, std::min(num_heads, 32) * sizeof_(int), NRAM2NRAM); + int begin = 32; + while (begin < num_heads) { + int count = std::min(begin, num_heads - begin); + __bang_add_scalar(nram_offset + begin, nram_offset, begin, count); + begin += count; + } + __bang_mul_scalar(nram_offset, nram_offset, head_stride, num_heads); + + __bang_cycle_add(nram_token_offset, nram_token_offset, nram_offset, seq * num_heads, num_heads); + __bang_add(nram_offset, nram_token_offset, nram_block_offset, seq * num_heads); + + if (key != nullptr && key_cache != nullptr) { + // (seq, num_heads, head_size) + __memcpy(nram_input, key + seq_begin * key_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM, + hidden_bytes, key_stride0 * dtype_size, seq - 1); + __scatter(key_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM, + head_bytes, seq * num_heads); + } + + if (value != nullptr && value_cache != nullptr) { + __memcpy(nram_input, value + seq_begin * value_stride0 * dtype_size, hidden_bytes, GDRAM2NRAM, + hidden_bytes, value_stride0 * dtype_size, seq - 1); + __scatter(value_cache, nram_input, (uint32_t *)nram_offset, nram_mask, head_bytes, NRAM2GDRAM, + head_bytes, seq * num_heads); + } +#endif +} +} // namespace kernels + +KernelStatus invokeReshapePagedCache(cnrtQueue_t queue, + cnnlDataType_t data_type, + void *key, + void *value, + void *key_cache, + void *value_cache, + void *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_num, + int block_size, + int head_size) { + if (is_arch300()) { + std::cerr << "[invokeReshapePagedCache]: kernel does not support MLU300 devices." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + int dtype_size = 1; + if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) { + dtype_size = 2; + } else if (data_type == CNNL_DTYPE_INT8) { + dtype_size = 1; + } else if (data_type == CNNL_DTYPE_FLOAT) { + dtype_size = 4; + } else { + std::cerr << "invokeReshapePagedCache: unsupport data type\n"; + return KernelStatus::KERNEL_STATUS_FAILED; + } + int64_t kv_cache_range = block_num * block_size * num_heads * head_size * dtype_size; + if (kv_cache_range > UINT32_MAX) { + std::cerr << "[invokeReshapePagedCache]: The addressing range of kv_cache cannot exceed 4G." + << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + constexpr int nram_size = 224 * 1024; + int hidden_bytes = num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int); + int seq_block = nram_size / hidden_bytes; + if (seq_block <= 0) { + std::cerr << "invokeReshapePagedCache: " + << "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) " + << "should be less than 224KB.\n"; + return KernelStatus::KERNEL_STATUS_FAILED; + } + if (seq_block > 16) { + seq_block = seq_block / 16 * 16; + } + uint32_t task_dim = (num_tokens + seq_block - 1) / seq_block; + task_dim = std::max(task_dim, (uint32_t)8); + task_dim = std::min(task_dim, (uint32_t)num_tokens); + cnrtDim3_t dim{task_dim, 1, 1}; + + kernels::MLUReshapePagedCacheKernel<<>>( + (int8_t *)key, (int8_t *)value, (int8_t *)key_cache, (int8_t *)value_cache, + (int *)slot_mapping, key_stride0, value_stride0, num_tokens, num_heads, block_size, head_size, + dtype_size, seq_block); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mluh new file mode 100644 index 0000000..ba925bf --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/reshape_paged_cache.mluh @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_ +#define CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_ + +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Perform reshape_paged_cache operation. + * @param handle: The handle of cnnl. + * @param key: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens, + * num_heads, head_size]. + * @param value: Pointer to the MLU memory that stores the value tensor which has shape [num_tokens, + * num_heads, head_size]. + * @param key_cache: Pointer to the MLU memory that stores the key_cache tensor which has shape + * [num_blocks, num_heads, block_size, head_size]. + * @param value_cache: Pointer to the MLU memory that stores the value_cache tensor which has shape + * [num_blocks, num_heads, block_size, head_size]. + * @param slot_mapping: Pointer to the MLU memory that stores the slot_mapping tensor which has + * shape [num_tokens]. Data type of slot mapping must be int32_t. + * @param key_stride0: The first dimension stride length of key_cache tensor. + * @param value_stride0: The first dimension stride length of value_cache tensor. + * @param num_tokens: Total number of tokens. + * @param num_heads: Head number. + * @param block_num: Total number of blocks. + * @param block_size: Number of tokens per block. + * @note: reshape_paged_cache does not support MLU300 device. + */ +KernelStatus invokeReshapePagedCache(cnrtQueue_t queue, + cnnlDataType_t data_type, + void *key, + void *value, + void *key_cache, + void *value_cache, + void *slot_mapping, + size_t key_stride0, + size_t value_stride0, + int num_tokens, + int num_heads, + int block_num, + int block_size, + int head_size); +} // namespace tmo + +#endif // CSRC_KERNELS_RESHAPE_PAGED_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mlu new file mode 100644 index 0000000..f1ea4a8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mlu @@ -0,0 +1,628 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include "rotary_embedding.mluh" +// clang-format off +#include +// clang-format on + +namespace tmo { +namespace kernels { +#define NRAM_REMAIN_SIZE (32 * 1024) +#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) + +__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; +__nram__ float nram_meta_mask[32] = {1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, + 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f}; +__nram__ float nram_mask[1024]; +__nram__ int nram_offsets[1024]; + +__mlu_func__ void loadTableAsync(void *nram_table, + void *gdram_table, + int *nram_offset, + int rotary_dim, + int rotary_stride, + int seq_block, + int seq_begin, + int dtype_size, + bool discrete, + bool decoder_mode) { + if (!discrete) { + int src_stride = decoder_mode ? 0 : rotary_stride * dtype_size; + __memcpy_async(nram_table, gdram_table, rotary_dim * dtype_size, GDRAM2NRAM, + rotary_dim * dtype_size, src_stride, seq_block - 1); + } else { +#if __BANG_ARCH__ >= 592 + __gather_async(nram_table, gdram_table, (uint32_t *)nram_offset, rotary_dim * dtype_size, + GDRAM2NRAM, rotary_dim * dtype_size, seq_block); +#else + for (int i = 0; i < seq_block; i++) { + __memcpy_async((int8_t *)nram_table + i * rotary_dim * dtype_size, + (int8_t *)gdram_table + nram_offset[i], rotary_dim * dtype_size, GDRAM2NRAM); + } +#endif + } +} + +template +__mlu_func__ void toFloat(float *dst, T *src, int count) { + if (std::is_same::value) { + __bang_half2float(dst, (half *)src, count); + } else if (std::is_same::value) { + __bang_bfloat162float(dst, (bfloat16_t *)src, count); + } +} + +template +__mlu_func__ void floatTo(T *dst, float *src, int count) { + if (std::is_same::value) { + __bang_float2half_rn((half *)dst, src, count); + } else if (std::is_same::value) { + __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count); + } +} + +template +__mlu_func__ void initMask(float *mask, int rotary_dim, bool interleaved) { + if (interleaved) { + T *mask0 = (T *)mask; + T *mask1 = (T *)(mask + 512); + int seg = (rotary_dim + 31) / 32; + __memcpy(mask0, nram_meta_mask, 32 * sizeof(float), NRAM2NRAM, 32 * sizeof(float), 0, seg - 1); + floatTo((T *)mask0, (float *)mask0, rotary_dim); + __bang_add_scalar(mask1, mask0, (T)-1, rotary_dim); + } else { + __bang_write_value((T *)mask, rotary_dim / 2, (T)-1); + __bang_write_value((T *)mask + rotary_dim / 2, rotary_dim / 2, (T)1); + } +} + +/* + * half: mask, in, sl, sr + * float: sl, , sr, , sin, cos + */ +template +__mlu_func__ void crossRotaryEmbedding(T *output, + T *input, + T *sin_table, + T *cos_table, + int *seq_offsets, + int head_num, + int seq_block, + int head_size, + int rotary_dim, + int rotary_stride, + size_t input_head_stride, + size_t input_seq_stride, + size_t output_head_stride, + size_t output_seq_stride, + int seq_begin, + bool discrete, + bool decoder_mode = false) { + int float_size = sizeof(float); + int dtype_size = sizeof(T); + int seq_rotary = seq_block * rotary_dim; + int block_head = head_num * seq_rotary; + float *q_1 = (float *)nram_buffer; + float *sincos = q_1 + block_head + 2; + float *q_2 = sincos + block_head + 2; + T *temp = (T *)q_2 + block_head + 2; + + if (seq_offsets != nullptr && (discrete || decoder_mode)) { + __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM); + __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block); + } + bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete; + + T *mask0 = (T *)nram_mask; + T *mask1 = (T *)(nram_mask + 512); + + T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * (block_head + 2)); + T *sincos_ = (T *)((int8_t *)sincos + (float_size - dtype_size) * (block_head + 2)); + T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * (block_head + 2)); + // if dtype is float, temp point to a new buffer, and temp_ is temp; + // if dtype is half/bfloat16, temp is q_2_, and temp_ is (T*)q_2; + T *temp_ = dtype_size == 4 ? temp : (T *)q_2; + + // load input + __memcpy_async(q_1_, input, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, + seq_block - 1, seq_rotary * dtype_size, head_num - 1, + input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, + head_num - 1); + __bang_write_zero(q_2_ + block_head, 2); + __sync(); + + // copy input + __memcpy_async(q_2_, q_1_, block_head * dtype_size, NRAM2NRAM); + __bang_cycle_mul(temp_, q_1_, mask0, block_head, rotary_dim); + __sync(); + + // load cos + loadTableAsync(sincos_, cos_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin, + dtype_size, gather_table, decoder_mode); + + __bang_cycle_mul(q_2_, q_2_, mask1, block_head, rotary_dim); + + // rotary_input + __bang_add(q_2_ + 2, temp_, q_2_ + 2, block_head); + + toFloat(q_1, q_1_, block_head); + __sync(); + + toFloat(sincos, sincos_, block_head); + + // input * cos + __bang_cycle_mul(q_1, q_1, sincos, block_head, seq_rotary); + __sync(); + + toFloat(q_2, q_2_, block_head + 2); + + // load sin + loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin, + dtype_size, gather_table, decoder_mode); + __sync(); + + toFloat(sincos, sincos_, block_head); + + // rotary_input * sin + __bang_cycle_mul(q_2, q_2 + 1, sincos, block_head, seq_rotary); + + // input_cos + rotary_input_sin + __bang_add(q_1, q_1, q_2, block_head); + + floatTo((T *)q_1, q_1, block_head); + + if ((head_size - rotary_dim) > 0) { + __memcpy_async(output + rotary_dim, input + rotary_dim, (head_size - rotary_dim) * dtype_size, + GDRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1, + output_head_stride * dtype_size, head_num - 1, input_seq_stride * dtype_size, + seq_block - 1, input_head_stride * dtype_size, head_num - 1); + } + + // copy out + __memcpy(output, q_1, rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size, + seq_block - 1, output_head_stride * dtype_size, head_num - 1, rotary_dim * dtype_size, + seq_block - 1, seq_rotary * dtype_size, head_num - 1); +} + +template +__mlu_func__ void foldRotaryEmbedding(T *output, + T *input, + T *sin_table, + T *cos_table, + int *seq_offsets, + int head_num, + int seq_block, + int head_size, + int rotary_dim, + int rotary_stride, + size_t input_head_stride, + size_t input_seq_stride, + size_t output_head_stride, + size_t output_seq_stride, + int seq_begin, + bool discrete, + bool decoder_mode, + bool loop_head, + int once_head_num) { + once_head_num = loop_head ? once_head_num : head_num; + int loop_num = (head_num + once_head_num - 1) / once_head_num; + // int head_per_loop = loop_head ? 1 : head_num; + int seq_rotary = seq_block * rotary_dim; + int block_head = once_head_num * seq_rotary; + int buffer_blocks = loop_head ? 2 : 1; + + float *buffer = (float *)nram_buffer; + float *q_2 = buffer + block_head * buffer_blocks; + float *sin = q_2 + block_head; + float *cos = sin + seq_rotary; + + int float_size = sizeof(float); + int dtype_size = sizeof(T); + T *sincos_ = (T *)((int8_t *)sin + (float_size - dtype_size) * seq_rotary * 2); + T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * block_head); + if (seq_offsets != nullptr && (discrete || decoder_mode)) { + __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM); + __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block); + __sync_io_move_compute(); + } + bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete; + + int load_head_num = 0; + int calc_head_num = 0; + int store_head_num = 0; + for (int i = 0; i < loop_num + 2; i++) { + // store + if (i > 1) { + store_head_num = std::min(once_head_num, head_num - (i - 2) * once_head_num); + if ((head_size - rotary_dim) > 0) { + __memcpy_async(output + (i - 2) * once_head_num * output_head_stride + rotary_dim, + input + (i - 2) * once_head_num * input_head_stride + rotary_dim, + (head_size - rotary_dim) * dtype_size, GDRAM2GDRAM, + output_seq_stride * dtype_size, seq_block - 1, + output_head_stride * dtype_size, store_head_num - 1, + input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, + store_head_num - 1); + } + float *nram_store = buffer + (i % 2) * block_head; + __memcpy_async(output + (i - 2) * once_head_num * output_head_stride, nram_store, + rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size, + seq_block - 1, output_head_stride * dtype_size, store_head_num - 1, + rotary_dim * dtype_size, seq_block - 1, seq_block * rotary_dim * dtype_size, + store_head_num - 1); + } + // load + float *temp_load = buffer + (i % 2) * block_head; + T *nram_load = (T *)((int8_t *)temp_load + (float_size - dtype_size) * block_head); + if (i < loop_num) { + load_head_num = std::min(once_head_num, head_num - i * once_head_num); + __memcpy_async(nram_load, input + i * once_head_num * input_head_stride, + rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block - 1, + seq_block * rotary_dim * dtype_size, load_head_num - 1, + input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, + load_head_num - 1); + } + if (i == 1) { + loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, + seq_begin, dtype_size, gather_table, decoder_mode); + loadTableAsync(sincos_ + seq_rotary, cos_table, nram_offsets, rotary_dim, rotary_stride, + seq_block, seq_begin, dtype_size, gather_table, decoder_mode); + } + // compute + if (i > 0 && i < loop_num + 1) { + float *q_1 = buffer + ((i + 1) % 2) * block_head; + T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * block_head); + calc_head_num = std::min(once_head_num, head_num - (i - 1) * once_head_num); + __memcpy_async(q_2_, q_1_ + rotary_dim / 2, rotary_dim / 2 * dtype_size, NRAM2NRAM, + rotary_dim * dtype_size, rotary_dim * dtype_size, + calc_head_num * seq_block - 1); + __memcpy_async(q_2_ + rotary_dim / 2, q_1_, rotary_dim / 2 * dtype_size, NRAM2NRAM, + rotary_dim * dtype_size, rotary_dim * dtype_size, + calc_head_num * seq_block - 1); + __sync_move(); + + toFloat(q_1, q_1_, block_head); + __bang_cycle_mul(q_2_, q_2_, (T *)nram_mask, block_head, rotary_dim); + toFloat(q_2, q_2_, block_head); + if (i == 1) { + __sync_io(); + toFloat(sin, sincos_, seq_rotary * 2); + } + __bang_cycle_mul(q_1, q_1, cos, block_head, seq_rotary); + __bang_cycle_mul(q_2, q_2, sin, block_head, seq_rotary); + __bang_add(q_1, q_1, q_2, block_head); + floatTo((T *)q_1, q_1, block_head); + } + __sync_io_move_compute(); + } +} + +// [bs, seq_block] +template +__mlu_global__ void MluRotaryEmebdding(void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_dim, + int seq_once, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool discrete, + bool dynamic_ntk, + bool decoder_mode, + bool loop_head, + int once_head_num) { + initMask(nram_mask, rotary_dim, interleaved); + + int head_begin = taskIdX; + int head_per_task = taskDimX == 1 ? head_num : 1; + // decode mode little diff: no loop + if (decoder_mode) { + int task_begin_seq = taskIdY * seq_once; + int seq_block = std::min(batch - task_begin_seq, seq_once); + if (seq_block <= 0 || __is_mpu()) { + return; + } + size_t input_offset = task_begin_seq * input_seq_stride + head_begin * input_head_stride; + size_t output_offset = task_begin_seq * output_seq_stride + head_begin * output_head_stride; + T *input_begin = (T *)input + input_offset; + T *output_begin = (T *)output + output_offset; + if (interleaved) { + crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table, + (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim, + rotary_stride, input_head_stride, input_seq_stride, output_head_stride, + output_seq_stride, task_begin_seq, discrete, decoder_mode); + } else { + foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table, + (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim, + rotary_stride, input_head_stride, input_seq_stride, output_head_stride, + output_seq_stride, task_begin_seq, discrete, decoder_mode, loop_head, + once_head_num); + } + return; + } + int seq_begin = cu_seq_lens == nullptr ? taskIdY * max_seq_len : cu_seq_lens[taskIdY]; + int seq_len = cu_seq_lens == nullptr ? max_seq_len : cu_seq_lens[taskIdY + 1] - seq_begin; + + for (int i = taskIdZ * seq_once; i < seq_len; i += taskDimZ * seq_once) { + int seq_block = std::min(seq_once, seq_len - i); + int global_seq_begin = seq_begin + i; + int seq_block_begin = i; + size_t input_offset = global_seq_begin * input_seq_stride + head_begin * input_head_stride; + size_t output_offset = global_seq_begin * output_seq_stride + head_begin * output_head_stride; + size_t bs_table_offset = dynamic_ntk ? (size_t)taskIdY * rotary_seq_len * rotary_stride : 0; + T *input_begin = (T *)input + input_offset; + T *output_begin = (T *)output + output_offset; + T *sin_table_begin = (T *)sin_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride; + T *cos_table_begin = (T *)cos_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride; + if (seq_offsets != nullptr && !discrete) { + sin_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride; + cos_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride; + } else if (seq_offsets != nullptr && discrete) { + sin_table_begin = (T *)sin_table + bs_table_offset; + cos_table_begin = (T *)cos_table + bs_table_offset; + } + if (interleaved) { + crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin, + (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block, + head_size, rotary_dim, rotary_stride, input_head_stride, + input_seq_stride, output_head_stride, output_seq_stride, + global_seq_begin, discrete, decoder_mode); + __sync_io_move_compute(); + } else { + foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin, + (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block, + head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride, + output_head_stride, output_seq_stride, global_seq_begin, discrete, + decoder_mode, loop_head, once_head_num); + } + } +} + +#if __BANG_ARCH__ < 592 +template <> +__mlu_global__ void MluRotaryEmebdding(void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_dim, + int seq_once, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool discrete, + bool dynamic_ntk, + bool decoder_mode, + bool loop_head, + int once_head_num) {} + +template <> +__mlu_global__ void MluRotaryEmebdding(void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_dim, + int seq_once, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool discrete, + bool dynamic_ntk, + bool decoder_mode, + bool loop_head, + int once_head_num) {} +#endif +} // namespace kernels + +KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue, + void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_dim, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool interleaved, + bool discrete, + bool dynamic_ntk, + cnnlDataType_t data_type) { + void (*rotary_embedding_kernels[])(void *, /* output */ + const void *, /* input */ + const void *, /* sin_table */ + const void *, /* cos_table */ + const int *, /* seq_offsets */ + const int *, /* cu_seq_lens */ + int, /* batch */ + int, /* max_seq_len */ + int, /* head_num */ + int, /* head_size */ + int, /* rotary_seq_len */ + int, /* rotary_dim */ + int, /* seq_once */ + int, /* rotary_stride */ + size_t, /* input_seq_stride */ + size_t, /* input_head_stride */ + size_t, /* output_seq_stride */ + size_t, /* output_head_stride */ + bool, /* discrete, */ + bool, /* dynamic_ntk */ + bool, /* decoder_mode */ + bool, /* loop_head */ + int) /* once_head_num */ + = {kernels::MluRotaryEmebdding, + kernels::MluRotaryEmebdding, + kernels::MluRotaryEmebdding, + kernels::MluRotaryEmebdding, + kernels::MluRotaryEmebdding, + kernels::MluRotaryEmebdding}; + + int kernel_index = 0; + if (data_type == CNNL_DTYPE_HALF) { + kernel_index = interleaved ? 0 : 1; + } else if (data_type == CNNL_DTYPE_BFLOAT16) { + kernel_index = interleaved ? 2 : 3; + } else if (data_type == CNNL_DTYPE_FLOAT) { + kernel_index = interleaved ? 4 : 5; + } + if (head_size > 256) { + std::cerr << "[invokeRotaryEmbedding]: only supported head_size <= 256, currently head_size = " + << head_size << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + CNdev dev; + cnCtxGetDevice(&dev); + int cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + int core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + int total_core_num = cluster_num * core_num; + + uint32_t seq_once = data_type == CNNL_DTYPE_FLOAT ? (rotary_dim > 128 ? 64 : 128) + : (rotary_dim > 128 ? 128 : 256); + // decode场景,需要判断空间是否够,fold场景下最大限制为每个ipu处理64,cross限制为batch*head小于等于sq_once + int batch_per_core = (batch + total_core_num - 1) / total_core_num; + int batch_per_core_cap = 64; + bool batch_limit = interleaved ? (batch_per_core * head_num <= seq_once) + : (batch_per_core <= batch_per_core_cap); + bool decoder_mode = batch_limit && max_seq_len == 1 && dynamic_ntk == false; + + bool do_one_head_per_task = (head_num > 32 && max_seq_len > 2048) || head_num > seq_once; + seq_once = do_one_head_per_task ? seq_once : seq_once / head_num; + + // fold rotary做了流水,拆分有所不同。 + bool loop_head = true; + int once_head_num = 1; + if (!interleaved) { + seq_once = rotary_dim > 128 ? 64 : 128; + // 小seq情况下,不够拆,需要减小seq_once + if (batch * (max_seq_len + seq_once - 1) / seq_once < total_core_num) { + seq_once = std::max(1, max_seq_len / (total_core_num / batch)); + } + do_one_head_per_task = false; + // 判断decode场景能否一次性处理完所有head + if (decoder_mode) { + loop_head = false; + int nram_buffer_size = 480 * 1024; + int nram_input_size = batch_per_core * head_num * rotary_dim * sizeof(float); + int nram_q2_size = batch_per_core * head_num * rotary_dim * sizeof(float); + int nram_table_size = batch_per_core * rotary_dim * sizeof(float) * 2; + int total_nram_size = nram_input_size + nram_q2_size + nram_table_size; + loop_head = total_nram_size > nram_buffer_size; + if (loop_head) { + // 如果需要循环,则重新计算每次处理多少头 + once_head_num = (nram_buffer_size - nram_table_size) / + (batch_per_core * rotary_dim * sizeof(float) * 3); + } + // rebalance + int loop_num = (head_num + once_head_num - 1) / once_head_num; + once_head_num = (head_num + loop_num - 1) / loop_num; + } + } + + uint32_t seq_segments = ((uint32_t)max_seq_len + seq_once - 1) / seq_once; + uint32_t task_dimx = do_one_head_per_task ? head_num : 1; + uint32_t task_dimz = total_core_num > seq_segments ? seq_segments : total_core_num; + uint32_t task_dimy = + decoder_mode && !do_one_head_per_task ? (uint32_t)total_core_num : (uint32_t)batch; + seq_once = decoder_mode ? (batch + task_dimy - 1) / task_dimy : seq_once; + cnrtDim3_t dim = {task_dimx, task_dimy, task_dimz}; + + if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) { + std::cerr << "[invokeRotaryEmbedding]: MLU300 devices do not support bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + rotary_embedding_kernels[kernel_index]<<>>( + output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len, head_num, + head_size, rotary_seq_len, rotary_dim, seq_once, rotary_stride, input_seq_stride, + input_head_stride, output_seq_stride, output_head_stride, discrete, dynamic_ntk, decoder_mode, + loop_head, once_head_num); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} + +KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue, + void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int total_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool interleaved, + cnnlDataType_t data_type) { + size_t type_size = 0; + cnnlGetSizeOfDataType(data_type, &type_size); + invokeRotaryEmbedding(queue, output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, + max_seq_len, head_num, head_size / 2, rotary_seq_len, head_size / 2, + rotary_stride, input_seq_stride, input_head_stride, output_seq_stride, + output_head_stride, interleaved, true, false, data_type); + + invokeRotaryEmbedding(queue, (int8_t *)output + head_size / 2 * type_size, + (int8_t *)input + head_size / 2 * type_size, sin_table, cos_table, + seq_offsets + total_seq_len, cu_seq_lens, batch, max_seq_len, head_num, + head_size / 2, rotary_seq_len, head_size / 2, rotary_stride, + input_seq_stride, input_head_stride, output_seq_stride, output_head_stride, + interleaved, true, false, data_type); + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mluh new file mode 100644 index 0000000..42780a6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mluh @@ -0,0 +1,129 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_ +#define CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Apply rotary embedding. + * @param queue: The queue for mlu. + * @param output: Output. Pointer to the MLU memory that stores the output, + * the shape must be [total_seq_len, head_num, head_size] + * @param input: Input. Pointer to the MLU memory that stores the input + * the shape must be [total_seq_len, head_num, head_size]. + * @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be + * continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If + * dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim]. + * @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be + * continous. If dynamic_ntk is true, the shape must be [batch, rotary_seq_len, rotary_dim]. If + * dynamic_ntk is false, the shape must be [rotary_seq_len, rotary_dim]. + * @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each + * batch. If discrete is true, the shape must be [total_seq_len]. If discrete is false, the shape + * must be [batch]. Seq_offsets could be nullptr if discrete is false, which means no offset for + * each batch. + * @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length + * of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all + * batches is max_seq_Len. + * @param batch: Batch size. + * @param max_seq_len: The maximum sequence length of input. + * @param head_num: Head number. + * @param head_size: Head size. + * @param rotary_seq_len: The rotary seq_len of sin_table and cos_table. + * @param rotary_dim: The rotary dimension of sin_table and cos_table. + * @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table. + * @param input_seq_stride: The stride of total_seq_len in input. + * @param input_head_stride: The stride of head_num in input. + * @param output_seq_stride: The stride of total_seq_len in output. + * @param output_head_stride: The stride of head_num in output. + * @param interleaved: A boolean value indicates compute mode of rotary embedding. + * @param discrete: A boolean value indicates whether all input tokens have offsets. + * @param dynamic_ntk: A boolean value indicates whether all batches have different sin_table and + * cos_table. + * @param data_type: Data type of all inputs and outputs. + */ +KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue, + void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_dim, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool interleaved, + bool discrete, + bool dynamic_ntk, + cnnlDataType_t data_type); + +/** + * @brief Apply rotary embedding. + * @param queue: The queue for mlu. + * @param output: Output. Pointer to the MLU memory that stores the output, + * the shape must be [total_seq_len, head_num, head_size] + * @param input: Input. Pointer to the MLU memory that stores the input + * the shape must be [total_seq_len, head_num, head_size]. + * @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be + * continous. The shape must be [rotary_seq_len, head_size / 2]. + * @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be + * continous. The shape must be [rotary_seq_len, head_size / 2]. + * @param seq_offsets: Input. Pointer to the MLU memory that stores the sequene offsets of each + * batch. The Shape must be [2, total_seq_len]. + * @param cu_seq_lens: Input. Pointer to the MLU memory that stores the cumulative sequence length + * of each batch. The shape must be [batch + 1]. If cu_seq_lens is nullptr, Sequence length of all + * batches is max_seq_Len. + * @param batch: Batch size. + * @param max_seq_len: The maximum sequence length of input. + * @param head_num: Head number. + * @param head_size: Head size. + * @param rotary_seq_len: The rotary seq_len of sin_table and cos_table. + * @param rotary_stride: The stride of rotary_seq_len stride in sin_table and cos_table. + * @param input_seq_stride: The stride of total_seq_len in input. + * @param input_head_stride: The stride of head_num in input. + * @param output_seq_stride: The stride of total_seq_len in output. + * @param output_head_stride: The stride of head_num in output. + * @param interleaved: A boolean value indicates compute mode of rotary embedding. + * @param data_type: Data type of all inputs and outputs. + */ +KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue, + void *output, + const void *input, + const void *sin_table, + const void *cos_table, + const int *seq_offsets, + const int *cu_seq_lens, + int batch, + int max_seq_len, + int total_seq_len, + int head_num, + int head_size, + int rotary_seq_len, + int rotary_stride, + size_t input_seq_stride, + size_t input_head_stride, + size_t output_seq_stride, + size_t output_head_stride, + bool interleaved, + cnnlDataType_t data_type); +} // namespace tmo + +#endif // CSRC_KERNELS_ROTARY_EMBEDDING_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mlu new file mode 100644 index 0000000..628c530 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mlu @@ -0,0 +1,22 @@ +#include "swap_blocks.mluh" + +namespace tmo { +KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle, + void *dst, + const void *src, + const int64_t &block_size_in_bytes, + const cnrtMemTransDir_t &memcpy_type, + const std::map &block_mapping) { + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + for (const auto &pair : block_mapping) { + int64_t src_block_number = pair.first; + int64_t dst_block_number = pair.second; + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + cnrtMemcpyAsync((int8_t *)dst + dst_offset, (int8_t *)src + src_offset, block_size_in_bytes, + queue, memcpy_type); + } + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mluh new file mode 100644 index 0000000..501d9b6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/swap_blocks.mluh @@ -0,0 +1,39 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_SWAP_BLOCKS_MLUH_ +#define CSRC_KERNELS_SWAP_BLOCKS_MLUH_ + +#include +#include "cnnl.h" +#include "kernel_utils.h" + +namespace tmo { +/** + * @brief Perform swap_blocks operation. + * @param handle: The handle of cnnl. + * @param dst: Output. Pointer to the MLU memory that stores the dst tensor which has shape + * [num_blocks, num_heads, block_size, head_size]. + * @param src: Input. Pointer to the MLU memory that stores the src tensor which has shape + * [num_blocks, num_heads, block_size, head_size]. + * @param block_size_in_bytes: Data block size for each copy. + * @param memcpy_type: Copy direction, including h2d, d2h and d2d. + * @param block_mapping: Mapping table of src and dst. + */ +KernelStatus invokeSwapBlocksKernel(const cnnlHandle_t handle, + void *dst, + const void *src, + const int64_t &block_size_in_bytes, + const cnrtMemTransDir_t &memcpy_type, + const std::map &block_mapping); +} // namespace tmo + +#endif // CSRC_KERNELS_SWAP_BLOCKS_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/update_out_and_lse.mlu b/torch_mlu_ops-v1.3.2/csrc/kernels/update_out_and_lse.mlu new file mode 100644 index 0000000..27bea18 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/update_out_and_lse.mlu @@ -0,0 +1,447 @@ +// clang-format off +#include +// clang-format on +#include "kernel_utils.h" +#include "update_out_and_lse.mluh" + +namespace tmo { +namespace kernels { +#define NRAM_REMAIN_SIZE (32 * 1024) +#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) +#define INF (2139095040) +__nram__ char nram_buffer[NRAM_BUFFER_SIZE]; + +__mlu_func__ void splitTask(int32_t total_task, int32_t &task_length, int32_t &task_offset) { + task_length = (total_task + taskDimX - 1) / taskDimX; + task_offset = taskIdX * task_length; + task_length = std::min(total_task - task_offset, task_length); +} + +template +__mlu_func__ void toFloat(float *dst, T *src, int32_t num) { + if (std::is_same::value) { + __bang_half2float(dst, (half *)src, num); + } else if (std::is_same::value) { + __bang_bfloat162float(dst, (bfloat16_t *)src, num); + } else if (std::is_same::value) { + __bang_move(dst, src, num * sizeof(float)); + } +} + +template +__mlu_func__ void floatTo(T *dst, float *src, int32_t num) { + if (std::is_same::value) { + __bang_float2half_rn((half *)dst, src, num); + } else if (std::is_same::value) { + __bang_float2bfloat16_rn((bfloat16_t *)dst, src, num); + } else if (std::is_same::value) { + __bang_move(dst, src, num * sizeof(float)); + } +} + +template +__mlu_func__ void subCvt(T *dst, float *src0, float *src1, int32_t num) { +#if __BANG_ARCH__ >= 500 + if (std::is_same::value) { + __asm__("sub.nram.crn.f16.f32 [%[dst]], [%[src0]], [%[src1]], %[num];" ::[dst] "r"(dst), + [src0] "r"(src0), [src1] "r"(src1), [num] "r"(num)); + } else if (std::is_same::value) { + __asm__("sub.nram.crn.bf16.f32 [%[dst]], [%[src0]], [%[src1]], %[num];" ::[dst] "r"(dst), + [src0] "r"(src0), [src1] "r"(src1), [num] "r"(num)); + } else { + __asm__("sub.nram.crn.f32 [%[dst]], [%[src0]], [%[src1]], %[num];" ::[dst] "r"(dst), + [src0] "r"(src0), [src1] "r"(src1), [num] "r"(num)); + } +#else + __bang_sub((float *)dst, src0, src1, num); + floatTo(dst, (float *)dst, num); +#endif +} + +// task_length在非decoder模式下,为每个实际的block_seq_len的长度;在decoder模式下,是batch*head_num的长度。 +// out_task_stride/block_task_stride在非deocder模式下,为seq_stride;在decoder模式下,是head_stride。 +// out = out - sigmoid(block_lse - lse) * (out - block_out) +// lse = lse - logsigmoid(lse - block_lse) +// sigmoid: y = 1 / (1 + e ^ (-x)) +// logsigmoid: y = log(1 / (1 + e ^ (-x))) = -log(1 + e ^ (-x)) +template +__mlu_func__ void updateOutAndLse(T *out, + float *lse, + const T *block_out, + const float *block_lse, + int32_t head_size, + int32_t task_length, + int64_t out_task_stride, + int64_t block_task_stride) { + constexpr int32_t tmp = 0x3fb8aa3b; + const float log2e = *(float *)&tmp; + const float neg_log2e = (-1) * log2e; + const float recip_log2e = 1 / log2e; + + constexpr bool is_fp32 = std::is_same::value; + constexpr int32_t buffer_num = is_fp32 ? 5 : 3; + constexpr int32_t pad_length = is_fp32 ? 16 : 32; + + int32_t task_once = + PAD_DOWN(NRAM_BUFFER_SIZE / ((head_size * buffer_num + 3) * sizeof(float)), pad_length); + int32_t task_loop = (task_length + task_once - 1) / task_once; + if (task_loop > 1) { + task_once = PAD_UP((task_length + task_loop - 1) / task_loop, pad_length); + } + float *nram_out = (float *)nram_buffer; + float *nram_block_out = nram_out + head_size * task_once * (is_fp32 + 1); + float *nram_lse = nram_block_out + head_size * task_once * (is_fp32 + 1); + float *nram_block_lse = nram_lse + task_once; + float *nram_sigmoid_result = nram_block_lse + task_once; + __attribute__((unused)) float *nram_end = nram_sigmoid_result + task_once; + + int32_t task_deal{task_once}; + + for (int32_t task_i = 0; task_i < task_loop; ++task_i) { + task_deal = std::min(task_once, task_length - task_i * task_once); + /* + 读取 lse block_lse + */ + __memcpy_async(nram_lse, lse, task_deal * sizeof(float), GDRAM2NRAM); + __memcpy_async(nram_block_lse, block_lse, task_deal * sizeof(float), GDRAM2NRAM); + __sync_io(); + + /* + 计算 lse + 读取 out + block_lse = (block_lse - lse) * -log2e + sigmoid_result = 1 / (pow2(block_lse) + 1) + */ + __bang_fusion(FUSION_FSM, nram_block_lse, nram_block_lse, nram_lse, neg_log2e, task_deal, + task_deal); + __bang_pow2(nram_sigmoid_result, nram_block_lse, task_deal); + __bang_add_scalar(nram_sigmoid_result, nram_sigmoid_result, (float)1, task_deal); + __bang_recip(nram_sigmoid_result, nram_sigmoid_result, task_deal); + __memcpy_async(nram_out, out, head_size * sizeof(T), GDRAM2NRAM, head_size * sizeof(T), + out_task_stride * sizeof(T), task_deal - 1); + + __sync_io(); + /* + 转置和升位宽 out + 读取 block_out + */ + __bang_transpose((T *)nram_out + task_deal * head_size, (T *)nram_out, task_deal, head_size); + toFloat(nram_out, (T *)nram_out + task_deal * head_size, task_deal * head_size); + __memcpy_async(nram_block_out, block_out, head_size * sizeof(T), GDRAM2NRAM, + head_size * sizeof(T), block_task_stride * sizeof(T), task_deal - 1); + __sync_io(); + /* + 转置和升位宽 block_out + */ + __bang_transpose((T *)nram_block_out + task_deal * head_size, (T *)nram_block_out, task_deal, + head_size); + toFloat(nram_block_out, (T *)nram_block_out + task_deal * head_size, task_deal * head_size); + /* + bang_fusor的计算流: + ((out - block_out) * sigmoid_result * -1 + out).tofp16() + */ + __bang_sub(nram_block_out, nram_out, nram_block_out, task_deal * head_size); + __bang_cycle_mul(nram_block_out, nram_block_out, nram_sigmoid_result, task_deal * head_size, + task_deal); + subCvt((T *)nram_out, nram_out, nram_block_out, task_deal * head_size); + __bang_transpose((T *)nram_out + task_deal * head_size, (T *)nram_out, head_size, task_deal); + __sync_compute(); + + __memcpy_async(out, (T *)nram_out + task_deal * head_size, head_size * sizeof(T), NRAM2GDRAM, + out_task_stride * sizeof(T), head_size * sizeof(T), task_deal - 1); + + /* + 算法上: lse = lse - logsigmoid(lse - block_lse) + = lse - (-log(1 + e ^ (-(lse - block_lse)))) + = lse + log(1 + e ^ (block_lse - lse)) + 之前block_lse = (block_lse - lse) * -log2e + 实际逻辑如下: + block_lse = block_lse * -1 + = (block_lse - lse) * log2e + block_lse = log2(pow2(block_lse) + 1) / log2e + lse + = lse + log(1 + e ^ (block_lse - lse)) + */ + __bang_mul_scalar(nram_block_lse, nram_block_lse, -1, task_deal); + __bang_pow2(nram_sigmoid_result, nram_block_lse, task_deal); + __bang_add_scalar(nram_sigmoid_result, nram_sigmoid_result, (float)1.0f, task_deal); + __bang_log2(nram_sigmoid_result, nram_sigmoid_result, task_deal); + __bang_mul_scalar(nram_sigmoid_result, nram_sigmoid_result, recip_log2e, task_deal); + + /* + nram_sigmoid_result 中的值为log(1 + e ^ (block_lse - lse)) + 在一些数值分布场景,例如block_lse - lse大于85左右时,这个值会出现inf。 + gpu采用的是std::log1p, + 在原始公式中,log(1 / (1 + e ^ (block_lse - lse)))中的log里的数值会极限接近0, + 采用log1p会比普通log在靠近0时有更高精度。 + mlu这里由于做了对数倒数外提*-1,所以log里的数值会变为inf(如果不外提,log2(0)同样会出现inf)。 + + logsigmoid在大数值场景下,是等于原值的,例如logsigmoid(-100) = -100 + 所以以下逻辑用于 识别inf的值,对inf的位置进行写入原值。 + */ + + __bang_ne_scalar((uint32_t *)nram_block_out, (uint32_t *)nram_sigmoid_result, (uint32_t)INF, + task_deal); + __bang_mul((uint32_t *)nram_sigmoid_result, (uint32_t *)nram_sigmoid_result, + (uint32_t *)nram_block_out, task_deal); + __bang_not((uint32_t *)nram_block_out, (uint32_t *)nram_block_out, task_deal); + __bang_mul((uint32_t *)nram_block_lse, (uint32_t *)nram_block_lse, (uint32_t *)nram_block_out, + task_deal); + __bang_mul_scalar(nram_block_lse, nram_block_lse, recip_log2e, + task_deal); // block_lse里的原值是*了log2e的,这里需要除回去。 + __bang_fusion(FUSION_FAA, nram_lse, nram_lse, nram_sigmoid_result, nram_block_lse, task_deal, + task_deal); + + __sync_compute(); + __memcpy_async(lse, nram_lse, task_deal * sizeof(float), NRAM2GDRAM); + + lse += task_deal; + block_lse += task_deal; + out += task_deal * out_task_stride; + block_out += task_deal * out_task_stride; + } +} + +// 非decoder模式,采用taskDimZ拆分batch,taskDimY拆分head,每个task内部进行block_seq_len的循环 +template +__mlu_global__ void MluUpdateOutAndLse(void *out, + float *lse, + const void *block_out, + const float *block_lse, + const int32_t *seq_offsets, + const int32_t *cu_seqs, + const int32_t *block_cu_seqs, + const int32_t batch, + const int32_t head_num, + const int32_t head_size, + const int32_t max_seq_len, + const int32_t block_seq_len, + const bool packed, + const int64_t bs_stride, + const int64_t seq_stride, + const int64_t head_stride, + const int64_t block_bs_stride, + const int64_t block_seq_stride, + const int64_t block_head_stride) { +#if __BANG_ARCH__ > 300 + if (!(std::is_same::value && __BANG_ARCH__ < 500)) { + int64_t kernel_out_offset = 0; + int64_t kernel_block_out_offset = 0; + int64_t kernel_lse_offset = 0; + int64_t kernel_block_lse_offset = 0; + int64_t kernel_seq_offset = 0; + + int32_t block_seq_len_real{block_seq_len}; + + if (seq_offsets != nullptr) { + kernel_seq_offset = __load_gdram(seq_offsets + taskIdZ); + } + + if (!packed) { + kernel_out_offset = + taskIdZ * bs_stride + taskIdY * head_stride + kernel_seq_offset * seq_stride; + kernel_block_out_offset = taskIdZ * block_bs_stride + taskIdY * block_head_stride; + kernel_lse_offset = + taskIdZ * max_seq_len * head_num + taskIdY * max_seq_len + kernel_seq_offset; + kernel_block_lse_offset = taskIdZ * block_seq_len * head_num + taskIdY * block_seq_len; + } else { + int32_t block_seq_begin = __load_gdram(block_cu_seqs + taskIdZ); + int32_t block_seq_end = __load_gdram(block_cu_seqs + taskIdZ + 1); + int32_t out_seq_begin = __load_gdram(cu_seqs + taskIdZ); + block_seq_len_real = block_seq_end - block_seq_begin; + kernel_out_offset = (out_seq_begin + kernel_seq_offset) * seq_stride + taskIdY * head_stride; + kernel_block_out_offset = block_seq_begin * block_seq_stride + taskIdY * head_stride; + kernel_lse_offset = + taskIdZ * max_seq_len * head_num + taskIdY * max_seq_len + kernel_seq_offset; + kernel_block_lse_offset = taskIdZ * block_seq_len * head_num + taskIdY * block_seq_len; + } + + auto kernel_out = (T *)out + kernel_out_offset; + auto kernel_lse = lse + kernel_lse_offset; + auto kernel_block_out = (T *)block_out + kernel_block_out_offset; + auto kernel_block_lse = block_lse + kernel_block_lse_offset; + + updateOutAndLse(kernel_out, kernel_lse, kernel_block_out, kernel_block_lse, head_size, + block_seq_len_real, seq_stride, block_seq_stride); + } +#endif +} + +// decoder模式下,采用launch 所有core,内部拆分batch*head_num维度, +// 每个core处理 batch*head_num/taskDimX +template +__mlu_global__ void MluUpdateOutAndLseDecoder(void *out, + float *lse, + const void *block_out, + const float *block_lse, + const int32_t *seq_offsets, + const int32_t *cu_seqs, + const int32_t *block_cu_seqs, + const int32_t batch, + const int32_t head_num, + const int32_t head_size, + const int32_t max_seq_len, + const int32_t block_seq_len, + const bool packed, + const int64_t bs_stride, + const int64_t seq_stride, + const int64_t head_stride, + const int64_t block_bs_stride, + const int64_t block_seq_stride, + const int64_t block_head_stride) { +#if __BANG_ARCH__ > 300 + if (!(std::is_same::value && __BANG_ARCH__ < 500)) { + int32_t task_length{0}, task_begin{0}; + splitTask(batch * head_num, task_length, task_begin); + if (task_length <= 0) { + return; + } + + int32_t batch_idx = task_begin / head_num; + int32_t head_idx = task_begin % head_num; + + auto kernel_out_offset = batch_idx * bs_stride + head_idx * head_stride; + auto kernel_block_out_offset = batch_idx * block_bs_stride + head_idx * block_head_stride; + auto kernel_lse_offset = batch_idx * max_seq_len * head_num + head_idx * max_seq_len; + auto kernel_block_lse_offset = batch_idx * block_seq_len * head_num + head_idx * block_seq_len; + + auto kernel_out = (T *)out + kernel_out_offset; + auto kernel_lse = lse + kernel_lse_offset; + auto kernel_block_out = (T *)block_out + kernel_block_out_offset; + auto kernel_block_lse = block_lse + kernel_block_lse_offset; + + updateOutAndLse(kernel_out, kernel_lse, kernel_block_out, kernel_block_lse, head_size, + task_length, head_stride, block_head_stride); + } +#endif +} + +#if __BANG_ARCH__ < 500 +template <> +__mlu_global__ void MluUpdateOutAndLseDecoder(void *out, + float *lse, + const void *block_out, + const float *block_lse, + const int32_t *seq_offsets, + const int32_t *cu_seqs, + const int32_t *block_cu_seqs, + const int32_t batch, + const int32_t head_num, + const int32_t head_size, + const int32_t max_seq_len, + const int32_t block_seq_len, + const bool packed, + const int64_t bs_stride, + const int64_t seq_stride, + const int64_t head_stride, + const int64_t block_bs_stride, + const int64_t block_seq_stride, + const int64_t block_head_stride) {} + +template <> +__mlu_global__ void MluUpdateOutAndLse(void *out, + float *lse, + const void *block_out, + const float *block_lse, + const int32_t *seq_offsets, + const int32_t *cu_seqs, + const int32_t *block_cu_seqs, + const int32_t batch, + const int32_t head_num, + const int32_t head_size, + const int32_t max_seq_len, + const int32_t block_seq_len, + const bool packed, + const int64_t bs_stride, + const int64_t seq_stride, + const int64_t head_stride, + const int64_t block_bs_stride, + const int64_t block_seq_stride, + const int64_t block_head_stride) {} +#endif +} // namespace kernels + +inline int32_t dtype_index(cnnlDataType_t dtype) { + switch (dtype) { + case CNNL_DTYPE_HALF: + return 0; + break; + case CNNL_DTYPE_BFLOAT16: + return 1; + break; + case CNNL_DTYPE_FLOAT: + return 2; + break; + default: + return 0; + break; + } +} + +KernelStatus invokeUpdateOutAndLse(cnrtQueue_t queue, + void *out, + float *lse, + const void *block_out, + const float *block_lse, + const int32_t *seq_offsets, + const int32_t *cu_seqs, + const int32_t *block_cu_seqs, + const int32_t batch, + const int32_t head_num, + const int32_t head_size, + const int32_t max_seq_len, + const int32_t block_seq_len, + const int64_t bs_stride, + const int64_t seq_stride, + const int64_t head_stride, + const int64_t block_bs_stride, + const int64_t block_seq_stride, + const int64_t block_head_stride, + const bool packed, + const cnnlDataType_t dtype) { + void (*update_out_and_lse_kernels[])( + void *, float *, const void *, const float *, const int32_t *, const int32_t *, + const int32_t *, const int32_t, const int32_t, const int32_t, const int32_t, const int32_t, + const bool, const int64_t, const int64_t, const int64_t, const int64_t, const int64_t, + const int64_t) = {kernels::MluUpdateOutAndLse, + kernels::MluUpdateOutAndLse, + kernels::MluUpdateOutAndLse, + kernels::MluUpdateOutAndLseDecoder, + kernels::MluUpdateOutAndLseDecoder, + kernels::MluUpdateOutAndLseDecoder}; + + // 非decoder模式,采用taskDimZ拆分batch,taskDimY拆分head,每个task内部进行block_seq_len的循环 + uint32_t task_dimx = 1; + uint32_t task_dimy = head_num; + uint32_t task_dimz = batch; + + bool decoder_mode = (block_seq_len == 1 && max_seq_len == 1); + if (decoder_mode) { + // decoder模式下,采用launch 所有core,内部拆分batch*head_num维度, + // 每个core处理 batch*head_num/taskDimX + CNdev dev; + cnCtxGetDevice(&dev); + int32_t cluster_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); + int32_t core_num; + CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); + + task_dimx = std::min(PAD_UP(batch * head_num, 16) / 16, cluster_num * core_num); + task_dimy = 1; + task_dimz = 1; + } + + cnrtDim3_t task_dim = {task_dimx, task_dimy, task_dimz}; + cnrtFunctionType_t func_type = cnrtFuncTypeBlock; + int32_t kernel_index = dtype_index(dtype) + decoder_mode * 3; + + if (dtype == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) { + std::cerr << "[invokeUpdateOutAndLse]: MLU300 devices do not support bfloat16." << std::endl; + return KernelStatus::KERNEL_STATUS_FAILED; + } + + update_out_and_lse_kernels[kernel_index]<<>>( + out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs, batch, head_num, + head_size, max_seq_len, block_seq_len, packed, bs_stride, seq_stride, head_stride, + block_bs_stride, block_seq_stride, block_head_stride); + + return KernelStatus::KERNEL_STATUS_SUCCESS; +} +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/kernels/update_out_and_lse.mluh b/torch_mlu_ops-v1.3.2/csrc/kernels/update_out_and_lse.mluh new file mode 100644 index 0000000..c2f4446 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/kernels/update_out_and_lse.mluh @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_KERNELS_UPDATE_OUT_AND_LSE_MLUH_ +#define CSRC_KERNELS_UPDATE_OUT_AND_LSE_MLUH_ + +#include "cnnl.h" +#include "kernel_utils.h" +namespace tmo { +/** + * @brief Update out and log-sum-exp(lse) according to block out and block lse. + * @param queue: The queue for mlu. + * @param out: Input/Output. Pointer to the MLU memory that stores the origin out. + * In pad mode, the shape must be [batch, max_seq_len, head_num, head_size]. + * In pack mode, the shape must be [total_seq_len, head_num, head_size]. + * Dim of seq_len and head_num may have stride. + * @param lse: Input/Output. Pointer to the MLU memory that stores the origin lse. + * The shape must be [batch, head_num, max_seq_len]. Lse must be continuous. + * @param block_out: Input. Pointer to the MLU memory that stores the block out. + * In pad mode, the shape must be [batch, block_seq_len, head_num, head_size]. + * In pack mode, the shape must be [total_block_seq_len, head_num, head_size]. + * Dim of seq_len and head_num may have stride. + * @param block_lse: Input. Pointer to the MLU memory that stores the origin lse. + * The shape must be [batch, head_num, block_seq_len]. Block_lse must be continuous. + * @param seq_offsets: Input. Pointer to the MLU memory that stores the origin out + * and lse sequence offset. The shape must be [batch]. + * Seq_offsets must be continuous, and could be nullptr. + * @param cu_seqs: Input. Pointer to the MLU memory that stores the cumulative sum of out seq_lens, + * In pack mode, the shape must be [batch + 1]. + * In pad mode. cu_seqs does not work, could be nullptr. + * @param block_cu_seqs: Input. Pointer to the MLU memory that stores the cumulative sum of block + * out seq_lens, In pack mode, the shape must be [batch + 1]. In pad mode. block_cu_seqs does not + * work, could be nullptr. + * @param dtype: Data type. + * @param batch: Batch size. + * @param head_num: Head number. + * @param head_size: Head size. + * @param max_seq_len: The sequence length of origin out. + * @param block_seq_len: The sequence length of block out. + * @param bs_stride: The stride of batch in origin out, does not work when packed is true. + * @param seq_stride: The stride of seq_len in origin out. + * @param head_stride: The stride of head_num in origin out. + * @param block_bs_stride: The stride of batch in block out, does not work when packed is true. + * @param block_seq_stride: The stride of seq_len in block out. + * @param block_head_stride: The stride of head_num in block out. + * @param packed: A boolean value indicates whether to use pack mode. + * @note All seq_lens in block out should be less than or equal to origin out. + */ +KernelStatus invokeUpdateOutAndLse(cnrtQueue_t queue, + void *out, + float *lse, + const void *block_out, + const float *block_lse, + const int *seq_offsets, + const int *cu_seqs, + const int *block_cu_seqs, + const int batch, + const int head_num, + const int head_size, + const int max_seq_len, + const int block_seq_len, + const int64_t bs_stride, + const int64_t seq_stride, + const int64_t head_stride, + const int64_t block_bs_stride, + const int64_t block_seq_stride, + const int64_t block_head_stride, + const bool packed, + const cnnlDataType_t dtype); +} // namespace tmo + +#endif // CSRC_KERNELS_RESHAPE_LINEAR_CACHE_MLUH_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp new file mode 100644 index 0000000..59ab864 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include +#include +#include "kernel_api.h" + +namespace tmo { +namespace ops { + +size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) { + size_t workspace_size = 0; + CNNL_CHECK_FATAL( + cnnlGetGroupGemmWorkspaceSize(handle, // handle + desc.group_gemm_desc(), // cnnlGroupGemmDescriptor_t + nullptr, // cnnlGroupGemmAlgo_t + num_expert, // groups + false, // is_trans_a + true, // is_trans_b + desc.group_host_tensor(), // m_desc + desc.group_host_tensor(), // n_desc + desc.group_host_tensor(), // k_desc + nullptr, // alpha_desc + desc.group_host_tensor(), // lda_desc + desc.a_desc(), // a_desc + desc.group_host_tensor(), // ldb_desc + desc.b_desc(), // b_desc + nullptr, // beta_desc + nullptr, // ldc_desc + nullptr, // c_desc + desc.group_host_tensor(), // ldd_desc + desc.d_desc(), // d_desc + &workspace_size)); + return workspace_size; +} + +void GroupGemm(const cnnlHandle_t &handle, + GroupGemmDesc &desc, + void *m, + void *alpha, + void *beta, + void *workspace, + size_t workspace_size, + int num_expert, + int k, + int n, + int lda, + std::vector &ldb) { + std::vector n_array(num_expert, n); + std::vector k_array(num_expert, k); + std::vector lda_array(num_expert, lda); + std::vector ldd_array(num_expert, n); + + CNNL_CHECK_FATAL(cnnlGroupGemm(handle, desc.group_gemm_desc(), /*desc*/ + nullptr, /*algo*/ + num_expert, /*groups*/ + false, /*is_trans_a*/ + true, /*is_trans_b*/ + desc.group_device_tensor(), m, /*m*/ + desc.group_host_tensor(), n_array.data(), /*n*/ + desc.group_host_tensor(), k_array.data(), /*k*/ + alpha ? desc.scale_factor_tensor() : nullptr, /*alpha_desc*/ + alpha, /*alpha*/ + desc.group_host_tensor(), lda_array.data(), /*lda*/ + desc.a_desc(), /*a*/ + ldb.empty() ? nullptr : desc.group_host_tensor(), + ldb.empty() ? nullptr : ldb.data(), /*ldb*/ + desc.b_desc(), /*b*/ + beta ? desc.scale_factor_tensor() : nullptr, /*beta_desc*/ + beta, /*beta*/ + desc.group_host_tensor(), /*ldc_desc*/ + ldd_array.data(), /*ldc*/ + desc.has_c() ? desc.c_desc() : desc.d_desc(), /*c*/ + workspace, workspace_size, /*workspace*/ + desc.group_host_tensor(), ldd_array.data(), /*ldd*/ + desc.d_desc())); /*d*/ +} + +} // namespace ops +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp new file mode 100644 index 0000000..1988e34 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp @@ -0,0 +1,117 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernel_api.h" + +namespace tmo { +namespace ops { + +void SmoothQuant(const cnnlHandle_t &handle, + void *input, + void *smooth_scale, + void *token_count, + void *gather_idx, + void *gather_idx_start_position, + void *output, + void *output_scale, + int n, + int c, + int e, + int input_stride, + int output_stride, + int topk, + cnnlDataType_t input_dtype) { + int dim_nb = 2; + int output_scale_dim_nb = 1; + if (!gather_idx) { + topk = 1; + } + int64_t input_dims[] = {n, c}; + int64_t output_dims[] = {n * topk, c}; + int64_t output_scale_dims[] = {n * topk}; + int64_t input_strides[] = {input_stride, 1}; + int64_t output_strides[] = {output_stride, 1}; + + cnnlTensorDescriptor_t input_tensor, smooth_scale_tensor, token_count_tensor, gather_idx_tensor, + gather_idx_start_pos_tensor, output_tensor, output_scale_tensor; + cnnlCreateTensorDescriptor(&input_tensor); + cnnlCreateTensorDescriptor(&smooth_scale_tensor); + cnnlCreateTensorDescriptor(&output_tensor); + cnnlCreateTensorDescriptor(&output_scale_tensor); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype, + dim_nb, input_dims, input_strides)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8, + dim_nb, output_dims, output_strides)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, output_scale_dim_nb, + output_scale_dims)); + + if (token_count) { + int64_t smooth_scale_dims[] = {e, c}; + int64_t token_count_dims[] = {e}; + cnnlCreateTensorDescriptor(&token_count_tensor); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_INT32, 1, token_count_dims)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, 2, smooth_scale_dims)); + } else { + int64_t smooth_scale_dims[] = {c}; + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, 1, smooth_scale_dims)); + } + + if (gather_idx) { + int64_t gather_idx_dims[] = {n * topk}; + cnnlCreateTensorDescriptor(&gather_idx_tensor); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_INT32, 1, gather_idx_dims)); + } + + if (gather_idx_start_position) { + int64_t gather_idx_start_pos_dims[] = {1}; + cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims)); + } + + CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4( + handle, nullptr /*smooth_quant_online_desc*/, input_tensor /*input_desc*/, + input /*input_ptr*/, smooth_scale_tensor /*input_smooth_quant_scale_desc*/, + smooth_scale /*input_sq_scale_ptr*/, nullptr /*input_smooth_quant_zero_desc*/, + nullptr /*input_sq_zero_ptr*/, + token_count ? token_count_tensor /*token_count_desc*/ : nullptr, + token_count /*token_count_ptr*/, gather_idx ? gather_idx_tensor /*gather_idx_desc*/ : nullptr, + gather_idx /*gather_idx_ptr*/, + gather_idx_start_position /*gather_idx_start_pos_desc*/ ? gather_idx_start_pos_tensor + : nullptr, + gather_idx_start_position /*gather_idx_start_pos_ptr*/, output_tensor /*output_desc*/, + output /*output_ptr*/, output_scale_tensor /*output_scale_desc*/, + output_scale /*output_scale_ptr*/, nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/, + nullptr /*worksapce*/, 0 /*workspace_size*/)); + + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor)); + if (token_count) { + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor)); + } + if (gather_idx) { + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor)); + } + if (gather_idx_start_position) { + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor)); + } +} + +} // namespace ops +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h b/torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h new file mode 100644 index 0000000..1a25cca --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#ifndef CSRC_OPS_KERNEL_API_H_ +#define CSRC_OPS_KERNEL_API_H_ + +#include +#include +#include "cnrt.h" + +#include "op_descriptor/attn_proj_descriptor.h" +#include "op_descriptor/batchmatmul_descriptor.h" +#include "op_descriptor/ffn_descriptor.h" +#include "op_descriptor/group_gemm_descriptor.h" +#include "op_descriptor/matmul_descriptor.h" +#include "op_descriptor/quant_matmul_descriptor.h" + +namespace tmo { +namespace ops { + +using GroupGemmDesc = tmo::op_desc::GroupGemmDesc; + +size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert); + +void GroupGemm(const cnnlHandle_t &handle, + GroupGemmDesc &desc, + void *m, + void *alpha, + void *beta, + void *workspace, + size_t workspace_size, + int num_expert, + int k, + int n, + int lda, + std::vector &ldb); + +void SmoothQuant(const cnnlHandle_t &handle, + void *input, + void *smooth_scale, + void *token_count, + void *gather_idx, + void *gather_idx_start_position, + void *output, + void *output_scale, + int n, + int c, + int e, + int input_stride, + int output_stride, + int topk, + cnnlDataType_t input_dtype); + +} // namespace ops +} // namespace tmo + +#endif // CSRC_OPS_KERNEL_API_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/attn_proj_descriptor.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/attn_proj_descriptor.cpp new file mode 100644 index 0000000..b4bc97f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/attn_proj_descriptor.cpp @@ -0,0 +1,58 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "attn_proj_descriptor.h" +#include "base_descriptor.h" + +namespace tmo { +namespace op_desc { + +namespace { +static OpDescPool attn_proj_instance; +} + +AttnProjDesc::AttnProjDesc() { + impl_ = attn_proj_instance.get(); +} + +void AttnProjDesc::setDesc( + const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode, + const cnnlDataType_t compute_dtype, + const bool q_has_value, + const bool k_has_value, + const bool v_has_value, + const bool has_bias, + const bool is_pack_mode, + const int packed_max_seq_len, + const bool trans_out, + const bool store_layernorm_result, + const float alpha, + const float beta, + const float layernorm_eps, + const bool is_quant) { + CNNL_CHECK_FATAL(cnnlSetTransformerAttnProjDescriptor( + impl_->attn_proj_desc_, layernorm_residual_mode, nullptr, /*activation desc*/ + compute_dtype, q_has_value, k_has_value, v_has_value, has_bias, is_pack_mode, + packed_max_seq_len, trans_out, store_layernorm_result, alpha, beta, layernorm_eps)); + const int allow_tf32 = 0; /*Not allow*/ + CNNL_CHECK_FATAL( + cnnlSetTransformerAttnProjDescriptorAllowTF32(impl_->attn_proj_desc_, allow_tf32)); + this->is_quant_ = is_quant; +} + +AttnProjDescImpl::AttnProjDescImpl() { + CNNL_CHECK_FATAL(cnnlCreateTransformerAttnProjDescriptor(&this->attn_proj_desc_)); + CNNL_CHECK_FATAL(cnnlCreateTransformerAttnProjQuantifyDescriptor(&this->attn_proj_quant_desc_)); +} + +} // namespace op_desc +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/attn_proj_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/attn_proj_descriptor.h new file mode 100644 index 0000000..dc06017 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/attn_proj_descriptor.h @@ -0,0 +1,82 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#ifndef CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_ + +#include +#include +#include "cnnl.h" +#include "cnnl_extra.h" +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +class TMO_HIDDEN AttnProjDescImpl { + public: + AttnProjDescImpl(); + // using clear function to avoid compilation errors. + ~AttnProjDescImpl() { clear(); } + + DELETE_COPY_ASSIGN_CONSTRUCT(AttnProjDescImpl); + + cnnlTransformerAttnProjDescriptor_t attn_proj_desc_; + cnnlTransformerAttnProjQuantizeDescriptor_t attn_proj_quant_desc_; + + private: + inline void clear() { + CNNL_CHECK_FATAL(cnnlDestroyTransformerAttnProjDescriptor(this->attn_proj_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyTransformerAttnProjQuantifyDescriptor(this->attn_proj_quant_desc_)); + } +}; + +class TMO_EXPORT AttnProjDesc { + public: + using impl_type_ = std::unique_ptr>; + AttnProjDesc(); + void setDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode, + const cnnlDataType_t compute_dtype, + const bool q_has_value, + const bool k_has_value, + const bool v_has_value, + const bool has_bias, + const bool is_pack_mode, + const int packed_max_seq_len, + const bool trans_out, + const bool store_layernorm_result, + const float alpha, + const float beta, + const float layernorm_eps, + const bool is_quant = false); + + operator cnnlTransformerAttnProjDescriptor_t() { return this->impl_->attn_proj_desc_; } + + operator cnnlTransformerAttnProjDescriptor_t() const { return this->impl_->attn_proj_desc_; } + + operator cnnlTransformerAttnProjQuantizeDescriptor_t() { + return is_quant_ ? this->impl_->attn_proj_quant_desc_ : nullptr; + } + + operator cnnlTransformerAttnProjQuantizeDescriptor_t() const { + return is_quant_ ? this->impl_->attn_proj_quant_desc_ : nullptr; + } + + private: + impl_type_ impl_; + bool is_quant_ = false; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/base_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/base_descriptor.h new file mode 100644 index 0000000..8b307b0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/base_descriptor.h @@ -0,0 +1,62 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_ + +#include +#include +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +static constexpr int OpCreateStage = 0; + +// Add BaseDescPool is only for store different op +// desc pool in a specific container. +template +class TMO_HIDDEN OpDescPool { + public: + using desc_unique_ptr = std::unique_ptr; + OpDescPool() { createMultiItems(); } + + // delete copy construct and assign construct. + DELETE_COPY_ASSIGN_CONSTRUCT(OpDescPool); + + desc_unique_ptr get() { + auto ptr = std::make_unique(); + return desc_unique_ptr(ptr.release()); + } + + ~OpDescPool() { + std::lock_guard lock(this->mutex_); + for (size_t i = 0; i < this->container_.size(); ++i) { + this->container_[i].release(); + } + this->container_.clear(); + } + + private: + void createMultiItems() { + for (int i = 0; i < OpCreateStage; ++i) { + container_.emplace_back(std::make_unique()); + } + } + // variable + std::mutex mutex_; + std::deque> container_; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/batchmatmul_descriptor.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/batchmatmul_descriptor.cpp new file mode 100644 index 0000000..2131f5b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/batchmatmul_descriptor.cpp @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "batchmatmul_descriptor.h" + +namespace tmo { +namespace op_desc { + +namespace { +static OpDescPool bmm_instance; +} + +BatchMatMulDesc::BatchMatMulDesc(cnnlMatMulHeuristicResult_t &heuristic_result, + cnnlMatMulAlgo_t &algo, + bool use_beta, + bool trans_a, + bool trans_b) { + impl_ = bmm_instance.get(); + auto bmm_desc_ = impl_->bmm_desc_; + heuristic_result = impl_->heuristic_result_; + algo = impl_->algo_; + int32_t matmul_trans_a = int(trans_a); + int32_t matmul_trans_b = int(trans_b); + int32_t matmul_use_tf32 = 0; + int32_t matmul_use_beta = int(use_beta); + CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_DESC_TRANSA, &(matmul_trans_a), + sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_DESC_TRANSB, &(matmul_trans_b), + sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_ALLOW_TF32, &(matmul_use_tf32), + sizeof(int32_t))); + CNNL_CHECK_FATAL( + cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_USE_BETA, &(matmul_use_beta), sizeof(int32_t))); +} + +BatchMatMulDescImpl::BatchMatMulDescImpl() { + CNNL_CHECK_FATAL(cnnlCreateMatMulDescriptor(&this->bmm_desc_)); + CNNL_CHECK_FATAL(cnnlCreateMatMulHeuristicResult(&this->heuristic_result_)); + CNNL_CHECK_FATAL(cnnlCreateMatMulAlgo(&this->algo_)); +} + +} // namespace op_desc +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/batchmatmul_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/batchmatmul_descriptor.h new file mode 100644 index 0000000..f65afb0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/batchmatmul_descriptor.h @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#ifndef CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_ + +#include +#include +#include "base_descriptor.h" +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +class TMO_HIDDEN BatchMatMulDescImpl { + public: + BatchMatMulDescImpl(); + // using clear function to avoid compilation errors. + ~BatchMatMulDescImpl() { clear(); } + DELETE_COPY_ASSIGN_CONSTRUCT(BatchMatMulDescImpl); + + cnnlMatMulDescriptor_t bmm_desc_; + cnnlMatMulHeuristicResult_t heuristic_result_; + cnnlMatMulAlgo_t algo_; + + private: + inline void clear() { + CNNL_CHECK_FATAL(cnnlDestroyMatMulDescriptor(this->bmm_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyMatMulHeuristicResult(this->heuristic_result_)); + CNNL_CHECK_FATAL(cnnlDestroyMatMulAlgo(this->algo_)); + } +}; + +class TMO_EXPORT BatchMatMulDesc { + public: + using impl_type_ = + std::unique_ptr>; + BatchMatMulDesc(cnnlMatMulHeuristicResult_t &heuristic_result, + cnnlMatMulAlgo_t &algo, + bool use_beta = false, + bool trans_a = false, + bool trans_b = true); + + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulDescriptor_t, impl_->bmm_desc_) + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulHeuristicResult_t, impl_->heuristic_result_) + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulAlgo_t, impl_->algo_) + + private: + impl_type_ impl_; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.cpp new file mode 100644 index 0000000..bda0e48 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.cpp @@ -0,0 +1,67 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "ffn_descriptor.h" + +namespace tmo { +namespace op_desc { + +namespace { +static OpDescPool instance; +} // end of anonymous namespace + +FeedForwardDesc::FeedForwardDesc() : is_quanti_(false) { + this->impl_ = instance.get(); +} + +FeedForwardDesc::FeedForwardDesc( + const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode, + const std::string &act_name, + cnnlDataType_t compute_type, + float eps, + float alpha, + float beta, + float act_coef) + : is_quanti_(false) { + this->impl_ = instance.get(); + this->setFeedForwardDesc(layernorm_residual_mode, act_name, compute_type, eps, alpha, beta, + act_coef); +} + +void FeedForwardDesc::setFeedForwardDesc( + const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode, + const std::string &act_name, + cnnlDataType_t compute_type, + float eps, + float alpha, + float beta, + float act_coef) { + // Now only using bfloat16 and half. + // alpha must be 1.0f when NO_RESIDUAL mode. + CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptor_v2( + impl_->ffn_desc_, eps, alpha, beta, compute_type, layernorm_residual_mode)); + act_coef = act_name == "silu" ? 1.0f : act_coef; + cnnlActivationMode_t act_mode = strToActivationMode(act_name); + CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v5(impl_->act_desc_, act_mode, CNNL_ACTIVATION_FAST, + CNNL_NOT_PROPAGATE_NAN, act_coef, + // following parameters are not used + 0, 0, 0, 0)); +} + +FeedForwardDescImpl::FeedForwardDescImpl() { + CNNL_CHECK_FATAL(cnnlCreateTransformerFeedForwardDescriptor(&this->ffn_desc_)); + CNNL_CHECK_FATAL(cnnlCreateTransformerFeedForwardQuantizeDescriptor(&this->quant_desc_)); + CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_)); +} + +} // namespace op_desc +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.h new file mode 100644 index 0000000..116c512 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.h @@ -0,0 +1,102 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_ + +#include +#include +#include "base_descriptor.h" +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +/** + * Note [FeedForwardDesc] + * ~~~~~~~~~~~~~~~~ + * FeedForwardDesc is class to create a Transformer FeedForward op descriptor. + * + * ffn_inner_size is intermediate_size; + * act_name mean act type, and will convert act type when set op desc; + * act_coef is only using when act_name is CNNL_ACTIVATION_CLIPPED_RELU, + * CNNL_ACTIVATION_ELU, CNNL_ACTIVATION_ELU_V2, CNNL_ACTIVATION_LEAKYRELU, + * CNNL_ACTIVATION_TF_LEAKYRELU, CNNL_ACTIVATION_CAFFE_RELU6. + * + */ + +struct TMO_HIDDEN FeedForwardDescImpl { + FeedForwardDescImpl(); + // using clear function to avoid compilation errors. + ~FeedForwardDescImpl() { clear(); } + DELETE_COPY_ASSIGN_CONSTRUCT(FeedForwardDescImpl); + // variable + cnnlTransformerFeedForwardDescriptor_t ffn_desc_; + cnnlTransformerFeedForwardQuantizeDescriptor_t quant_desc_; + cnnlActivationDescriptor_t act_desc_; + + private: + inline void clear() { + CNNL_CHECK_FATAL(cnnlDestroyTransformerFeedForwardDescriptor(this->ffn_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyTransformerFeedForwardQuantizeDescriptor(this->quant_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_)); + } +}; + +class TMO_EXPORT FeedForwardDesc { + public: + using impl_type = + std::unique_ptr>; + FeedForwardDesc(); + FeedForwardDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode, + const std::string &act_name, + cnnlDataType_t compute_type, + float eps, + float alpha, + float beta, + float act_coef = 0); + + void setFeedForwardDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode, + const std::string &act_name, + cnnlDataType_t compute_type, + float eps, + float alpha, + float beta, + float act_coef = 0); + + // TODO(SG): quanti op desc is not support now, will support later. + // void setQuantiDesc(); + + // TODO(SG): using std::string to pass layer_norm type? It's not clear now. + // Add layer_norm fused later. + // void setLayerNormInfo(); + + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlTransformerFeedForwardDescriptor_t, impl_->ffn_desc_); + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlActivationDescriptor_t, impl_->act_desc_); + + operator cnnlTransformerFeedForwardQuantizeDescriptor_t() const { + return is_quanti_ == false + ? nullptr + : const_cast(impl_->quant_desc_); + } + operator cnnlTransformerFeedForwardQuantizeDescriptor_t() { + return is_quanti_ == false ? nullptr : impl_->quant_desc_; + } + + private: + impl_type impl_; + bool is_quanti_{false}; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/group_gemm_descriptor.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/group_gemm_descriptor.cpp new file mode 100644 index 0000000..b57252d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/group_gemm_descriptor.cpp @@ -0,0 +1,178 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "group_gemm_descriptor.h" +#include +#include + +namespace tmo { +namespace op_desc { + +namespace { +static OpDescPool instance; +} // end of anonymous namespace + +GroupGemmDesc::GroupGemmDesc(int num_expert, + int max_m, + int n, + int k, + cnnlDataType_t dtype, + bool idx_mode, + QuantMode quant_mode) { + this->impl_ = instance.get(); + this->idx_mode_ = idx_mode; + this->quant_mode_ = quant_mode; + set(num_expert, max_m, n, k, dtype); +} + +void GroupGemmDesc::set(int num_expert, int max_m, int n, int k, cnnlDataType_t dtype) { + this->num_expert_ = num_expert; + this->max_m_ = max_m; + this->n_ = n; + this->k_ = k; + this->dtype_ = dtype; + // no need to set compute_dtype, default is CNNL_DTYPE_FLOAT + CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_M_MAX, + &max_m, sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_N_MAX, + &n, sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_K_MAX, + &k, sizeof(int32_t))); + int dim_nb = 1; + int64_t dims[] = {num_expert}; + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->group_host_tensor_, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_INT32, dim_nb, dims)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->group_device_tensor_, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_INT32, dim_nb, dims)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->scale_factor_tensor_, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, dim_nb, dims)); + CNNL_CHECK_FATAL( + cnnlSetTensorDescriptorPointerMode(impl_->group_host_tensor_, CNNL_POINTER_MODE_HOST)); + CNNL_CHECK_FATAL( + cnnlSetTensorDescriptorPointerMode(impl_->group_device_tensor_, CNNL_POINTER_MODE_DEVICE)); + CNNL_CHECK_FATAL( + cnnlSetTensorDescriptorPointerMode(impl_->scale_factor_tensor_, CNNL_POINTER_MODE_DEVICE)); +} + +void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale, void *b_scale) { + setPerRowColScaleBiasAct(a_scale, b_scale, nullptr, nullptr, CNNL_DTYPE_INVALID, + CNNL_ACTIVATION_IDENTITY, CNNL_PROPAGATE_NAN, 0.f, false, 0, 0, 0); +} + +void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale, + void *b_scale, + void *quant_flag, + void *bias, + cnnlDataType_t bias_dtype, + int channels, + int blk, + int pos) { + setPerRowColScaleBiasAct(a_scale, b_scale, quant_flag, bias, bias_dtype, CNNL_ACTIVATION_IDENTITY, + CNNL_PROPAGATE_NAN, 0.f, false, channels, blk, pos); +} + +void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale, + void *b_scale, + void *quant_flag, + void *bias, + cnnlDataType_t bias_dtype, + cnnlActivationMode_t mode, + cnnlNanPropagation_t nan_prop, + float coef, + bool has_active, + int channels, + int scale_block_size, + int channel_pos) { + bool is_a_quant = a_scale != nullptr; + bool use_b_scale = false; + bool quant_grouped = (scale_block_size > 0 && channels / scale_block_size > 0) ? true : false; + if (a_scale) { + CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(a_scale_desc(), CNNL_DTYPE_FLOAT, a_scale, + CNNL_POINTER_MODE_DEVICE, nullptr)); + } + if (!quant_grouped && b_scale) { + use_b_scale = true; + CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(b_scale_desc(), CNNL_DTYPE_FLOAT, b_scale, + CNNL_POINTER_MODE_DEVICE, nullptr)); + } + if (bias) { + this->has_bias_ = true; + CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(bias_desc(), bias_dtype, bias, + CNNL_POINTER_MODE_DEVICE, nullptr)); + } + this->has_active_ = has_active; + if (has_active) { + CNNL_CHECK_FATAL(cnnlSetActivationDescriptor(active_desc(), mode, nan_prop, coef)); + } + + if (quant_grouped) { + const int dim_nb = 2; + int group_num = channels / scale_block_size; + int64_t dims[dim_nb] = {group_num, num_expert_ * n_}; + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(b_scale_quant_grouped_desc(), CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, dim_nb, dims)); + + CNNL_CHECK_FATAL(cnnlSetGroupGemmGroupwiseScale( + group_gemm_desc(), CNNL_MATMUL_DESC_B_SCALE_POINTER, channels, scale_block_size, + channel_pos, b_scale_quant_grouped_desc(), b_scale)); + if (quant_flag) { + const int dim_qf = 2; + int64_t dims[dim_qf] = {num_expert_, group_num}; + CNNL_CHECK_FATAL( + cnnlSetTensorDescriptorPointerMode(quant_flag_desc(), CNNL_POINTER_MODE_HOST)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(quant_flag_desc(), CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_INT32, dim_qf, dims)); + CNNL_CHECK_FATAL(cnnlSetGroupGemmGroupMixedQuantBitFlag( + group_gemm_desc(), CNNL_MATMUL_DESC_B_QUANT_FLAG_POINTER, channels, scale_block_size, + quant_flag_desc(), quant_flag)); + } + } + + CNNL_CHECK_FATAL(cnnlSetGroupGemmPerRowColScaleBiasAct( + group_gemm_desc(), is_a_quant ? a_scale_desc() : nullptr, + use_b_scale ? b_scale_desc() : nullptr, this->has_bias_ ? bias_desc() : nullptr, + has_active ? active_desc() : nullptr)); +} + +void GroupGemmDesc::setInputOutputTensor(cnnlDataType_t a_dtype, + cnnlDataType_t b_dtype, + cnnlDataType_t d_dtype, + cnnlDataType_t idx_dtype, + void *a, + void *b, + void *c, + void *d, + void *idx, + int64_t k, + int64_t k_stride, + int64_t total_m, + bool has_c, + int64_t *b_offset) { + this->has_c_ = has_c; + if (idx_mode()) { + cnnlSetGroupGemmTensorGatherIdxDesc(a_desc(), a_dtype, idx_dtype, a, idx, k, k_stride, total_m); + } else { + CNNL_CHECK_FATAL( + cnnlSetGroupGemmTensorDescriptor(a_desc(), a_dtype, a, CNNL_POINTER_MODE_DEVICE, nullptr)); + } + auto b_ptr_mode = b_offset ? CNNL_POINTER_MODE_HOST : CNNL_POINTER_MODE_DEVICE; + CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(b_desc(), b_dtype, b, b_ptr_mode, b_offset)); + if (has_c) { + CNNL_CHECK_FATAL( + cnnlSetGroupGemmTensorDescriptor(c_desc(), d_dtype, c, CNNL_POINTER_MODE_DEVICE, nullptr)); + } + CNNL_CHECK_FATAL( + cnnlSetGroupGemmTensorDescriptor(d_desc(), d_dtype, d, CNNL_POINTER_MODE_DEVICE, nullptr)); +} + +} // namespace op_desc +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/group_gemm_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/group_gemm_descriptor.h new file mode 100644 index 0000000..d58ecce --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/group_gemm_descriptor.h @@ -0,0 +1,177 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_ + +#include +#include +#include +#include +#include "base_descriptor.h" +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +struct TMO_HIDDEN GroupGemmDescImpl { + GroupGemmDescImpl() { + CNNL_CHECK_FATAL(cnnlCreateGroupGemmDescriptor(&group_gemm_desc_)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&a_idx_desc_, CNNL_GROUP_GEMM_ADDRESSING_GATHER_IDX)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&a_offset_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&b_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&c_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&d_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&a_scale_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&b_scale_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL( + cnnlCreateGroupGemmTensorDescriptor(&bias_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS)); + CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&b_scale_quant_grouped_desc_)); + CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&group_host_tensor_)); + CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&group_device_tensor_)); + CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&scale_factor_tensor_)); + CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&quant_flag_tensor_)); + CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&active_desc_)); + } + // using clear function to avoid compilation errors. + ~GroupGemmDescImpl() { clear(); } + DELETE_COPY_ASSIGN_CONSTRUCT(GroupGemmDescImpl); + // variable + cnnlGroupGemmDescriptor_t group_gemm_desc_; + cnnlGroupGemmTensorDescriptor_t a_idx_desc_; + cnnlGroupGemmTensorDescriptor_t a_offset_desc_; + cnnlGroupGemmTensorDescriptor_t b_desc_; + cnnlGroupGemmTensorDescriptor_t c_desc_; + cnnlGroupGemmTensorDescriptor_t d_desc_; + cnnlGroupGemmTensorDescriptor_t a_scale_desc_; + cnnlGroupGemmTensorDescriptor_t b_scale_desc_; + cnnlGroupGemmTensorDescriptor_t bias_desc_; + cnnlTensorDescriptor_t b_scale_quant_grouped_desc_; + cnnlTensorDescriptor_t group_host_tensor_; + cnnlTensorDescriptor_t group_device_tensor_; + cnnlTensorDescriptor_t scale_factor_tensor_; + cnnlTensorDescriptor_t quant_flag_tensor_; + cnnlActivationDescriptor_t active_desc_; + + private: + inline void clear() { + CNNL_CHECK_FATAL( + cnnlSetTensorDescriptorPointerMode(group_host_tensor_, CNNL_POINTER_MODE_DEVICE)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmDescriptor(group_gemm_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_idx_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_offset_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(b_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(c_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(d_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_scale_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(b_scale_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(bias_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(b_scale_quant_grouped_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(group_host_tensor_)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(group_device_tensor_)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(scale_factor_tensor_)); + CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(quant_flag_tensor_)); + CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(active_desc_)); + } +}; + +class TMO_EXPORT GroupGemmDesc { + public: + enum class QuantMode { noQuant, W8, W4, W4W8 }; + using impl_type = std::unique_ptr>; + GroupGemmDesc() = delete; + GroupGemmDesc(int num_expert, + int max_m, + int n, + int k, + cnnlDataType_t dtype, + bool idx_mode, + QuantMode quant_mode); + void set(int num_expert, int max_m, int n, int k, cnnlDataType_t dtype); + void setPerRowColScaleBiasAct(void *a_scale, void *b_scale); + void setPerRowColScaleBiasAct(void *a_scale, + void *b_scale, + void *quant_flag, + void *bias, + cnnlDataType_t bias_dtype, + int channels, + int blk, + int pos); + void setPerRowColScaleBiasAct(void *a_scale, + void *b_scale, + void *quant_flag, // for w4/w8 quantize + void *bias, // nullptr + cnnlDataType_t bias_dtype, // CNNL_DTYPE_INVALID + cnnlActivationMode_t mode, // CNNL_ACTIVATION_SILU + cnnlNanPropagation_t nan_prop, // CNNL_PROPAGATE_NAN + float coef, // 0.f + bool has_active, // false + int channels, + int scale_block_size, + int channel_pos); + void setInputOutputTensor(cnnlDataType_t a_dtype, + cnnlDataType_t b_dtype, + cnnlDataType_t d_dtype, + cnnlDataType_t idx_dtype, + void *a, + void *b, + void *c, + void *d, + void *idx, + int64_t k, + int64_t k_stride, + int64_t total_m, + bool has_c, + int64_t *b_offset = nullptr); + cnnlGroupGemmDescriptor_t group_gemm_desc() { return impl_->group_gemm_desc_; } + cnnlGroupGemmTensorDescriptor_t a_desc() { + return idx_mode_ ? impl_->a_idx_desc_ : impl_->a_offset_desc_; + } + cnnlGroupGemmTensorDescriptor_t b_desc() { return impl_->b_desc_; } + cnnlGroupGemmTensorDescriptor_t c_desc() { return impl_->c_desc_; } + cnnlGroupGemmTensorDescriptor_t d_desc() { return impl_->d_desc_; } + cnnlGroupGemmTensorDescriptor_t a_scale_desc() { return impl_->a_scale_desc_; } + cnnlGroupGemmTensorDescriptor_t b_scale_desc() { return impl_->b_scale_desc_; } + cnnlGroupGemmTensorDescriptor_t bias_desc() { return impl_->bias_desc_; } + cnnlTensorDescriptor_t b_scale_quant_grouped_desc() { return impl_->b_scale_quant_grouped_desc_; } + cnnlTensorDescriptor_t group_host_tensor() { return impl_->group_host_tensor_; } + cnnlTensorDescriptor_t group_device_tensor() { return impl_->group_device_tensor_; } + cnnlTensorDescriptor_t scale_factor_tensor() { return impl_->scale_factor_tensor_; } + cnnlTensorDescriptor_t quant_flag_desc() { return impl_->quant_flag_tensor_; } + cnnlActivationDescriptor_t active_desc() { return impl_->active_desc_; } + bool idx_mode() { return idx_mode_; } + bool has_c() { return has_c_; } + + private: + impl_type impl_; + bool idx_mode_ = false; + bool has_c_ = false; + bool has_bias_ = false; + bool has_active_ = false; + int num_expert_; + int max_m_; + int n_; + int k_; + cnnlDataType_t dtype_; + QuantMode quant_mode_ = QuantMode::noQuant; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/matmul_descriptor.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/matmul_descriptor.cpp new file mode 100644 index 0000000..cb96f6e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/matmul_descriptor.cpp @@ -0,0 +1,104 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "matmul_descriptor.h" + +namespace tmo { +namespace op_desc { + +namespace { +static OpDescPool matmul_instance; +} + +MatMulDesc::MatMulDesc(const cnnlTensorDescriptor_t &bias_tensor, + const void *bias_data, + size_t &workspace_size, + int32_t lda, + int32_t ldb, + int32_t ldc, + const std::string &act_name, + bool use_beta, + bool trans_a, + bool trans_b, + bool fast_act, + bool approximate) { + impl_ = matmul_instance.get(); + auto matmul_desc_ = impl_->matmul_desc_; + auto act_desc_ = impl_->act_desc_; + workspace_size = max_workspace_size_; + + int32_t matmul_trans_a = int(trans_a); + int32_t matmul_trans_b = int(trans_b); + int32_t matmul_use_tf32 = 0; + int32_t matmul_use_beta = int(use_beta); + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_TRANSA, + &(matmul_trans_a), sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_TRANSB, + &(matmul_trans_b), sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_ALLOW_TF32, + &(matmul_use_tf32), sizeof(int32_t))); + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_USE_BETA, + &(matmul_use_beta), sizeof(int32_t))); + CNNL_CHECK_FATAL( + cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_LDA, &(lda), sizeof(int32_t))); + CNNL_CHECK_FATAL( + cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_LDB, &(ldb), sizeof(int32_t))); + CNNL_CHECK_FATAL( + cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_LDC, &(ldc), sizeof(int32_t))); + + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_PREF_MAX_WORKSPACE_BYTES, + &(max_workspace_size_), sizeof(int64_t))); + if (act_name == "none") { + if (bias_data == nullptr) { + cnnlMatMulEpilogueType_t fuse_type = CNNL_MATMUL_EPI_NONE; + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr( + matmul_desc_, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE, + (void *)&fuse_type, sizeof(fuse_type))); + + } else { + cnnlMatMulEpilogueType_t fuse_type = CNNL_MATMUL_EPI_BIAS; + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr( + matmul_desc_, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE, + (void *)&fuse_type, sizeof(fuse_type))); + CNNL_CHECK_FATAL(cnnlSetMatMulExBias(matmul_desc_, bias_tensor, (void *)bias_data)); + } + } else { + float act_coef = 0.0f; + cnnlActivationMode_t act_mode = strToActivationMode(act_name); + if (act_name == "silu") { + act_coef = 1.0f; + act_mode = CNNL_ACTIVATION_SILU; + } + cnnlActivationPreference_t act_pref = CNNL_ACTIVATION_HIGH_PRECISION; + if (fast_act) { + act_pref = CNNL_ACTIVATION_FAST; + } + CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v6( + act_desc_, act_mode, act_pref, CNNL_NOT_PROPAGATE_NAN, act_coef, 0, 0, 0, 0, approximate)); + // fuse info + cnnlMatMulEpilogueType_t fuse_type = CNNL_MATMUL_EPI_BIAS_SCALE_BN_ACTIVATION; + CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr( + matmul_desc_, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE, &(fuse_type), + sizeof(fuse_type))); + CNNL_CHECK_FATAL(cnnlSetMatMulExBiasScaleBNActive(matmul_desc_, bias_tensor, bias_data, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, + 0, 0, 0, 0, 0, act_desc_)); + } +} + +MatMulDescImpl::MatMulDescImpl() { + CNNL_CHECK_FATAL(cnnlCreateMatMulExDescriptor(&this->matmul_desc_)); + CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_)); +} + +} // namespace op_desc +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/matmul_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/matmul_descriptor.h new file mode 100644 index 0000000..ef3355c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/matmul_descriptor.h @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#ifndef CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_ + +#include +#include +#include "base_descriptor.h" +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +class TMO_HIDDEN MatMulDescImpl { + public: + MatMulDescImpl(); + // using clear function to avoid compilation errors. + ~MatMulDescImpl() { clear(); } + DELETE_COPY_ASSIGN_CONSTRUCT(MatMulDescImpl); + + cnnlMatMulExDescriptor_t matmul_desc_; + cnnlActivationDescriptor_t act_desc_; + + private: + inline void clear() { + CNNL_CHECK_FATAL(cnnlDestroyMatMulExDescriptor(this->matmul_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_)); + } +}; + +class TMO_EXPORT MatMulDesc { + public: + using impl_type_ = std::unique_ptr>; + MatMulDesc(const cnnlTensorDescriptor_t &bias_tensor, + const void *bias_data, + size_t &workspace_size, + int32_t lda, + int32_t ldb, + int32_t ldc, + const std::string &act_name = "none", + bool use_beta = false, + bool trans_a = false, + bool trans_b = true, + bool fast_act = true, + bool approximate = true); + + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulExDescriptor_t, impl_->matmul_desc_) + CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlActivationDescriptor_t, impl_->act_desc_) + + private: + size_t max_workspace_size_ = 128 * 1024 * 1024; + impl_type_ impl_; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/quant_matmul_descriptor.cpp b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/quant_matmul_descriptor.cpp new file mode 100644 index 0000000..90655d0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/quant_matmul_descriptor.cpp @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "quant_matmul_descriptor.h" + +namespace tmo { +namespace op_desc { + +namespace { +static OpDescPool instance; +} + +QuantMatmulDesc::QuantMatmulDesc(const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + unsigned quant_bit_size, + cnnlDataType_t compute_dtype, + const std::string &act_name, + bool use_hp_active, + float act_coef, + bool trans_a, + bool trans_b) { + impl_ = instance.get(); + this->setQuantMatmulDesc(quant_algo, a_quant_layout, b_quant_layout, quant_bit_size, + compute_dtype, act_name, use_hp_active, act_coef, trans_a, trans_b); +} + +void QuantMatmulDesc::setQuantMatmulDesc(const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + unsigned quant_bit_size, + cnnlDataType_t compute_dtype, + const std::string &act_name, + bool use_hp_active, + float act_coef, + bool trans_a, + bool trans_b) { + cnnlQuantizeDescriptor_t no_quant; + CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&no_quant)); + // set quantize algo and layout + cnnlLLMQuantAlgo_t cnnl_quant_algo = strToQuantizeAlgo(quant_algo); + CNNL_CHECK_FATAL(cnnlSetQuantizeDescriptorBehaviour( + impl_->a_quant_desc_, strToQuantizeLayout(a_quant_layout), quant_bit_size, 0, false)); + CNNL_CHECK_FATAL(cnnlSetQuantizeDescriptorBehaviour( + impl_->b_quant_desc_, strToQuantizeLayout(b_quant_layout), quant_bit_size, 0, false)); + // set activation + act_coef = act_name == "silu" ? 1.0f : act_coef; + cnnlActivationMode_t act_mode = strToActivationMode(act_name); + CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v5( + impl_->act_desc_, act_mode, + use_hp_active ? CNNL_ACTIVATION_HIGH_PRECISION : CNNL_ACTIVATION_FAST, CNNL_NOT_PROPAGATE_NAN, + act_coef, 0, 0, 0, 0)); + // set into op_desc + CNNL_CHECK_FATAL(cnnlSetLLMQuantMatmulDescriptor( + impl_->op_desc_, impl_->a_quant_desc_, impl_->b_quant_desc_, no_quant, no_quant, + impl_->act_desc_, compute_dtype, cnnl_quant_algo, trans_a, trans_b)); + CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(no_quant)); +} + +QuantMatmulDescImpl::QuantMatmulDescImpl() { + CNNL_CHECK_FATAL(cnnlCreateLLMQuantMatmulDescriptor(&this->op_desc_)); + CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->a_quant_desc_)); + CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->b_quant_desc_)); + CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->c_quant_desc_)); + CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->d_quant_desc_)); + CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_)); +} + +} // namespace op_desc +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/quant_matmul_descriptor.h b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/quant_matmul_descriptor.h new file mode 100644 index 0000000..2f737ca --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/quant_matmul_descriptor.h @@ -0,0 +1,86 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#ifndef CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_ +#define CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_ + +#include +#include +#include "base_descriptor.h" +#include "common/utils.h" + +namespace tmo { +namespace op_desc { + +class TMO_HIDDEN QuantMatmulDescImpl { + public: + QuantMatmulDescImpl(); + // using clear function to avoid compilation errors. + ~QuantMatmulDescImpl() { clear(); } + DELETE_COPY_ASSIGN_CONSTRUCT(QuantMatmulDescImpl); + // variable + cnnlLLMQuantMatmulDescriptor_t op_desc_; + cnnlQuantizeDescriptor_t a_quant_desc_; + cnnlQuantizeDescriptor_t b_quant_desc_; + cnnlQuantizeDescriptor_t c_quant_desc_; + cnnlQuantizeDescriptor_t d_quant_desc_; + cnnlActivationDescriptor_t act_desc_; + + private: + inline void clear() { + CNNL_CHECK_FATAL(cnnlDestroyLLMQuantMatmulDescriptor(this->op_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->a_quant_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->b_quant_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->c_quant_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->d_quant_desc_)); + CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_)); + } +}; + +class TMO_EXPORT QuantMatmulDesc { + public: + using impl_type_ = + std::unique_ptr>; + QuantMatmulDesc(const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + unsigned quant_bit_size, + cnnlDataType_t compute_dtype, + const std::string &act_name, + bool use_hp_active, + float act_coef, + bool trans_a, + bool trans_b); + + // set descriptor + void setQuantMatmulDesc(const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + unsigned quant_bit_size, + cnnlDataType_t compute_dtype, + const std::string &act_name, + bool use_hp_active, + float act_coef, + bool trans_a, + bool trans_b); + + operator cnnlLLMQuantMatmulDescriptor_t() { return this->impl_->op_desc_; } + operator cnnlLLMQuantMatmulDescriptor_t() const { return this->impl_->op_desc_; } + + private: + impl_type_ impl_; +}; + +} // namespace op_desc +} // namespace tmo + +#endif // CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp new file mode 100644 index 0000000..a681ca7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/active.cpp @@ -0,0 +1,58 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/moe/add_bias_activation.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +void active(const torch::Tensor &input, + const torch::Tensor &output, + const c10::optional &bias, + const c10::optional &cusum_token_count, + const std::string &act_mode, + bool is_gated, + int64_t start_expert_id, + int64_t expert_size, + double active_coef) { + TORCH_CHECK( + act_mode == "silu" || act_mode == "gelu" || act_mode == "quick_gelu" || act_mode == "swish", + "act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.") + cnnlActivationMode_t act_type = act_mode == "gelu" ? CNNL_ACTIVATION_GELU : CNNL_ACTIVATION_SWISH; + if (act_mode == "quick_gelu") { + active_coef = 1.702f; + } else if (act_mode == "silu") { + active_coef = 1.0f; + } + TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2") + auto input_shape = input.sizes(); + int64_t in_channel = input_shape.back(); + TORCH_CHECK(in_channel > 0, "in_channel > 0") + if (is_gated) { + TORCH_CHECK(in_channel % 2 == 0, "in_channel % 2 == 0 if is_gated is true") + } + int64_t total_tokens = input.numel() / in_channel; + int64_t inner_size = is_gated ? in_channel / 2 : in_channel; + int64_t num_expert = cusum_token_count.has_value() ? (cusum_token_count.value().size(0) - 1) : 0; + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + int64_t output_stride = output.stride(-2); + auto data_dtype = getCnnlDataType(input.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + tmo::invokeGroupAddBiasActivationKernel( + queue, getAtTensorPtr(output), getAtTensorPtr(input), getAtTensorPtr(bias), + (int *)getAtTensorPtr(cusum_token_count), num_expert, total_tokens, inner_size, output_stride, + data_dtype, is_gated, act_type, start_expert_id, expert_size, active_coef); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/apply_rotary.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/apply_rotary.cpp new file mode 100644 index 0000000..029b98f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/apply_rotary.cpp @@ -0,0 +1,145 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/rotary_embedding.mluh" +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { +void apply_rotary(const torch::Tensor &input, + const torch::Tensor &sin_cache, + const torch::Tensor &cos_cache, + const c10::optional &position_ids, + const c10::optional &cu_seqlens, + bool interleaved, + bool discrete, + bool dynamic_ntk, + int64_t max_seqlen) { + auto output = input; + // 1.check device and tensor type + checkTensorSameAttr(input, sin_cache, cos_cache); + + const bool has_position_ids = position_ids.has_value(); + const bool has_cu_seqlens = cu_seqlens.has_value(); + const int origin_device_id = input.get_device(); + const void *position_ids_ptr = nullptr; + const void *cu_seqlens_ptr = nullptr; + + if (has_position_ids) { + TORCH_CHECK(position_ids.value().dtype() == torch::kInt32, + "position_ids type need be torch::kInt32"); + TORCH_CHECK(position_ids.value().get_device() == origin_device_id, + "Tensor device index is not the same, original index: ", origin_device_id, + "now index is: ", position_ids.value().get_device()); + position_ids_ptr = position_ids.value().data_ptr(); + } + + if (has_cu_seqlens) { + TORCH_CHECK(cu_seqlens.value().dtype() == torch::kInt32, + "cu_seqlens type need be torch::kInt32"); + TORCH_CHECK(cu_seqlens.value().get_device() == origin_device_id, + "Tensor device index is not the same, original index: ", origin_device_id, + "now index is: ", cu_seqlens.value().get_device()); + cu_seqlens_ptr = cu_seqlens.value().data_ptr(); + } + + // 2. check shape + int total_seqlen = 0; + int batch_size = 0; + int head_size = input.size(-1); + if (input.dim() == 3) { // pack mode + TORCH_CHECK(has_cu_seqlens, + "input has 3 dims: (total_seq_len, head_num, head_size)," + " which means pack mode, cu_seqlens should not be None"); + total_seqlen = input.size(0); + batch_size = cu_seqlens.value().size(0) - 1; + } else if (input.dim() == 4) { + TORCH_CHECK(!has_cu_seqlens, + "input has 4 dims: (batch_size, seq_len, head_num, head_size)," + " which means pad mode, cu_seqlens should be None"); + TORCH_CHECK(max_seqlen == input.size(1), + "input has 4 dims: (batch_size, seq_len, head_num, head_size)," + " which means pad mode, max_seqlen must be equtals to input.size(1)"); + batch_size = input.size(0); + total_seqlen = batch_size * input.size(1); + } else { + TORCH_CHECK(false, "input only support 3 or 4 dims"); + } + TORCH_CHECK(head_size <= 256, "only support input head_size <= 256"); + + const int rope_seqlen = dynamic_ntk ? sin_cache.size(1) : sin_cache.size(0); + const int rope_dim = dynamic_ntk ? sin_cache.size(2) : sin_cache.size(1); + + if (has_position_ids) { + if (discrete) { + CHECK_SHAPE(position_ids.value(), total_seqlen); + } else { + CHECK_SHAPE(position_ids.value(), batch_size); + } + } else { + TORCH_CHECK(!discrete, "discrete must be false if position ids is null.") + } + + if (!(has_position_ids && discrete)) { + TORCH_CHECK(max_seqlen <= rope_seqlen, "max_seqlen must less than or equal to rope_seqlen.") + } + + if (dynamic_ntk) { + CHECK_SHAPE(sin_cache, batch_size, rope_seqlen, rope_dim); + CHECK_SHAPE(cos_cache, batch_size, rope_seqlen, rope_dim); + } else { + CHECK_SHAPE(sin_cache, rope_seqlen, rope_dim); + CHECK_SHAPE(cos_cache, rope_seqlen, rope_dim); + } + + // 3. check strides + TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous"); + + if (dynamic_ntk) { + TORCH_CHECK(sin_cache.stride(1) == cos_cache.stride(1), + "sin_cache second stride must be equal to cos_cache second stride"); + } else { + TORCH_CHECK(sin_cache.stride(0) == cos_cache.stride(0), + "sin_cache first stride must be equal to cos_cache second stride"); + } + + if (has_position_ids) { + TORCH_CHECK(position_ids.value().is_contiguous(), "position_ids must be contiguous"); + } + + if (has_cu_seqlens) { + TORCH_CHECK(cu_seqlens.value().is_contiguous(), "cu_seqlens must be contiguous"); + } + + // prepare inputs + auto dims = input.dim(); + const int64_t num_heads = input.size(dims - 2); + const int64_t head_dim = input.size(dims - 1); + const int64_t input_seq_stride = input.strides()[dims - 3]; + const int64_t input_head_stride = input.strides()[dims - 2]; + const int64_t output_seq_stride = output.strides()[dims - 3]; + const int64_t output_head_stride = output.strides()[dims - 2]; + + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + auto queue = torch_mlu::getCurMLUStream(); + auto data_type = getCnnlDataType(input.scalar_type()); + + invokeRotaryEmbedding(queue, output.data_ptr(), input.data_ptr(), sin_cache.data_ptr(), + cos_cache.data_ptr(), (int *)position_ids_ptr, (int *)cu_seqlens_ptr, + batch_size, max_seqlen, num_heads, head_dim, rope_seqlen, rope_dim, + dynamic_ntk ? sin_cache.strides()[1] : sin_cache.strides()[0], + input_seq_stride, input_head_stride, output_seq_stride, output_head_stride, + interleaved, discrete, dynamic_ntk, data_type); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/attn_proj.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/attn_proj.cpp new file mode 100644 index 0000000..c53925b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/attn_proj.cpp @@ -0,0 +1,151 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +std::vector attention_project(const at::Tensor &input, + const at::Tensor &q_weight, + const c10::optional &q_bias, + const c10::optional &k_weight, + const c10::optional &k_bias, + const c10::optional &v_weight, + const c10::optional &v_bias, + const c10::optional &norm_weight, + const c10::optional &norm_bias, + const c10::optional &residual, + const std::string &out_layout, + int64_t head_size, + double eps, + double alpha, + double beta, + bool norm_out) { + // check device and dtype + checkTensorSameAttr(input, q_weight, q_bias, k_weight, k_bias, v_weight, v_bias, + norm_weight, norm_bias, residual); + + // check contiguous + CHECK_TENSOR_CONTIGUOUS(input) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(residual) + + // check size + const int64_t nDim = input.dim(); + at::Tensor input_view = input; + if (nDim == 2) { + input_view = input.unsqueeze(0); + } + + const bool has_k = k_weight.has_value(); + const bool has_v = v_weight.has_value(); + const int64_t n = input_view.size(0); + const int64_t t = input_view.size(1); + const int64_t hidden_size_q = q_weight.size(0); + const int64_t hidden_size_k = has_k ? k_weight.value().size(0) : 0; + const int64_t hidden_size_v = has_v ? v_weight.value().size(0) : 0; + // check params + bool has_bias = q_bias.has_value(); + bool has_ln = norm_weight.has_value(); + bool has_residual = residual.has_value(); + TORCH_CHECK(!(has_ln && has_residual), "cannot support layernorm and residual at the same time.") + TORCH_CHECK(out_layout == "nthc" || (out_layout == "nhtc" && nDim == 3), + "input must be 3-D if out_layout is 'nhtc'") + bool trans_out = (out_layout == "nhtc"); + const int64_t head_num_q = hidden_size_q / head_size; + const int64_t head_num_k = hidden_size_k / head_size; + const int64_t head_num_v = hidden_size_v / head_size; + + const torch_mlu::mlu::MLUGuard device_guard(input_view.device()); + auto q_shape = trans_out ? std::vector({n, head_num_q, t, head_size}) + : std::vector({n, t, hidden_size_q}); + auto k_shape = trans_out ? std::vector({n, head_num_k, t, head_size}) + : std::vector({n, t, hidden_size_k}); + auto v_shape = trans_out ? std::vector({n, head_num_v, t, head_size}) + : std::vector({n, t, hidden_size_v}); + auto out_q = at::empty(q_shape, input_view.options()); + auto out_k = has_k ? at::empty(k_shape, input_view.options()) : at::Tensor(); + auto out_v = has_v ? at::empty(q_shape, input_view.options()) : at::Tensor(); + auto out_ln = norm_out ? at::empty(input.sizes(), input_view.options()) : at::Tensor(); + + // create tensor descs + auto descs = + createTensorDescs({input_view, q_weight, q_bias.value_or(at::Tensor()), + k_weight.value_or(at::Tensor()), k_bias.value_or(at::Tensor()), + v_weight.value_or(at::Tensor()), v_bias.value_or(at::Tensor()), + norm_weight.value_or(at::Tensor()), norm_bias.value_or(at::Tensor()), + residual.value_or(at::Tensor()), out_q, out_k, out_v, out_ln}); + + // create and set attn_proj_desc + cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode = + has_ln ? CNNL_TRANSFORMER_PRE_LAYERNORM_NO_RESIDUAL + : (has_residual ? CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL + : CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL); + auto compute_dtype = getCnnlDataType(input_view.scalar_type()); + auto attn_proj_desc = tmo::op_desc::AttnProjDesc(); + attn_proj_desc.setDesc(layernorm_residual_mode, /*layernorm_residual_mode*/ + compute_dtype, true, /*has_q*/ + has_k, /*has_k*/ + has_v, /*has_v*/ + has_bias, false, 0, /*no packed*/ + trans_out, /*trans_out*/ + norm_out, /*store_layernorm out*/ + alpha, beta, /*alpha && beta */ + eps /*eps*/); + + auto handle = torch_mlu::getCurrentHandle(); + // get workspace size + size_t workspace_size = 0; + CNNL_CHECK_FATAL(cnnlGetTransformerAttnProjWorkspaceSize(handle, attn_proj_desc, nullptr, + descs[0].get(), descs[1].get(), + descs[10].get(), &workspace_size)); + auto workspace = + at::empty({static_cast(workspace_size)}, input.options().dtype(at::kByte)); + + // run forward + CNNL_CHECK_FATAL(cnnlTransformerAttnProj( + handle, attn_proj_desc, nullptr, /*quant_desc*/ + descs[0].get(), getAtTensorPtr(input_view), /* input */ + descs[9].get(), getAtTensorPtr(residual), /* residual */ + descs[1].get(), getAtTensorPtr(q_weight), /* q weight */ + descs[3].get(), getAtTensorPtr(k_weight), /* k weight */ + descs[5].get(), getAtTensorPtr(v_weight), /* v weight */ + descs[2].get(), getAtTensorPtr(q_bias), /* q bias */ + descs[4].get(), getAtTensorPtr(k_bias), /* k bias */ + descs[6].get(), getAtTensorPtr(v_bias), /* v bias */ + nullptr, nullptr, /*no valid token*/ + descs[7].get(), getAtTensorPtr(norm_weight), /* layernorm weight */ + descs[8].get(), getAtTensorPtr(norm_bias), /* layernorm bias */ + getAtTensorPtr(workspace), workspace_size, /* workspace */ + descs[10].get(), getAtTensorPtr(out_q), /* q out */ + descs[11].get(), getAtTensorPtr(out_k), /* k out */ + descs[12].get(), getAtTensorPtr(out_v), /* v out */ + descs[13].get(), getAtTensorPtr(out_ln) /* layernorm out */ + )); + + // return + if (nDim == 2) { + out_q.squeeze_(0); + if (has_k) out_k.squeeze_(0); + if (has_v) out_v.squeeze_(0); + if (norm_out) out_ln.squeeze_(0); + } + std::vector output_list; + output_list.emplace_back(out_q); + if (has_k) output_list.emplace_back(out_k); + if (has_v) output_list.emplace_back(out_v); + if (norm_out) output_list.emplace_back(out_ln); + return output_list; +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/batch_matmul.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/batch_matmul.cpp new file mode 100644 index 0000000..c8785bc --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/batch_matmul.cpp @@ -0,0 +1,91 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" +namespace tmo { +namespace torch_api { + +void batch_matmul(const at::Tensor &a, + const at::Tensor &b, + const at::Tensor &c, + double alpha, + double beta, + double a_scale, + double b_scale, + bool trans_a, + bool trans_b) { + bool use_beta = beta == 0 ? false : true; + int batch = a.size(0); + int m = trans_a ? a.size(-1) : a.size(-2); + int n = trans_b ? b.size(-2) : b.size(-1); + checkTensorSameAttr(a, b, c); + // check contiguous + TORCH_CHECK(a.is_contiguous(), "a must be contiguous.") + TORCH_CHECK(b.is_contiguous(), "b must be contiguous.") + TORCH_CHECK(c.is_contiguous(), "c must be contiguous.") + CHECK_SHAPE(c, batch, m, n); + + // get cnnl data type and init output + auto a_dtype = getCnnlDataType(a.scalar_type()); + auto b_dtype = getCnnlDataType(b.scalar_type()); + TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype."); + auto c_dtype = getCnnlDataType(c.scalar_type()); + if (a_dtype == CNNL_DTYPE_BFLOAT16) c_dtype = CNNL_DTYPE_FLOAT; + // create tensor desc + auto descs = createTensorDescs({a, b, c}); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[2].get(), c_dtype)); + + if (a_dtype == CNNL_DTYPE_INT8) { + float int_max = 127.0; + int quant_bit = 8; + float max_a = int_max / a_scale; + float max_b = int_max / b_scale; + int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2)); + int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2)); + float new_a_scale = std::pow(2.0f, pos_a) * a_scale; + float new_b_scale = std::pow(2.0f, pos_b) * b_scale; + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale)); + TORCH_CHECK(c_dtype != CNNL_DTYPE_BFLOAT16, + "output dtype cannot be bfloat16 when a/b is fixed-point") + } + + // get && set Op desc + cnnlMatMulHeuristicResult_t heuristic_result; + cnnlMatMulAlgo_t algo; + auto bmm_ex_desc = + tmo::op_desc::BatchMatMulDesc(heuristic_result, algo, use_beta, trans_a, trans_b); + + const torch_mlu::mlu::MLUGuard device_guard(a.device()); + auto handle = torch_mlu::getCurrentHandle(); + + size_t workspace_size = 0; + int requested_algo_count = 1; + int returned_algo_count = 0; + CNNL_CHECK_FATAL(cnnlGetBatchMatMulExAlgoHeuristic( + handle, bmm_ex_desc, descs[0].get(), descs[1].get(), descs[2].get(), nullptr, + requested_algo_count, &heuristic_result, &returned_algo_count)); + CNNL_CHECK_FATAL(cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size)); + auto workspace = at::empty({static_cast(workspace_size)}, a.options().dtype(at::kByte)); + + // run forward + float alpha_f = alpha; + float beta_f = beta; + CNNL_CHECK_FATAL(cnnlBatchMatMulEx(handle, bmm_ex_desc, algo, &alpha_f, descs[0].get(), + getAtTensorPtr(a), descs[1].get(), getAtTensorPtr(b), &beta_f, + descs[2].get(), getAtTensorPtr(c), getAtTensorPtr(workspace), + workspace_size)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/cnpx.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/cnpx.cpp new file mode 100644 index 0000000..04e389a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/cnpx.cpp @@ -0,0 +1,45 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "cnpx.h" +#include +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +static cnpxDomainHandle_t domain = cnpxDomainCreate("CNPERF_KERNEL_TMO"); +static bool cnperf_kernel_analysis = getenv("CNPERF_KERNEL_ANALYSIS"); + +void cnpxPush(const OpTheory &op) { + if (cnperf_kernel_analysis) { + size_t calc = op.getTheoryCalc(); + cnnlDataType_t calc_dtype = op.getCalcDtype(); + size_t io = op.getTheoryIO(); + auto op_name = op.getOpName(); + std::ostringstream jsonStream; + jsonStream << "{\"name\":\"" << op_name << "\", \"theo_calc\":" << calc + << ", \"theo_bytes\":" << io << ", \"calc_type\":" << calc_dtype << "}"; + std::string json = jsonStream.str(); + // std::cout << json.c_str() << std::endl; + cnpxDomainRangePush(domain, json.c_str()); + } +} + +void cnpxPop() { + if (cnperf_kernel_analysis) { + cnpxDomainRangePop(domain); + } +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/comm_overlap.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/comm_overlap.cpp new file mode 100644 index 0000000..b64ed56 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/comm_overlap.cpp @@ -0,0 +1,39 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "comm_overlap.h" +namespace tmo { +namespace torch_api { +#define CNCL_TYPE_AND_SCALAR_TYPE(_) \ + _(cnclFloat, at::kFloat) \ + _(cnclBfloat16, at::kBFloat16) \ + _(cnclHalf, at::kHalf) \ + _(cnclInt32, at::kInt) \ + _(cnclInt64, at::kLong) \ + _(cnclInt8, at::kChar) \ + _(cnclUint8, at::kByte) \ + _(cnclInt16, at::kShort) + +cnclDataType_t getCnclDataType(const at::ScalarType &data_type) { + switch (data_type) { +#define DEFINE_CASE(cncl_dtype, scalar_type) \ + case scalar_type: \ + return cncl_dtype; + + CNCL_TYPE_AND_SCALAR_TYPE(DEFINE_CASE) +#undef DEFINE_CASE + default: + std::string msg("getCnclDataType() not supported for "); + throw std::runtime_error(msg + c10::toString(data_type)); + } +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/comm_overlap.h b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/comm_overlap.h new file mode 100644 index 0000000..998aef7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/comm_overlap.h @@ -0,0 +1,69 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_ +#define CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_ +#include +#include +#include +#include +#include "cncl.h" +#include "framework/core/MLUEvent.h" +#include "framework/core/MLUStream.h" +#include "torch/extension.h" +#include "torch_api/utils.h" +namespace tmo { +namespace torch_api { + +cnclDataType_t getCnclDataType(const at::ScalarType &data_type); + +template +struct ParallelAllReduce { + static std::vector event_; + static std::optional stream_; + cnclComm_t cncl_comm_; + std::vector d_list_; + + ParallelAllReduce(int64_t cncl_comm) { cncl_comm_ = reinterpret_cast(cncl_comm); } + + at::Tensor operator()(MmOp &mm) { + auto compute_stream_ = torch_mlu::getCurrentMLUStream(); + if (!stream_.has_value()) { + stream_.emplace(torch_mlu::getStreamFromPool()); + } + auto comm_stream_ = stream_.value(); + uint64_t loop = mm.getLoopNum(); + if (event_.size() < loop + 1) { + event_.resize(loop + 1); + } + for (uint64_t i = 0; i < loop; i++) { + d_list_.push_back(mm.forward(i)); + event_[i].place(compute_stream_); + event_[i].wait(comm_stream_); + // reduce_sum d_send + CNCL_CHECK(cnclAllReduce(getAtTensorPtr(d_list_[i]), getAtTensorPtr(d_list_[i]), + d_list_[i].numel(), getCnclDataType(d_list_[i].scalar_type()), + cnclSum, cncl_comm_, comm_stream_.stream())); + } + event_[loop].place(comm_stream_); + event_[loop].wait(compute_stream_); + return mm.getOutput(); + } +}; + +template +std::vector ParallelAllReduce::event_; + +template +std::optional ParallelAllReduce::stream_; +} // namespace torch_api +} // namespace tmo +#endif // CSRC_TORCH_API_COMPUTE_ALLREDUCE_COMM_OVERLAP_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/flash_attn_sq_matmul_allreduce.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/flash_attn_sq_matmul_allreduce.cpp new file mode 100644 index 0000000..3cab915 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/flash_attn_sq_matmul_allreduce.cpp @@ -0,0 +1,180 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "comm_overlap.h" +#include "torch_api/torch_ops_api.h" +namespace tmo { +namespace torch_api { + +using namespace torch::indexing; + +struct FlashAttnSmoothQuantMatmul { + at::Tensor q_; + at::Tensor k_; + at::Tensor v_; + const c10::optional &cu_seq_lens_q_; + const c10::optional &cu_seq_lens_kv_; + const at::Tensor &smooth_; + const at::Tensor &weight_; + const at::Tensor &weight_scale_; + const c10::optional &bias_; + const std::string &compute_dtype_; + std::string input_dtype_str_; + int64_t max_seq_len_q_; + int64_t max_seq_len_kv_; + double softmax_scale_; + bool is_causal_; + int64_t block_seq_; + std::vector cumsum_seq_; + at::Tensor output_; + + FlashAttnSmoothQuantMatmul(const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const c10::optional &cu_seq_lens_q, + const c10::optional &cu_seq_lens_kv, + const at::Tensor &smooth, + const at::Tensor &weight, + const at::Tensor &weight_scale, + const c10::optional &bias, + const int64_t max_seq_len_q, + const int64_t max_seq_len_kv, + const double softmax_scale, + const bool is_causal, + const std::string &compute_dtype, + const int64_t block_seq) + : q_(q), + k_(k), + v_(v), + cu_seq_lens_q_(cu_seq_lens_q), + cu_seq_lens_kv_(cu_seq_lens_kv), + smooth_(smooth), + weight_(weight), + weight_scale_(weight_scale), + bias_(bias), + compute_dtype_(compute_dtype), + max_seq_len_q_(max_seq_len_q), + max_seq_len_kv_(max_seq_len_kv), + softmax_scale_(softmax_scale), + is_causal_(is_causal) { + input_dtype_str_ = q.scalar_type() == at::kBFloat16 ? "bfloat16" + : q.scalar_type() == at::kHalf ? "half" + : "float"; + bool is_pack = cu_seq_lens_q.has_value(); + if (is_pack) { + TORCH_CHECK(cu_seq_lens_q.value().size(0) == 2 && cu_seq_lens_kv.value().size(0) == 2, + "only support 1 batch.") + TORCH_CHECK(q.dim() == 3 && k.dim() == 3 && v.dim() == 3, "q,k,v must be 3-d in pack mode.") + // 1 batch pack to pad + q_ = q_.unsqueeze(0); + k_ = k_.unsqueeze(0); + v_ = v_.unsqueeze(0); + } else { + TORCH_CHECK(q.size(0) == 1, "only support 1 batch.") + TORCH_CHECK(q.dim() == 4, "q must be 4-d in pad mode.") + } + int total_seq_q = q_.size(0) * q_.size(1); + output_ = at::empty({total_seq_q, weight.size(0)}, q.options()); + if (total_seq_q >= 4096) { + block_seq_ = 4; + split_4(total_seq_q); + } else { + block_seq_ = 1; + split_1(total_seq_q); + } + } + + void split_1(int seq) { + cumsum_seq_.resize(2); + cumsum_seq_[0] = 0; + cumsum_seq_[1] = seq; + } + + void split_4(int seq) { + auto pad_up = [](int x, int y) -> int { return (x + y - 1) / y * y; }; + int seq_4 = pad_up(seq / 8, 256); + int seq_3 = pad_up(seq / 8, 256); + int seq_2 = pad_up(seq / 4, 256); + int seq_1 = seq - seq_2 - seq_3 - seq_4; + cumsum_seq_.resize(5); + cumsum_seq_[0] = 0; + cumsum_seq_[1] = cumsum_seq_[0] + seq_1; + cumsum_seq_[2] = cumsum_seq_[1] + seq_2; + cumsum_seq_[3] = cumsum_seq_[2] + seq_3; + cumsum_seq_[4] = cumsum_seq_[3] + seq_4; + } + + auto split_tensor(const at::Tensor &a, int64_t dim, int64_t start, int64_t end) { + return a.narrow(dim, start, end - start); + } + + at::Tensor forward(const int64_t block_id) { + // flash attn + int64_t end_kv = is_causal_ ? cumsum_seq_[block_id + 1] : cumsum_seq_[block_seq_]; + auto q_i = split_tensor(q_, 1, cumsum_seq_[block_id], cumsum_seq_[block_id + 1]); + auto k_i = split_tensor(k_, 1, 0, end_kv); + auto v_i = split_tensor(v_, 1, 0, end_kv); + auto attn_out = at::empty_like(q_i); + flash_attention(q_i, k_i, v_i, attn_out, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt, + c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt, q_i.size(1), + k_i.size(1), softmax_scale_, is_causal_, -1, -1, compute_dtype_, false); + // smooth quant + auto smooth_quant_input = + attn_out.flatten(-2, -1).flatten(0, 1); // (b, s, hn, hs) -> (b*s, hn*hs) + auto quant_out = + at::empty(smooth_quant_input.sizes(), + torch::TensorOptions().dtype(torch::kInt8).device(smooth_quant_input.device())); + auto quant_out_scale = at::empty({smooth_quant_input.size(0)}, smooth_.options()); + smooth_quant(smooth_quant_input, smooth_, quant_out, quant_out_scale, c10::nullopt, + c10::nullopt, c10::nullopt, c10::nullopt, "per_token", true); + // quant matmul + auto d_i = quant_matmul( + quant_out, quant_out_scale, c10::nullopt, weight_, weight_scale_, c10::nullopt, bias_, + c10::nullopt, c10::nullopt, c10::nullopt, weight_scale_, c10::nullopt, input_dtype_str_, + split_tensor(output_, 0, cumsum_seq_[block_id], cumsum_seq_[block_id + 1]), "smooth_quant", + "quantize_per_token", "quantize_per_channel", 8, "none", false, 1.0, 1.0, 1.0, false, true); + return d_i; + } + + int64_t getLoopNum() const { return block_seq_; } + at::Tensor getOutput() const { return output_; } +}; + +at::Tensor flash_attn_sq_mm_allreduce(const int64_t cncl_comm, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const c10::optional &cu_seq_lens_q, + const c10::optional &cu_seq_lens_kv, + const c10::optional &alibi_slope, + const c10::optional &attn_bias, + const at::Tensor &smooth, + const at::Tensor &weight, + const at::Tensor &weight_scale, + const c10::optional &bias, + const int64_t max_seq_len_q, + const int64_t max_seq_len_kv, + const double softmax_scale, + const bool is_causal, + const int64_t window_size_left, + const int64_t window_size_right, + const std::string &compute_dtype, + const int64_t block_seq) { + FlashAttnSmoothQuantMatmul mm(q, k, v, cu_seq_lens_q, cu_seq_lens_kv, smooth, weight, + weight_scale, bias, max_seq_len_q, max_seq_len_kv, softmax_scale, + is_causal, compute_dtype, block_seq); + ParallelAllReduce parallel_rs(cncl_comm); + return parallel_rs(mm); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/group_gemm_combine_result_allreduce.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/group_gemm_combine_result_allreduce.cpp new file mode 100644 index 0000000..d80d5e8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/group_gemm_combine_result_allreduce.cpp @@ -0,0 +1,192 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include +#include "comm_overlap.h" +#include "kernels/moe/combine_result.mluh" +#include "torch_api/torch_ops_api.h" +#include "torch_api/utils.h" + +namespace tmo { +namespace torch_api { + +using namespace torch::indexing; + +struct GroupGemmCombineResultSplitN { + const at::Tensor &a_; + const at::Tensor &b_; + const at::Tensor &m_list_; + const at::Tensor &combine_idx_; + const at::Tensor &combine_weight_; + const c10::optional &c_; + const c10::optional &alpha_; + const c10::optional &beta_; + const c10::optional &a_scale_; + const c10::optional &b_scale_; + const c10::optional &data_type_; + int64_t num_token_; + int64_t topk_; + int64_t block_n_; + int64_t expert_num_; + int64_t n_; + bool has_c_; + bool has_alpha_; + bool has_beta_; + bool has_a_scale_; + bool has_b_scale_; + bool has_data_type_; + int64_t n_per_blk_; + cnrtQueue_t queue_; + at::Tensor output_buff_; + at::Tensor b_scale_trans_; + c10::optional b_offset_ = c10::nullopt; + cnnlDataType_t cnnl_dtype_; + + GroupGemmCombineResultSplitN(const at::Tensor &a_tensor, + const at::Tensor &b_tensor, + const at::Tensor &m_list, + const at::Tensor &combine_idx, + const at::Tensor &combine_weight, + const c10::optional &c_tensor, + const c10::optional &alpha, + const c10::optional &beta, + const c10::optional &a_scale, + const c10::optional &b_scale, + const c10::optional &data_type, + const int64_t num_token, + const int64_t topk, + const int64_t block_n) + : a_(a_tensor), + b_(b_tensor), + m_list_(m_list), + combine_idx_(combine_idx), + combine_weight_(combine_weight), + c_(c_tensor), + alpha_(alpha), + beta_(beta), + a_scale_(a_scale), + b_scale_(b_scale), + data_type_(data_type), + num_token_(num_token), + topk_(topk), + block_n_(block_n) { + has_c_ = c_.has_value(); + has_alpha_ = alpha_.has_value(); + has_beta_ = beta_.has_value(); + has_a_scale_ = a_scale_.has_value(); + has_b_scale_ = b_scale_.has_value(); + has_data_type_ = data_type_.has_value(); + auto b_shape = b_.sizes(); + expert_num_ = b_.dim() == 3 ? b_shape[0] : b_shape[0] * b_shape[2]; + n_ = b_shape[1]; + auto k = b_.size(-1); + auto m = a_.size(0); + set_block(m); + + queue_ = torch_mlu::getCurMLUStream(); + auto a_options = a_.options(); + if (has_data_type_) { + auto dtype = data_type_.value(); + TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16", + "data type must be 'float', 'half' or 'bfloat16'"); + auto torch_dtype = str2TorchDtype(dtype); + cnnl_dtype_ = str2CnnlDtype(dtype); + output_buff_ = at::empty({block_n_, num_token_, n_per_blk_}, a_options.dtype(torch_dtype)); + } else { + output_buff_ = at::empty({block_n_, num_token_, n_per_blk_}, a_options); + cnnl_dtype_ = getCnnlDataType(a_.scalar_type()); + } + if (has_b_scale_) { + b_scale_trans_ = b_scale_.value().view({expert_num_, block_n_, -1}); + b_scale_trans_ = b_scale_trans_.transpose(0, 1).contiguous(); + } + + auto b_offset = at::empty({expert_num_}, a_options.dtype(at::kLong).device(at::kCPU)); + auto element_size = b_.element_size(); + if (block_n_ > 1 && b_.dim() == 3) { + for (int64_t i = 0; i < expert_num_; i++) { + b_offset[i] = n_ * k * i * element_size; + } + b_offset_ = b_offset; + } else if (b_.dim() == 4) { + for (int64_t i = 0; i < expert_num_; i++) { + b_offset[i] = (i / b_shape[2] * b_.stride(0) + i % b_shape[2] * b_shape[3]) * element_size; + } + b_offset_ = b_offset; + } + } + + auto split(const at::Tensor &a, int64_t dim, int64_t start, int64_t block) { + return a.narrow(dim, start, block); + } + + at::Tensor forward(const int64_t block_id) { + auto gg_o = + group_gemm(a_, split(b_, 1, block_id * n_per_blk_, n_per_blk_), m_list_, c10::nullopt, + (has_c_) ? split(c_.value(), 1, block_id * n_per_blk_, n_per_blk_) : c_, alpha_, + beta_, a_scale_, has_b_scale_ ? b_scale_trans_[block_id] : b_scale_, + c10::nullopt, data_type_, c10::nullopt, b_offset_, num_token_); + + tmo::invokeMoeCombineResultKernel( + queue_, getAtTensorPtr(output_buff_[block_id]), getAtTensorPtr(gg_o), nullptr, nullptr, + (float *)getAtTensorPtr(combine_weight_), nullptr, (int *)getAtTensorPtr(combine_idx_), + num_token_, topk_, expert_num_, n_per_blk_, 0, expert_num_, cnnl_dtype_); + return output_buff_[block_id]; + } + + void set_block(int64_t m) { + if (block_n_ < 1) { + if (m >= 4096 && m < 8192 && n_ % 2048 == 0) { + n_per_blk_ = 2048; + block_n_ = n_ / n_per_blk_; + } else if (m >= 8192 && n_ % 1024 == 0) { + n_per_blk_ = 1024; + block_n_ = n_ / n_per_blk_; + } else { + n_per_blk_ = n_; + block_n_ = 1; + } + } else { + TORCH_CHECK(n_ % block_n_ == 0, "n must be divisible by block_n"); + n_per_blk_ = n_ / block_n_; + } + } + + int64_t getLoopNum() const { return block_n_; } + at::Tensor getOutput() const { return output_buff_.transpose(0, 1).reshape({num_token_, n_}); } +}; + +at::Tensor group_gemm_combine_result_allreduce(int64_t cncl_comm, + const at::Tensor &a_tensor, + const at::Tensor &b_tensor, + const at::Tensor &m_list, + const at::Tensor &combine_idx, + const at::Tensor &combine_weight, + const c10::optional &c_tensor, + const c10::optional &alpha, + const c10::optional &beta, + const c10::optional &a_scale, + const c10::optional &b_scale, + const c10::optional &data_type, + const int64_t num_token, + const int64_t topk, + const int64_t block_n) { + GroupGemmCombineResultSplitN gg(a_tensor, b_tensor, m_list, combine_idx, combine_weight, c_tensor, + alpha, beta, a_scale, b_scale, data_type, num_token, topk, + block_n); + ParallelAllReduce parallel_rs(cncl_comm); + return parallel_rs(gg); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/matmul_allreduce.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/matmul_allreduce.cpp new file mode 100644 index 0000000..704a808 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/matmul_allreduce.cpp @@ -0,0 +1,94 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "comm_overlap.h" +#include "torch_api/torch_ops_api.h" +namespace tmo { +namespace torch_api { + +using namespace torch::indexing; + +struct MatmulSplitM { + const at::Tensor &a_; + const at::Tensor &b_; + const c10::optional &bias_; + const c10::optional &c_; + const c10::optional &d_; + double alpha_; + double beta_; + int64_t block_m_; + bool has_res_; + bool has_output_; + at::Tensor output_; + + MatmulSplitM(const at::Tensor &a, + const at::Tensor &b, + const c10::optional &bias, + const c10::optional &c, + const c10::optional &d, + const double alpha, + const double beta, + const int64_t block_m) + : a_(a), b_(b), bias_(bias), c_(c), d_(d), alpha_(alpha), beta_(beta), block_m_(block_m) { + auto m = a_.size(0); + if (block_m_ > m) { + block_m_ = 1; + } else if (block_m_ < 1) { + block_m_ = m > 8192 ? 8 : (m < 2048 ? 1 : 4); + } + + has_res_ = c_.has_value(); + has_output_ = d_.has_value(); + if (has_output_) { + output_ = d_.value(); + } else { + output_ = at::empty({a.size(0), b.size(0)}, a.options()); + } + } + + auto split(const at::Tensor &a, const int64_t block_id) { + auto m = a.size(0); + auto m_per_blk = m / block_m_; + auto remain = m % block_m_; + auto start = block_id * m_per_blk + std::min(block_id, remain); + auto end = (block_id + 1) * m_per_blk + std::min(block_id + 1, remain); + return a.narrow(0, start, end - start); + } + + at::Tensor forward(const int64_t block_id) { + auto Di = split(output_, block_id); + matmul(split(a_, block_id), b_, Di, bias_, has_res_ ? split(c_.value(), block_id) : c_, None, + "none", alpha_, beta_, true, true, 1.0, 1.0, false, true); + return Di; + } + + int64_t getLoopNum() const { return block_m_; } + at::Tensor getOutput() const { return output_; } + at::Tensor getDSplit(const int64_t block_id) { return split(output_, block_id); } +}; + +at::Tensor matmul_allreduce(const int64_t cncl_comm, + const at::Tensor &a, + const at::Tensor &b, + const c10::optional &bias, + const c10::optional &c, + const c10::optional &d, + const double alpha, + const double beta, + const int64_t block_m) { + MatmulSplitM mm(a, b, bias, c, d, alpha, beta, block_m); + ParallelAllReduce parallel_rs(cncl_comm); + return parallel_rs(mm); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/quant_matmul_allreduce.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/quant_matmul_allreduce.cpp new file mode 100644 index 0000000..3694295 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/compute_allreduce/quant_matmul_allreduce.cpp @@ -0,0 +1,198 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include +#include +#include "comm_overlap.h" +#include "torch_api/torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +using namespace torch::indexing; + +struct QuantMatmulSplitM { + const at::Tensor &a_; + const c10::optional &a_scale_; + const c10::optional &a_zero_; + const at::Tensor &b_; + const c10::optional &b_scale_; + const c10::optional &b_zero_; + const c10::optional &bias_; + const c10::optional &c_; + const c10::optional &c_scale_; + const c10::optional &c_zero_; + const c10::optional &output_scale_; + const c10::optional &output_zero_; + const c10::optional &data_type_; + const c10::optional &d_; + const std::string &quant_algo_; + const std::string &a_quant_layout_; + const std::string &b_quant_layout_; + int64_t quant_bit_size_; + double alpha_; + double beta_; + bool trans_a_; + bool trans_b_; + int64_t block_m_; + bool has_a_scale_; + bool has_a_zero_; + bool has_b_scale_; + bool has_b_zero_; + bool has_bias_; + bool has_c_; + bool has_c_scale_; + bool has_c_zero_; + bool has_output_scale_; + bool has_output_zero_; + bool has_output_; + bool has_dtype_; + at::Tensor output_; + + QuantMatmulSplitM(const at::Tensor &a_tensor, + const c10::optional &a_scale, + const c10::optional &a_zero, + const at::Tensor &b_tensor, + const c10::optional &b_scale, + const c10::optional &b_zero, + const c10::optional &bias, + const c10::optional &c_tensor, + const c10::optional &c_scale, + const c10::optional &c_zero, + const c10::optional &gemm_output_scale, + const c10::optional &gemm_output_zero, + const c10::optional &data_type, + const c10::optional &d, + const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + int64_t quant_bit_size, + double alpha, + double beta, + bool trans_a, + bool trans_b, + const int64_t block_m) + : a_(a_tensor), + a_scale_(a_scale), + a_zero_(a_zero), + b_(b_tensor), + b_scale_(b_scale), + b_zero_(b_zero), + bias_(bias), + c_(c_tensor), + c_scale_(c_scale), + c_zero_(c_zero), + output_scale_(gemm_output_scale), + output_zero_(gemm_output_zero), + data_type_(data_type), + d_(d), + quant_algo_(quant_algo), + a_quant_layout_(a_quant_layout), + b_quant_layout_(b_quant_layout), + quant_bit_size_(quant_bit_size), + alpha_(alpha), + beta_(beta), + trans_a_(trans_a), + trans_b_(trans_b), + block_m_(block_m) { + auto m = a_.size(0); + if (block_m_ > m) { + block_m_ = 1; + } else if (block_m_ < 1) { + block_m_ = m > 8192 ? 8 : (m < 2048 ? 1 : 4); + } + has_a_scale_ = a_scale_.has_value(); + has_a_zero_ = a_zero_.has_value(); + has_b_scale_ = b_scale_.has_value(); + has_b_zero_ = b_zero_.has_value(); + has_bias_ = bias_.has_value(); + has_c_ = c_.has_value(); + has_c_scale_ = c_scale_.has_value(); + has_c_zero_ = c_zero_.has_value(); + has_output_scale_ = output_scale_.has_value(); + has_output_zero_ = output_zero_.has_value(); + has_output_ = d_.has_value(); + has_dtype_ = data_type_.has_value(); + auto a_options = a_.options(); + if (has_output_) { + output_ = d_.value(); + } else if (has_dtype_) { + auto dtype = data_type_.value(); + TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16", + "data type must be 'float', 'half' or 'bfloat16'"); + auto torch_dtype = str2TorchDtype(dtype); + output_ = at::empty({a_.size(0), b_.size(0)}, a_options.dtype(torch_dtype)); + } else { + output_ = at::empty({a_.size(0), b_.size(0)}, a_options); + } + } + + auto split(const at::Tensor &a, const int64_t block_id) { + auto m = a.size(0); + auto m_per_blk = m / block_m_; + auto remain = m % block_m_; + auto start = block_id * m_per_blk + std::min(block_id, remain); + auto end = (block_id + 1) * m_per_blk + std::min(block_id + 1, remain); + return a.narrow(0, start, end - start); + } + + at::Tensor forward(const int64_t block_id) { + auto Di = quant_matmul( + split(a_, block_id), has_a_scale_ ? split(a_scale_.value(), block_id) : a_scale_, + has_a_zero_ ? split(a_zero_.value(), block_id) : a_zero_, b_, b_scale_, b_zero_, bias_, + has_c_ ? split(c_.value(), block_id) : c_, + has_c_scale_ ? split(c_scale_.value(), block_id) : c_scale_, + has_c_zero_ ? split(c_zero_.value(), block_id) : c_zero_, output_scale_, output_zero_, + data_type_, split(output_, block_id), quant_algo_, a_quant_layout_, b_quant_layout_, + quant_bit_size_, "none", false, 1.0, alpha_, beta_, trans_a_, trans_b_); + return Di; + } + + int64_t getLoopNum() const { return block_m_; } + at::Tensor getOutput() const { return output_; } +}; + +at::Tensor quant_matmul_allreduce(const int64_t cncl_comm, + const at::Tensor &a_tensor, + const c10::optional &a_scale, + const c10::optional &a_zero, + const at::Tensor &b_tensor, + const c10::optional &b_scale, + const c10::optional &b_zero, + const c10::optional &bias, + const c10::optional &c_tensor, + const c10::optional &c_scale, + const c10::optional &c_zero, + const c10::optional &gemm_output_scale, + const c10::optional &gemm_output_zero, + const c10::optional &data_type, + const c10::optional &d, + const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + int64_t quant_bit_size, + double alpha, + double beta, + bool trans_a, + bool trans_b, + const int64_t block_m) { + TORCH_CHECK(!trans_a && trans_b, "trans_a must be false and trans_b must be true"); + QuantMatmulSplitM quant_mm(a_tensor, a_scale, a_zero, b_tensor, b_scale, b_zero, bias, c_tensor, + c_scale, c_zero, gemm_output_scale, gemm_output_zero, data_type, d, + quant_algo, a_quant_layout, b_quant_layout, quant_bit_size, alpha, + beta, trans_a, trans_b, block_m); + ParallelAllReduce parallel_rs(cncl_comm); + return parallel_rs(quant_mm); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/copy_blocks.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/copy_blocks.cpp new file mode 100644 index 0000000..b820921 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/copy_blocks.cpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/copy_blocks.mluh" +#include +#include +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { +void copy_blocks(const std::vector &k_caches, + const std::vector &v_caches, + const c10::Dict> &block_mapping_dict) { + // Create block mapping array. + std::vector block_mapping_vec; + for (const auto &item : block_mapping_dict) { + int64_t src_block_number = item.key(); + auto value_vec = item.value().vec(); + for (int64_t dst_block_number : value_vec) { + block_mapping_vec.push_back(int32_t(src_block_number)); + block_mapping_vec.push_back(int32_t(dst_block_number)); + } + } + + TORCH_CHECK(!k_caches.empty() && !block_mapping_vec.empty(), + "k_caches and block_mapping_vec can not be empty.") + int32_t num_layers = k_caches.size(); + for (auto i = 0; i < num_layers; i++) { + TORCH_CHECK(k_caches[i].dim() == 4, "every layer k_cache must be 4d.") + } + // check same device and tensor type + TORCH_CHECK(isMlu(k_caches[0]), "k_caches must on mlu."); + TORCH_CHECK(k_caches[0].dtype() == torch::kInt8 || k_caches[0].dtype() == torch::kUInt8 || + k_caches[0].dtype() == torch::kInt16 || k_caches[0].dtype() == torch::kInt32 || + k_caches[0].dtype() == torch::kLong || k_caches[0].dtype() == torch::kFloat16 || + k_caches[0].dtype() == torch::kFloat32 || k_caches[0].dtype() == torch::kBFloat16, + "data type only supports torch::kInt8, torch::kUInt8, torch::kInt16, torch::kInt32, " + "torch::kLong, torch::kFloat16, torch::kFloat32 and torch::kBFloat16"); + if (!v_caches.empty()) { + TORCH_CHECK(k_caches.size() == v_caches.size(), + "k_caches size must equal to " + "v_caches size if v_caches is not none.") + TORCH_CHECK(isMlu(v_caches[0]), "v_caches must on mlu."); + for (auto i = 0; i < num_layers; i++) { + TORCH_CHECK(k_caches[i].dtype() == v_caches[i].dtype(), + "the data type of k_caches and v_caches are not the same.") + TORCH_CHECK(k_caches[i].dim() == v_caches[i].dim(), + "every layer k_cache dim must equal to v_cache dim.") + // check shape + TORCH_CHECK(k_caches[i][0].numel() == v_caches[0][0].numel(), + "the block_size of k_caches and v_caches are not the same.") + } + } + + const torch_mlu::mlu::MLUGuard device_guard(k_caches[0].device()); + auto queue = torch_mlu::getCurMLUStream(); + + size_t block_size_bytes = k_caches[0][0].numel() * k_caches[0].element_size(); + std::vector new_key_caches, new_value_caches; + for (auto i = 0; i < num_layers; ++i) { + new_key_caches.push_back(k_caches[i].data_ptr()); + if (!v_caches.empty()) { + new_value_caches.push_back(v_caches[i].data_ptr()); + } + } + TMO_KERNEL_CHECK_FATAL(invokeCopyBlocksKernel(queue, new_key_caches, new_value_caches, + block_mapping_vec, block_size_bytes)); +} + +std::tuple, std::vector> copy_blocks_out_of_place( + const std::vector &k_caches, + const std::vector &v_caches, + const c10::Dict> &block_mapping) { + copy_blocks(k_caches, v_caches, block_mapping); + return std::make_tuple(k_caches, v_caches); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_linear_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_linear_cache.cpp new file mode 100644 index 0000000..6b39830 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_linear_cache.cpp @@ -0,0 +1,176 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/dequant_from_linear_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +#define CHECK_TENSOR_DTYPE(x, expected_type) \ + TORCH_CHECK(x.scalar_type() == expected_type, "Tensor " #x " type should be ", \ + torchDtype2Str(expected_type), "."); + +void dequant_from_linear_cache( + at::Tensor &key, // [total_seqlen, head_num, head_size] + const c10::optional &value, // same as above + const at::Tensor &key_cache, // [max_batch_size, head_num, cache_mem_len, head_size] + const c10::optional &value_cache, // same as above + const at::Tensor &key_quant_scale, // quant_mode is 0: [head_num, head_size] + // quant_mode is 1: + // [max_batch_size, head_num, cache_mem_len] + const c10::optional &value_quant_scale, // same as above + const at::Tensor &context_lengths, + const int64_t max_context_len, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seq_offset, + const int64_t quant_mode, // 0:per_channel, 1:per_head + const int64_t quant_bit) { // 4 or 8 + // check same attr for tensors + checkTensorSameAttr(key, key_cache, key_quant_scale, value, value_cache, + value_quant_scale, context_lengths, context_seq_offset, + cache_bs_id, cache_seq_offset); + checkTensorSameAttr(context_lengths, context_seq_offset, cache_bs_id, + cache_seq_offset); + + // check quant parameters first + TORCH_CHECK(quant_mode >= 0 && quant_mode <= 1, "quantization mode support 0 and 1."); + TORCH_CHECK(quant_bit == 4 || quant_bit == 8, "quantization bit width support 4 and 8."); + + /****************************************check key***************************************/ + // check dtype + TORCH_CHECK(key.scalar_type() == torch::kFloat16 || key.scalar_type() == torch::kBFloat16, + "Tensor key type should be half or bfloat16."); + CHECK_TENSOR_DTYPE(key_cache, torch::kInt8); + CHECK_TENSOR_DTYPE(key_quant_scale, torch::kFloat32); + + // check_contiguous + TORCH_CHECK(key.stride(-1) == 1, "Tensor key last dim must be contiguous."); + CHECK_TENSOR_CONTIGUOUS(key_cache); + CHECK_TENSOR_CONTIGUOUS(key_quant_scale); + + // check shape + TORCH_CHECK(key.dim() == 3, "The dimensions of tensor key only support 3."); + TORCH_CHECK(key_cache.dim() == 4, "The dimensions of tensor key_cache only supports 4."); + TORCH_CHECK(context_lengths.dim() == 1, + "The dimensions of tensor context_lengths only supports 1."); + const int32_t total_seqlen = key.size(0); + const int32_t head_num = key.size(1); + const int32_t head_size = key.size(2); + const int32_t max_batch_size = key_cache.size(0); + const int32_t cache_mem_len = key_cache.size(2); + const int32_t batch_size = context_lengths.size(0); + CHECK_SHAPE(context_lengths, batch_size); + TORCH_CHECK(max_batch_size >= batch_size, + "max_batch_size should be greater than or equal to batch_size."); + TORCH_CHECK(cache_mem_len % 2 == 0, "cache_mem_len should be a multiply of 2."); + if (quant_mode == 0) { + CHECK_SHAPE(key_quant_scale, head_num, head_size); + } else if (quant_mode == 1) { + CHECK_SHAPE(key_quant_scale, max_batch_size, head_num, cache_mem_len); + } + + if (quant_bit == 4) { + TORCH_CHECK(head_size % 2 == 0, "head_size should be a multiply of 2 if quant_bit is 4."); + CHECK_SHAPE(key_cache, max_batch_size, head_num, cache_mem_len, head_size >> 1); + } else { + CHECK_SHAPE(key_cache, max_batch_size, head_num, cache_mem_len, head_size); + } + + /***************************************check value***************************************/ + if (value.has_value() || value_cache.has_value() || value_quant_scale.has_value()) { + TORCH_CHECK(value.has_value() && value_cache.has_value() && value_quant_scale.has_value(), + "value, value_cache, and value_quant_scale must all exists.") + } + + if (value_cache.has_value()) { + checkTensorSameAttr(key, value); + checkTensorSameAttr(key_cache, value_cache); + checkTensorSameAttr(key_quant_scale, value_quant_scale); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_quant_scale); + TORCH_CHECK(key_quant_scale.dim() == value_quant_scale.value().dim(), + "value_cache_quant_scale dim should keep same with key_cache_quant_scale."); + CHECK_SHAPE(value.value(), total_seqlen, head_num, head_size); + if (quant_mode == 0) { + CHECK_SHAPE(value_quant_scale.value(), head_num, head_size); + } else { + CHECK_SHAPE(value_quant_scale.value(), max_batch_size, head_num, cache_mem_len); + } + + if (quant_bit == 4) { + CHECK_SHAPE(value_cache.value(), max_batch_size, head_num, cache_mem_len >> 1, head_size); + } else { + CHECK_SHAPE(value_cache.value(), max_batch_size, head_num, cache_mem_len, head_size); + } + + for (int i = 0; i < key.dim(); i++) { + TORCH_CHECK(value.value().stride(i) == key.stride(i), + "key and value must have same stride along axi ", i, "."); + } + } + + TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32, + "context_lengths type need be torch::kInt32."); + /*********************************check optional tensor***********************************/ + const void *context_seq_offset_ptr = nullptr; + if (context_seq_offset.has_value()) { + CHECK_SHAPE(context_seq_offset.value(), batch_size); + context_seq_offset_ptr = context_seq_offset.value().data_ptr(); + } else { + TORCH_CHECK(batch_size <= 1024, + "batch_size greater than 1024 not support when context_seq_offset is None."); + } + + const void *cache_bs_id_ptr = nullptr; + if (cache_bs_id.has_value()) { + CHECK_SHAPE(cache_bs_id.value(), batch_size); + cache_bs_id_ptr = cache_bs_id.value().data_ptr(); + } + + const void *cache_seq_offset_ptr = nullptr; + if (cache_seq_offset.has_value()) { + CHECK_SHAPE(cache_seq_offset.value(), batch_size); + cache_seq_offset_ptr = cache_seq_offset.value().data_ptr(); + } + + // propare parameters before calling invokeDequantFromLinearCache + const int32_t key_group_num = 1; + const int32_t value_group_num = 1; + const size_t context_head_stride = key.stride(1); + const size_t context_seq_stride = key.stride(0); + const size_t cache_bs_stride = key_cache.stride(0); + const size_t cache_head_stride = key_cache.stride(1); + const size_t key_cache_seq_stride = key_cache.stride(2); + const size_t value_cache_seq_stride = value_cache.has_value() ? value_cache.value().stride(2) : 0; + const size_t cache_scale_bs_stride = key_quant_scale.dim() == 3 ? key_quant_scale.stride(0) : 0; + const size_t cache_scale_head_stride = + key_quant_scale.dim() == 3 ? key_quant_scale.stride(1) : key_quant_scale.stride(0); + + const torch_mlu::mlu::MLUGuard device_guard(key.device()); + auto data_dtype = getCnnlDataType(key.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + + // run forward + TMO_KERNEL_CHECK_FATAL(invokeDequantFromLinearCache( + queue, getAtTensorPtr(key), getAtTensorPtr(value), getAtTensorPtr(key_cache), + getAtTensorPtr(value_cache), getAtTensorPtr(key_quant_scale), + getAtTensorPtr(value_quant_scale), getAtTensorPtr(context_lengths), context_seq_offset_ptr, + cache_bs_id_ptr, cache_seq_offset_ptr, (int)max_context_len, batch_size, head_num, + key_group_num, value_group_num, cache_mem_len, head_size, quant_mode, quant_bit, + context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride, + key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride, + data_dtype)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_paged_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_paged_cache.cpp new file mode 100644 index 0000000..57514e5 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/dequant_from_paged_cache.cpp @@ -0,0 +1,162 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/dequant_from_paged_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +#define CHECK_TENSOR_DTYPE(x, expected_type) \ + TORCH_CHECK(x.scalar_type() == expected_type, "Tensor " #x " type should be ", \ + torchDtype2Str(expected_type), "."); + +void dequant_from_paged_cache(at::Tensor &key, // [total_seqlen, head_num, head_size] + const c10::optional &value, // same as above + const at::Tensor &key_cache, // [token_num, head_num, block_size] + const c10::optional &value_cache, // same as above + // [token_num, head_num, block_size, head_size] for per-token + // [head_num, head_size] for per-channel + const at::Tensor &key_cache_quant_scale, + const c10::optional &value_cache_quant_scale, + const at::Tensor &context_lengths, // [batch_size] + int64_t max_context_len, + const c10::optional &context_seq_offset, // [batch_size] + const at::Tensor &block_tables, // [batch_size, max_block_num] + int64_t quant_mode, // 0 is per-channel, 1 is per-token + int64_t quant_bit) { // quantization bit only support 8 + // check same attr for tensors + checkTensorSameAttr(key, key_cache, key_cache_quant_scale, value, value_cache, + value_cache_quant_scale, context_lengths, + context_seq_offset, block_tables); + checkTensorSameAttr(context_lengths, context_seq_offset, block_tables); + + // check quant parameters first + TORCH_CHECK(quant_bit == 8, "quantization bit width only supports 8."); + TORCH_CHECK(quant_mode >= 0 && quant_mode <= 1, "quantization mode support 0 and 1."); + + /***************************************check key***************************************/ + // check dtype + TORCH_CHECK(key.scalar_type() == torch::kFloat16 || key.scalar_type() == torch::kBFloat16, + "Tensor key type should be half or bfloat16."); + CHECK_TENSOR_DTYPE(key_cache, torch::kInt8); + CHECK_TENSOR_DTYPE(key_cache_quant_scale, torch::kFloat32); + CHECK_TENSOR_DTYPE(block_tables, torch::kInt32); + + // check contiguous + TORCH_CHECK(key.stride(-1) == 1, "Tensor key last dim must be contiguous.") + CHECK_TENSOR_CONTIGUOUS(key_cache); + CHECK_TENSOR_CONTIGUOUS(key_cache_quant_scale); + CHECK_TENSOR_CONTIGUOUS(block_tables); + + // check shape + TORCH_CHECK(key.dim() == 3, "The dimensions of tensor key only supports 3."); + TORCH_CHECK(key_cache.dim() == 4, "The dimensions of tensor key_cache only supports 4."); + TORCH_CHECK(context_lengths.dim() == 1, + "The dimensions of tensor context_lengths only supports 1."); + TORCH_CHECK(block_tables.dim() == 2, "The dimensions of tensor block_tables only supports 2."); + const int32_t token_num = key.size(0); + const int32_t head_num = key.size(1); + const int32_t head_size = key.size(2); + const int32_t block_num = key_cache.size(0); + const int32_t block_size = key_cache.size(2); + const int32_t batch_size = context_lengths.size(0); + const int32_t max_block_num = block_tables.size(1); + CHECK_SHAPE(key_cache, block_num, head_num, block_size, head_size); + int64_t kv_cache_range = + (int64_t)block_num * head_num * block_size * head_size * key_cache.element_size(); + // kernel use uint32_t to calculate offsets + TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of key_cache cannot exceed 4GB."); + if (quant_mode == 0) { + CHECK_SHAPE(key_cache_quant_scale, head_num, head_size); + } else if (quant_mode == 1) { + CHECK_SHAPE(key_cache_quant_scale, block_num, head_num, block_size); + } + + /***************************************check value***************************************/ + if (value.has_value() || value_cache.has_value() || value_cache_quant_scale.has_value()) { + TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_quant_scale.has_value(), + "value, value_cache, and value_cache_quant_scale must all exists.") + } + + if (value_cache.has_value()) { + checkTensorSameAttr(key, value); + checkTensorSameAttr(key_cache, value_cache); + checkTensorSameAttr(key_cache_quant_scale, value_cache_quant_scale); + TORCH_CHECK(value.value().stride(-1) == 1, "Tensor value last dim must be contiguous.") + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(value_cache_quant_scale); + CHECK_SHAPE(value.value(), token_num, head_num, head_size); + TORCH_CHECK(value_cache_quant_scale.value().dim() == key_cache_quant_scale.dim(), + "value_cache_quant_scale dim should keep same with key_cache_quant_scale."); + if (quant_mode == 0) { + CHECK_SHAPE(value_cache_quant_scale.value(), head_num, head_size); + } else { + CHECK_SHAPE(value_cache_quant_scale.value(), block_num, head_num, block_size); + } + + CHECK_SHAPE(value_cache.value(), block_num, head_num, block_size, head_size); + + for (int i = 0; i < key.dim(); i++) { + TORCH_CHECK(value.value().stride(i) == key.stride(i), + "key and value must have same stride along axi ", i, "."); + } + } + + /*********************************check index tensor***********************************/ + CHECK_SHAPE(context_lengths, batch_size); + CHECK_SHAPE(block_tables, batch_size, max_block_num); + TORCH_CHECK(max_context_len <= block_size * max_block_num, + "max_context_len should smaller than or equal to block_size * max_block_num."); + + /*********************************check optional tensor***********************************/ + const void *context_seq_offset_ptr = nullptr; + if (context_seq_offset.has_value()) { + CHECK_SHAPE(context_seq_offset.value(), batch_size); + context_seq_offset_ptr = context_seq_offset.value().data_ptr(); + } else { + TORCH_CHECK(batch_size <= 1024, + "batch_size greater than 1024 not support when context_seq_offset is None."); + } + + // propare parameters before calling invokeDequantFromLinearCache + const int32_t key_group_num = 1; + const int32_t value_group_num = 1; + const size_t context_head_stride = key.stride(1); + const size_t context_seq_stride = key.stride(0); + const size_t cache_bs_stride = key_cache.stride(0); + const size_t cache_head_stride = key_cache.stride(1); + const size_t key_cache_seq_stride = key_cache.stride(2); + const size_t value_cache_seq_stride = value_cache.has_value() ? value_cache.value().stride(2) : 0; + const size_t cache_scale_bs_stride = + key_cache_quant_scale.dim() == 3 ? key_cache_quant_scale.stride(0) : 0; + const size_t cache_scale_head_stride = key_cache_quant_scale.dim() == 3 + ? key_cache_quant_scale.stride(1) + : key_cache_quant_scale.stride(0); + + const torch_mlu::mlu::MLUGuard device_guard(key.device()); + auto queue = torch_mlu::getCurMLUStream(); + auto data_dtype = getCnnlDataType(key.scalar_type()); + + // run forward + TMO_KERNEL_CHECK_FATAL(invokeDequantFromPagedCache( + queue, getAtTensorPtr(key), getAtTensorPtr(value), getAtTensorPtr(key_cache), + getAtTensorPtr(value_cache), getAtTensorPtr(key_cache_quant_scale), + getAtTensorPtr(value_cache_quant_scale), getAtTensorPtr(context_lengths), + context_seq_offset_ptr, getAtTensorPtr(block_tables), (int)max_context_len, max_block_num, + batch_size, head_num, key_group_num, value_group_num, block_size, head_size, quant_mode, + quant_bit, context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride, + key_cache_seq_stride, value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride, + data_dtype)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/ffn.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/ffn.cpp new file mode 100644 index 0000000..e7b9bca --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/ffn.cpp @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "common/utils.h" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +/* +torch_ffn_api API funtion in pytorch op is: +class FeedForward(torch.nn.Module): + def __init__(self, hidden_size: int, inner_size: int, act_mode: str, + bias = False, gated = False): + super(FeedForward, self).__init__() + self.up_linear = torch.nn.Linear(hidden_size, inner_size, bias) + self.gated = gated + if self.gated: + self.gated_linear = torch.nn.Linear(hidden_size, inner_size, bias) + self.down_linear = torch.nn.Linear(inner_size, hidden_size, bias) + self.act = act_mode_dict[act_mode] + + def forward(self, x): + act_out = self.act(self.up_linear(x).float()).to(x.dtype) + return self.down_linear(act_out * self.gated_linear(x)) \ + if self.gated else self.down_linear(act_out) + +Demo of pytorch python module in TGI: +class LlamaBtMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.act = config.hidden_act + self.gate_proj_weight = weights.get_multi_weights_col( + f"{prefix}.gate_proj", quantize=config.quantize, dim=dim + ) + self.gate_proj_bias = None + self.up_proj_weight = weights.get_multi_weights_col( + f"{prefix}.up_proj", quantize=config.quantize, dim=dim + ) + self.up_proj_bias = None + self.down_proj_weight = weights.get_multi_weights_col( + f"{prefix}.down_proj", quantize=config.quantize, dim=dim + ) + self.down_proj_bias = None + + def forward(self, hidden_states): + return tmo.ffn(hidden_states, self.up_proj_weight, self.up_proj_bias, + self.down_proj_weight, self.down_proj_bias, self.gate_proj_weight, + self.gate_proj_bias, self.act) +*/ + +// act_mode now only support silu, gelu and relu. +// std::string pytorch func +// silu --> nn.SiLU +// gelu --> nn.functional.gelu +// relu --> nn.ReLU +// Maybe aligned with transformer act mode later. +// https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L200 + +// Dimensions of ffn . +// Dimension of input: [batch_num, seq_len, hidden_size] +// Dimension of up_fc_filters: [filter_size, hidden_size] +// Dimension of up_fc_bias: [filter_size] +// Dimension of down_fc_filters: [hidden_size, filter_size] +// Dimension of down_fc_bias: [hidden_size] +// Dimension of gated_fc_filters: [filter_size, hidden_size] only for gated fnn +// Dimension of gated_fc_bias: [filter_size] only for gated fnn +// Dimension of layer norm weight: [seq_len, hidden_size] only for layer norm fused +// Dimension of layer norm bias: [seq_len, hidden_size] only for layer norm fused +// Dimension of output is same as input. +// Dimension of \b output: [batch_num, seq_len, hidden_size] + +// Fused layer norm is not support now. And only support fused +// per layer norm later. + +at::Tensor ffn(const at::Tensor &input, + const at::Tensor &up_fc_weight, + const c10::optional &up_fc_bias, + const at::Tensor &down_proj_weight, + const c10::optional &down_proj_bias, + const c10::optional &gate_up_proj_weight, + const c10::optional &gate_up_proj_bias, + const c10::optional &layernorm_weight, + const c10::optional &layernorm_bias, + const std::string &act_mode, + const std::string &residual_is, + double eps, + double alpha, + double beta) { + // Check tensor type and tensor device. + checkTensorSameAttr(input, up_fc_weight, up_fc_bias, down_proj_weight, + down_proj_bias, gate_up_proj_weight, gate_up_proj_bias, + layernorm_weight, layernorm_bias); + // Check dims. + const int64_t nDim = input.dim(); + at::Tensor input_view = input; + if (nDim == 2) input_view = input_view.unsqueeze(0); + + // Check contiguous. + CHECK_TENSOR_CONTIGUOUS(input_view) + + const torch_mlu::mlu::MLUGuard device_guard(input_view.device()); + at::Tensor output = at::empty_like(input_view); + + // Convert torch tensor to tensor desc + auto descs = createTensorDescs( + {input_view, up_fc_weight, up_fc_bias.value_or(torch::Tensor()), down_proj_weight, + down_proj_bias.value_or(torch::Tensor()), gate_up_proj_weight.value_or(torch::Tensor()), + gate_up_proj_bias.value_or(torch::Tensor()), layernorm_weight.value_or(torch::Tensor()), + layernorm_bias.value_or(torch::Tensor()), output}); + + bool has_ln = layernorm_weight.has_value(); + TORCH_CHECK(residual_is == "input" || residual_is == "normed_input" || residual_is == "none", + "residual_is must be 'input' or 'normed_input' or 'none'.") + bool has_residual = residual_is != "none"; + auto ln_res_mode = tmo::lnres::makeLnresEnum(has_ln, has_residual, residual_is == "input"); + auto compute_type = getCnnlDataType(input_view.scalar_type()); + + tmo::op_desc::FeedForwardDesc ffn_desc(ln_res_mode, act_mode, compute_type, eps, alpha, beta); + if (gate_up_proj_weight.has_value() && gate_up_proj_weight->defined()) { + CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptorGateFiltersBias( + ffn_desc, descs[5].get(), getAtTensorPtr(gate_up_proj_weight), descs[6].get(), + getAtTensorPtr(gate_up_proj_bias))); + } + + // Get current handle. + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + // Get workspace size and malloc workspace. + size_t workspace_size = 0; + CNNL_CHECK_FATAL(cnnlGetTransformerFeedForwardWorkspaceSize_v2( + handle, ffn_desc, descs[0].get(), descs[1].get(), nullptr, ffn_desc, &workspace_size)); + auto workspace = + at::empty({static_cast(workspace_size)}, input.options().dtype(at::kByte)); + + // forward + CNNL_CHECK_FATAL(cnnlTransformerFeedForward( + handle, ffn_desc, ffn_desc, nullptr, descs[0].get(), getAtTensorPtr(input_view), + descs[1].get(), getAtTensorPtr(up_fc_weight), descs[2].get(), getAtTensorPtr(up_fc_bias), + descs[3].get(), getAtTensorPtr(down_proj_weight), descs[4].get(), + getAtTensorPtr(down_proj_bias), descs[7].get(), getAtTensorPtr(layernorm_weight), + descs[8].get(), getAtTensorPtr(layernorm_bias), getAtTensorPtr(workspace), workspace_size, + descs[9].get(), getAtTensorPtr(output))); + return nDim == 2 ? output.squeeze_(0) : output; +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/flash_attention.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/flash_attention.cpp new file mode 100644 index 0000000..63952ef --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/flash_attention.cpp @@ -0,0 +1,117 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" +namespace tmo { +namespace torch_api { + +void flash_attention(const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &output, + const c10::optional &output_lse, + const c10::optional &cu_seq_lens_q, + const c10::optional &cu_seq_lens_kv, + const c10::optional &alibi_slope, + const c10::optional &attn_bias, + const c10::optional &k_cache_quant_scale, + const c10::optional &v_cache_quant_scale, + const c10::optional &block_tables, + const int64_t max_seq_len_q, + const int64_t max_seq_len_kv, + const double softmax_scale, + const bool is_causal, + const int64_t window_size_left, + const int64_t window_size_right, + const std::string &compute_dtype, + bool return_lse) { + TORCH_CHECK(compute_dtype == "float" || compute_dtype == "half" || compute_dtype == "bfloat16", + "compute_dtype must be 'float', 'half' or 'bfloat16'."); + TORCH_CHECK(k_cache_quant_scale.has_value() == false, "k_cache_scale for reserve."); + TORCH_CHECK(v_cache_quant_scale.has_value() == false, "v_cache_scale for reserve."); + bool has_block_table = block_tables.has_value(); + bool is_pack = cu_seq_lens_q.has_value(); + int64_t batch = is_pack ? cu_seq_lens_q.value().size(0) - 1 : q.size(0); + int qk_head_size = q.size(-1); + int v_head_size = v.size(-1); + + // Check tensor type and tensor device. + checkTensorSameAttr(q, k, v, output); + // 3d for packed + TORCH_CHECK(q.dim() == 3 || q.dim() == 4, "query must be 3d or 4d."); + if (has_block_table) { + TORCH_CHECK(block_tables.value().dim() == 2, "block_tables must be 2d."); + TORCH_CHECK(k.dim() == 4, "with block table, key_cache must be 4d."); + TORCH_CHECK(v.dim() == 4, "with block table, value_cache must be 4d."); + int max_num_blocks_per_seq = block_tables.value().size(1); + int num_blocks = k.size(0); + int block_size = k.size(2); + int k_head_num = k.size(1); + CHECK_SHAPE(k, num_blocks, k_head_num, block_size, qk_head_size); + CHECK_SHAPE(v, num_blocks, k_head_num, block_size, v_head_size); + if (max_num_blocks_per_seq > 1) { // paged + CHECK_SHAPE(cu_seq_lens_kv.value(), batch + 1); + } + } else { + // 3d for packed + TORCH_CHECK(k.dim() == 3 || k.dim() == 4, "key_cache must be 3d or 4d."); + TORCH_CHECK(v.dim() == 3 || v.dim() == 4, "value_cache must be 3d or 4d."); + if (k.dim() == 3) { // packed_kv + CHECK_SHAPE(cu_seq_lens_kv.value(), batch + 1); + } + } + // Convert torch tensor to tensor descs + auto descs = createTensorDescs( + {q, k, v, cu_seq_lens_q.value_or(at::Tensor()), cu_seq_lens_kv.value_or(at::Tensor()), + alibi_slope.value_or(at::Tensor()), attn_bias.value_or(at::Tensor()), + block_tables.value_or(at::Tensor()), output, output_lse.value_or(at::Tensor())}); + + // Get current handle. + const torch_mlu::mlu::MLUGuard device_guard(q.device()); + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + // Get workspace size and malloc workspace. + cnnlDataType_t cnnl_compute_dtype = compute_dtype == "float" ? CNNL_DTYPE_FLOAT + : compute_dtype == "half" ? CNNL_DTYPE_HALF + : CNNL_DTYPE_BFLOAT16; + size_t workspace_size = 0; + CNNL_CHECK_FATAL(cnnlGetScaledDotProductAttnWorkspaceSize_v2( + handle, nullptr /*op_desc*/, nullptr /*quant_desc*/, descs[0].get(), descs[1].get(), + descs[2].get(), descs[3].get(), descs[4].get(), descs[6].get(), descs[5].get(), + descs[7].get(), max_seq_len_q, max_seq_len_kv, is_causal, window_size_left, window_size_right, + return_lse, CNNL_ACTIVATION_FAST, cnnl_compute_dtype, &workspace_size)); + auto workspace = at::empty({static_cast(workspace_size)}, q.options().dtype(at::kByte)); + + int64_t total_q = is_pack ? q.size(0) : q.size(0) * q.size(1); + int64_t total_k = + has_block_table ? k.size(0) * k.size(2) : (k.dim() == 3 ? k.size(0) : k.size(0) * k.size(1)); + int64_t head_q = q.size(-2); + int64_t head_k = has_block_table ? k.size(1) : k.size(-2); + cnnlDataType_t data_dtype = getCnnlDataType(q.scalar_type()); + FlashAttnTheory obj(batch, total_q, total_k, head_q, head_k, qk_head_size, v_head_size, is_causal, + data_dtype); + cnpxPush(obj); + // call cnnl extra op. + CNNL_CHECK_FATAL(cnnlScaledDotProductAttn_v3( + handle, nullptr, nullptr, descs[0].get(), getAtTensorPtr(q), descs[1].get(), + getAtTensorPtr(k), descs[2].get(), getAtTensorPtr(v), descs[3].get(), + getAtTensorPtr(cu_seq_lens_q), descs[4].get(), getAtTensorPtr(cu_seq_lens_kv), nullptr, + nullptr, descs[6].get(), getAtTensorPtr(attn_bias), descs[5].get(), + getAtTensorPtr(alibi_slope), descs[7].get(), getAtTensorPtr(block_tables), max_seq_len_q, + max_seq_len_kv, is_causal, window_size_left, window_size_right, softmax_scale, + CNNL_ACTIVATION_FAST, cnnl_compute_dtype, getAtTensorPtr(workspace), workspace_size, + return_lse, descs[9].get(), getAtTensorPtr(output_lse), descs[8].get(), + getAtTensorPtr(output))); + cnpxPop(); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/fuse_norm.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/fuse_norm.cpp new file mode 100644 index 0000000..7ab9850 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/fuse_norm.cpp @@ -0,0 +1,134 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +void fused_layernorm(const at::Tensor &input, + const at::Tensor &output, + const c10::optional &residual, + const c10::optional &gamma, + const c10::optional &beta, + const c10::optional &bias, + const c10::optional &quant_scale, + const c10::optional &residual_out, + const c10::optional &smooth_quant_scale, + const std::string &norm_mode, + double eps, + bool store_output_before_norm, + bool dynamic_quant) { + // check device and dtype + checkTensorSameAttr(input, residual, gamma, beta, bias); + checkTensorSameAttr(input, output, residual_out, smooth_quant_scale); + cnnlQuantizeScheme_t output_quant_scheme = CNNL_QUANTIZE_NONE; + if (dynamic_quant) { + TORCH_CHECK(quant_scale.has_value(), "dynamic_quant output, must have quant_scale"); + output_quant_scheme = CNNL_QUANTIZE_PER_TOKEN; + } else if (quant_scale.has_value()) { + output_quant_scheme = CNNL_QUANTIZE_PER_CHANNEL; + } + // check params + bool has_residual = residual.has_value(); + bool quant_out = quant_scale.has_value(); + int hidden_size = input.size(-1); + TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2."); + TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous."); + TORCH_CHECK(input.sizes() == output.sizes(), "input and output must have the same shape"); + if (has_residual) { + TORCH_CHECK(input.sizes() == residual.value().sizes(), + "input and residual must have the same shape"); + } + TORCH_CHECK(norm_mode == "layernorm" || norm_mode == "rmsnorm", + "norm_mode must be 'layernorm' or 'rmsnorm'."); + cnnlTransformerNormType_t mode = + norm_mode == "layernorm" ? CNNL_TRANSFORMER_LAYERNORM : CNNL_TRANSFORMER_RMSNORM; + if (norm_mode == "layernorm") { + TORCH_CHECK(gamma.has_value() && beta.has_value(), "layernorm mode need gamma and beta."); + TORCH_CHECK(gamma.value().sizes() == beta.value().sizes(), + "gamma and beta must have the same shape") + TORCH_CHECK(gamma.value().dim() == 1 && gamma.value().size(0) == hidden_size, + "layernorm mode, gamma and beta size must be hidden_size."); + } else { + TORCH_CHECK(gamma.has_value(), "rmsnorm mode need gamma."); + TORCH_CHECK(gamma.value().dim() == 1 && gamma.value().size(0) == hidden_size, + "rmsnorm mode, gamma size must be hidden_size."); + } + if (quant_out) { + TORCH_CHECK(quant_scale.value().dim() == 1 && quant_scale.value().size(0) == hidden_size, + "quant_scale shape must be [hidden_size]"); + } + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + at::Tensor smooth_quant_scale_flat = + dynamic_quant ? smooth_quant_scale.value().flatten() : at::Tensor(); + at::Tensor input_flat; + at::Tensor output_flat; + at::Tensor residual_flat; + at::Tensor residual_out_flat; + if (quant_out) { + // input must be 2-dim, output dim must be same as input + TORCH_CHECK(input.is_contiguous(), "quant_out is not support when input has stride."); + input_flat = input.flatten(0, -2); + output_flat = output.flatten(0, -2); + residual_flat = has_residual ? residual.value().flatten(0, -2) : residual_flat; + residual_out_flat = + store_output_before_norm ? residual_out.value().flatten(0, -2) : residual_out_flat; + } else if (input.dim() > 3) { + input_flat = input.flatten(0, -3); + output_flat = output.flatten(0, -3); + residual_flat = has_residual ? residual.value().flatten(0, -3) : residual_flat; + residual_out_flat = + store_output_before_norm ? residual_out.value().flatten(0, -3) : residual_out_flat; + } else { // 2-dim or 3-dim + input_flat = input; + output_flat = output; + residual_flat = has_residual ? residual.value() : residual_flat; + residual_out_flat = store_output_before_norm ? residual_out.value() : residual_out_flat; + } + TORCH_CHECK(input.data_ptr() == input_flat.data_ptr(), "check the strides of input."); + TORCH_CHECK(output_flat.data_ptr() == output.data_ptr(), "check the strides of output."); + if (has_residual) + TORCH_CHECK(residual.value().data_ptr() == residual_flat.data_ptr(), + "check the strides of residual."); + if (store_output_before_norm) + TORCH_CHECK(residual_out.value().data_ptr() == residual_out_flat.data_ptr(), + "check the strides of residual_out."); + + // create tensor desc + auto descs = createTensorDescs({input_flat, gamma.value_or(at::Tensor()), + beta.value_or(at::Tensor()), bias.value_or(at::Tensor()), + residual_flat, quant_scale.value_or(at::Tensor()), + residual_out_flat, output_flat, smooth_quant_scale_flat}); + auto compute_dtype = getCnnlDataType(input_flat.scalar_type()); + // forward + auto handle = torch_mlu::getCurrentHandle(); + FusedNormTheory obj(input_flat.size(0), input_flat.size(-1), has_residual, bias.has_value(), + quant_out, dynamic_quant, store_output_before_norm, compute_dtype, norm_mode); + cnpxPush(obj); + CNNL_CHECK_FATAL(cnnlFuseNorm_v3(handle, descs[0].get(), getAtTensorPtr(input_flat), // input + descs[5].get(), getAtTensorPtr(quant_scale), // input_scale + descs[1].get(), getAtTensorPtr(gamma), // norm_scale + descs[2].get(), getAtTensorPtr(beta), // norm_bias + descs[4].get(), getAtTensorPtr(residual_flat), // residual + descs[3].get(), getAtTensorPtr(bias), // bias + eps, output_quant_scheme, store_output_before_norm, mode, + compute_dtype, nullptr, 0, // set workspace + descs[7].get(), getAtTensorPtr(output_flat), // output + descs[6].get(), + getAtTensorPtr(residual_out_flat), // residual_out + descs[8].get(), getAtTensorPtr(smooth_quant_scale_flat))); + cnpxPop(); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/fused_moe.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/fused_moe.cpp new file mode 100644 index 0000000..957cf83 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/fused_moe.cpp @@ -0,0 +1,373 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include +#include +#include "kernels/moe/moe.mluh" +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { + +const std::string arch_370 = "MLU370"; +using GroupGemmDesc = tmo::op_desc::GroupGemmDesc; +using QuantMode = tmo::op_desc::GroupGemmDesc::QuantMode; + +at::Tensor fused_moe(const at::Tensor &hidden_states, + const at::Tensor &gating_output, + const at::Tensor &w1, + const at::Tensor &w2, + const c10::optional &bias1, + const c10::optional &bias2, + const c10::optional &residual, + const c10::optional &input_smooth, + const c10::optional &act_smooth, + const c10::optional &w1_scale, + const c10::optional &w2_scale, + const c10::optional> &w1_quant_flag, + const c10::optional> &w2_quant_flag, + const int64_t topk, + const bool renormalize, + const bool gated, + const std::string &act_mode, + const int64_t start_expert_id, + const int64_t block_n, + const int64_t cncl_comm) { + auto sizes_1 = hidden_states.sizes(); + auto sizes_2 = gating_output.sizes(); + auto w2_shape = w2.sizes(); + TORCH_CHECK(sizes_1.size() == sizes_2.size(), + "hidden_states and gating_output must have the same rank.") + TORCH_CHECK(sizes_1.size() == 2 || sizes_1.size() == 3, "hidden_states must be 2-D or 3-D.") + TORCH_CHECK(sizes_1[0] == sizes_2[0], + "hidden_states and gating_output must have the same batch."); + if (sizes_1.size() == 3) { + TORCH_CHECK(sizes_1[1] == sizes_2[1], + "hidden_states and gating_output must have the same seq."); + } + if (residual.has_value()) { + TORCH_CHECK(residual.value().sizes() == sizes_1, + "hidden_states and residual must have the same shape."); + } + TORCH_CHECK(!bias1.has_value() && !bias2.has_value(), "Currently not support bias1 and bias2.") + TORCH_CHECK(w1.dim() == 3 || w1.dim() == 1, "w1 should be 1-D or 3-D.") + TORCH_CHECK(w2.dim() == 3 || w2.dim() == 4 || w2.dim() == 1, "w2 should be 1-D or 3-D or 4-D.") + TORCH_CHECK((w1.dim() == 1 && w2.dim() == 1) || (w1.dim() != 1 && w2.dim() != 1), + "w1 and w2 should be both 1-D or not both 1-D.") + if (w1.dim() == 1) { + TORCH_CHECK(w1_quant_flag.has_value() && w2_quant_flag.has_value(), + "w1_quant_flag and w2_quant_flag need to exist simultaneously."); + } + // check contiguous + CHECK_TENSOR_CONTIGUOUS(hidden_states) + CHECK_TENSOR_CONTIGUOUS(gating_output) + CHECK_TENSOR_CONTIGUOUS(w1) + CHECK_TENSOR_CONTIGUOUS(w2) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(act_smooth) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(w1_scale) + const int64_t hidden_size = hidden_states.size(-1); + const int64_t num_expert = gating_output.size(-1); + const int64_t expert_size = w1_quant_flag.has_value() ? w1_scale.value().size(1) : w1.size(0); + auto hidden_states_ = hidden_states.view({-1, hidden_size}); + auto gating_output_ = gating_output.view({-1, num_expert}); + const int64_t num_token = hidden_states_.size(0); + const int64_t num_expand_token = num_token * topk; + const int64_t gemm1_co = w1_quant_flag.has_value() ? w1_scale.value().size(2) : w1.size(1); + const int64_t inner_size = gemm1_co / (1 + gated); + bool has_input_smooth = input_smooth.has_value(); + bool has_act_smooth = act_smooth.has_value(); + bool has_w1_scale = w1_scale.has_value(); + bool has_w2_scale = w2_scale.has_value(); + int opt_num = has_input_smooth + has_act_smooth + has_w1_scale + has_w2_scale; + QuantMode quant_mode = QuantMode::noQuant; + bool quant_grouped = false; + + bool per_token_sq = opt_num == 4; + TORCH_CHECK(opt_num == 0 || opt_num == 4, + "input_smooth, act_smooth, w1_scale and w2_scale must be present and absent at the " + "same time.") + if (per_token_sq) { + TORCH_CHECK(input_smooth.value().dtype() == torch::kFloat32, + "the data type of input_smooth must be float."); + checkTensorSameAttr(input_smooth, act_smooth, w1_scale, w2_scale); + CHECK_SHAPE(input_smooth.value(), expert_size, hidden_size); + CHECK_SHAPE(act_smooth.value(), expert_size, inner_size); + + quant_grouped = w1_scale.value().dim() == 3 ? true : false; + if (quant_grouped) { + CHECK_SHAPE(w1_scale.value(), w1_scale.value().size(0), expert_size, gemm1_co); + CHECK_SHAPE(w2_scale.value(), w2_scale.value().size(0), expert_size, hidden_size); + } else { + CHECK_SHAPE(w1_scale.value(), expert_size, gemm1_co); + CHECK_SHAPE(w2_scale.value(), expert_size, hidden_size); + } + + if (w1_quant_flag.has_value()) { + quant_mode = QuantMode::W4W8; + quant_grouped = true; + } else { + TORCH_CHECK(hidden_size == w1.size(-1) || hidden_size == w1.size(-1) * 2, + "hidden_size == w1.size(-1) || hidden_size == w1.size(-1) * 2."); + quant_mode = hidden_size == w1.size(-1) ? QuantMode::W8 : QuantMode::W4; + } + } + + if (quant_mode == QuantMode::W8 || quant_mode == QuantMode::noQuant) { + CHECK_SHAPE(w1, expert_size, gemm1_co, hidden_size); + if (w2.dim() == 3) { + CHECK_SHAPE(w2, expert_size, hidden_size, inner_size); + } else { + TORCH_CHECK(w2_shape[0] * w2_shape[2] == expert_size, + "w2_shape[0] * w2_shape[2] == expert_size"); + CHECK_SHAPE(w2, w2_shape[0], hidden_size, w2_shape[2], inner_size); + } + } else if (quant_mode == QuantMode::W4) { + CHECK_SHAPE(w1, expert_size, gemm1_co, hidden_size / 2); + CHECK_SHAPE(w2, expert_size, hidden_size, inner_size / 2); + } + + if (bias1.has_value()) { + CHECK_SHAPE(bias1.value(), num_expert, gemm1_co); + } + if (bias2.has_value()) { + CHECK_SHAPE(bias2.value(), num_expert, hidden_size); + } + TORCH_CHECK(topk <= num_expert, "topk <= num_expert.") + TORCH_CHECK(act_mode == "silu" || act_mode == "gelu", + "act_mode must be 'silu' or 'gelu', but got ", act_mode) + + checkTensorSameAttr(hidden_states, bias1, bias2); + checkTensorSameAttr(hidden_states, gating_output, w1, w2, residual, + input_smooth, act_smooth, w1_scale, w2_scale); + + const torch_mlu::mlu::MLUGuard device_guard(hidden_states_.device()); + torch_mlu::DeviceProp *dev_prop = torch_mlu::getDeviceProperties(hidden_states.get_device()); + std::string dev_name = dev_prop->name; + bool is_mlu370 = arch_370.compare(3, 3, dev_name, 3, 3) >= 0 ? true : false; + auto handle = torch_mlu::getCurrentHandle(); + auto gating_output_dtype = getCnnlDataType(gating_output.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + auto tensor_options = hidden_states_.options(); + auto data_dtype = getCnnlDataType(hidden_states.scalar_type()); + auto weight_dtype = getCnnlDataType(w1.scalar_type()); + auto input_dtype = per_token_sq ? CNNL_DTYPE_INT8 : weight_dtype; + if (quant_mode == QuantMode::W4) { + weight_dtype = CNNL_DTYPE_INT4X2; + } + + // create tensors + auto reduce_weight = at::empty({num_token, topk}, tensor_options.dtype(torch::kFloat)); + auto expert_id = at::empty({num_token, topk}, tensor_options.dtype(torch::kInt32)); + auto int32_idx = at::empty({2, num_token, topk}, tensor_options.dtype(torch::kInt32)); + auto gather_expand_idx = int32_idx[0]; + auto gather_combine_idx = int32_idx[1]; + auto token_count = at::empty({num_expert}, tensor_options.dtype(torch::kInt32)); + auto cumsum_token_count = at::empty({num_expert + 1}, tensor_options.dtype(torch::kInt32)); + auto gen_idx_workspace = + at::empty({num_expert + 1 + num_expand_token}, tensor_options.dtype(torch::kInt32)); + int64_t mid_dim = (1 + gated) * (1 + per_token_sq); + auto gemm1_output = at::empty({num_expand_token, mid_dim, inner_size}, w1.options()); + auto gemm2_output = at::empty({num_expand_token, hidden_size}, tensor_options); + auto quant_input = at::Tensor(); + auto input_scale = at::Tensor(); + auto act_scale = at::Tensor(); + auto expand_hidden_state = at::Tensor(); + if (is_mlu370 || !per_token_sq) { + expand_hidden_state = at::empty({num_expand_token, hidden_size}, tensor_options); + } + if (per_token_sq) { + quant_input = at::empty({num_expand_token, hidden_size}, tensor_options.dtype(torch::kInt8)); + input_scale = at::empty({num_expand_token}, tensor_options.dtype(torch::kFloat)); + act_scale = at::empty({num_expand_token}, tensor_options.dtype(torch::kFloat)); + } + + //=========================================1:topk_softmax========================================= + int normalize_mode = renormalize ? 1 : 0; + tmo::invokeMoeSoftmaxTopkKernel(queue, (float *)getAtTensorPtr(reduce_weight), + (int *)getAtTensorPtr(expert_id), getAtTensorPtr(gating_output_), + nullptr, num_token, num_expert, 1, topk, -1, 0, + gating_output_dtype, normalize_mode); + //=========================================2:generate_idx========================================= + tmo::invokeMoeGenIdxKernel( + queue, (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(gather_combine_idx), + (int *)getAtTensorPtr(token_count), (int *)getAtTensorPtr(cumsum_token_count), + getAtTensorPtr(gen_idx_workspace), getAtTensorPtr(expert_id), num_token, num_expert, topk); + if (per_token_sq) { + //========================================3:pertoken_sq(optinal)================================== + if (is_mlu370) { + tmo::invokeMoeExpandInputKernel( + queue, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(hidden_states_), + (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(cumsum_token_count), + num_token, hidden_size, topk, data_dtype, num_expert, start_expert_id, expert_size); + SmoothQuantTheory sm_obj1(num_expand_token, hidden_size, data_dtype, "per_token"); + cnpxPush(sm_obj1); + tmo::ops::SmoothQuant( + handle, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(input_smooth), + getAtTensorPtr(token_count[start_expert_id]), nullptr, nullptr, + getAtTensorPtr(quant_input), getAtTensorPtr(input_scale), num_expand_token, hidden_size, + expert_size, hidden_size, hidden_size, 1, data_dtype); + cnpxPop(); + } else { + SmoothQuantTheory sm_obj1(num_expand_token, hidden_size, data_dtype, "per_token"); + cnpxPush(sm_obj1); + tmo::ops::SmoothQuant(handle, getAtTensorPtr(hidden_states_), getAtTensorPtr(input_smooth), + getAtTensorPtr(token_count[start_expert_id]), + getAtTensorPtr(gather_expand_idx), + getAtTensorPtr(cumsum_token_count[start_expert_id]), + getAtTensorPtr(quant_input), getAtTensorPtr(input_scale), num_token, + hidden_size, expert_size, hidden_size, hidden_size, topk, data_dtype); + cnpxPop(); + } + } else { + //========================================3:expand_input======================================== + tmo::invokeMoeExpandInputKernel( + queue, getAtTensorPtr(expand_hidden_state), getAtTensorPtr(hidden_states_), + (int *)getAtTensorPtr(gather_expand_idx), (int *)getAtTensorPtr(cumsum_token_count), + num_token, hidden_size, topk, data_dtype, num_expert, start_expert_id, expert_size); + } + //========================================4:group_gemm1=========================================== + GroupGemmDesc group_gemm_desc1( + expert_size, num_token /*the maximum number of token that may be processed by each expert*/, + gemm1_co, hidden_size, data_dtype, false, quant_mode); + group_gemm_desc1.setInputOutputTensor( + input_dtype, weight_dtype, data_dtype, CNNL_DTYPE_INT32, + per_token_sq ? getAtTensorPtr(quant_input) : getAtTensorPtr(expand_hidden_state), + getAtTensorPtr(w1), nullptr, getAtTensorPtr(gemm1_output), nullptr, hidden_size, hidden_size, + num_expand_token, false); + std::vector w1_flag_vec; + if (per_token_sq) { + if (w1_quant_flag.has_value()) { + auto vec = w1_quant_flag.value().vec(); + w1_flag_vec.resize(vec.size()); + std::copy(vec.begin(), vec.end(), w1_flag_vec.begin()); + } + group_gemm_desc1.setPerRowColScaleBiasAct( + getAtTensorPtr(input_scale), getAtTensorPtr(w1_scale), w1_flag_vec.data(), nullptr, + data_dtype, quant_grouped ? hidden_size : 0, + quant_grouped ? hidden_size / w1_scale.value().size(0) : 0, 0); + } + size_t group_gemm1_wsize = + tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc1, expert_size); + auto group_gemm1_workspace = at::empty({static_cast(group_gemm1_wsize)}, + hidden_states.options().dtype(at::kByte)); + std::vector ldb_array; + if (quant_mode != QuantMode::W4W8) ldb_array.assign(expert_size, hidden_size); + GroupGemmTheory gg_obj1(num_expand_token, expert_size, hidden_size, gemm1_co, false /*has_res*/, + input_dtype, data_dtype); + cnpxPush(gg_obj1); + tmo::ops::GroupGemm(handle, group_gemm_desc1, getAtTensorPtr(token_count[start_expert_id]), + nullptr, nullptr, getAtTensorPtr(group_gemm1_workspace), group_gemm1_wsize, + expert_size, hidden_size, /*k*/ + gemm1_co, /*n*/ + hidden_size /*lda*/, ldb_array /*ldb*/); + cnpxPop(); + //========================================5:activation============================================ + cnnlActivationMode_t act_type = act_mode == "silu" ? CNNL_ACTIVATION_SWISH : CNNL_ACTIVATION_GELU; + GroupAddBiasActiveTheory add_bias_obj(expert_size, num_expand_token, inner_size, gated, + bias1.has_value(), data_dtype, act_mode); + cnpxPush(add_bias_obj); + tmo::invokeGroupAddBiasActivationKernel( + queue, getAtTensorPtr(gemm1_output), getAtTensorPtr(gemm1_output), + nullptr /*getAtTensorPtr(bias1)*/, (int *)getAtTensorPtr(cumsum_token_count), num_expert, + num_expand_token, inner_size, gemm1_co, data_dtype, gated, act_type, start_expert_id, + expert_size, 1.0f); + cnpxPop(); + //========================================6:smooth_quant========================================== + if (per_token_sq) { + SmoothQuantTheory sm_obj2(num_expand_token, inner_size * expert_size, data_dtype, "per_token"); + cnpxPush(sm_obj2); + tmo::ops::SmoothQuant(handle, getAtTensorPtr(gemm1_output), getAtTensorPtr(act_smooth), + getAtTensorPtr(token_count[start_expert_id]), nullptr /*gather_idx*/, + nullptr, getAtTensorPtr(gemm1_output), getAtTensorPtr(act_scale), + num_expand_token, inner_size, expert_size, gemm1_co /*input_stride*/, + 2 * gemm1_co /*output_stride*/, 1 /*topk*/, data_dtype); + cnpxPop(); + } + if (cncl_comm > 0) { + TORCH_CHECK(num_expert == expert_size, "expert_size must be num_expert when cncl_comm > 0"); + c10::optional act_scale_opt(act_scale); + c10::optional w2_scale_opt(w2_scale); + std::string dtype_s = torchDtype2Str(hidden_states.scalar_type()); + auto gg_input = + gemm1_output.as_strided({num_expand_token, inner_size}, {mid_dim * inner_size, 1}); + auto output = group_gemm_combine_result_allreduce( + cncl_comm, gg_input, w2, token_count, gather_combine_idx, reduce_weight, c10::nullopt, + c10::nullopt, c10::nullopt, per_token_sq ? act_scale_opt : c10::nullopt, + per_token_sq ? w2_scale_opt : c10::nullopt, dtype_s, num_token, topk, block_n); + if (residual.has_value()) { + output.view(sizes_1) += residual.value(); + } + return output.view(sizes_1); + } else { + //========================================8:group_gemm2=========================================== + GroupGemmDesc group_gemm_desc2(expert_size, num_token, hidden_size, inner_size, data_dtype, + false, quant_mode); + std::vector b_offset(expert_size); + int w2_ldb = inner_size; + if (w2.dim() == 4) { + w2_ldb = w2_shape[2] * inner_size; + auto elem_size = w2.element_size(); + for (int64_t i = 0; i < expert_size; i++) { + b_offset[i] = (i / w2_shape[2] * w2.stride(0) + i % w2_shape[2] * w2_shape[3]) * elem_size; + } + } + group_gemm_desc2.setInputOutputTensor(input_dtype, weight_dtype, data_dtype, CNNL_DTYPE_INT32, + getAtTensorPtr(gemm1_output), getAtTensorPtr(w2), nullptr, + getAtTensorPtr(gemm2_output), nullptr, inner_size, + gemm1_co * (per_token_sq + 1), num_expand_token, false, + w2.dim() == 4 ? b_offset.data() : nullptr); + std::vector w2_flag_vec; + if (per_token_sq) { + if (w2_quant_flag.has_value()) { + auto vec = w2_quant_flag.value().vec(); + w2_flag_vec.resize(vec.size()); + std::copy(vec.begin(), vec.end(), w2_flag_vec.begin()); + } + group_gemm_desc2.setPerRowColScaleBiasAct( + getAtTensorPtr(act_scale), getAtTensorPtr(w2_scale), w2_flag_vec.data(), nullptr, + data_dtype, quant_grouped ? inner_size : 0, + quant_grouped ? inner_size / w2_scale.value().size(0) : 0, 0); + } + + size_t group_gemm2_wsize = + tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc2, expert_size); + auto group_gemm2_workspace = at::empty({static_cast(group_gemm2_wsize)}, + hidden_states.options().dtype(at::kByte)); + std::vector w2_ldb_array; + if (quant_mode != QuantMode::W4W8) w2_ldb_array.assign(expert_size, w2_ldb); + GroupGemmTheory gg_obj2(num_expand_token, expert_size, inner_size, hidden_size, + false /*has_res*/, input_dtype, data_dtype); + cnpxPush(gg_obj2); + tmo::ops::GroupGemm(handle, group_gemm_desc2, getAtTensorPtr(token_count[start_expert_id]), + nullptr, nullptr, getAtTensorPtr(group_gemm2_workspace), group_gemm2_wsize, + expert_size, inner_size, /*k*/ + hidden_size, /*n*/ + gemm1_co * (per_token_sq + 1) /*lda*/, w2_ldb_array /*ldb*/); + cnpxPop(); + //========================================9:combine_result======================================= + auto output = at::empty(hidden_states_.sizes(), hidden_states_.options()); + MoeCombineResultTheory cr_obj(num_token, topk, hidden_size, expert_size, bias2.has_value(), + residual.has_value(), data_dtype); + cnpxPush(cr_obj); + tmo::invokeMoeCombineResultKernel( + queue, getAtTensorPtr(output), getAtTensorPtr(gemm2_output), nullptr, + getAtTensorPtr(residual), (float *)getAtTensorPtr(reduce_weight), + (int *)getAtTensorPtr(cumsum_token_count), (int *)getAtTensorPtr(gather_combine_idx), + num_token, topk, num_expert, hidden_size, start_expert_id, expert_size, data_dtype); + cnpxPop(); + return output.view(sizes_1); + } +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/fused_rope.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/fused_rope.cpp new file mode 100644 index 0000000..01f1a77 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/fused_rope.cpp @@ -0,0 +1,281 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "kernels/fused_rope.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +void fused_rope(at::Tensor &qkv, + at::Tensor &key_cache_hp, + at::Tensor &value_cache_hp, + const c10::optional &key_cache_lp, + const c10::optional &value_cache_lp, + const at::Tensor &sin_table, + const at::Tensor &cos_table, + const at::Tensor &position_ids, + const at::Tensor &gamma, + const at::Tensor &beta, + const c10::optional &key_scale_hp, + const c10::optional &value_scale_hp, + const c10::optional &key_scale_lp, + const c10::optional &value_scale_lp, + const c10::optional &cache_bs_id_hp, + const c10::optional &cache_seq_offsets_hp, + const c10::optional &cache_bs_id_lp, + const c10::optional &cache_seq_offsets_lp, + const c10::optional &slot_mapping_hp, + const c10::optional &slot_mapping_lp, + const double eps) { + bool paged_cache_hp = slot_mapping_hp.has_value(); + bool paged_cache_lp = slot_mapping_lp.has_value(); + bool mixed_cache = key_cache_lp.has_value() && value_cache_lp.has_value(); + const int origin_device_id = qkv.get_device(); + + checkTensorSameAttr(qkv, sin_table, cos_table, gamma, beta); + TORCH_CHECK(qkv.is_contiguous(), "qkv tensor must be contiguous.") + TORCH_CHECK(key_cache_hp.is_contiguous(), "key_cache_hp tensor must be contiguous") + TORCH_CHECK(value_cache_hp.is_contiguous(), "value_cache_hp tensor must be contiguous") + if (mixed_cache) { + TORCH_CHECK(key_cache_lp.value().is_contiguous(), "key_cache_lp tensor must be contiguous") + TORCH_CHECK(key_scale_hp.has_value() && value_scale_hp.has_value(), + "key_scale_hp and value_scale_hp must not be null under mixed cache.") + } + if (mixed_cache) { + TORCH_CHECK(value_cache_lp.value().is_contiguous(), "value_cache_lp tensor must be contiguous") + } + TORCH_CHECK(position_ids.is_contiguous(), "position_ids tensor must be contiguous"); + TORCH_CHECK(gamma.is_contiguous(), "gamma tensor must be contiguous"); + TORCH_CHECK(beta.is_contiguous(), "beta tensor must be contiguous"); + if (cache_bs_id_hp.has_value()) { + TORCH_CHECK(cache_bs_id_hp.value().is_contiguous(), "cache_bs_id_hp tensor must be contiguous.") + } + if (cache_seq_offsets_hp.has_value()) { + TORCH_CHECK(cache_seq_offsets_hp.value().is_contiguous(), + "cache_seq_offsets_hp tensor must be contiguous") + } + if (cache_bs_id_lp.has_value()) { + TORCH_CHECK(cache_bs_id_lp.value().is_contiguous(), "cache_bs_id_lp tensor must be contiguous.") + } + if (cache_seq_offsets_lp.has_value()) { + TORCH_CHECK(cache_seq_offsets_lp.value().is_contiguous(), + "cache_seq_offsets_lp tensor must be contiguous") + } + if (paged_cache_hp) { + TORCH_CHECK(slot_mapping_hp.value().is_contiguous(), + "slot_mapping_hp tensor must be contiguous.") + } + if (paged_cache_lp) { + TORCH_CHECK(slot_mapping_lp.value().is_contiguous(), + "slot_mapping_lp tensor must be contiguous.") + } + + // check qkv key_cache value_cache + auto dtype = qkv.dtype(); + TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, + "qkv dtype must be half or bfloat16.") + TORCH_CHECK(qkv.dim() == 4, + "qkv tensor dim must be 4: (batch_size, 1, head_num_q + head_num_kv * 2, head_size).") + TORCH_CHECK(qkv.size(1) == 1, "only support seq_len == 1."); + TORCH_CHECK(key_cache_hp.dim() == 4, + "key_cache_hp dim must be 4: (max_bs, head_num_kv, max_decode_len_hp, head_size)," + "or (num_blocks, head_num_kv, block_size_hp, head_size)."); + TORCH_CHECK(value_cache_hp.dim() == 4, + "value_cache_hp dim must be 4: (max_bs, head_num_kv, max_decode_len_hp, head_size)," + "or (num_blocks, head_num_kv, block_size_hp, head_size)."); + void *key_cache_lp_ptr = nullptr; + void *value_cache_lp_ptr = nullptr; + if (mixed_cache) { + TORCH_CHECK( + key_cache_lp.value().dim() == 4, + "key_cache_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp, head_size / 2)," + "or (num_blocks, head_num_kv, block_size_lp, head_size / 2)."); + TORCH_CHECK( + value_cache_lp.value().dim() == 4, + "value_cache_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp / 2, head_size)," + "or (num_blocks, head_num_kv, block_size_lp / 2, head_size)."); + TORCH_CHECK(key_cache_lp.value().get_device() == origin_device_id, + "key_scale_hp tensor device index is not same, original index: ", origin_device_id, + "now index is: ", key_cache_lp.value().get_device()); + TORCH_CHECK(value_cache_lp.value().get_device() == origin_device_id, + "value_scale_hp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", value_cache_lp.value().get_device()); + key_cache_lp_ptr = key_cache_lp.value().data_ptr(); + value_cache_lp_ptr = value_cache_lp.value().data_ptr(); + } + + int batch_size = qkv.size(0); + int head_num_qkv = qkv.size(-2); + int head_size = qkv.size(-1); + int head_num_kv = key_cache_hp.size(1); // linear cache and paged cache same + int max_bs_hp = key_cache_hp.size(0); // for linear cache, paged cache no use + int max_bs_lp = mixed_cache ? key_cache_lp.value().size(0) : 0; + int max_decode_len_hp = key_cache_hp.size(2); // for linear cache, paged cache no use + int max_decode_len_lp = mixed_cache ? key_cache_lp.value().size(2) : 0; + int num_blocks_hp = key_cache_hp.size(0); + int num_blocks_lp = mixed_cache ? key_cache_lp.value().size(0) : 0; + int block_size_hp = key_cache_hp.size(2); // for paged cache, linear cache no use + int block_size_lp = mixed_cache ? key_cache_lp.value().size(2) : 0; + int head_num_q = head_num_qkv - 2 * head_num_kv; + int group_size = head_size; + + TORCH_CHECK(head_size <= 128, "only support qkv head_size <= 128."); + TORCH_CHECK(head_size % 2 == 0, "head_size must be divided by 2."); + TORCH_CHECK(head_num_q <= 32, "only support head_num_q <= 32."); + TORCH_CHECK(head_num_kv <= 32, "only support head_num_kv <= 32."); + int kv_out_size = key_scale_hp.has_value() && value_scale_hp.has_value() ? 1 : 2; + size_t hp_cache_total_bytes = + paged_cache_hp + ? (size_t)num_blocks_hp * head_num_kv * block_size_hp * head_size * kv_out_size + : (size_t)max_bs_hp * max_decode_len_hp * head_num_kv * head_size * kv_out_size; + TORCH_CHECK(hp_cache_total_bytes <= INT32_MAX, + "hp_cache memory btyes must be less than or equal to INT32_MAX."); + if (mixed_cache) { + size_t lp_cache_total_bytes = + paged_cache_lp ? (size_t)num_blocks_lp * head_num_kv * block_size_lp * head_size / 2 + : (size_t)max_bs_lp * max_decode_len_lp * head_num_kv * head_size / 2; + TORCH_CHECK(lp_cache_total_bytes <= INT32_MAX, + "lp_cache memory btyes must be less than or equal to INT32_MAX."); + } + + // check sin_table cos_table gamma beta + TORCH_CHECK(sin_table.dim() == 2 && cos_table.dim() == 2, + "sin_table and cos_table tensor dim must be 2: (rotary_seqlen, head_size)."); + TORCH_CHECK(sin_table.size(1) == head_size && cos_table.size(1) == head_size, + "rotary_dim must be same with head_size."); + TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0), + "sin_table first stride must be equal to cos_table first_stride."); + int rotary_stride = sin_table.stride(0); + + CHECK_SHAPE(gamma, head_size); + CHECK_SHAPE(beta, head_size); + + // check key value scale + bool has_kv_scale = key_scale_hp.has_value() && value_scale_hp.has_value(); + const void *key_scale_hp_ptr = nullptr; + const void *value_scale_hp_ptr = nullptr; + if (has_kv_scale) { + CHECK_SHAPE(key_scale_hp.value(), head_num_kv, head_size); + CHECK_SHAPE(value_scale_hp.value(), head_num_kv, head_size); + TORCH_CHECK(key_scale_hp.value().scalar_type() == torch::kFloat32, + "key_scale_hp dtype must be float."); + TORCH_CHECK(value_scale_hp.value().scalar_type() == torch::kFloat32, + "value_scale_hp dtype must be float."); + TORCH_CHECK(key_scale_hp.value().get_device() == origin_device_id, + "key_scale_hp tensor device index is not same, original index: ", origin_device_id, + "now index is: ", key_scale_hp.value().get_device()); + TORCH_CHECK(value_scale_hp.value().get_device() == origin_device_id, + "value_scale_hp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", value_scale_hp.value().get_device()); + key_scale_hp_ptr = key_scale_hp.value().data_ptr(); + value_scale_hp_ptr = value_scale_hp.value().data_ptr(); + } + + void *key_scale_lp_ptr = nullptr; + void *value_scale_lp_ptr = nullptr; + if (mixed_cache) { + TORCH_CHECK(key_scale_lp.value().dim() == 4, + "key_scale_lp dim must be 4: (max_bs, head_num_kv, max_decode_len_lp, group_num)," + "or (num_blocks, head_num_kv, block_size_lp, group_num)."); + TORCH_CHECK(key_scale_lp.value().scalar_type() == torch::kFloat32, + "key_scale_hp dtype must be float."); + TORCH_CHECK(value_scale_lp.value().scalar_type() == torch::kFloat32, + "value_scale_hp dtype must be float."); + TORCH_CHECK(key_scale_lp.value().get_device() == origin_device_id, + "key_scale_hp tensor device index is not same, original index: ", origin_device_id, + "now index is: ", key_scale_lp.value().get_device()); + TORCH_CHECK(value_scale_lp.value().get_device() == origin_device_id, + "value_scale_hp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", value_scale_lp.value().get_device()); + group_size = head_size / key_scale_lp.value().size(-1); + key_scale_lp_ptr = key_scale_lp.value().data_ptr(); + value_scale_lp_ptr = value_scale_lp.value().data_ptr(); + } + + // check cache_bs_id cache_seq_offsets + const void *cache_bs_id_hp_ptr = nullptr; + const void *cache_bs_id_lp_ptr = nullptr; + bool has_cache_bs_id_hp = cache_bs_id_hp.has_value(); + bool has_cache_bs_id_lp = cache_bs_id_lp.has_value(); + if (has_cache_bs_id_hp) { + CHECK_SHAPE(cache_bs_id_hp.value(), batch_size); + TORCH_CHECK(cache_bs_id_hp.value().scalar_type() == torch::kInt32, + "cache_bs_id_hp dtype need be torch::kInt32"); + TORCH_CHECK(cache_bs_id_hp.value().get_device() == origin_device_id, + "cache_bs_id_hp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", cache_bs_id_hp.value().get_device()); + cache_bs_id_hp_ptr = cache_bs_id_hp.value().data_ptr(); + } + if (mixed_cache && has_cache_bs_id_lp) { + CHECK_SHAPE(cache_bs_id_lp.value(), batch_size); + TORCH_CHECK(cache_bs_id_lp.value().scalar_type() == torch::kInt32, + "cache_bs_id_lp dtype need be torch::kInt32"); + TORCH_CHECK(cache_bs_id_lp.value().get_device() == origin_device_id, + "cache_bs_id_lp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", cache_bs_id_lp.value().get_device()); + cache_bs_id_lp_ptr = cache_bs_id_lp.value().data_ptr(); + } + const void *slot_mapping_hp_ptr = nullptr; + const void *slot_mapping_lp_ptr = nullptr; + if (paged_cache_hp) { + CHECK_SHAPE(slot_mapping_hp.value(), batch_size); + TORCH_CHECK(slot_mapping_hp.value().scalar_type() == torch::kInt32, + "slot_mapping_hp dtype need be torch::kInt32"); + TORCH_CHECK(slot_mapping_hp.value().get_device() == origin_device_id, + "slot_mapping_hp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", slot_mapping_hp.value().get_device()); + slot_mapping_hp_ptr = slot_mapping_hp.value().data_ptr(); + } + if (mixed_cache && paged_cache_lp) { + CHECK_SHAPE(slot_mapping_lp.value(), batch_size); + TORCH_CHECK(slot_mapping_lp.value().scalar_type() == torch::kInt32, + "slot_mapping_lp dtype need be torch::kInt32"); + TORCH_CHECK(slot_mapping_lp.value().get_device() == origin_device_id, + "slot_mapping_lp tensor device index is not same, original index: ", + origin_device_id, "now index is: ", slot_mapping_lp.value().get_device()); + slot_mapping_lp_ptr = slot_mapping_lp.value().data_ptr(); + } + const void *cache_seq_offsets_hp_ptr = nullptr; + const void *cache_seq_offsets_lp_ptr = nullptr; + if (!paged_cache_hp) { + CHECK_SHAPE(cache_seq_offsets_hp.value(), batch_size); + TORCH_CHECK(cache_seq_offsets_hp.value().dtype() == torch::kInt32, + "cache_seq_offsets_hp dtype need be torch::kInt32"); + TORCH_CHECK(cache_seq_offsets_hp.value().get_device() == origin_device_id, + "cache_seq_offsets_hp tensor device index is not same, original index:", + origin_device_id, "now index is: ", cache_seq_offsets_hp.value().get_device()); + cache_seq_offsets_hp_ptr = cache_seq_offsets_hp.value().data_ptr(); + } + if (mixed_cache && !paged_cache_lp) { + CHECK_SHAPE(cache_seq_offsets_lp.value(), batch_size); + TORCH_CHECK(cache_seq_offsets_lp.value().dtype() == torch::kInt32, + "cache_seq_offsets_lp dtype need be torch::kInt32"); + TORCH_CHECK(cache_seq_offsets_lp.value().get_device() == origin_device_id, + "cache_seq_offsets_lp tensor device index is not same, original index:", + origin_device_id, "now index is: ", cache_seq_offsets_lp.value().get_device()); + cache_seq_offsets_lp_ptr = cache_seq_offsets_lp.value().data_ptr(); + } + + auto queue = torch_mlu::getCurrentMLUStream(); + auto data_type = getCnnlDataType(qkv.scalar_type()); + invokeFusedRope(queue, getAtTensorPtr(qkv), getAtTensorPtr(key_cache_hp), + getAtTensorPtr(value_cache_hp), key_cache_lp_ptr, value_cache_lp_ptr, + getAtTensorPtr(sin_table), getAtTensorPtr(cos_table), + getAtTensorPtr(position_ids), getAtTensorPtr(gamma), getAtTensorPtr(beta), + key_scale_hp_ptr, value_scale_hp_ptr, key_scale_lp_ptr, value_scale_lp_ptr, + cache_bs_id_hp_ptr, cache_seq_offsets_hp_ptr, cache_bs_id_lp_ptr, + cache_seq_offsets_lp_ptr, slot_mapping_hp_ptr, slot_mapping_lp_ptr, rotary_stride, + batch_size, head_num_q, head_num_kv, head_size, max_decode_len_hp, + max_decode_len_lp, block_size_hp, block_size_lp, group_size, data_type, eps); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.cpp new file mode 100644 index 0000000..103f41f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.cpp @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "glue_ops.h" +#include + +void copy_blocks__functionalization_glue( + const std::vector &k_caches, + const std::vector &v_caches, + const c10::Dict> &block_mapping) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(k_caches)); + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(v_caches)); + at::functionalization::impl::sync(k_caches); + at::functionalization::impl::sync(v_caches); + auto new_k_caches = at::functionalization::impl::from_functional_tensor(k_caches); + auto new_v_caches = at::functionalization::impl::from_functional_tensor(v_caches); + + // Grab the dispatcher entry corresponding to the out-of-place op, "foo" + static auto op_handle = + c10::Dispatcher::singleton() + // specify namespace::op_name, op_overload_name + .findSchemaOrThrow("torch_mlu_ops::copy_blocks_out_of_place", "") + // Specify the C++ schema of the out-of-place op. + .typed, std::vector>( + const std::vector &k_caches, const std::vector &v_caches, + const c10::Dict> &block_mapping)>(); + // Next, redispatch to the out-of-place op, foo() (user called foo_, we call fooo) + std::tuple, std::vector> tmp_output; + { + at::AutoDispatchSkipFunctionalize guard; + tmp_output = op_handle.call(new_k_caches, new_v_caches, block_mapping); + } + // Finally, tell functionalization about this mutation. + at::functionalization::impl::replace_(k_caches, std::get<0>(tmp_output)); + at::functionalization::impl::replace_(v_caches, std::get<1>(tmp_output)); + at::functionalization::impl::commit_update(k_caches); + at::functionalization::impl::commit_update(v_caches); +} diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.h b/torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.h new file mode 100644 index 0000000..dfe3f3a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/glue_ops.h @@ -0,0 +1,22 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_TORCH_API_GLUE_OPS_H_ +#define CSRC_TORCH_API_GLUE_OPS_H_ +#include "torch/extension.h" +#include "utils.h" + +void copy_blocks__functionalization_glue( + const std::vector &k_caches, + const std::vector &v_caches, + const c10::Dict> &block_mapping); + +#endif // CSRC_TORCH_API_GLUE_OPS_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/group_gemm.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/group_gemm.cpp new file mode 100644 index 0000000..2feacb3 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/group_gemm.cpp @@ -0,0 +1,231 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { + +using GroupGemmDesc = op_desc::GroupGemmDesc; +using QuantMode = GroupGemmDesc::QuantMode; + +at::Tensor group_gemm(const at::Tensor &a_tensor, + const at::Tensor &b_tensor, + const at::Tensor &m_list, + const c10::optional &gather_idx, + const c10::optional &c_tensor, + const c10::optional &alpha, + const c10::optional &beta, + const c10::optional &a_scale, + const c10::optional &b_scale, + const c10::optional &bias, + const c10::optional &data_type, + const c10::optional> &quant_flag, + const c10::optional &b_offset, + const int64_t max_m) { + bool has_dtype = data_type.has_value(); + bool has_c_tensor = c_tensor.has_value(); + bool has_gather_idx = gather_idx.has_value(); + bool has_a_scale = a_scale.has_value(); + bool has_b_scale = b_scale.has_value(); + bool has_alpha = alpha.has_value(); + bool has_beta = beta.has_value(); + bool has_bias = bias.has_value(); + bool has_quant_flag = quant_flag.has_value(); + QuantMode quant_mode = QuantMode::noQuant; + bool quant_grouped = false; + + // check stride + TORCH_CHECK(a_tensor.stride(-1) == 1, "a_tensor last dim must be contiguous"); + CHECK_TENSOR_CONTIGUOUS(m_list); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(c_tensor); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_idx); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(a_scale); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(b_scale); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(alpha); + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(beta); + + // check dim + TORCH_CHECK(m_list.dim() == 1, "the shape of m_list should be 2D"); + TORCH_CHECK(a_tensor.dim() == 2, "the shape of a_tensor should be 2D"); + TORCH_CHECK(b_tensor.dim() == 3 || b_tensor.dim() == 4 || (b_tensor.dim() == 1 && has_quant_flag), + "the shape of b_tensor should be 3D or 4D"); + if (has_c_tensor) { + TORCH_CHECK(c_tensor.value().dim() == 2, "the shape of c_tensor should be 2D"); + } + if (has_gather_idx) { + TORCH_CHECK(gather_idx.value().dim() == 1, "the shape of gather_idx should be 1D"); + } + if (has_a_scale) { + TORCH_CHECK(a_scale.value().dim() == 1, "the shape of a_scale should be 1D"); + } + if (has_b_scale) { + TORCH_CHECK(b_scale.value().dim() == 2 || b_scale.value().dim() == 3, + "the shape of b_scale should be 2D or 3D"); + } + if (has_alpha) { + TORCH_CHECK(alpha.value().dim() == 1, "the shape of alpha should be 1D"); + } + if (has_beta) { + TORCH_CHECK(beta.value().dim() == 1, "the shape of beta should be 1D"); + } + + // check shape + int64_t experts_num = m_list.size(0); + int64_t num_token = a_tensor.size(0); + int64_t k = a_tensor.size(-1); + int64_t n = has_quant_flag ? b_scale.value().size(2) : b_tensor.size(1); + int64_t lda = a_tensor.stride(0); + int64_t total_m = has_gather_idx ? gather_idx.value().size(0) : num_token; + + TORCH_CHECK(experts_num > 0, "number of groups must be large than zero"); + CHECK_SHAPE(a_tensor, num_token, k); + if (has_a_scale) { + CHECK_SHAPE(a_scale.value(), total_m); + } + if (has_quant_flag) { + quant_grouped = true; + quant_mode = QuantMode::W4W8; + TORCH_CHECK(has_b_scale, "if quant_flag given, b_scale must exist"); + CHECK_SHAPE(b_scale.value(), b_scale.value().size(0), experts_num, n); + } else if (has_b_scale) { + int64_t b_k = b_tensor.size(-1); + quant_grouped = b_scale.value().dim() == 3 ? true : false; + if (quant_grouped) { + CHECK_SHAPE(b_scale.value(), b_scale.value().size(0), experts_num, n); + } else { + CHECK_SHAPE(b_scale.value(), experts_num, n); + } + TORCH_CHECK(k == b_k || k == b_k * 2, "k == b_k || k == b_k * 2."); + quant_mode = (b_k == k) ? QuantMode::W8 : QuantMode::W4; + } + if (has_alpha) { + CHECK_SHAPE(alpha.value(), experts_num); + } + if (has_beta) { + CHECK_SHAPE(beta.value(), experts_num); + } + auto b_shape = b_tensor.sizes(); + cnnlDataType_t a_dtype = getCnnlDataType(a_tensor.scalar_type()); + cnnlDataType_t b_dtype = getCnnlDataType(b_tensor.scalar_type()); + torch::Dtype dtype = a_tensor.scalar_type(); + if (quant_mode == QuantMode::W8 || quant_mode == QuantMode::noQuant) { + if (b_tensor.dim() == 3) { + CHECK_SHAPE(b_tensor, experts_num, n, k); + } else { + TORCH_CHECK(b_shape[0] * b_shape[2] == experts_num, "b_shape[0] * b_shape[2] == experts_num"); + CHECK_SHAPE(b_tensor, b_shape[0], n, b_shape[2], k); + } + } else if (quant_mode == QuantMode::W4) { + CHECK_SHAPE(b_tensor, experts_num, n, k / 2); + b_dtype = CNNL_DTYPE_INT4X2; + } + + TORCH_CHECK(m_list.dtype() == torch::kInt32, "data type of m_list should be int32"); + if (has_gather_idx) { + TORCH_CHECK(gather_idx.value().dtype() == torch::kInt32, + "data type of gather_idx should be int32"); + } + + if (has_a_scale) { + TORCH_CHECK(has_dtype, "data_type must given when a_scale and b_scale are given"); + TORCH_CHECK(data_type.value() == "float" || data_type.value() == "half" || + data_type.value() == "bfloat16", + "data_type must be 'float', 'half' or 'bfloat16'."); + dtype = data_type.value() == "float" ? torch::kFloat32 + : data_type.value() == "half" ? torch::kFloat16 + : torch::kBFloat16; + } + + const torch_mlu::mlu::MLUGuard device_guard(a_tensor.device()); + at::Tensor d_tensor = at::empty({total_m, n}, a_tensor.options().dtype(dtype)); + cnnlDataType_t d_dtype = getCnnlDataType(d_tensor.scalar_type()); + + // check device + TORCH_CHECK(isMlu(m_list), "m_list must on mlu"); + TORCH_CHECK(isMlu(a_tensor), "a_tensor must on mlu"); + TORCH_CHECK(isMlu(b_tensor), "b_tensor must on mlu"); + TORCH_CHECK(isMlu(d_tensor), "d_tensor must on mlu"); + if (has_c_tensor) { + TORCH_CHECK(isMlu(c_tensor.value()), "c_tensor must on mlu"); + } + if (has_gather_idx) { + TORCH_CHECK(isMlu(gather_idx.value()), "gather_idx must on mlu"); + } + if (has_a_scale) { + TORCH_CHECK(isMlu(a_scale.value()), "a_scale must on mlu"); + } + if (has_b_scale) { + TORCH_CHECK(isMlu(b_scale.value()), "b_scale must on mlu"); + } + if (has_alpha) { + TORCH_CHECK(isMlu(alpha.value()), "alpha must on mlu"); + } + if (has_beta) { + TORCH_CHECK(isMlu(beta.value()), "beta must on mlu"); + } + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + GroupGemmDesc group_gemm_desc(experts_num, max_m, n, k, d_dtype, has_gather_idx, quant_mode); + group_gemm_desc.setInputOutputTensor(a_dtype, b_dtype, d_dtype, CNNL_DTYPE_INT32, + getAtTensorPtr(a_tensor), getAtTensorPtr(b_tensor), + getAtTensorPtr(c_tensor), getAtTensorPtr(d_tensor), + getAtTensorPtr(gather_idx), k, lda, total_m, has_c_tensor, + (int64_t *)getAtTensorPtr(b_offset)); + std::vector flag_vec; + if (quant_mode != QuantMode::noQuant || has_bias) { + if (has_quant_flag) { + auto vec = quant_flag.value().vec(); + flag_vec.resize(vec.size()); + std::copy(vec.begin(), vec.end(), flag_vec.begin()); + } + group_gemm_desc.setPerRowColScaleBiasAct( + has_a_scale ? getAtTensorPtr(a_scale) : nullptr, + has_b_scale ? getAtTensorPtr(b_scale) : nullptr, flag_vec.data(), + has_bias ? getAtTensorPtr(bias) : nullptr, d_dtype, quant_grouped ? k : 0, + quant_grouped ? k / b_scale.value().size(0) : 0, 0); + } + + size_t group_gemm_wsize = + tmo::ops::getGroupGemmWorkspaceSize(handle, group_gemm_desc, experts_num); + auto group_gemm_workspace = + at::empty({static_cast(group_gemm_wsize)}, a_tensor.options().dtype(at::kByte)); + + std::vector ldb_array; + // if (quant_mode == QuantMode::W4W8) { + // ldb_array.resize(experts_num); + // auto q_group = b_scale.value().size(0); + // for (auto i = 0; i < experts_num; ++i) { + // int sum = 0; + // for (auto j = 0; j < q_group; j++) { + // sum += flag_vec[i * q_group + j]; + // } + // ldb_array[i] = (sum / 4) * (k / q_group / 2); + // } + // } + if (quant_mode != QuantMode::W4W8) { + int ldb = b_tensor.dim() == 3 ? k : b_shape[2] * k; + ldb_array.assign(experts_num, ldb); + } + GroupGemmTheory obj(total_m, experts_num, k, n, has_c_tensor, a_dtype, getCnnlDataType(dtype)); + cnpxPush(obj); + ops::GroupGemm(handle, group_gemm_desc, getAtTensorPtr(m_list), getAtTensorPtr(alpha), + getAtTensorPtr(beta), getAtTensorPtr(group_gemm_workspace), group_gemm_wsize, + experts_num, k, n, lda, ldb_array); + cnpxPop(); + return d_tensor; +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/matmul.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/matmul.cpp new file mode 100644 index 0000000..89329f1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/matmul.cpp @@ -0,0 +1,129 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" +namespace tmo { +namespace torch_api { + +at::Tensor matmul(const at::Tensor &a, + const at::Tensor &b, + const c10::optional &d, + const c10::optional &bias, + const c10::optional &c, + const c10::optional &dtype, + const std::string &act_mode, + double alpha, + double beta, + bool fast_act, + bool approximate, + double a_scale, + double b_scale, + bool trans_a, + bool trans_b) { + bool has_bias = bias.has_value(); + bool has_res = c.has_value(); + bool use_beta = false; + int m = trans_a ? a.size(1) : a.size(0); + int n = trans_b ? b.size(0) : b.size(1); + int k = trans_a ? a.size(0) : a.size(1); + // check device and dtype + checkTensorSameAttr(a, b, c, bias); + // check contiguous + TORCH_CHECK(a.stride(-1) == 1, "a last dim must be contiguous.") + TORCH_CHECK(b.stride(-1) == 1, "b last dim must be contiguous.") + if (has_bias) TORCH_CHECK(bias.value().is_contiguous(), "bias must be contiguous.") + if (has_res) { + use_beta = true; + TORCH_CHECK(c.value().is_contiguous(), "c must be contiguous.") + CHECK_SHAPE(c.value(), m, n); + } + at::Tensor bias_view; + if (has_bias) { + TORCH_CHECK(bias.value().dim() == 1, "bias must be 1-D tensor.") + bias_view = bias.value().unsqueeze(0); + } + // get cnnl data type and init output + auto a_dtype = getCnnlDataType(a.scalar_type()); + auto b_dtype = getCnnlDataType(b.scalar_type()); + TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype."); + if (a_dtype == CNNL_DTYPE_INT8) { + TORCH_CHECK(d.has_value() || dtype.has_value(), + "d and d_dtype cant be none at the same time when input dtype is int8"); + } + + at::Tensor output; + if (d.has_value()) { + output = d.value(); + } else { + torch::Dtype out_dtype = a.scalar_type(); + if (dtype.has_value()) { + TORCH_CHECK( + dtype.value() == "float" || dtype.value() == "half" || dtype.value() == "bfloat16", + "data_type must be 'float', 'half' or 'bfloat16'."); + out_dtype = dtype.value() == "float" ? torch::kFloat32 + : dtype.value() == "half" ? torch::kFloat16 + : torch::kBFloat16; + } + output = at::empty({m, n}, a.options().dtype(out_dtype)); + } + + auto cd_dtype = getCnnlDataType(output.scalar_type()); + if (a_dtype == CNNL_DTYPE_BFLOAT16) cd_dtype = CNNL_DTYPE_FLOAT; + // create tensor desc + auto descs = createTensorDescs({a, b, bias_view, c.value_or(at::Tensor()), output}); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[4].get(), cd_dtype)); + if (has_res) { + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[3].get(), cd_dtype)); + } + if (a_dtype == CNNL_DTYPE_INT8) { + float int_max = 127.0; + int quant_bit = 8; + float max_a = int_max / a_scale; + float max_b = int_max / b_scale; + int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2)); + int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2)); + float new_a_scale = std::pow(2.0f, pos_a) * a_scale; + float new_b_scale = std::pow(2.0f, pos_b) * b_scale; + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale)); + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale)); + TORCH_CHECK(cd_dtype != CNNL_DTYPE_BFLOAT16, + "output dtype cannot be bfloat16 when a/b is fixed-point") + } + + // get && set Op desc + int lda = a.stride(0); + int ldb = b.stride(0); + int ldc = output.stride(0); + size_t workspace_size = 0; + auto matmul_desc = + tmo::op_desc::MatMulDesc(descs[2].get(), getAtTensorPtr(bias_view), workspace_size, lda, ldb, + ldc, act_mode, use_beta, trans_a, trans_b, fast_act, approximate); + const torch_mlu::mlu::MLUGuard device_guard(a.device()); + auto handle = torch_mlu::getCurrentHandle(); + auto workspace = at::empty({static_cast(workspace_size)}, a.options().dtype(at::kByte)); + + // run forward + float alpha_f = alpha; + float beta_f = beta; + MatmulTheory obj(m, k, n, has_res, a_dtype, cd_dtype); + cnpxPush(obj); + CNNL_CHECK_FATAL(cnnlMatMulEx(handle, matmul_desc, &alpha_f, descs[0].get(), getAtTensorPtr(a), + descs[1].get(), getAtTensorPtr(b), &beta_f, descs[3].get(), + getAtTensorPtr(c), descs[4].get(), getAtTensorPtr(output), nullptr, + getAtTensorPtr(workspace), workspace_size)); + cnpxPop(); + return output; +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_cast_gating.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_cast_gating.cpp new file mode 100644 index 0000000..ad7235b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_cast_gating.cpp @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include +#include "kernels/moe/cast_gating.mluh" +#include "torch_ops_api.h" +namespace tmo { +namespace torch_api { + +at::Tensor moe_cast_gating(const at::Tensor &input, const at::Tensor &weight) { + // check + checkTensorSameAttr(input, weight); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous") + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous") + TORCH_CHECK(weight.dim() == 2, "weight.dim() == 2") + TORCH_CHECK(input.size(-1) == weight.size(-1), "input.size(-1) == weight.size(-1)") + TORCH_CHECK(input.scalar_type() == torch::kFloat16 || input.scalar_type() == torch::kBFloat16, + "input type need be torch::kFloat16 or torch::kBFloat16") + TORCH_CHECK(weight.scalar_type() == torch::kFloat32, "weight type need be torch::kFloat32") + + auto hidden_size = input.size(-1); + auto expert_num = weight.size(0); + TORCH_CHECK(hidden_size > 0 && hidden_size <= 16384, "hidden_size > 0 && hidden_size <= 16384") + TORCH_CHECK(expert_num > 0 && expert_num <= 128, "expert_num > 0 && expert_num <= 128") + auto input_shape = input.sizes(); // [..., hidden_size] + auto input_row = + std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, std::multiplies()); + std::vector output_shape(input_shape.begin(), input_shape.end() - 1); + output_shape.push_back(expert_num); + + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + auto output = at::empty(output_shape, weight.options()); + const int64_t workspace_size = 16 * 1024 * 1024; + auto workspace = at::empty({workspace_size}, weight.options().dtype(at::kByte)); + auto queue = torch_mlu::getCurMLUStream(); + auto input_dtype = getCnnlDataType(input.scalar_type()); + TMO_KERNEL_CHECK_FATAL(tmo::invokeCastGating( + queue, input.data_ptr(), weight.data_ptr(), output.data_ptr(), input_row, expert_num, + hidden_size, input_dtype, workspace.data_ptr(), workspace_size)); + return output; +} +} // namespace torch_api +} // namespace tmo \ No newline at end of file diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_combine_result.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_combine_result.cpp new file mode 100644 index 0000000..6a50b03 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_combine_result.cpp @@ -0,0 +1,124 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "kernels/moe/combine_result.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { +at::Tensor moe_combine_result(const at::Tensor &input, + const at::Tensor &reduce_weight, + const at::Tensor &gather_ids, + const c10::optional &residual, + const c10::optional &cusum_token_count, + int64_t start_expert_id, + int64_t expert_size, + const c10::optional &bias) { + // check device and dtype + checkTensorSameAttr(input, reduce_weight, residual, bias, cusum_token_count, + gather_ids); + + // check contiguous + CHECK_TENSOR_CONTIGUOUS(input) + CHECK_TENSOR_CONTIGUOUS(reduce_weight) + CHECK_TENSOR_CONTIGUOUS(gather_ids) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(residual) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(bias) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(cusum_token_count) + + // do not support bias + TORCH_CHECK(!bias.has_value(), "bias is not supported."); + + // check input + TORCH_CHECK(input.dim() == 2, "input should be with shape [num_tokens, hidden_size]."); + TORCH_CHECK(input.dtype() == torch::kFloat32 || input.dtype() == torch::kFloat16 || + input.dtype() == torch::kBFloat16, + "input dtype should be float32/float16/bfloat16."); + + int num_tokens = input.size(0); + int hidden_size = input.size(1); + + // check reduce weight + TORCH_CHECK(reduce_weight.dim() == 2, "reduce_weight should be with shape [num_token, topk]."); + TORCH_CHECK(reduce_weight.dtype() == torch::kFloat32, + "reduce_weight should be with dtype float32."); + int topk = reduce_weight.size(1); + TORCH_CHECK(reduce_weight.numel() == num_tokens, + "reduce_weight.numel() should equal to input.size(0)."); + int num_token = num_tokens / topk; + + // check gather_ids + TORCH_CHECK(gather_ids.dim() == 1, "gather_ids should be with shape [num_tokens]."); + TORCH_CHECK(gather_ids.dtype() == torch::kInt32, "gather_ids should be with dtype int32."); + TORCH_CHECK(gather_ids.size(0) == num_tokens, + "gather_ids.numel() should equal to input.size(0)."); + + if (residual.has_value()) { + TORCH_CHECK(residual.value().dim() == 2, + "residual should be with shape [num_token, hidden_size]."); + TORCH_CHECK(residual.value().dtype() == input.dtype(), + "residual should have same dtype with input."); + TORCH_CHECK(residual.value().size(0) == num_token, + "residual.size(0) should equal to input.size(0)."); + TORCH_CHECK(residual.value().size(1) == hidden_size, + "residual.size(1) should equal to input.size(1)."); + } + + // check bias and cusum_token_count + bool has_bias = bias.has_value(); + bool has_cusum_token_count = cusum_token_count.has_value(); + TORCH_CHECK((!has_bias || (has_bias && has_cusum_token_count)), + "if bias is not None, cusum_token_count should not be None."); + + int num_expert = -1; + if (has_bias) { + TORCH_CHECK(bias.value().dim() == 2, "bias should be with shape [num_expert, hidden_size]."); + TORCH_CHECK(bias.value().dtype() == input.dtype(), "bias should have same dtype with input."); + TORCH_CHECK(bias.value().size(1) == hidden_size, "bias.size(1) should equal to input.size(1)."); + num_expert = bias.value().size(0); + } + + if (has_cusum_token_count) { + TORCH_CHECK(cusum_token_count.value().dim() == 1, + "cusum_token_count should be with shape [num_expert + 1]."); + TORCH_CHECK(cusum_token_count.value().dtype() == torch::kInt32, + "cusum_token_count should be with dtype int32."); + if (num_expert > 0) { + TORCH_CHECK(cusum_token_count.value().dim() == 1, + "cusum_token_count should be with shape [num_expert + 1]."); + } else { + num_expert = cusum_token_count.value().size(0) - 1; + } + } + + // check expert + TORCH_CHECK(start_expert_id >= 0, "start_expert_id shoule be larger or equal to 0."); + TORCH_CHECK(num_expert == -1 || num_expert >= (start_expert_id + expert_size), + "num_expert shape shoule be larger or equal to start_expert_id + expert_size."); + if (num_expert == -1) { + num_expert = expert_size; + } + + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + auto queue = torch_mlu::getCurMLUStream(); + auto output_shape = std::vector({num_token, hidden_size}); + auto output = at::empty(output_shape, input.options()); + auto dtype = getCnnlDataType(input.scalar_type()); + + invokeMoeCombineResultKernel(queue, output.data_ptr(), input.data_ptr(), getAtTensorPtr(bias), + getAtTensorPtr(residual), (float *)reduce_weight.data_ptr(), + (int *)getAtTensorPtr(cusum_token_count), + (int *)getAtTensorPtr(gather_ids), num_token, topk, num_expert, + hidden_size, start_expert_id, expert_size, dtype); + return output; +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_expand_input.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_expand_input.cpp new file mode 100644 index 0000000..bd32c91 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_expand_input.cpp @@ -0,0 +1,60 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/moe/expand_input.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +at::Tensor moe_expand_input(const torch::Tensor &input, + const torch::Tensor &gather_idx, + const c10::optional &cusum_token_count, + int64_t start_expert_id, + int64_t expert_size) { + TORCH_CHECK(input.dim() == 2, "input dim must be equal to 2.") + TORCH_CHECK(gather_idx.dim() == 1, "gather_idx dim must be equal to 1.") + // check device and dtype + TORCH_CHECK(isMlu(input), "input tensor must on mlu.") + TORCH_CHECK(isMlu(gather_idx), "gather_idx must on mlu.") + TORCH_CHECK(gather_idx.dtype() == torch::kInt32, "data type of gather_idx must be int32.") + int64_t expand_token_num = gather_idx.size(0); + int64_t token_num = input.size(0); + int64_t hidden_size = input.size(1); + TORCH_CHECK(expand_token_num % token_num == 0, "expand_token_num % token_num == 0.") + int64_t topk = expand_token_num / token_num; + int64_t expert_num = 0; + if (cusum_token_count.has_value()) { + TORCH_CHECK(isMlu(cusum_token_count.value()), "cusum_token_count must on mlu.") + TORCH_CHECK(cusum_token_count.value().dtype() == torch::kInt32, + "data type of cusum_token_count must be int32.") + expert_num = cusum_token_count.value().size(0) - 1; + TORCH_CHECK(start_expert_id >= 0 && start_expert_id < expert_num, + "start_expert_id >=0 && start_expert_id < expert_num.") + TORCH_CHECK((start_expert_id + expert_size) <= expert_num, + "start_expert_id + expert_size <= expert_num.") + } + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + std::vector output_shape = {expand_token_num, hidden_size}; + auto output = torch::empty(output_shape, input.options()); + + auto dtype = getCnnlDataType(input.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + + tmo::invokeMoeExpandInputKernel(queue, getAtTensorPtr(output), getAtTensorPtr(input), + (int *)getAtTensorPtr(gather_idx), + (int *)getAtTensorPtr(cusum_token_count), token_num, hidden_size, + topk, dtype, expert_num, start_expert_id, expert_size); + return output; +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_gen_idx.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_gen_idx.cpp new file mode 100644 index 0000000..4414355 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_gen_idx.cpp @@ -0,0 +1,44 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/moe/gen_idx.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +std::vector moe_gen_idx(const torch::Tensor &expert_id, int64_t expert_num) { + TORCH_CHECK(expert_id.dim() == 2, "expert_id dim must be equal to 2.") + int64_t token_num = expert_id.size(0); + int64_t topk = expert_id.size(1); + + // check device and dtype + TORCH_CHECK(isMlu(expert_id), "expert_id must on mlu.") + TORCH_CHECK(expert_id.dtype() == torch::kInt32, "data type of expert_id must be int32.") + + const torch_mlu::mlu::MLUGuard device_guard(expert_id.device()); + auto expand_idx = at::empty({token_num * topk}, expert_id.options()); + auto combine_idx = at::empty({token_num * topk}, expert_id.options()); + auto token_count = at::empty({expert_num}, expert_id.options()); + auto cusum_token_count = at::empty({expert_num + 1}, expert_id.options()); + auto gen_idx_workspace = at::empty({expert_num + 1 + token_num * topk}, expert_id.options()); + + auto queue = torch_mlu::getCurMLUStream(); + tmo::invokeMoeGenIdxKernel( + queue, (int *)getAtTensorPtr(expand_idx), (int *)getAtTensorPtr(combine_idx), + (int *)getAtTensorPtr(token_count), (int *)getAtTensorPtr(cusum_token_count), + getAtTensorPtr(gen_idx_workspace), getAtTensorPtr(expert_id), token_num, expert_num, topk); + std::vector output = {expand_idx, combine_idx, token_count, cusum_token_count}; + return output; +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_softmax_topk.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_softmax_topk.cpp new file mode 100644 index 0000000..0bdeb5b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/moe_softmax_topk.cpp @@ -0,0 +1,83 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/moe/softmax_topk.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +std::vector moe_softmax_topk(const at::Tensor &input, // (num_token, num_expert) + int64_t topk, + int64_t num_expert_group, + int64_t topk_group, + bool normalize, + const c10::optional &mask, + const std::string &normed_by) { + CHECK_TENSOR_CONTIGUOUS(input) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(mask) + int normalize_mode = 0; + if (normalize) { + TORCH_CHECK(normed_by == "topk_logit" || normed_by == "softmax_logit", + "normed_by must be 'topk_logit' or 'softmax_logit'") + if (normed_by == "topk_logit") { + normalize_mode = 1; + } else if (normed_by == "softmax_logit") { + normalize_mode = 2; + } + } + TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2") + int num_expert = input.size(-1); + int num_token = (int)(input.numel() / num_expert); + int num_mask = 1; + auto input_shape = input.sizes(); + TORCH_CHECK(topk > 0 && topk <= num_expert, "topk > 0 && topk <= num_expert") + bool has_mask = mask.has_value(); + if (has_mask) { + TORCH_CHECK(mask.value().dim() == input.dim(), "the dim of mask should be the same as input") + TORCH_CHECK(mask.value().size(-1) == num_expert, + "the last dim of mask should be the same as the last dim of input") + TORCH_CHECK(mask.value().size(-2) == input.size(-2), + "the penultimate dim of mask should be the same as the penultimate dim of input") + int high_dim = mask.value().numel() / (mask.value().size(-1) * mask.value().size(-2)); + TORCH_CHECK(high_dim == 1, "the product of all but the lower two dimensions of mask is 1") + num_mask = mask.value().numel() / num_expert; + TORCH_CHECK(mask.value().dtype() == input.dtype(), + "the dtype of mask should be the same as input") + } + if (num_expert_group > 1) { + TORCH_CHECK(has_mask == false, "if num_expert_group > 1, mask should be None") + TORCH_CHECK(num_expert % num_expert_group == 0, "num_expert % num_expert_group == 0") + TORCH_CHECK(topk_group > 0 && topk_group <= num_expert_group, + "topk_group > 0 && topk_group <= num_expert_group") + TORCH_CHECK(topk <= (num_expert / num_expert_group) * topk_group, + "topk <= (num_expert / num_expert_group) * topk_group") + } + std::vector out_shape(input_shape.begin(), input_shape.end()); + out_shape.back() = topk; + + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + auto tensor_options = input.options(); + auto reduce_weight = at::empty(out_shape, tensor_options.dtype(torch::kFloat)); + auto expert_id = at::empty(out_shape, tensor_options.dtype(torch::kInt32)); + + auto queue = torch_mlu::getCurMLUStream(); + auto input_dtype = getCnnlDataType(input.scalar_type()); + TMO_KERNEL_CHECK_FATAL(invokeMoeSoftmaxTopkKernel( + queue, (float *)reduce_weight.data_ptr(), (int *)expert_id.data_ptr(), input.data_ptr(), + getAtTensorPtr(mask), num_token, num_expert, num_mask, topk, num_expert_group, topk_group, + input_dtype, normalize_mode)); + return {reduce_weight, expert_id}; +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/offline_quant_to_linear_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/offline_quant_to_linear_cache.cpp new file mode 100644 index 0000000..f0309a9 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/offline_quant_to_linear_cache.cpp @@ -0,0 +1,177 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/offline_quant_to_linear_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +void offline_quant_to_linear_cache( + const at::Tensor &key, // [total_seqlen, num_heads, head_size] + // or [bs, max_context_len, num_heads, head_size] + const c10::optional &value, + at::Tensor &key_cache, // [max_bs, num_heads, cache_memory_len, head_size] + const c10::optional &value_cache, + const at::Tensor &key_cache_scale, // [num_heads, cache_memory_len] or [num_heads, head_size] + const c10::optional &value_cache_scale, + const at::Tensor &context_lengths, + const int64_t max_context_len, + const int64_t quant_mode, // 0:per_channel, others:per_head + const bool packed, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seqlen_offset) { + /****************************************check key***************************************/ + // check dtype + TORCH_CHECK(key.scalar_type() == torch::kFloat32 || key.scalar_type() == torch::kFloat16 || + key.scalar_type() == torch::kBFloat16, + "key type need be torch::kFloat32, torch::kFloat16 or torch::kBFloat16"); + TORCH_CHECK(key_cache.scalar_type() == torch::kInt8, "key_cache type need be torch::kInt8"); + TORCH_CHECK(key_cache_scale.scalar_type() == torch::kFloat32, + "key_cache_scale type need be torch::kFloat32"); + + // check shape + const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0); + const int max_bs = key_cache.size(0); + const int num_heads = key_cache.size(1); + const int cache_mem_len = key_cache.size(2); + const int head_size = key_cache.size(3); + const int total_seqlen = packed ? key.size(0) : batch_size * key.size(-2); + + TORCH_CHECK(key_cache.size(0) >= batch_size, + "first dim of key_cache must be great than or equal to batch_size"); + if (packed) { + CHECK_SHAPE(key, total_seqlen, num_heads, head_size); + CHECK_SHAPE(context_lengths, batch_size + 1); + } else { + CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size); + CHECK_SHAPE(context_lengths, batch_size); + } + if (quant_mode == 0) { + CHECK_SHAPE(key_cache_scale, num_heads, head_size); + } else { + CHECK_SHAPE(key_cache_scale, num_heads, cache_mem_len); + } + + // check contiguous + TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous."); + TORCH_CHECK(key_cache_scale.is_contiguous(), "key_cache_scale tensor must be contiguous."); + + // check device + const int64_t device_id = key.get_device(); + TORCH_CHECK(key_cache.get_device() == device_id, + "key_cache tensor device index is not same, original index: ", device_id, + "now index is: ", key_cache.get_device()); + TORCH_CHECK(key_cache_scale.get_device() == device_id, + "key_cache_scale tensor device index is not same, original index: ", device_id, + "now index is: ", key_cache_scale.get_device()); + + /***************************************check value***************************************/ + if (value.has_value() || value_cache.has_value() || value_cache_scale.has_value()) { + TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_scale.has_value(), + "value_cache, value_cache_scale, value must all have value.") + } + if (value_cache.has_value()) { + TORCH_CHECK(value_cache.value().scalar_type() == torch::kInt8, + "value_cache type need be torch::kInt8"); + TORCH_CHECK(value_cache_scale.value().scalar_type() == torch::kFloat32, + "value_cache_scale type need be torch::kFloat32"); + TORCH_CHECK(value.value().scalar_type() == key.scalar_type(), + "value type need be same with key"); + if (packed) { + CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size); + } else { + CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size); + } + if (quant_mode == 0) { + CHECK_SHAPE(value_cache_scale.value(), num_heads, head_size); + } else { + CHECK_SHAPE(value_cache_scale.value(), num_heads, cache_mem_len); + } + CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size); + + for (int i = 0; i < key.dim(); i++) { + TORCH_CHECK(value.value().stride(i) == key.stride(i), + "key and value must have same stride along axi ", i); + } + TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous."); + TORCH_CHECK(value_cache_scale.value().is_contiguous(), + "value_cache_scale tensor must be contiguous."); + TORCH_CHECK(value_cache.value().get_device() == device_id, + "value_cache tensor device index is not same, original index: ", device_id, + "now index is: ", value_cache.value().get_device()); + TORCH_CHECK(value_cache_scale.value().get_device() == device_id, + "value_cache_scale tensor device index is not same, original index: ", device_id, + "now index is: ", value_cache_scale.value().get_device()); + } + + TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32, + "context_lengths type need be torch::kInt32"); + /*********************************check optional tensor***********************************/ + const void *context_seq_offset_ptr = nullptr; + if (context_seq_offset.has_value()) { + CHECK_SHAPE(context_seq_offset.value(), batch_size); + TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32, + "context_seq_offset type need be torch::kInt32"); + TORCH_CHECK(context_seq_offset.value().get_device() == device_id, + "context_seq_offset tensor device index is not same, original index: ", device_id, + "now index is: ", context_seq_offset.value().get_device()); + context_seq_offset_ptr = context_seq_offset.value().data_ptr(); + } + + const void *cache_bs_id_ptr = nullptr; + if (cache_bs_id.has_value()) { + CHECK_SHAPE(cache_bs_id.value(), batch_size); + TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32, + "cache_bs_id type need be torch::kInt32"); + TORCH_CHECK(cache_bs_id.value().get_device() == device_id, + "cache_bs_id tensor device index is not same, original index: ", device_id, + "now index is: ", cache_bs_id.value().get_device()); + cache_bs_id_ptr = cache_bs_id.value().data_ptr(); + } + + const void *cache_seqlen_offset_ptr = nullptr; + if (cache_seqlen_offset.has_value()) { + CHECK_SHAPE(cache_seqlen_offset.value(), batch_size); + TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32, + "cache_seqlen_offset type need be torch::kInt32"); + TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id, + "cache_seqlen_offset tensor device index is not same, original index: ", device_id, + "now index is: ", cache_seqlen_offset.value().get_device()); + cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr(); + } + + const size_t context_bs_stride = packed ? 0 : key.stride(0); + const size_t context_head_stride = packed ? key.stride(1) : key.stride(2); + const size_t context_seq_stride = packed ? key.stride(0) : key.stride(1); + const size_t cache_bs_stride = key_cache.stride(0); + const size_t cache_head_stride = key_cache.stride(1); + const size_t cache_seq_stride = key_cache.stride(2); + const size_t cache_scale_head_stride = key_cache_scale.stride(0); + + const torch_mlu::mlu::MLUGuard device_guard(key.device()); + auto data_dtype = getCnnlDataType(key.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + + // run forward + TMO_KERNEL_CHECK_FATAL(invokeOfflineQuantToLinearCache( + queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache), + getAtTensorPtr(key_cache_scale), getAtTensorPtr(value_cache_scale), cache_bs_id_ptr, + cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr, + getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size, + (int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride, + context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, + cache_scale_head_stride, packed, quant_mode)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/offline_quant_to_paged_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/offline_quant_to_paged_cache.cpp new file mode 100644 index 0000000..130fe4b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/offline_quant_to_paged_cache.cpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/offline_quant_to_paged_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { +void offline_quant_to_paged_cache( + const torch::Tensor &k, // [num_tokens, num_heads, head_size] + const c10::optional &v, // [num_tokens, num_heads, head_size] + const torch::Tensor &k_cache_scale, // [num_heads, head_size] + const c10::optional &v_cache_scale, // [num_heads, head_size] + const torch::Tensor &slot_mapping, + torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_size] + const c10::optional + &v_cache) { // [num_blocks, num_heads, block_size, head_size] + // 1. check device and tensor type + TORCH_CHECK(slot_mapping.dtype() == torch::kInt32, "slot_mapping type need be torch::kInt32"); + checkTensorSameAttr(k, v, k_cache_scale, v_cache_scale, slot_mapping); + + // 2. check dim and shape + TORCH_CHECK(k.dim() == 3, "dim of k must be 3"); + TORCH_CHECK(k_cache.dim() == 4, "dim of k_cache must be 4"); + TORCH_CHECK(k_cache_scale.dim() == 2, "dim of k_cache_scale must be 2"); + TORCH_CHECK( + v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value(), + "v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value()."); + if (v.has_value()) { + TORCH_CHECK(v.value().dim() == 3, "dim of v must be 3"); + TORCH_CHECK(v_cache.value().dim() == 4, "dim of v_cache must be 4"); + TORCH_CHECK(v_cache_scale.value().dim() == 2, "dim of v_cache_scale must be 2"); + } + TORCH_CHECK(slot_mapping.dim() == 1, "dim of slot_mapping must be 1"); + + const int num_tokens = k.size(0); + const int num_heads = k.size(1); + const int head_size = k.size(2); + const int block_size = k_cache.size(2); + const int num_blocks = k_cache.size(0); + + CHECK_SHAPE(k, num_tokens, num_heads, head_size); + CHECK_SHAPE(k_cache_scale, num_heads, head_size); + CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_size); + if (v.has_value()) { + CHECK_SHAPE(v.value(), num_tokens, num_heads, head_size); + CHECK_SHAPE(v_cache_scale.value(), num_heads, head_size); + CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_size); + } + CHECK_SHAPE(slot_mapping, num_tokens); + + // 3. check strides + TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous."); + TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous."); + TORCH_CHECK(k.stride(-2) == head_size, "k second dim must be contiguous."); + if (v.has_value()) { + TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous."); + TORCH_CHECK(v.value().stride(-2) == head_size, "v second dim must be contiguous."); + } + + int64_t kv_cache_range = + (int64_t)num_blocks * num_heads * block_size * head_size * k_cache.element_size(); + TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G."); + + const torch_mlu::mlu::MLUGuard device_guard(k.device()); + auto queue = torch_mlu::getCurMLUStream(); + cnnlDataType_t dtype = getCnnlDataType(k.scalar_type()); + int32_t key_stride0 = static_cast(k.stride(0)); + int32_t value_stride0 = 1; + if (v.has_value()) { + value_stride0 = static_cast(v.value().stride(0)); + } + TMO_KERNEL_CHECK_FATAL(tmo::invokeOfflineQuantToPagedCache( + queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache), + getAtTensorPtr(v_cache), getAtTensorPtr(k_cache_scale), getAtTensorPtr(v_cache_scale), + getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens, num_heads, num_blocks, + block_size, head_size)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/op_theory.h b/torch_mlu_ops-v1.3.2/csrc/torch_api/op_theory.h new file mode 100644 index 0000000..47381a1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/op_theory.h @@ -0,0 +1,423 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_TORCH_API_OP_THEORY_H_ +#define CSRC_TORCH_API_OP_THEORY_H_ + +#include "utils.h" + +namespace tmo { +namespace torch_api { + +// base class +struct OpTheory { + virtual size_t getTheoryCalc(void) const = 0; + virtual size_t getTheoryIO(void) const = 0; + virtual cnnlDataType_t getCalcDtype(void) const = 0; + virtual const std::string &getOpName(void) const = 0; +}; + +struct MatmulTheory : public OpTheory { + int m_, k_, n_; + bool has_res_; + cnnlDataType_t calc_type_, data_type_; + std::string name_; + MatmulTheory(int m, + int k, + int n, + bool has_res, + cnnlDataType_t calc_type, + cnnlDataType_t data_type) + : OpTheory(), + m_(m), + k_(k), + n_(n), + has_res_(has_res), + calc_type_(calc_type), + data_type_(data_type), + name_("matmul") {} + // lt: m * n * k, ct: add_bias + act + size_t getTheoryCalc(void) const override { + size_t size_calc = (size_t)m_ * n_ * k_; + return size_calc; + } + size_t getTheoryIO(void) const override { + size_t calc_dw, data_dw; + cnnlGetSizeOfDataType(calc_type_, &calc_dw); + cnnlGetSizeOfDataType(data_type_, &data_dw); + return calc_dw * (m_ * k_ + n_ * k_) + data_dw * m_ * n_ * (1 + has_res_); + } + cnnlDataType_t getCalcDtype(void) const override { return calc_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct QuantMatmulTheory : public OpTheory { + int m_, k_, n_, quant_bit_size_; + bool has_res_; + cnnlDataType_t data_type_; + std::string quant_algo_, name_; + QuantMatmulTheory(int m, + int k, + int n, + int quant_bit_size, + bool has_res, + cnnlDataType_t data_type, + std::string quant_algo) + : OpTheory(), + m_(m), + k_(k), + n_(n), + quant_bit_size_(quant_bit_size), + has_res_(has_res), + data_type_(data_type), + quant_algo_(quant_algo), + name_("quant_matmul") {} + size_t getTheoryCalc(void) const override { + size_t size_calc = (size_t)m_ * n_ * k_; + return size_calc; + } + size_t getTheoryIO(void) const override { + size_t io_size = 0, data_dw = 0; + cnnlGetSizeOfDataType(data_type_, &data_dw); + if (quant_algo_ == "weight_only") { + io_size = (data_dw * m_ * k_ + m_ * n_ * (1 + has_res_)) + n_ * k_ * (quant_bit_size_ / 8.0); + } else { + io_size = data_dw * m_ * n_ * (1 + has_res_) + (m_ * k_ + n_ * k_) * (quant_bit_size_ / 8.0); + } + return io_size; + } + cnnlDataType_t getCalcDtype(void) const override { + if (quant_algo_ == "weight_only") + return data_type_; + else + return CNNL_DTYPE_INT8; + } + const std::string &getOpName(void) const override { return name_; } +}; + +struct GroupGemmTheory : public OpTheory { + int total_m_, experts_num_, k_, n_, quant_bit_size_; + bool has_res_; + cnnlDataType_t calc_type_, data_type_; + std::string name_; + GroupGemmTheory(int total_m, + int experts_num, // The max number of processed experts. + int k, + int n, + bool has_res, + cnnlDataType_t calc_type, + cnnlDataType_t data_type) + : OpTheory(), + total_m_(total_m), + experts_num_(experts_num), + k_(k), + n_(n), + has_res_(has_res), + calc_type_(calc_type), + data_type_(data_type), + name_("group_gemm") {} // smooth_quant or float-point + size_t getTheoryCalc(void) const override { + size_t size_calc = (size_t)total_m_ * n_ * k_; + return size_calc; + } + size_t getTheoryIO(void) const override { + size_t calc_dw, data_dw; + cnnlGetSizeOfDataType(calc_type_, &calc_dw); + cnnlGetSizeOfDataType(data_type_, &data_dw); + // assume load max_num weights + size_t size_io = calc_dw * ((size_t)total_m_ * k_ + (size_t)experts_num_ * n_ * k_) + + data_dw * total_m_ * n_ * (1 + has_res_); + return size_io; + } + cnnlDataType_t getCalcDtype(void) const override { return calc_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct FlashAttnTheory : public OpTheory { + int batch_, total_q_, total_k_, head_num_q_, head_num_kv_, head_size_qk_, head_size_v_; + bool is_causal_; + cnnlDataType_t data_type_; + std::string name_; + FlashAttnTheory(int batch, + int total_q, + int total_k, + int head_num_q, + int head_num_kv, + int head_size_qk, + int head_size_v, + bool is_causal, + cnnlDataType_t data_type) + : OpTheory(), + batch_(batch), + total_q_(total_q), + total_k_(total_k), + head_num_q_(head_num_q), + head_num_kv_(head_num_kv), + head_size_qk_(head_size_qk), + head_size_v_(head_size_v), + is_causal_(is_causal), + data_type_(data_type), + name_("flash_attention") {} + size_t getTheoryCalc(void) const override { + // q * k + qk * v + size_t seq_q = total_q_ / batch_; + size_t seq_kv = total_k_ / batch_; + if (is_causal_) seq_kv = seq_kv - 0.5 * seq_q; + size_t size_lt = (size_t)batch_ * head_num_q_ * seq_q * seq_kv * (head_size_qk_ + head_size_v_); + return size_lt; + } + size_t getTheoryIO(void) const override { + size_t data_dw; + cnnlGetSizeOfDataType(data_type_, &data_dw); + size_t size_io = data_dw * (head_num_q_ * total_q_ + head_num_kv_ * total_k_) * + (head_size_qk_ + head_size_v_); + return size_io; + } + cnnlDataType_t getCalcDtype(void) const override { return data_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct SingleQueryCachedKvAttnTheory : public OpTheory { + int batch_, seq_q_, head_num_q_, head_num_kv_, num_block_, block_size_, head_size_qk_, + head_size_v_, kv_cache_quant_bit_size_; + cnnlQuantizeLayout_t cache_quant_layout_; + cnnlDataType_t data_type_; + std::string name_; + SingleQueryCachedKvAttnTheory(int batch, + int seq_q, + int head_num_q, + int head_num_kv, + int num_block, + int block_size, + int head_size_qk, + int head_size_v, + int kv_cache_quant_bit_size, + cnnlQuantizeLayout_t cache_quant_layout, + cnnlDataType_t data_type) + : OpTheory(), + batch_(batch), + seq_q_(seq_q), + head_num_q_(head_num_q), + head_num_kv_(head_num_kv), + num_block_(num_block), + block_size_(block_size), + head_size_qk_(head_size_qk), + head_size_v_(head_size_v), + kv_cache_quant_bit_size_(kv_cache_quant_bit_size), + cache_quant_layout_(cache_quant_layout), + data_type_(data_type), + name_("single_query_cached_kv_attn") {} + size_t getTheoryCalc(void) const override { + // q * k + qk * v + size_t size_lt = + (size_t)seq_q_ * head_num_q_ * num_block_ * block_size_ * (head_size_qk_ + head_size_v_); + return size_lt; + } + size_t getTheoryIO(void) const override { + size_t data_dw = 0, dw_float = 0, size_scale = 0; + cnnlGetSizeOfDataType(data_type_, &data_dw); + cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float); + size_t cache_dw = kv_cache_quant_bit_size_; + if (kv_cache_quant_bit_size_ == -1) { + cache_dw = data_dw * 8; + } else { + if (cache_quant_layout_ == CNNL_QUANTIZE_PER_CHANNEL) { + size_scale = dw_float * head_num_kv_ * (head_size_qk_ + head_size_v_); + } else { + size_scale = dw_float * num_block_ * head_num_kv_ * block_size_ * 2; + } + } + size_t size_io = + size_scale + data_dw * batch_ * seq_q_ * head_num_q_ * (head_size_qk_ + head_size_v_) + + (cache_dw / 8) * head_num_kv_ * num_block_ * block_size_ * (head_size_qk_ + head_size_v_); + return size_io; + } + cnnlDataType_t getCalcDtype(void) const override { return data_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct FusedNormTheory : public OpTheory { + int num_token_, hidden_size_; + bool has_res_, has_bias_, quant_out_, dynamic_quant_, store_output_before_norm_; + cnnlDataType_t data_type_; + std::string norm_mode_, name_; + FusedNormTheory(int num_token, + int hidden_size, + bool has_res, + bool has_bias, + bool quant_out, + bool dynamic_quant, + bool store_output_before_norm, + cnnlDataType_t data_type, + std::string norm_mode) + : OpTheory(), + num_token_(num_token), + hidden_size_(hidden_size), + has_res_(has_res), + has_bias_(has_bias), + quant_out_(quant_out), + dynamic_quant_(dynamic_quant), + store_output_before_norm_(store_output_before_norm), + data_type_(data_type), + norm_mode_(norm_mode), + name_("fused_norm") {} + size_t getTheoryCalc(void) const override { + size_t size_ct = 0; + if (norm_mode_ == "layernorm") { + size_ct = 7; // avg + sub + square + sum + cycle_mul + mul gamma + add beta, + } else if (norm_mode_ == "rmsnorm") { + size_ct = 4; // square + avg + cycle_mul + mul gamma + } + size_ct += (has_res_ + has_bias_); + if (dynamic_quant_) + size_ct += 4; + else if (quant_out_) + size_ct += 1; + return size_ct * num_token_ * hidden_size_; + } + size_t getTheoryIO(void) const override { + size_t data_dw, dw_int8, dw_float; + cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float); + cnnlGetSizeOfDataType(data_type_, &data_dw); + cnnlGetSizeOfDataType(CNNL_DTYPE_INT8, &dw_int8); + size_t size_io = data_dw * num_token_ * hidden_size_; + if (has_res_) size_io += data_dw * num_token_ * hidden_size_; + if (has_bias_) size_io += data_dw * hidden_size_; + size_io += hidden_size_ * data_dw * (1 + (norm_mode_ == "layernorm")); // gamma&beta + if (quant_out_) { + size_io += dw_int8 * num_token_ * hidden_size_; // output + size_io += dw_float * (hidden_size_ + dynamic_quant_ ? num_token_ : 0); // scale + } else { + size_io += data_dw * num_token_ * hidden_size_; + } + if (store_output_before_norm_) size_io += data_dw * num_token_ * hidden_size_; + return size_io; + } + cnnlDataType_t getCalcDtype(void) const override { return data_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct SmoothQuantTheory : public OpTheory { + int m_, ci_; + cnnlDataType_t data_type_; + bool dynamic_quant_; + std::string name_; + SmoothQuantTheory(int m, int ci, cnnlDataType_t data_type, bool dynamic_quant) + : OpTheory(), + m_(m), + ci_(ci), + data_type_(data_type), + dynamic_quant_(dynamic_quant), + name_("smooth_quant_online") {} + size_t getTheoryCalc(void) const override { + if (!dynamic_quant_) { + return 3 * (size_t)m_ * ci_; // convert + mul scale + convert + } + size_t calc_scale = m_; // 128 / maxvalue + size_t size_calc = (2 /*max + min*/ + 2 /*mul scale*/ + 2 /*convert*/) * (size_t)m_ * ci_; + return calc_scale + size_calc; + } + size_t getTheoryIO(void) const override { + size_t data_dw, dw_int8, dw_float; + cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dw_float); + cnnlGetSizeOfDataType(CNNL_DTYPE_INT8, &dw_int8); + cnnlGetSizeOfDataType(data_type_, &data_dw); + size_t scale_size = dw_float * (ci_ + dynamic_quant_ ? m_ : 0); + return scale_size + (data_dw + dw_int8) * (m_ * ci_); + } + cnnlDataType_t getCalcDtype(void) const override { return data_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct GroupAddBiasActiveTheory : public OpTheory { + int expert_size_, num_expand_token_, ci_, inner_size_; + bool has_gate_, is_addbias_; + cnnlDataType_t data_type_; + std::string act_mode_, name_; + GroupAddBiasActiveTheory(int expert_size, + int num_expand_token, + int inner_size, + bool has_gate, + bool is_addbias, + cnnlDataType_t data_type, + std::string act_mode) + : OpTheory(), + expert_size_(expert_size), + num_expand_token_(num_expand_token), + inner_size_(inner_size), + has_gate_(has_gate), + is_addbias_(is_addbias), + data_type_(data_type), + act_mode_(act_mode), + name_("group_add_bias_active") {} + size_t getTheoryCalc(void) const override { + size_t size_vec = 0, act_size = 0; + if (is_addbias_) size_vec += (size_t)num_expand_token_ * inner_size_ * (1 + has_gate_); + if (act_mode_ == "silu") { // silu = sigmoid + mul + act_size += 2 * (size_t)num_expand_token_ * inner_size_; + } else { // gelu = 1 vector_compution + act_size += (size_t)num_expand_token_ * inner_size_; + } + if (has_gate_) size_vec += (size_t)num_expand_token_ * inner_size_; + return size_vec + act_size; + } + size_t getTheoryIO(void) const override { + size_t data_dw = 0; + cnnlGetSizeOfDataType(data_type_, &data_dw); + size_t bias_io = is_addbias_ ? expert_size_ * ci_ * (1 + has_gate_) : 0; + return data_dw * (bias_io + num_expand_token_ * inner_size_ * (2 + has_gate_)); + } + cnnlDataType_t getCalcDtype(void) const override { return data_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +struct MoeCombineResultTheory : public OpTheory { + int num_token_, topk_, ci_, expert_size_; + bool has_bias_, has_res_; + cnnlDataType_t data_type_; + std::string name_; + MoeCombineResultTheory(int num_token, + int topk, + int ci, + int expert_size, + bool has_bias, + bool has_res, + cnnlDataType_t data_type) + : OpTheory(), + num_token_(num_token), + topk_(topk), + ci_(ci), + expert_size_(expert_size), + has_bias_(has_bias), + has_res_(has_res), + data_type_(data_type), + name_("moe_combine_result") {} + size_t getTheoryCalc(void) const override { + size_t size_calc = + (size_t)((has_bias_ + 3 /*cvt+fma+cvt*/) * topk_ + has_res_) * num_token_ * ci_; + return size_calc; + } + size_t getTheoryIO(void) const override { + size_t data_dw = 0; + cnnlGetSizeOfDataType(data_type_, &data_dw); + size_t bias_io = has_bias_ ? expert_size_ * ci_ : 0; + size_t res_io = has_res_ ? num_token_ * ci_ : 0; + return data_dw * (bias_io + res_io + num_token_ * topk_ * ci_ /*input*/ + + num_token_ * ci_ /*output*/ + num_token_ * topk_ /*reduce_weight*/); + } + cnnlDataType_t getCalcDtype(void) const override { return data_type_; } + const std::string &getOpName(void) const override { return name_; } +}; + +} // namespace torch_api +} // namespace tmo + +#endif // CSRC_TORCH_API_OP_THEORY_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/preload.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/preload.cpp new file mode 100644 index 0000000..bedf127 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/preload.cpp @@ -0,0 +1,25 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/preload.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { +void preload(const torch::Tensor &weight, const int64_t size) { + const torch_mlu::mlu::MLUGuard device_guard(weight.device()); + auto queue = torch_mlu::getCurMLUStream(); + + invokePreload(queue, weight.data_ptr(), weight.element_size() * weight.numel(), size); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_matmul.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_matmul.cpp new file mode 100644 index 0000000..d63c3be --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_matmul.cpp @@ -0,0 +1,115 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +at::Tensor quant_matmul(const at::Tensor &a_tensor, + const c10::optional &a_scale, + const c10::optional &a_zero, + const at::Tensor &b_tensor, + const c10::optional &b_scale, + const c10::optional &b_zero, + const c10::optional &bias, + const c10::optional &c_tensor, + const c10::optional &c_scale, + const c10::optional &c_zero, + const c10::optional &gemm_output_scale, + const c10::optional &gemm_output_zero, + const c10::optional &data_type, + const c10::optional &d, + const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + int64_t quant_bit_size, + const std::string &act_mode, + bool use_hp_active, + double act_coef, + double alpha, + double beta, + bool trans_a, + bool trans_b) { + // check device and dtype + // only support weight only or smooth quant + TORCH_CHECK(quant_algo == "weight_only" || quant_algo == "smooth_quant", "illegal quant mode"); + TORCH_CHECK(quant_bit_size == 4 || quant_bit_size == 8, "illegal quant bit size"); + TORCH_CHECK(trans_a == false && trans_b == true, "illegal trans mode"); + // create output tensor + auto tensor_options = a_tensor.options(); + std::vector output_shape = {a_tensor.sizes()[0], b_tensor.sizes()[0]}; + at::Tensor output_tensor; + if (d.has_value()) { + output_tensor = d.value(); + } else if (data_type.has_value()) { + auto dtype = data_type.value(); + TORCH_CHECK(dtype == "float" || dtype == "half" || dtype == "bfloat16", + "data type must be 'float', 'half' or 'bfloat16'") + auto torch_dtype = str2TorchDtype(dtype); + output_tensor = at::empty(output_shape, tensor_options.dtype(torch_dtype)); + } else { + output_tensor = at::empty(output_shape, tensor_options); + } + + // create tensor desc + auto descs = createTensorDescs( + {a_tensor, a_scale.value_or(torch::Tensor()), a_zero.value_or(torch::Tensor()), b_tensor, + b_scale.value_or(torch::Tensor()), b_zero.value_or(torch::Tensor()), + bias.value_or(torch::Tensor()), c_tensor.value_or(torch::Tensor()), + c_scale.value_or(torch::Tensor()), c_zero.value_or(torch::Tensor()), output_tensor, + gemm_output_scale.value_or(torch::Tensor()), gemm_output_zero.value_or(torch::Tensor())}); + + // get && set Op desc + auto compute_dtype = getCnnlDataType(output_tensor.scalar_type()); + auto op_desc = tmo::op_desc::QuantMatmulDesc(quant_algo, a_quant_layout, b_quant_layout, + quant_bit_size, compute_dtype, act_mode, + use_hp_active, act_coef, trans_a, trans_b); + + auto handle = torch_mlu::getCurrentHandle(); + // get workspace size + size_t workspace_size = 0; + CNNL_CHECK_FATAL(cnnlGetLLMQuantMatmulWorkspaceSize(handle, op_desc, descs[0].get(), + descs[3].get(), descs[7].get(), + descs[10].get(), &workspace_size)); + auto workspace = + at::empty({static_cast(workspace_size)}, a_tensor.options().dtype(at::kByte)); + + // run forward + const cnnlTensorDescriptor_t a_descs[] = {descs[0].get(), descs[1].get(), descs[2].get()}; + const cnnlTensorDescriptor_t b_descs[] = {descs[3].get(), descs[4].get(), descs[5].get()}; + const cnnlTensorDescriptor_t c_descs[] = {descs[7].get(), descs[8].get(), descs[9].get()}; + const cnnlTensorDescriptor_t d_descs[] = {descs[10].get(), nullptr, nullptr}; + const void *a_tensors[] = {getAtTensorPtr(a_tensor), getAtTensorPtr(a_scale), + getAtTensorPtr(a_zero)}; + const void *b_tensors[] = {getAtTensorPtr(b_tensor), getAtTensorPtr(b_scale), + getAtTensorPtr(b_zero)}; + const void *c_tensors[] = {getAtTensorPtr(c_tensor), getAtTensorPtr(c_scale), + getAtTensorPtr(c_zero)}; + void *d_tensors[] = {getAtTensorPtr(output_tensor), nullptr, nullptr}; + float alpha_arr[1] = {(float)alpha}; + float beta_arr[1] = {(float)beta}; + bool has_res = c_tensor.has_value(); + QuantMatmulTheory obj(a_tensor.size(0), a_tensor.size(1), b_tensor.size(0), quant_bit_size, + has_res, compute_dtype, quant_algo); + cnpxPush(obj); + CNNL_CHECK_FATAL(cnnlLLMQuantMatmul( + handle, op_desc, nullptr, alpha_arr, a_descs, a_tensors, b_descs, b_tensors, nullptr, + beta_arr, c_descs, c_tensors, descs[6].get(), getAtTensorPtr(bias), descs[11].get(), + getAtTensorPtr(gemm_output_scale), descs[12].get(), getAtTensorPtr(gemm_output_zero), + getAtTensorPtr(workspace), workspace_size, d_descs, d_tensors)); + cnpxPop(); + return output_tensor; +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_linear_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_linear_cache.cpp new file mode 100644 index 0000000..1ae0881 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_linear_cache.cpp @@ -0,0 +1,208 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/quant_to_linear_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { + +void quant_to_linear_cache( + const at::Tensor &key, // [total_seqlen, num_heads, head_size] + const c10::optional &value, // or [bs, max_context_len, num_heads, head_size] + at::Tensor &key_cache, // [max_bs, num_heads, cache_memory_len, head_size] + const c10::optional + &value_cache, // [max_bs, num_heads, cache_memory_len, head_size] + at::Tensor &key_cache_scale, // [max_bs, num_heads, cache_memory_len] + const c10::optional &value_cache_scale, + const at::Tensor &context_lengths, + const int64_t max_context_len, + bool packed, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seqlen_offset, + const int64_t quant_bit) { + // check device and dtype + TORCH_CHECK(key.scalar_type() == torch::kFloat32 || key.scalar_type() == torch::kFloat16 || + key.scalar_type() == torch::kBFloat16, + "key type need be torch::kFloat32, torch::kFloat16 or torch::kBFloat16"); + TORCH_CHECK(key_cache.scalar_type() == torch::kInt8, "key_cache type need be torch::kInt8"); + TORCH_CHECK(key_cache_scale.scalar_type() == torch::kFloat32, + "key_cache_scale type need be torch::kFloat32"); + if (value.has_value() || value_cache.has_value() || value_cache_scale.has_value()) { + TORCH_CHECK(value.has_value() && value_cache.has_value() && value_cache_scale.has_value(), + "value_cache, value_cache_scale, value must all have value.") + } + if (value_cache.has_value()) { + TORCH_CHECK(value_cache.value().scalar_type() == torch::kInt8, + "value_cache type need be torch::kInt8"); + TORCH_CHECK(value_cache_scale.value().scalar_type() == torch::kFloat32, + "value_cache_scale type need be torch::kFloat32"); + } + if (value.has_value()) { + TORCH_CHECK(value.value().scalar_type() == key.scalar_type(), + "value type need be same with key"); + } + TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32, + "context_lengths type need be torch::kInt32"); + const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0); + + const int max_bs = key_cache.size(0); + const int num_heads = key_cache.size(1); + const int cache_mem_len = key_cache.size(2); + const int head_size = key.size(-1); + // check quantize param + TORCH_CHECK(quant_bit == 8 || quant_bit == 4, "quant_bit must be 8 or 4"); + if (key_cache_scale.dim() == 4) { + TORCH_CHECK(key_cache_scale.size(-1) > 0 && head_size % key_cache_scale.size(-1) == 0, + "key_cache_scale.size(-1) must be legal"); + } + int group_size = key_cache_scale.dim() == 3 ? head_size : head_size / key_cache_scale.size(-1); + + const int total_seqlen = packed ? key.size(0) : batch_size * key.size(-2); + + const size_t context_bs_stride = packed ? 0 : key.stride(0); + const size_t context_head_stride = packed ? key.stride(1) : key.stride(2); + const size_t context_seq_stride = packed ? key.stride(0) : key.stride(1); + const size_t cache_bs_stride = key_cache.stride(0); + const size_t cache_head_stride = key_cache.stride(1); + const size_t key_cache_seq_stride = key_cache.stride(2); + size_t value_cache_seq_stride = head_size; + const size_t cache_scale_bs_stride = key_cache_scale.stride(0); + const size_t cache_scale_head_stride = key_cache_scale.stride(1); + + if (key_cache_scale.dim() == 3) { + CHECK_SHAPE(key_cache_scale, max_bs, num_heads, cache_mem_len); + } else { + CHECK_SHAPE(key_cache_scale, max_bs, num_heads, cache_mem_len, head_size / group_size); + } + if (quant_bit == 8) { + TORCH_CHECK(key_cache.size(-1) == head_size, "last dim of key_cache should be head_size"); + } else { + TORCH_CHECK(key_cache.size(-1) == head_size / 2, "last dim of key_cache should be head_size"); + } + if (value_cache.has_value()) { + if (quant_bit == 8) { + CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size); + } else if (quant_bit == 4) { + CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len / 2, head_size); + } + if (key_cache_scale.dim() == 3) { + CHECK_SHAPE(value_cache_scale.value(), max_bs, num_heads, cache_mem_len); + } else { + CHECK_SHAPE(value_cache_scale.value(), max_bs, num_heads, cache_mem_len, + head_size / group_size); + } + } + + const int64_t device_id = key.get_device(); + TORCH_CHECK(key_cache.get_device() == device_id, + "key_cache tensor device index is not same, original index: ", device_id, + "now index is: ", key_cache.get_device()); + TORCH_CHECK(key_cache_scale.get_device() == device_id, + "key_cache_scale tensor device index is not same, original index: ", device_id, + "now index is: ", key_cache_scale.get_device()); + + bool has_context_seq_offset = context_seq_offset.has_value(); + const void *context_seq_offset_ptr = nullptr; + if (has_context_seq_offset) { + CHECK_SHAPE(context_seq_offset.value(), batch_size); + TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32, + "context_seq_offset type need be torch::kInt32"); + TORCH_CHECK(context_seq_offset.value().get_device() == device_id, + "context_seq_offset tensor device index is not same, original index: ", device_id, + "now index is: ", context_seq_offset.value().get_device()); + context_seq_offset_ptr = context_seq_offset.value().data_ptr(); + } + + bool has_cache_bs_id = cache_bs_id.has_value(); + const void *cache_bs_id_ptr = nullptr; + if (has_cache_bs_id) { + CHECK_SHAPE(cache_bs_id.value(), batch_size); + TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32, + "cache_bs_id type need be torch::kInt32"); + TORCH_CHECK(cache_bs_id.value().get_device() == device_id, + "cache_bs_id tensor device index is not same, original index: ", device_id, + "now index is: ", cache_bs_id.value().get_device()); + cache_bs_id_ptr = cache_bs_id.value().data_ptr(); + } + + bool has_cache_seqlen_offset = cache_seqlen_offset.has_value(); + const void *cache_seqlen_offset_ptr = nullptr; + if (has_cache_seqlen_offset) { + CHECK_SHAPE(cache_seqlen_offset.value(), batch_size); + TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32, + "cache_seqlen_offset type need be torch::kInt32"); + TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id, + "cache_seqlen_offset tensor device index is not same, original index: ", device_id, + "now index is: ", cache_seqlen_offset.value().get_device()); + cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr(); + } + + // check contiguous + TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous."); + TORCH_CHECK(key_cache_scale.is_contiguous(), "key_cache_scale tensor must be contiguous."); + if (value_cache.has_value()) { + TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous."); + TORCH_CHECK(value_cache.value().get_device() == device_id, + "value_cache tensor device index is not same, original index: ", device_id, + "now index is: ", value_cache.value().get_device()); + TORCH_CHECK(value_cache_scale.value().is_contiguous(), + "value_cache_scale tensor must be contiguous."); + TORCH_CHECK(value_cache_scale.value().get_device() == device_id, + "value_cache_scale tensor device index is not same, original index: ", device_id, + "now index is: ", value_cache_scale.value().get_device()); + } + + // check shape + if (packed) { + CHECK_SHAPE(key, total_seqlen, num_heads, head_size); + if (value.has_value()) { + CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size); + } + CHECK_SHAPE(context_lengths, batch_size + 1); + } else { + CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size); + if (value.has_value()) { + CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size); + } + CHECK_SHAPE(context_lengths, batch_size); + } + + if (value.has_value()) { + for (int i = 0; i < key.dim(); i++) { + TORCH_CHECK(value.value().stride(i) == key.stride(i), + "key and value must have same stride along axi ", i); + } + } + + TORCH_CHECK(key_cache.size(0) >= batch_size, + "first dim of key_cache must be great than or equal to batch_size"); + TORCH_CHECK(key_cache.dim() == 4, "dim of key_cache must be 4."); + + const torch_mlu::mlu::MLUGuard device_guard(key.device()); + auto data_dtype = getCnnlDataType(key.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + + // run forward + invokeQuantToLinearCache( + queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache), + getAtTensorPtr(key_cache_scale), getAtTensorPtr(value_cache_scale), cache_bs_id_ptr, + cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr, + getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size, + (int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride, + context_seq_stride, cache_bs_stride, cache_head_stride, key_cache_seq_stride, + value_cache_seq_stride, cache_scale_bs_stride, cache_scale_head_stride, packed, + (int)quant_bit, group_size); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_paged_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_paged_cache.cpp new file mode 100644 index 0000000..4681008 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/quant_to_paged_cache.cpp @@ -0,0 +1,81 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/quant_to_paged_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { +void quant_to_paged_cache( + const torch::Tensor &k, // [num_tokens, num_heads, head_dim] + const c10::optional &v, // [num_tokens, num_heads, head_dim] + torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_dim] + const c10::optional &v_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor &k_cache_scale, // [num_blocks, num_heads, block_size] + const c10::optional &v_cache_scale, // [num_blocks, num_heads, block_size] + const torch::Tensor &slot_mapping) { + // 1. check device and tensor type + TORCH_CHECK(slot_mapping.dtype() == torch::kInt32, "slot_mapping type need be torch::kInt32"); + TORCH_CHECK(slot_mapping.get_device() == k.get_device(), + "Tensor device index is not same, original index: ", k.get_device(), + "now index is: ", slot_mapping.get_device()); + + // 2. check shape + const int num_tokens = k.size(0); + const int num_heads = k.size(1); + const int head_dim = k.size(2); + const int block_size = k_cache.size(2); + const int head_size = k_cache.size(3); + const int num_blocks = k_cache.size(0); + + CHECK_SHAPE(k, num_tokens, num_heads, head_dim); + CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_dim); + CHECK_SHAPE(k_cache_scale, num_blocks, num_heads, block_size); + TORCH_CHECK(v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value(), + "v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().") + if (v.has_value()) { + CHECK_SHAPE(v.value(), num_tokens, num_heads, head_dim); + CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_dim); + CHECK_SHAPE(v_cache_scale.value(), num_blocks, num_heads, block_size); + } + CHECK_SHAPE(slot_mapping, num_tokens); + + // 3. check strides + TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous."); + TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous."); + TORCH_CHECK(k.stride(-2) == head_dim, "k second dim must be contiguous."); + if (v.has_value()) { + TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous."); + TORCH_CHECK(v.value().stride(-2) == head_dim, "v second dim must be contiguous."); + } + + int64_t kv_cache_range = + (int64_t)num_blocks * num_heads * block_size * head_dim * k_cache.element_size(); + TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G."); + + const torch_mlu::mlu::MLUGuard device_guard(k.device()); + auto queue = torch_mlu::getCurMLUStream(); + cnnlDataType_t dtype = getCnnlDataType(k.scalar_type()); + int32_t key_stride0 = static_cast(k.stride(0)); + int32_t value_stride0 = 1; + if (v.has_value()) { + value_stride0 = static_cast(v.value().stride(0)); + } + + TMO_KERNEL_CHECK_FATAL(tmo::invokeQuantToPagedCache( + queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache), + getAtTensorPtr(v_cache), getAtTensorPtr(k_cache_scale), getAtTensorPtr(v_cache_scale), + getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens, num_heads, num_blocks, + block_size, head_size)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_linear_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_linear_cache.cpp new file mode 100644 index 0000000..60ea18a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_linear_cache.cpp @@ -0,0 +1,136 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/reshape_linear_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { +void reshape_linear_cache(const at::Tensor &key, + const c10::optional &value, + at::Tensor &key_cache, + const c10::optional &value_cache, + const at::Tensor &context_lengths, + const int64_t max_context_len, + bool packed, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seqlen_offset) { + // check device and dtype + checkTensorSameAttr(key, key_cache, value, value_cache); + + const int batch_size = packed ? context_lengths.size(0) - 1 : context_lengths.size(0); + + const int max_bs = key_cache.size(0); + const int num_heads = key_cache.size(1); + const int cache_mem_len = key_cache.size(2); + const int head_size = key_cache.size(3); + + const int total_seqlen = packed ? key.size(0) : batch_size * key.size(1); + + const int context_bs_stride = packed ? 0 : key.stride(0); + const int context_head_stride = packed ? key.stride(1) : key.stride(2); + const int context_seq_stride = packed ? key.stride(0) : key.stride(1); + const int cache_bs_stride = key_cache.stride(0); + const int cache_head_stride = key_cache.stride(1); + const int cache_seq_stride = key_cache.stride(2); + + TORCH_CHECK(context_lengths.scalar_type() == torch::kInt32, + "context_lengths type need be torch::kInt32"); + + if (value_cache.has_value()) { + CHECK_SHAPE(value_cache.value(), max_bs, num_heads, cache_mem_len, head_size); + } + + const int64_t device_id = key.get_device(); + bool has_context_seq_offset = context_seq_offset.has_value(); + const void *context_seq_offset_ptr = nullptr; + if (has_context_seq_offset) { + CHECK_SHAPE(context_seq_offset.value(), batch_size); + TORCH_CHECK(context_seq_offset.value().scalar_type() == torch::kInt32, + "context_seq_offset type need be torch::kInt32"); + TORCH_CHECK(context_seq_offset.value().get_device() == device_id, + "context_seq_offset tensor device index is not same, original index: ", device_id, + "now index is: ", context_seq_offset.value().get_device()); + context_seq_offset_ptr = context_seq_offset.value().data_ptr(); + } + + bool has_cache_bs_id = cache_bs_id.has_value(); + const void *cache_bs_id_ptr = nullptr; + if (has_cache_bs_id) { + CHECK_SHAPE(cache_bs_id.value(), batch_size); + TORCH_CHECK(cache_bs_id.value().scalar_type() == torch::kInt32, + "cache_bs_id type need be torch::kInt32"); + TORCH_CHECK(cache_bs_id.value().get_device() == device_id, + "cache_bs_id tensor device index is not same, original index: ", device_id, + "now index is: ", cache_bs_id.value().get_device()); + cache_bs_id_ptr = cache_bs_id.value().data_ptr(); + } + + bool has_cache_seqlen_offset = cache_seqlen_offset.has_value(); + const void *cache_seqlen_offset_ptr = nullptr; + if (has_cache_seqlen_offset) { + CHECK_SHAPE(cache_seqlen_offset.value(), batch_size); + TORCH_CHECK(cache_seqlen_offset.value().scalar_type() == torch::kInt32, + "cache_seqlen_offset type need be torch::kInt32"); + TORCH_CHECK(cache_seqlen_offset.value().get_device() == device_id, + "cache_seqlen_offset tensor device index is not same, original index: ", device_id, + "now index is: ", cache_seqlen_offset.value().get_device()); + cache_seqlen_offset_ptr = cache_seqlen_offset.value().data_ptr(); + } + + // check contiguous + TORCH_CHECK(key_cache.is_contiguous(), "key_cache tensor must be contiguous."); + if (value_cache.has_value()) { + TORCH_CHECK(value_cache.value().is_contiguous(), "value_cache tensor must be contiguous."); + } + + // check shape + if (packed) { + CHECK_SHAPE(key, total_seqlen, num_heads, head_size); + if (value.has_value()) { + CHECK_SHAPE(value.value(), total_seqlen, num_heads, head_size); + } + CHECK_SHAPE(context_lengths, batch_size + 1); + } else { + CHECK_SHAPE(key, batch_size, key.size(1), num_heads, head_size); + if (value.has_value()) { + CHECK_SHAPE(value.value(), batch_size, key.size(1), num_heads, head_size); + } + CHECK_SHAPE(context_lengths, batch_size); + } + + if (value.has_value()) { + for (int i = 0; i < key.dim(); i++) { + TORCH_CHECK(value.value().stride(i) == key.stride(i), + "key and value must have same stride along axis ", i); + } + } + + TORCH_CHECK(key_cache.size(0) >= batch_size, + "first dim of key_cache must be great than or equal to batch_size"); + TORCH_CHECK(key_cache.dim() == 4, "dim of key_cache must be 4."); + + const torch_mlu::mlu::MLUGuard device_guard(key.device()); + auto data_dtype = getCnnlDataType(key.scalar_type()); + auto queue = torch_mlu::getCurMLUStream(); + + // run forward + TMO_KERNEL_CHECK_FATAL(invokeReshapeLinearCache( + queue, getAtTensorPtr(key_cache), getAtTensorPtr(value_cache), cache_bs_id_ptr, + cache_seqlen_offset_ptr, getAtTensorPtr(key), getAtTensorPtr(value), context_seq_offset_ptr, + getAtTensorPtr(context_lengths), data_dtype, batch_size, num_heads, head_size, + (int)max_context_len, cache_mem_len, context_bs_stride, context_head_stride, + context_seq_stride, cache_bs_stride, cache_head_stride, cache_seq_stride, packed)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_paged_cache.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_paged_cache.cpp new file mode 100644 index 0000000..83b1218 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/reshape_paged_cache.cpp @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/reshape_paged_cache.mluh" +#include "torch_ops_api.h" + +namespace tmo { +namespace torch_api { +void reshape_paged_cache( + const torch::Tensor &k, // [num_tokens, num_heads, head_dim] + const c10::optional &v, // [num_tokens, num_heads, head_dim] + torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_dim] + const c10::optional &v_cache, // [num_blocks, num_heads, block_size, head_dim] + const torch::Tensor &slot_mapping) { + // 1. check device and tensor type + checkTensorSameAttr(k, v, k_cache, v_cache); + TORCH_CHECK(slot_mapping.dtype() == torch::kInt32 || slot_mapping.dtype() == torch::kLong, + "slot_mapping type need be torch::kInt32 or torch::Long"); + TORCH_CHECK(slot_mapping.get_device() == k.get_device(), + "Tensor device index is not same, original index: ", k.get_device(), + "now index is: ", slot_mapping.get_device()); + + // 2. check shape + const int num_tokens = k.size(0); + const int num_heads = k.size(1); + const int head_dim = k.size(2); + const int block_size = k_cache.size(2); + const int head_size = k_cache.size(3); + const int num_blocks = k_cache.size(0); + + CHECK_SHAPE(k, num_tokens, num_heads, head_dim); + CHECK_SHAPE(k_cache, num_blocks, num_heads, block_size, head_dim); + TORCH_CHECK(v.has_value() == v_cache.has_value(), "v.has_value() == v_cache.has_value().") + if (v.has_value()) { + CHECK_SHAPE(v.value(), num_tokens, num_heads, head_dim); + CHECK_SHAPE(v_cache.value(), num_blocks, num_heads, block_size, head_dim); + } + CHECK_SHAPE(slot_mapping, num_tokens); + + // 3. check strides + TORCH_CHECK(k_cache.is_contiguous(), "k_cache need be contiguous."); + TORCH_CHECK(slot_mapping.is_contiguous(), "slot_mapping need be contiguous."); + TORCH_CHECK(k.stride(-1) == 1, "k last dim must be contiguous."); + TORCH_CHECK(k.stride(-2) == head_dim, "k second dim must be contiguous."); + if (v.has_value()) { + TORCH_CHECK(v_cache.value().is_contiguous(), "v_cache need be contiguous."); + TORCH_CHECK(v.value().stride(-1) == 1, "v last dim must be contiguous."); + TORCH_CHECK(v.value().stride(-2) == head_dim, "v second dim must be contiguous."); + } + + // check large tensor + int64_t kv_cache_range = + (int64_t)num_blocks * num_heads * block_size * head_dim * k_cache.element_size(); + TORCH_CHECK(kv_cache_range <= UINT32_MAX, "The addressing range of kv_cache cannot exceed 4G."); + + const torch_mlu::mlu::MLUGuard device_guard(k.device()); + auto queue = torch_mlu::getCurMLUStream(); + cnnlDataType_t dtype = getCnnlDataType(k.scalar_type()); + int32_t key_stride0 = static_cast(k.stride(0)); + int32_t value_stride0 = 1; + if (v.has_value()) { + value_stride0 = static_cast(v.value().stride(0)); + } + + TMO_KERNEL_CHECK_FATAL(tmo::invokeReshapePagedCache( + queue, dtype, getAtTensorPtr(k), getAtTensorPtr(v), getAtTensorPtr(k_cache), + getAtTensorPtr(v_cache), getAtTensorPtr(slot_mapping), key_stride0, value_stride0, num_tokens, + num_heads, num_blocks, block_size, head_size)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/single_query_cached_kv_attn.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/single_query_cached_kv_attn.cpp new file mode 100644 index 0000000..cc08ab0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/single_query_cached_kv_attn.cpp @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "torch_ops_api.h" +#include "utils.h" +namespace tmo { +namespace torch_api { + +// q_ori, out Tensor shape is [batch, seq_len, head_num, head_size] +// if (pagedattn) k_cache, v_cache shape is [num_blocks, k_head_num, block_size, head_size] +// else [max_bs, head_num, max_seq_len, head_size] +// if(pagedattn) k_cache_quant_scale shape is [num_blocks, k_head_num, block_size] else [max_bs, +// head_num, max_seq_len] +// if (pagedattn) block_tables shape is [bs, max_nblock_per_seq] else [bs, 1] +// context_lens: true seqkv_len of each bacth +void single_query_cached_kv_attn( + const torch::Tensor &q_ori, + const torch::Tensor &k_cache, + const torch::Tensor &v_cache, + const torch::Tensor &output, + const torch::Tensor &block_tables, + const torch::Tensor &context_lens, // [batch] + const c10::optional &output_lse, + const c10::optional &k_cache_quant_scale, + const c10::optional &v_cache_quant_scale, + const c10::optional &alibi_slopes, // [bs, head_num] or [head_num] + int64_t max_context_len, + int64_t windows_size_left, + int64_t windows_size_right, + double softmax_scale, + bool return_lse, + int64_t kv_cache_quant_bit_size) { + // Check tensor type and tensor device. + cnnlQuantizeLayout_t cache_quant_layout = CNNL_QUANTIZE_NONE; + bool is_kv_quant = k_cache_quant_scale.has_value(); + bool has_alibi = alibi_slopes.has_value(); + + // Check Tensor Shape + int batch = q_ori.size(0); + int seq_q = q_ori.size(1); + int head_num = q_ori.size(2); + int qk_head_size = q_ori.size(3); + int k_head_num = k_cache.size(1); + int v_head_size = v_cache.size(-1); + int max_num_block_per_seq = block_tables.size(1); + int num_blocks = k_cache.size(0); + int block_size = k_cache.size(2); + TORCH_CHECK(windows_size_right < 0, "only support windows_size_right < 0 currently."); + TORCH_CHECK(head_num % k_head_num == 0, "num_heads need be mutiple of num_kv_heads."); + TORCH_CHECK( + kv_cache_quant_bit_size == 4 || kv_cache_quant_bit_size == 8 || kv_cache_quant_bit_size == -1, + "illegal quant bit size, only support 4, 8 or -1."); + CHECK_SHAPE(block_tables, batch, max_num_block_per_seq); + if (kv_cache_quant_bit_size == 4) { + int v_cache_len = PAD_UP_DIV(block_size, 2); + CHECK_SHAPE(k_cache, num_blocks, k_head_num, block_size, qk_head_size / 2); + CHECK_SHAPE(v_cache, num_blocks, k_head_num, v_cache_len, v_head_size); + } else { + CHECK_SHAPE(k_cache, num_blocks, k_head_num, block_size, qk_head_size); + CHECK_SHAPE(v_cache, num_blocks, k_head_num, block_size, v_head_size); + } + TORCH_CHECK(q_ori.stride(-1) == 1 && q_ori.stride(-2) == qk_head_size, + "q last two dim need be contiguous."); + TORCH_CHECK(k_cache.is_contiguous() && v_cache.is_contiguous(), + "k_cache and v_cache need be contiguous."); + TORCH_CHECK(isMlu(q_ori), "q_ori need be mlu tensor."); + + checkTensorSameAttr(q_ori, k_cache, v_cache, output, block_tables, + context_lens, k_cache_quant_scale, v_cache_quant_scale, + output_lse, alibi_slopes); + if (is_kv_quant) { + TORCH_CHECK(k_cache_quant_scale.value().dim() == 2 || k_cache_quant_scale.value().dim() == 3 || + k_cache_quant_scale.value().dim() == 4, + "k_cache_quant_scale must be 2d or 3d or 4d."); + TORCH_CHECK(k_cache_quant_scale.value().dim() == v_cache_quant_scale.value().dim(), + "the dim of k_cache_quant_scale and v_cache_quant_scale must be euqal."); + if (k_cache_quant_scale.value().dim() == 2) { + CHECK_SHAPE(k_cache_quant_scale.value(), k_head_num, qk_head_size); + CHECK_SHAPE(v_cache_quant_scale.value(), k_head_num, v_head_size); + cache_quant_layout = CNNL_QUANTIZE_PER_CHANNEL; + } else if (k_cache_quant_scale.value().dim() == 3) { + CHECK_SHAPE(k_cache_quant_scale.value(), num_blocks, k_head_num, block_size); + CHECK_SHAPE(v_cache_quant_scale.value(), num_blocks, k_head_num, block_size); + cache_quant_layout = CNNL_QUANTIZE_PER_TOKEN; + } else { + CHECK_SHAPE(k_cache_quant_scale.value(), num_blocks, k_head_num, block_size, 1); + CHECK_SHAPE(v_cache_quant_scale.value(), num_blocks, k_head_num, block_size, 1); + cache_quant_layout = CNNL_QUANTIZE_PER_TOKEN; + } + } + TORCH_CHECK( + block_tables.scalar_type() == torch::kInt32 || block_tables.scalar_type() == torch::kLong, + "block_tables type need be torch::kInt32 or torch::kLong."); + // Check context_lens + TORCH_CHECK(context_lens.dtype() == torch::kInt32, "context_lens type need be torch::kInt32."); + TORCH_CHECK(context_lens.is_contiguous(), "context_lens need be contiguous."); + + if (has_alibi) { + CHECK_SHAPE(alibi_slopes.value(), batch, head_num); + } + if (return_lse) { + TORCH_CHECK(seq_q == 1, "return lse only support seq_q = 1 currently."); + CHECK_SHAPE(output_lse.value(), batch, head_num, seq_q); + } + if (windows_size_left >= 0) { + windows_size_left = windows_size_left + 1; // would be removed next version + } + + // Convert torch tensor to tensor descs + auto descs = createTensorDescs( + {q_ori, k_cache, v_cache, k_cache_quant_scale.value_or(at::Tensor()), + v_cache_quant_scale.value_or(at::Tensor()), context_lens, block_tables, + alibi_slopes.value_or(at::Tensor()), output, output_lse.value_or(at::Tensor())}); + + if (kv_cache_quant_bit_size == 4) { + cnnlDataType_t data_type = CNNL_DTYPE_INT4X2; + // k_cache + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[1].get(), CNNL_LAYOUT_ARRAY, data_type, + k_cache.sizes().size(), k_cache.sizes().data())); + // v_cache + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[2].get(), CNNL_LAYOUT_ARRAY, data_type, + v_cache.sizes().size(), v_cache.sizes().data())); + } + + // Get current handle. + const torch_mlu::mlu::MLUGuard device_guard(q_ori.device()); + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + // Get workspace size and malloc workspace. + size_t workspace_size = 0; + cnnlSingleQueryCachedKVAttnDescriptor_t att_desc; + CNNL_CHECK_FATAL(cnnlGetSingleQueryCachedKVAttnWorkspaceSize_v2( + handle, att_desc, descs[0].get(), descs[1].get(), descs[2].get(), (int)max_context_len, + &workspace_size)); + auto workspace = + at::empty({static_cast(workspace_size)}, q_ori.options().dtype(at::kByte)); + + cnnlDataType_t data_dtype = getCnnlDataType(q_ori.scalar_type()); + SingleQueryCachedKvAttnTheory obj(batch, seq_q, head_num, k_head_num, num_blocks, block_size, + qk_head_size, v_head_size, kv_cache_quant_bit_size, + cache_quant_layout, data_dtype); + cnpxPush(obj); + // call cnnl extra op. + CNNL_CHECK_FATAL(cnnlSingleQueryCachedKVAttn_v3( + handle, att_desc, descs[0].get(), getAtTensorPtr(q_ori), descs[1].get(), + getAtTensorPtr(k_cache), descs[2].get(), getAtTensorPtr(v_cache), descs[3].get(), + getAtTensorPtr(k_cache_quant_scale), descs[4].get(), getAtTensorPtr(v_cache_quant_scale), + descs[5].get(), getAtTensorPtr(context_lens), descs[6].get(), getAtTensorPtr(block_tables), + descs[7].get(), getAtTensorPtr(alibi_slopes), cache_quant_layout, (int)max_context_len, + windows_size_left, windows_size_right, softmax_scale, return_lse, getAtTensorPtr(workspace), + workspace_size, descs[8].get(), getAtTensorPtr(output), descs[9].get(), + getAtTensorPtr(output_lse))); + cnpxPop(); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/smooth_quant.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/smooth_quant.cpp new file mode 100644 index 0000000..83f0ab9 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/smooth_quant.cpp @@ -0,0 +1,112 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { + +void smooth_quant(const at::Tensor &input, + const at::Tensor &input_scale, + const at::Tensor &output, + const at::Tensor &output_scale, + const c10::optional &input_zero, + const c10::optional &token_count, + const c10::optional &gather_index, + const c10::optional &gather_index_start_position, + const std::string &quant_mode, + const bool &dynamic_quant) { + bool is_pertoken = quant_mode == "per_token"; + CHECK_TENSOR_CONTIGUOUS(input_scale) + if (output_scale.dim() > 0) { + CHECK_TENSOR_CONTIGUOUS(output_scale) + } + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(input_zero) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(token_count) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_index) + CHECK_OPTIONAL_TENSOR_CONTIGUOUS(gather_index_start_position) + + TORCH_CHECK(isMlu(input), "input must be mlu tensor."); + checkTensorSameAttr(input, input_scale, input_zero, token_count, gather_index, + gather_index_start_position); + + TORCH_CHECK(quant_mode == "per_token" || quant_mode == "per_tensor", + "quant_mode must be 'per_token' or per_tensor.") + if (dynamic_quant) { + TORCH_CHECK(is_pertoken, "only support per_token if dynamic_quant == true") + TORCH_CHECK(output_scale.numel() > 0, "invalid output_scale if dynamic_quant = true") + } else { + TORCH_CHECK(output_scale.numel() <= 0, "not support output_scale if dynamic_quant = false") + } + if (gather_index.has_value() || token_count.has_value()) { + TORCH_CHECK(input.dim() == 2, "input.dim() == 2 if has gather_index or token_count") + } + TORCH_CHECK(input.dim() >= 2, "input.dim() >= 2") + TORCH_CHECK(output.dim() >= 2, "output.dim() >= 2") + if (!gather_index.has_value()) { + TORCH_CHECK(output.sizes() == input.sizes(), "input and output must have the same shape") + } + if (output_scale.numel() > 0) { + if (!gather_index.has_value()) { + TORCH_CHECK(std::vector(output_scale.sizes().begin(), output_scale.sizes().end()) == + std::vector(input.sizes().begin(), input.sizes().end() - 1), + "output_scale_shape must be equal to input_shape[0:-1]") + } + } + if (gather_index_start_position.has_value()) { + TORCH_CHECK(gather_index.has_value(), + "gather_index must exist if gather_index_start_position has value") + TORCH_CHECK(gather_index.value().dim() == 1, "gather_index.dim() == 1") + } + + const torch_mlu::mlu::MLUGuard device_guard(input.device()); + // flatten input + at::Tensor input_flat = input.flatten(0, input.dim() - 2); + TORCH_CHECK(input_flat.data_ptr() == input.data_ptr(), "check the input strides") + // flatten output + at::Tensor output_flat = output.flatten(0, output.dim() - 2); + TORCH_CHECK(output_flat.data_ptr() == output.data_ptr(), "check the output stride") + // flatten output_scale + at::Tensor output_scale_flat; + if (dynamic_quant) { + output_scale_flat = output_scale.flatten(0, -1); + } + + // Convert torch tensor to tensor descriptor + auto descs = createTensorDescs( + {input_flat, input_scale, input_zero.value_or(at::Tensor()), + token_count.value_or(at::Tensor()), gather_index.value_or(at::Tensor()), + gather_index_start_position.value_or(at::Tensor()), output_flat, output_scale_flat}); + + auto input_shape = input.sizes(); + int64_t in_channel = input_shape.back(); + int64_t m = + std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, std::multiplies()); + SmoothQuantTheory obj(m, in_channel, getCnnlDataType(input.scalar_type()), dynamic_quant); + cnpxPush(obj); + CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4( + torch_mlu::getCurrentHandle(), nullptr /*smooth_quant_online_desc*/, descs[0].get(), + getAtTensorPtr(input_flat), /*input*/ + descs[1].get(), getAtTensorPtr(input_scale), /*input_scale*/ + descs[2].get(), getAtTensorPtr(input_zero), /*input_zero*/ + descs[3].get(), getAtTensorPtr(token_count), /*token_count*/ + descs[4].get(), getAtTensorPtr(gather_index), /*gather_index*/ + descs[5].get(), getAtTensorPtr(gather_index_start_position), /*gather_index_start_position*/ + descs[6].get(), getAtTensorPtr(output_flat), /*output*/ + descs[7].get(), getAtTensorPtr(output_scale_flat), /*output_scale*/ + nullptr, nullptr, /*output_zero*/ + nullptr, 0 /*worksapce*/)); + cnpxPop(); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/swap_blocks.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/swap_blocks.cpp new file mode 100644 index 0000000..9f9483c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/swap_blocks.cpp @@ -0,0 +1,65 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/swap_blocks.mluh" +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { +void swap_blocks(torch::Tensor &dst, + const torch::Tensor &src, + const c10::Dict &block_mapping_dict) { + std::map block_mapping; + for (const auto &item : block_mapping_dict) { + block_mapping[item.key()] = item.value(); + } + + // check shape + TORCH_CHECK(src[0].numel() == dst[0].numel(), "the block_size of src and dst are not the same.") + + // check dtype + TORCH_CHECK(src.dtype() == dst.dtype(), "the data type of src and dst are not the same.") + TORCH_CHECK(src.dtype() == torch::kInt8 || src.dtype() == torch::kUInt8 || + src.dtype() == torch::kInt16 || src.dtype() == torch::kInt32 || + src.dtype() == torch::kLong || src.dtype() == torch::kFloat16 || + src.dtype() == torch::kFloat32 || src.dtype() == torch::kBFloat16, + "data type only supports torch::kInt8, torch::kUInt8, torch::kInt16, torch::kInt32, " + "torch::kLong, torch::kFloat16, torch::kFloat32 and torch::kBFloat16"); + + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + torch::Device mlu_device = src_device; + cnrtMemTransDir_t memcpy_type; + bool src_is_mlu = isMlu(src); + bool dst_is_mlu = isMlu(dst); + + if (src_is_mlu && dst_is_mlu) { + TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same MLU"); + memcpy_type = cnrtMemcpyDevToDev; + } else if (src_is_mlu && dst_device.is_cpu()) { + memcpy_type = cnrtMemcpyDevToHost; + } else if (src_device.is_cpu() && dst_is_mlu) { + memcpy_type = cnrtMemcpyHostToDev; + mlu_device = dst_device; + } else { + TORCH_CHECK(false, "Invalid device combination"); + } + + const torch_mlu::mlu::MLUGuard device_guard(mlu_device); + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + TMO_KERNEL_CHECK_FATAL(invokeSwapBlocksKernel(handle, dst.data_ptr(), src.data_ptr(), + block_size_in_bytes, memcpy_type, block_mapping)); +} +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/torch_ops_api.h b/torch_mlu_ops-v1.3.2/csrc/torch_api/torch_ops_api.h new file mode 100644 index 0000000..66e8a5f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/torch_ops_api.h @@ -0,0 +1,472 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef CSRC_TORCH_API_TORCH_OPS_API_H_ +#define CSRC_TORCH_API_TORCH_OPS_API_H_ + +#include +#include +#include +#include +#include "op_theory.h" +#include "ops/kernel_api.h" +#include "torch/extension.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { + +void fused_layernorm(const at::Tensor &input, + const at::Tensor &out, + const c10::optional &residual, + const c10::optional &gamma, + const c10::optional &beta, + const c10::optional &bias, + const c10::optional &quant_scale, + const c10::optional &residual_out, + const c10::optional &smooth_quant_scale, + const std::string &norm_mode, + double eps, + bool store_output_before_norm, + bool dynamic_quant); + +std::vector attention_project(const at::Tensor &input, + const at::Tensor &q_weight, + const c10::optional &q_bias, + const c10::optional &k_weight, + const c10::optional &k_bias, + const c10::optional &v_weight, + const c10::optional &v_bias, + const c10::optional &norm_weight, + const c10::optional &norm_bias, + const c10::optional &residual, + const std::string &out_layout, + int64_t head_size, + double eps, + double alpha, + double beta, + bool norm_out); + +at::Tensor ffn(const at::Tensor &input, + const at::Tensor &up_fc_weight, + const c10::optional &up_fc_bias, + const at::Tensor &down_proj_weight, + const c10::optional &down_proj_bias, + const c10::optional &gate_up_proj_weight, + const c10::optional &gate_up_proj_bias, + const c10::optional &layernorm_weight, + const c10::optional &layernorm_bias, + const std::string &act_mode, + const std::string &residual_is, + double eps, + double alpha, + double beta); + +void flash_attention(const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const c10::optional &output_lse, + const c10::optional &cu_seq_lens_q, + const c10::optional &cu_seq_lens_kv, + const c10::optional &alibi_slope, + const c10::optional &attn_bias, + const c10::optional &k_cache_quant_scale, + const c10::optional &v_cache_quant_scale, + const c10::optional &block_tables, + const int64_t max_seq_len_q, + const int64_t max_seq_len_kv, + const double softmax_scale, + const bool is_causal, + const int64_t window_size_left, + const int64_t window_size_right, + const std::string &compute_dtype, + bool return_lse); + +void single_query_cached_kv_attn( + const torch::Tensor &q_ori, + const torch::Tensor &k_cache, + const torch::Tensor &v_cache, + const torch::Tensor &output, + const torch::Tensor &block_tables, + const torch::Tensor &context_lens, // [batch] + const c10::optional &output_lse, + const c10::optional &k_cache_quant_scale, + const c10::optional &v_cache_quant_scale, + const c10::optional &alibi_slopes, // [bs, head_num] or [head_num] + int64_t max_context_len, + int64_t windows_size_left, + int64_t windows_size_right, + double softmax_scale, + bool return_lse, + int64_t kv_cache_quant_bit_size); + +void reshape_linear_cache(const at::Tensor &key, + const c10::optional &value, + at::Tensor &key_cache, + const c10::optional &value_cache, + const at::Tensor &context_lengths, + const int64_t max_context_len, + bool packed, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seqlen_offset); + +void reshape_paged_cache(const torch::Tensor &k, + const c10::optional &v, + torch::Tensor &k_cache, + const c10::optional &v_cache, + const torch::Tensor &slot_mapping); + +void quant_to_paged_cache(const torch::Tensor &k, + const c10::optional &v, + torch::Tensor &k_cache, + const c10::optional &v_cache, + torch::Tensor &k_cache_scale, + const c10::optional &v_cache_scale, + const torch::Tensor &slot_mapping); + +void offline_quant_to_paged_cache(const torch::Tensor &k, + const c10::optional &v, + const torch::Tensor &k_cache_scale, + const c10::optional &v_cache_scale, + const torch::Tensor &slot_mapping, + torch::Tensor &k_cache, + const c10::optional &v_cache); + +void quant_to_linear_cache(const at::Tensor &key, + const c10::optional &value, + at::Tensor &key_cache, + const c10::optional &value_cache, + at::Tensor &key_cache_scale, + const c10::optional &value_cache_scale, + const at::Tensor &context_lengths, + const int64_t max_context_len, + bool packed, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seqlen_offset, + const int64_t quant_bit); + +void offline_quant_to_linear_cache(const at::Tensor &key, + const c10::optional &value, + at::Tensor &key_cache, + const c10::optional &value_cache, + const at::Tensor &key_cache_scale, + const c10::optional &value_cache_scale, + const at::Tensor &context_lengths, + const int64_t max_context_len, + const int64_t quant_mode, + const bool packed, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seqlen_offset); + +void apply_rotary(const torch::Tensor &input, + const torch::Tensor &sin_cache, + const torch::Tensor &cos_cache, + const c10::optional &position_ids, + const c10::optional &cu_seqlens, + bool interleaved, + bool discrete, + bool dynamic_ntk, + int64_t max_seqlen); + +void swap_blocks(torch::Tensor &dst, + const torch::Tensor &src, + const c10::Dict &block_mapping); + +void copy_blocks(const std::vector &k_caches, + const std::vector &v_caches, + const c10::Dict> &block_mapping); +std::tuple, std::vector> copy_blocks_out_of_place( + const std::vector &k_caches, + const std::vector &v_caches, + const c10::Dict> &block_mapping); + +at::Tensor fused_moe(const at::Tensor &hidden_states, + const at::Tensor &gating_output, + const at::Tensor &w1, + const at::Tensor &w2, + const c10::optional &bias1, + const c10::optional &bias2, + const c10::optional &residual, + const c10::optional &input_smooth, + const c10::optional &act_smooth, + const c10::optional &w1_scale, + const c10::optional &w2_scale, + const c10::optional> &w1_quant_flag, + const c10::optional> &w2_quant_flag, + const int64_t topk, + const bool renormalize, + const bool gated, + const std::string &act_mode, + const int64_t start_expert_id, + const int64_t block_n = 0, + const int64_t cncl_comm = 0); + +// d = (a * b + bias) * alpha + c * beta +// d = (((a * a_scale * b * b_scale) * gemm_output_scale + bias) * alpha + c * beta) * d_scale +at::Tensor quant_matmul( + const at::Tensor &a_tensor, // input + const c10::optional &a_scale, // input scale, smooth quant + const c10::optional &a_zero, // input zero + const at::Tensor &b_tensor, // weight + const c10::optional &b_scale, // weight scale, smooth quant + const c10::optional &b_zero, // weight zero + const c10::optional &bias, // bias + const c10::optional &c_tensor, // residual + const c10::optional &c_scale, // residual scale + const c10::optional &c_zero, // residual zero + const c10::optional &gemm_output_scale, // for dequant, weight only + const c10::optional &gemm_output_zero, // for dequant, preserve + const c10::optional &data_type, // output data type + const c10::optional &d, // output + const std::string &quant_algo, // quant type + const std::string &a_quant_layout, // input quant mode + const std::string &b_quant_layout, // weight quant mode + int64_t quant_bit_size = 8, // weight quant bit size + const std::string &act_mode = "none", // act mode + bool use_hp_active = false, // for active, whether uses high precision mode + double act_coef = 1.f, // active_coef + double alpha = 1.f, // for quant matmul + double beta = 1.f, // for residual + bool trans_a = false, // whether transpose input + bool trans_b = true); // whether transpose weight + +at::Tensor quant_matmul_allreduce(const int64_t cncl_comm, + const at::Tensor &a_tensor, + const c10::optional &a_scale, + const c10::optional &a_zero, + const at::Tensor &b_tensor, + const c10::optional &b_scale, + const c10::optional &b_zero, + const c10::optional &bias, + const c10::optional &c_tensor, + const c10::optional &c_scale, + const c10::optional &c_zero, + const c10::optional &gemm_output_scale, + const c10::optional &gemm_output_zero, + const c10::optional &data_type, + const c10::optional &d, + const std::string &quant_algo, + const std::string &a_quant_layout, + const std::string &b_quant_layout, + int64_t quant_bit_size, + double alpha, + double beta, + bool trans_a, + bool trans_b, + const int64_t block_m); + +void active(const torch::Tensor &input, + const torch::Tensor &output, + const c10::optional &bias, + const c10::optional &cusum_token_count, + const std::string &act_mode, + bool is_gated, + int64_t start_expert_id, + int64_t expert_size, + double active_coef = 1.0); + +void smooth_quant(const at::Tensor &input, + const at::Tensor &input_scale, + const at::Tensor &output, + const at::Tensor &output_scale, + const c10::optional &input_zero, + const c10::optional &token_count, + const c10::optional &gather_index, + const c10::optional &gather_index_start_position, + const std::string &quant_mode, + const bool &dynamic_quant); + +at::Tensor matmul(const at::Tensor &a, + const at::Tensor &b, + const c10::optional &d, + const c10::optional &bias, + const c10::optional &c, + const c10::optional &dtype, + const std::string &act_mode, + double alpha, + double beta, + bool fast_act, + bool approximate, + double a_scale, + double b_scale, + bool trans_a, + bool trans_b); + +at::Tensor matmul_allreduce(const int64_t cncl_comm, + const at::Tensor &a, + const at::Tensor &b, + const c10::optional &bias, + const c10::optional &c, + const c10::optional &d, + const double alpha, + const double beta, + const int64_t block_m); + +at::Tensor group_gemm(const at::Tensor &a_tensor, + const at::Tensor &b_tensor, + const at::Tensor &m_list, + const c10::optional &gather_idx, + const c10::optional &c_tensor, + const c10::optional &alpha, + const c10::optional &beta, + const c10::optional &a_scale, + const c10::optional &b_scale, + const c10::optional &bias, + const c10::optional &data_type, + const c10::optional> &quant_flag, + const c10::optional &b_offset, + const int64_t max_m /*max m in m_list*/); + +void preload(const torch::Tensor &weight, const int64_t size); + +at::Tensor group_gemm_combine_result_allreduce(const int64_t cncl_comm, + const at::Tensor &a_tensor, + const at::Tensor &b_tensor, + const at::Tensor &m_list, + const at::Tensor &combine_idx, + const at::Tensor &combine_weight, + const c10::optional &c_tensor, + const c10::optional &alpha, + const c10::optional &beta, + const c10::optional &a_scale, + const c10::optional &b_scale, + const c10::optional &data_type, + const int64_t num_token, + const int64_t topk, + const int64_t block_n); + +at::Tensor flash_attn_sq_mm_allreduce(const int64_t cncl_comm, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const c10::optional &cu_seq_lens_q, + const c10::optional &cu_seq_lens_kv, + const c10::optional &alibi_slope, + const c10::optional &attn_bias, + const at::Tensor &smooth, + const at::Tensor &weight, + const at::Tensor &weight_scale, + const c10::optional &bias, + const int64_t max_seq_len_q, + const int64_t max_seq_len_kv, + const double softmax_scale, + const bool is_causal, + const int64_t window_size_left, + const int64_t window_size_right, + const std::string &compute_dtype, + const int64_t block_seq); + +std::vector moe_softmax_topk(const at::Tensor &input, + int64_t topk, + int64_t num_expert_group, + int64_t topk_group, + bool normalize, + const c10::optional &mask, + const std::string &normed_by); + +at::Tensor moe_expand_input(const torch::Tensor &input, + const torch::Tensor &gather_idx, + const c10::optional &cusum_token_count, + int64_t start_expert_id, + int64_t expert_size); + +std::vector moe_gen_idx(const torch::Tensor &expert_id, int64_t expert_num); + +at::Tensor moe_combine_result(const at::Tensor &input, + const at::Tensor &reduce_weight, + const at::Tensor &gather_ids, + const c10::optional &residual, + const c10::optional &cusum_token_count, + int64_t start_expert_id, + int64_t expert_size, + const c10::optional &bias); + +void fused_rope(at::Tensor &qkv, + at::Tensor &key_cache_hp, + at::Tensor &value_cache_hp, + const c10::optional &key_cache_lp, + const c10::optional &value_cache_lp, + const at::Tensor &sin_table, + const at::Tensor &cos_table, + const at::Tensor &position_ids, + const at::Tensor &gamma, + const at::Tensor &beta, + const c10::optional &key_scale_hp, + const c10::optional &value_scale_hp, + const c10::optional &key_scale_lp, + const c10::optional &value_scale_lp, + const c10::optional &cache_bs_id_hp, + const c10::optional &cache_seq_offsets_hp, + const c10::optional &cache_bs_id_lp, + const c10::optional &cache_seq_offsets_lp, + const c10::optional &slot_mapping_hp, + const c10::optional &slot_mapping_lp, + const double eps); + +void batch_matmul(const at::Tensor &a, + const at::Tensor &b, + const at::Tensor &c, + double alpha, + double beta, + double a_scale, + double b_scale, + bool trans_a, + bool trans_b); + +at::Tensor moe_cast_gating(const at::Tensor &input, const at::Tensor &weight); + +void update_out_and_lse(at::Tensor &out, + at::Tensor &lse, + const at::Tensor &block_out, + const at::Tensor &block_lse, + const c10::optional &seq_offsets, + const c10::optional &cu_seqs, + const c10::optional &block_cu_seqs); + +void cnpxPush(const OpTheory &op); + +void cnpxPop(void); + +void dequant_from_linear_cache(at::Tensor &key, + const c10::optional &value, + const at::Tensor &key_cache, + const c10::optional &value_cache, + const at::Tensor &key_quant_scale, + const c10::optional &value_quant_scale, + const at::Tensor &context_lengths, + const int64_t max_context_len, + const c10::optional &context_seq_offset, + const c10::optional &cache_bs_id, + const c10::optional &cache_seq_offset, + const int64_t quant_mode, + const int64_t quant_bit); + +void dequant_from_paged_cache(at::Tensor &key, + const c10::optional &value, + const at::Tensor &key_cache, + const c10::optional &value_cache, + const at::Tensor &key_cache_quant_scale, + const c10::optional &value_cache_quant_scale, + const at::Tensor &context_lengths, + int64_t max_context_len, + const c10::optional &context_seq_offset, + const at::Tensor &block_tables, + int64_t quant_mode, + int64_t quant_bit); + +} // namespace torch_api +} // namespace tmo + +#endif // CSRC_TORCH_API_TORCH_OPS_API_H_ diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/torch_register_function.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/torch_register_function.cpp new file mode 100644 index 0000000..d0b070f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/torch_register_function.cpp @@ -0,0 +1,263 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include +#include "glue_ops.h" +#include "torch_ops_api.h" + +TORCH_LIBRARY_FRAGMENT(torch_mlu_ops, m) { + // torch 2.1.0 does not support impl_abstract_pystub, enable it when torch 2.3.0 is released +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 + m.impl_abstract_pystub("torch_mlu_ops.abstract"); +#endif + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::attention_project(Tensor input, Tensor q_weight, Tensor? q_bias, Tensor? " + "k_weight, " + "Tensor? k_bias, Tensor? v_weight, Tensor? v_bias, Tensor? norm_weight, Tensor? norm_bias, " + "Tensor? residual, str out_layout, int head_size, float eps, float alpha, float beta, bool " + "norm_out) -> (Tensor[])")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::ffn(Tensor input, Tensor up_fc_weight, Tensor? up_fc_bias, Tensor " + "down_proj_weight, " + "Tensor? down_proj_bias, Tensor? gate_up_proj_weight, Tensor? gate_up_proj_bias, Tensor? " + "layernorm_weight, Tensor? layernorm_bias, str act_mode, str residual_is, float eps, float " + "alpha, float beta) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::flash_attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out, Tensor(b!)? " + "out_lse, Tensor? cu_seq_lens_q, Tensor? cu_seq_lens_kv, " + "Tensor? alibi_slope, Tensor? attn_bias, Tensor? k_quant_scale, Tensor? v_quant_scale, " + "Tensor? block_tables, int max_seq_len_q, int " + "max_seq_len_kv, float softmax_scale, bool is_causal, int window_size_left, int " + "window_size_right, str compute_dtype, bool return_lse) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::single_query_cached_kv_attn(Tensor q_ori, Tensor k_cache, Tensor " + "v_cache, Tensor(a!) output, Tensor block_tables, Tensor context_lens, " + "Tensor(b!)? out_lse, Tensor? k_cache_quant_scale, Tensor? v_cache_quant_scale, " + "Tensor? alibi_slopes, int max_contxt_len, int windows_size_left, int " + "windows_size_right, float softmax_scale, bool return_lse, int kv_cache_quant_bit_size) -> " + "()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::apply_rotary(Tensor(a!) input, Tensor sin_cache, Tensor cos_cache, " + "Tensor? position_ids, Tensor? cu_seqlens, bool interleaved, bool " + "discrete, bool dynamic_ntk, int max_seqlen) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::reshape_linear_cache(Tensor key, Tensor? value, Tensor(a!) key_cache, " + "Tensor(b!)? " + "value_cache, Tensor context_lengths, int max_context_len, bool packed, Tensor? " + "context_seq_offset, Tensor? cache_bs_id, Tensor? cache_seqlen_offset) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::reshape_paged_cache(Tensor k, Tensor? v, Tensor(a!) k_cache, " + "Tensor(b!)? v_cache, Tensor slot_mapping) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::quant_to_paged_cache(Tensor k, Tensor? v, Tensor(a!) k_cache, Tensor(b!)? " + "v_cache, " + "Tensor(c!) k_cache_scale, Tensor(d!)? v_cache_scale, Tensor slot_mapping) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::offline_quant_to_paged_cache(Tensor k, Tensor? v, Tensor k_cache_scale," + "Tensor? v_cache_scale, Tensor slot_mapping, Tensor(a!) k_cache, Tensor(b!)? v_cache)-> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::quant_to_linear_cache(Tensor key, Tensor? value, Tensor(a!) key_cache, " + "Tensor(b!)? " + "value_cache, Tensor(c!) key_cache_scale, Tensor(d!)? value_cache_scale, Tensor " + "context_lengths, " + "int max_context_len, bool packed, Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? " + "cache_seqlen_offset, int quant_bit=8) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::offline_quant_to_linear_cache(Tensor key, Tensor? value, Tensor(a!) " + "key_cache," + " Tensor(b!)? value_cache, Tensor(c!) key_cache_scale, Tensor(d!)? value_cache_scale, " + "Tensor context_lengths, int max_context_len, int quant_mode, bool packed, " + "Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? " + "cache_seqlen_offset) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::quant_matmul(Tensor a_tensor, Tensor? a_scale, Tensor? a_zero, Tensor " + "b_tensor, " + "Tensor? b_scale, Tensor? b_zero, Tensor? bias, Tensor? c_tensor, Tensor? c_scale, Tensor? " + "c_zero, Tensor? gemm_output_scale, Tensor? gemm_output_zero, str? data_type, Tensor? d, " + "str " + "quant_algo, str a_quant_layout, str b_quant_layout, int quant_bit_size=8, str " + "act_mode='none', bool use_hp_active=False, float act_coef=1.0, float alpha=1.0, float " + "beta=1.0, bool trans_a=False, bool trans_b=True) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::quant_matmul_allreduce(int cncl_comm, Tensor a_tensor, Tensor? a_scale, " + "Tensor? a_zero, Tensor b_tensor, Tensor? b_scale, Tensor? b_zero, " + "Tensor? bias, Tensor? c_tensor, Tensor? c_scale, Tensor? " + "c_zero, Tensor? gemm_output_scale, Tensor? gemm_output_zero, str? data_type, Tensor? d, " + "str quant_algo, str a_quant_layout, str b_quant_layout, int quant_bit_size=8, " + "float alpha=1.0, float beta=1.0, bool trans_a=False, bool trans_b=True, int block_m=0) -> " + "Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::active(Tensor input, Tensor(a!) output, Tensor? bias, " + "Tensor? cusum_token_count, str act_mode, bool is_gated, int start_expert_id, " + "int expert_size, float active_coef=1.0) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::smooth_quant(Tensor input, Tensor input_scale, Tensor(a!) output, " + "Tensor(b!) output_scale, " + "Tensor? input_zero, Tensor? token_count, Tensor? gather_index, " + "Tensor? gather_index_start_position, str quant_mode, bool dynamic_quant) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::fused_layernorm(Tensor input, Tensor(a!) out, " + "Tensor? residual, Tensor? gamma, Tensor? beta, Tensor? bias, Tensor? " + "quant_scale, Tensor(b!)? residual_out, Tensor? smooth_quant_scale, str norm_mode, " + "float eps, bool store_output_before_norm, bool dynamic_quant) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::fused_moe(Tensor hidden_states, Tensor gating_output, Tensor " + "w1, Tensor w2, Tensor? bias1, Tensor? bias2, Tensor? residual, Tensor? " + "input_smooth, Tensor? act_smooth, Tensor? w1_scale, Tensor? w2_scale, " + "int[]? w1_quant_flag, int[]? w2_quant_flag, " + "int topk, bool renormalize, bool gated, str act_mode, int start_expert_id, " + "int block_n, int cncl_comm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::matmul(Tensor a, Tensor b, Tensor? d, Tensor? bias, Tensor? c, str? dtype," + " str act_mode, float alpha, float beta, bool fast_act, bool approximate, float a_scale, " + "float " + "b_scale, bool trans_a, bool trans_b) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::batch_matmul(Tensor a, Tensor b, Tensor(a!) c," + "float alpha, float beta, float a_scale, float b_scale, bool trans_a, bool trans_b) -> ()")); + m.def( + TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::matmul_allreduce(int cncl_comm, Tensor a, Tensor b, " + "Tensor? bias, Tensor? c, Tensor? d, " + "float alpha, float beta, int block_m) -> Tensor")); + m.def( + TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::swap_blocks(Tensor(a!) dst, Tensor src, Dict(int, " + "int) block_mapping) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::copy_blocks(Tensor(a!)[] k_caches, Tensor(b!)[] v_caches, " + "Dict(int, int[]) block_mapping) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::copy_blocks_out_of_place(Tensor[] k_caches, Tensor[] v_caches, " + "Dict(int, int[]) block_mapping) -> (Tensor[], Tensor[])")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::group_gemm(Tensor a, Tensor b, Tensor m_list, Tensor? idx, Tensor? c, " + "Tensor? alpha, Tensor? beta, Tensor? a_scale, Tensor? b_scale, Tensor? bias," + "str? data_type, int[]? quant_flag, Tensor? b_offset, int max_m) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::preload(Tensor(a!) weight, int size) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::group_gemm_combine_result_allreduce(int cncl_comm, Tensor a, Tensor b, " + "Tensor m_list, Tensor combine_idx, Tensor combine_weight, Tensor? c, " + "Tensor? alpha, Tensor? beta, Tensor? a_scale, Tensor? b_scale, " + "str? data_type, int num_token, int topk, int block_n) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::flash_attn_sq_mm_allreduce(int cncl_comm, Tensor q, Tensor k, " + "Tensor v, Tensor? cu_seq_lens_q, Tensor? cu_seq_lens_k, Tensor? alibi_slope, " + "Tensor? attn_bias, Tensor smooth, Tensor weight, Tensor weight_scale, " + "Tensor? bias, int max_seq_len_q, int max_seq_len_kv, float softmax_scale, bool is_causal, " + "int window_size_left, int window_size_right, " + "str compute_dtype, int block_seq) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::moe_softmax_topk(Tensor input, int topk, int num_expert_group, " + "int topk_group, bool normalize, Tensor? mask, str normed_by) -> Tensor[]")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::moe_expand_input(Tensor input, Tensor gather_idx, Tensor ? " + "cusum_token_count, int start_expert_id, int expert_size) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::moe_gen_idx(Tensor expert_id, int expert_num) -> Tensor[]")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::moe_combine_result(Tensor input, Tensor reduce_weight," + "Tensor gather_ids, Tensor? residual," + "Tensor? cusum_token_count, int start_expert_id, int expert_size, Tensor? bias) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::fused_rope(Tensor(a!) input, Tensor(b!) k_cache_hp, Tensor(c!) v_cache_hp, " + "Tensor(d!)? k_cache_lp, Tensor(e!)? v_cache_lp, Tensor sin_table, Tensor cos_table, " + "Tensor position_ids, Tensor gamma, Tensor beta, Tensor? k_scale_hp, Tensor? v_scale_hp, " + "Tensor(f!)? k_scale_lp, Tensor(g!)? v_scale_lp, Tensor? cache_bs_id_hp, " + "Tensor? cache_seq_offsets_hp, Tensor? cache_bs_id_lp, Tensor? cache_seq_offsets_lp, " + "Tensor? slot_mapping_hp, Tensor? slot_mapping_lp, float eps) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::moe_cast_gating(Tensor input, Tensor weight) -> Tensor")); + m.def( + TORCH_SELECTIVE_SCHEMA("torch_mlu_ops::update_out_and_lse(Tensor(a!) out, Tensor(b!) lse, " + "Tensor block_out, Tensor block_lse," + "Tensor? seq_offsets, Tensor? cu_seqs, Tensor? block_cu_seqs) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::dequant_from_linear_cache(Tensor(a!) key, Tensor(b!)? value, " + "Tensor key_cache, Tensor? value_cache, Tensor key_cache_quant_scale, " + "Tensor? value_cache_quant_scale, Tensor context_lengths, int max_context_len, " + "Tensor? context_seq_offset, Tensor? cache_bs_id, Tensor? cache_seq_offset, " + "int quant_mode=0, int quant_bit=8) -> ()")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torch_mlu_ops::dequant_from_paged_cache(Tensor(a!) key, Tensor(b!)? value, " + "Tensor key_cache, Tensor? value_cache, Tensor key_cache_quant_scale, " + "Tensor? value_cache_quant_scale, Tensor context_lengths, int max_context_len, " + "Tensor? context_seq_offset, Tensor block_tables, int quant_mode=0, int quant_bit=8) -> ()")); +} + +TORCH_LIBRARY_IMPL(torch_mlu_ops, PrivateUse1, m) { + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::attention_project"), + TORCH_FN(tmo::torch_api::attention_project)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::ffn"), TORCH_FN(tmo::torch_api::ffn)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::flash_attention"), + TORCH_FN(tmo::torch_api::flash_attention)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::single_query_cached_kv_attn"), + TORCH_FN(tmo::torch_api::single_query_cached_kv_attn)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::apply_rotary"), + TORCH_FN(tmo::torch_api::apply_rotary)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::reshape_linear_cache"), + TORCH_FN(tmo::torch_api::reshape_linear_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::reshape_paged_cache"), + TORCH_FN(tmo::torch_api::reshape_paged_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::quant_to_paged_cache"), + TORCH_FN(tmo::torch_api::quant_to_paged_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::offline_quant_to_paged_cache"), + TORCH_FN(tmo::torch_api::offline_quant_to_paged_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::quant_to_linear_cache"), + TORCH_FN(tmo::torch_api::quant_to_linear_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::offline_quant_to_linear_cache"), + TORCH_FN(tmo::torch_api::offline_quant_to_linear_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::quant_matmul"), + TORCH_FN(tmo::torch_api::quant_matmul)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::quant_matmul_allreduce"), + TORCH_FN(tmo::torch_api::quant_matmul_allreduce)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::active"), TORCH_FN(tmo::torch_api::active)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::smooth_quant"), + TORCH_FN(tmo::torch_api::smooth_quant)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::fused_layernorm"), + TORCH_FN(tmo::torch_api::fused_layernorm)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::fused_moe"), TORCH_FN(tmo::torch_api::fused_moe)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::matmul"), TORCH_FN(tmo::torch_api::matmul)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::batch_matmul"), + TORCH_FN(tmo::torch_api::batch_matmul)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::matmul_allreduce"), + TORCH_FN(tmo::torch_api::matmul_allreduce)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::swap_blocks"), TORCH_FN(tmo::torch_api::swap_blocks)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::copy_blocks"), TORCH_FN(tmo::torch_api::copy_blocks)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::copy_blocks_out_of_place"), + TORCH_FN(tmo::torch_api::copy_blocks_out_of_place)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::group_gemm"), TORCH_FN(tmo::torch_api::group_gemm)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::preload"), TORCH_FN(tmo::torch_api::preload)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::group_gemm_combine_result_allreduce"), + TORCH_FN(tmo::torch_api::group_gemm_combine_result_allreduce)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::flash_attn_sq_mm_allreduce"), + TORCH_FN(tmo::torch_api::flash_attn_sq_mm_allreduce)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::moe_softmax_topk"), + TORCH_FN(tmo::torch_api::moe_softmax_topk)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::moe_expand_input"), + TORCH_FN(tmo::torch_api::moe_expand_input)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::moe_gen_idx"), TORCH_FN(tmo::torch_api::moe_gen_idx)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::moe_combine_result"), + TORCH_FN(tmo::torch_api::moe_combine_result)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::moe_cast_gating"), + TORCH_FN(tmo::torch_api::moe_cast_gating)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::fused_rope"), TORCH_FN(tmo::torch_api::fused_rope)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::update_out_and_lse"), + TORCH_FN(tmo::torch_api::update_out_and_lse)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::dequant_from_linear_cache"), + TORCH_FN(tmo::torch_api::dequant_from_linear_cache)); + m.impl(TORCH_SELECTIVE_NAME("torch_mlu_ops::dequant_from_paged_cache"), + TORCH_FN(tmo::torch_api::dequant_from_paged_cache)); +} + +// Function(copy_blocks__functionalization_glue) is no longer needed in the torch 2.5 +TORCH_LIBRARY_IMPL(torch_mlu_ops, Functionalize, m) { + m.impl("torch_mlu_ops::copy_blocks", copy_blocks__functionalization_glue); +} diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/update_out_and_lse.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/update_out_and_lse.cpp new file mode 100644 index 0000000..c7172c8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/update_out_and_lse.cpp @@ -0,0 +1,115 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "kernels/update_out_and_lse.mluh" +#include "torch_ops_api.h" +#include "utils.h" + +namespace tmo { +namespace torch_api { + +void update_out_and_lse(at::Tensor &out, + at::Tensor &lse, + const at::Tensor &block_out, + const at::Tensor &block_lse, + const c10::optional &seq_offsets, + const c10::optional &cu_seqs, + const c10::optional &block_cu_seqs) { + const torch_mlu::mlu::MLUGuard device_guard(out.device()); + auto queue = torch_mlu::getCurMLUStream(); + checkTensorSameAttr(out, lse, block_out, block_lse, seq_offsets, cu_seqs, + block_cu_seqs); + checkTensorSameAttr(seq_offsets, cu_seqs, block_cu_seqs); + checkTensorSameAttr(out, block_out); + checkTensorSameAttr(lse, block_lse); + + TORCH_CHECK(out.dim() == 3 || out.dim() == 4, "the dim of out must be 3 or 4"); + TORCH_CHECK(block_out.dim() == 3 || block_out.dim() == 4, "the dim of block_out must be 3 or 4"); + TORCH_CHECK(lse.dim() == 3, "the dim of lse must be 3"); + TORCH_CHECK(block_lse.dim() == 3, "the dim of block_lse must be 3"); + + bool packed = out.dim() == 3; + + int32_t batch{0}, head_num{0}, head_size{0}, max_seq_len{0}, block_seq_len{0}; + int64_t bs_stride{0}, seq_stride{0}, head_stride{0}, block_bs_stride{0}, block_seq_stride{0}, + block_head_stride{0}; + auto data_type = getCnnlDataType(out.scalar_type()); + + batch = lse.size(0); + head_num = lse.size(1); + max_seq_len = lse.size(2); + block_seq_len = block_lse.size(2); + head_size = out.size(-1); + + TORCH_CHECK(block_seq_len <= max_seq_len, + "the max block_seq_len of block_lse can not be greater than the max_seq_len of lse. " + "block_seq_len: ", + block_seq_len, " max_seq_len: ", max_seq_len); + TORCH_CHECK(data_type == CNNL_DTYPE_FLOAT || data_type == CNNL_DTYPE_BFLOAT16 || + data_type == CNNL_DTYPE_HALF, + "the data_type of out only support float, bfloat16 and half"); + + if (packed) { + seq_stride = out.stride(0); + head_stride = out.stride(1); + block_seq_stride = block_out.stride(0); + block_head_stride = block_out.stride(1); + + if (max_seq_len == block_seq_len && block_seq_len == 1) { + bs_stride = seq_stride; + block_bs_stride = block_seq_stride; + } + + CHECK_SHAPE(out, out.size(0), head_num, head_size); + CHECK_SHAPE(block_out, block_out.size(0), head_num, head_size); + } else { + bs_stride = out.stride(0); + seq_stride = out.stride(1); + head_stride = out.stride(2); + block_bs_stride = block_out.stride(0); + block_seq_stride = block_out.stride(1); + block_head_stride = block_out.stride(2); + CHECK_SHAPE(out, batch, max_seq_len, head_num, head_size); + CHECK_SHAPE(block_out, batch, block_seq_len, head_num, head_size); + } + + TORCH_CHECK(head_size <= 1024, "the head_size can not be greater than 1024."); + + if (seq_offsets.has_value()) { + TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous"); + TORCH_CHECK(seq_offsets.value().dtype() == torch::kInt32, + "the dtype of seq_offsets must be int32"); + TORCH_CHECK(seq_offsets.value().dim() == 1, "the dim of seq_offsets must be 1"); + CHECK_SHAPE(seq_offsets.value(), batch); + } + + if (cu_seqs.has_value()) { + TORCH_CHECK(cu_seqs.value().is_contiguous(), "cu_seqs must be contiguous"); + TORCH_CHECK(cu_seqs.value().dim() == 1, "the dim of cu_seqs must be 1"); + CHECK_SHAPE(cu_seqs.value(), batch + 1); + } + if (block_cu_seqs.has_value()) { + TORCH_CHECK(block_cu_seqs.value().is_contiguous(), "block_cu_seqs must be contiguous"); + TORCH_CHECK(block_cu_seqs.value().dim() == 1, "the dim of block_cu_seqs must be 1"); + CHECK_SHAPE(block_cu_seqs.value(), batch + 1); + } + + invokeUpdateOutAndLse(queue, getAtTensorPtr(out), (float *)getAtTensorPtr(lse), + getAtTensorPtr(block_out), (float *)getAtTensorPtr(block_lse), + (int32_t *)getAtTensorPtr(seq_offsets), (int32_t *)getAtTensorPtr(cu_seqs), + (int32_t *)getAtTensorPtr(block_cu_seqs), batch, head_num, head_size, + max_seq_len, block_seq_len, bs_stride, seq_stride, head_stride, + block_bs_stride, block_seq_stride, block_head_stride, packed, data_type); +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp b/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp new file mode 100644 index 0000000..ed536b2 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp @@ -0,0 +1,75 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "utils.h" +#include "common/utils.h" + +namespace tmo { +namespace torch_api { + +#define CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(_) \ + _(CNNL_DTYPE_FLOAT, at::kFloat) \ + _(CNNL_DTYPE_BFLOAT16, at::kBFloat16) \ + _(CNNL_DTYPE_HALF, at::kHalf) \ + _(CNNL_DTYPE_INT32, at::kInt) \ + _(CNNL_DTYPE_INT8, at::kChar) \ + _(CNNL_DTYPE_UINT8, at::kByte) \ + _(CNNL_DTYPE_BOOL, at::kBool) \ + _(CNNL_DTYPE_INT16, at::kShort) \ + _(CNNL_DTYPE_COMPLEX_HALF, at::kComplexHalf) \ + _(CNNL_DTYPE_COMPLEX_FLOAT, at::kComplexFloat) + +cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type) { + switch (data_type) { +#define DEFINE_CASE(cnnl_dtype, scalar_type) \ + case scalar_type: \ + return cnnl_dtype; + CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(DEFINE_CASE) +#undef DEFINE_CASE + case at::kLong: + return CNNL_DTYPE_INT32; + case at::kDouble: + return CNNL_DTYPE_FLOAT; + case at::kComplexDouble: + return CNNL_DTYPE_COMPLEX_FLOAT; + default: + std::string msg("getCnnlDataType() not supported for "); + throw std::runtime_error(msg + c10::toString(data_type)); + } +} + +std::vector createTensorDescs(const std::initializer_list &tensors) { + std::vector descs; + for (size_t i = 0; i < tensors.size(); ++i) { + descs.emplace_back(TensorDesc{nullptr, cnnlDestroyTensorDescriptor}); + auto tensor = tensors.begin()[i]; + if (!tensor.defined()) { + continue; + } + cnnlTensorDescriptor_t desc; + CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&desc)); + descs[i].reset(desc); + cnnlDataType_t data_type = getCnnlDataType(tensor.scalar_type()); + if (tensor.strides().size() == 0) { + CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[i].get(), CNNL_LAYOUT_ARRAY, data_type, + tensor.sizes().size(), tensor.sizes().data())); + } else { + CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(descs[i].get(), CNNL_LAYOUT_ARRAY, data_type, + tensor.sizes().size(), tensor.sizes().data(), + tensor.strides().data())); + } + } + return descs; +} + +} // namespace torch_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.h b/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.h new file mode 100644 index 0000000..588a455 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.h @@ -0,0 +1,164 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#ifndef CSRC_TORCH_API_UTILS_H_ +#define CSRC_TORCH_API_UTILS_H_ + +#include +#include +#include +#include +#include "ATen/ScalarType.h" +#include "ATen/Tensor.h" +#include "aten/cnnl/cnnlHandle.h" +#include "c10/util/Exception.h" +#include "cnnl.h" +#include "framework/core/MLUStream.h" +#include "framework/core/caching_allocator.h" +#include "framework/core/device.h" +#include "framework/core/mlu_guard.h" +#include "torch/torch.h" +#include "torch/version.h" + +namespace tmo { +namespace torch_api { +using TensorDesc = std::unique_ptr, + decltype(&cnnlDestroyTensorDescriptor)>; +std::vector createTensorDescs(const std::initializer_list &tensors); +cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type); + +template +bool isMlu(const T &tensor) { +#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1) + return tensor.device().is_privateuseone(); +#else + return tensor.is_mlu(); +#endif +} + +enum class TensorAttr { DEVICE, DTYPE, ALL }; +struct attr_t { + int64_t device_id; + at::ScalarType dtype; +}; + +inline void checkDevice(int64_t &device_id, const at::Tensor &tensor) { + auto tensor_device_id = tensor.get_device(); + if (device_id == -1) { + device_id = tensor_device_id; + return; + } + TORCH_CHECK(tensor_device_id == device_id, + "Tensor device id is not same, original device_id: ", device_id, + "now device_id is: ", tensor_device_id); +} + +inline void checkDtype(at::ScalarType &dtype, const at::Tensor &tensor) { + auto tensor_dtype = tensor.scalar_type(); + if (dtype == at::ScalarType::Undefined) { + dtype = tensor_dtype; + return; + } + TORCH_CHECK(tensor_dtype == dtype, "Tensor dtype is not same. original dtype: ", dtype, + "now dtype is: ", tensor_dtype); +} + +template +inline void checkTensorAttr(attr_t &attr_states, const at::Tensor &tensor) { + if (attr == TensorAttr::DEVICE) { + checkDevice(attr_states.device_id, tensor); + } else if (attr == TensorAttr::DTYPE) { + checkDtype(attr_states.dtype, tensor); + } else if (attr == TensorAttr::ALL) { + checkDevice(attr_states.device_id, tensor); + checkDtype(attr_states.dtype, tensor); + } +} + +template ::type, at::Tensor>::value>::type> +void checkTensorSameWithSpecificAttr(attr_t &attr_states, const c10::optional &tensor) { + if (!tensor.has_value() || !tensor->defined()) return; + auto temp_tensor = tensor.value(); + TORCH_CHECK(isMlu(temp_tensor), "Only support mlu tensor."); + checkTensorAttr(attr_states, temp_tensor); +} + +template ::type, at::Tensor>::value>::type> +void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor) { + if (!tensor.defined()) return; + TORCH_CHECK(isMlu(tensor), "Only support mlu tensor."); + checkTensorAttr(attr_states, tensor); +} + +template +void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor, Args &&...args) { + checkTensorSameWithSpecificAttr(attr_states, tensor); + checkTensorSameWithSpecificAttr(attr_states, std::forward(args)...); +} + +template +void checkTensorSameAttr(Args &&...args) { + attr_t attr_states = {-1, at::ScalarType::Undefined}; + checkTensorSameWithSpecificAttr(attr_states, std::forward(args)...); +} + +inline at::ScalarType str2TorchDtype(const std::string &type) { + static std::map dtype_map = { + {"float", torch::kFloat32}, {"half", torch::kHalf}, {"bfloat16", torch::kBFloat16}, + {"int32", torch::kInt32}, {"int8", torch::kInt8}, + }; + return dtype_map.at(type); +} + +inline std::string &torchDtype2Str(const at::ScalarType type) { + static std::map torch_dtype_map = { + {torch::kFloat32, "float"}, {torch::kHalf, "half"}, {torch::kBFloat16, "bfloat16"}, + {torch::kInt32, "int32"}, {torch::kInt8, "int8"}, + }; + return torch_dtype_map.at(type); +} + +inline cnnlDataType_t str2CnnlDtype(const std::string &type) { + static std::map cnnl_dtype_map = { + {"float", CNNL_DTYPE_FLOAT}, {"half", CNNL_DTYPE_HALF}, {"bfloat16", CNNL_DTYPE_BFLOAT16}, + {"int32", CNNL_DTYPE_INT32}, {"int8", CNNL_DTYPE_INT8}, + }; + return cnnl_dtype_map.at(type); +} + +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") + +#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") + +#define CHECK_OPTIONAL_TENSOR_CONTIGUOUS(x) \ + if (x.has_value()) TORCH_CHECK(x.value().is_contiguous(), #x " must be contiguous.") + +inline void *getAtTensorPtr(const c10::optional &tensor) { + return tensor.has_value() ? tensor.value().data_ptr() : nullptr; +} + +inline void *getAtTensorPtr(const at::Tensor &tensor) { + return tensor.defined() ? tensor.data_ptr() : nullptr; +} + +} // namespace torch_api +} // namespace tmo + +#endif // CSRC_TORCH_API_UTILS_H_ diff --git a/torch_mlu_ops-v1.3.2/docs/build.sh b/torch_mlu_ops-v1.3.2/docs/build.sh new file mode 100755 index 0000000..703608e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/build.sh @@ -0,0 +1,19 @@ +#! /bin/bash + +script_path=`dirname $0` +echo "script path:" ${script_path} +cd $script_path + + +### User Guide ### +pushd user_guide + ./makelatexpdf.sh + cp ./build/latex/Cambricon*.pdf ../../ + cp ./build/torch_mlu_ops_user_guide_html.zip ../../ +popd + +pushd release_notes + ./makelatexpdf.sh + cp ./build/latex/Cambricon*.pdf ../../ + cp ./build/torch_mlu_ops_release_notes_html.zip ../../ +popd diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/Makefile b/torch_mlu_ops-v1.3.2/docs/release_notes/Makefile new file mode 100644 index 0000000..433460f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = Torch-MLU-Ops ReleaseNotes +SOURCEDIR = . +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/README.rst b/torch_mlu_ops-v1.3.2/docs/release_notes/README.rst new file mode 100644 index 0000000..0f5d49d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/README.rst @@ -0,0 +1,84 @@ +Release Notes写作注意事项 +-------------------------------- + +* overview文件夹下的overview.rst用来写模块概述,简单描述该模块的定义、作用等内容。 +* 如果写新增版本的releasenotes,请先按版本号新建文件夹,比如新增了0.1.1版本,则新建文件夹并命名为 ``0.1.1``,在 ``0.1.1`` 文件夹下新建文件index.rst,建议将之前版本的index.rst copy过来直接修改。 +* 将新增文件添加到工程根目录下的 ``index.rst`` 主索引文件中,否则新增内容将无法包含在最终文档中。 +* 按照rst语法、release notes写作规范,写变更内容。 +* release notes写代码的修改、变更等相关内容,和文档的更新历史没直接关系。文档的“更新历史”只写与文档相关的变更。 +* 执行 ``./makelatexpdf.sh`` 编译文档。 + +Release Notes写作规范 +------------------------ + +**特性变更:** + +如何写: + + * 该章节包含新增特性、删除特性以及修改特性。新增算子、新增接口、删除算子、删除接口以及算子和接口的修改变动,都在该章节详细描述。 + * 用无序列表的方式详细描述每一项内容。 + * 修改部分要详细描述修改原因,修改能带来什么样的好处。 删除部分也要说明删除原因。: + + 举例如下: + + - 新增xxx特性,支持xxx视频的播放。 + - 删除xxx特性,因为xxx特性为临时方案,目前已被YYY特性取代,因此删除xxx临时特性。 + - 新增x1、x2算子以支持xxx功能。 + +**已修复问题:** + +如何写: + + 第一个版本的Release notes该章节无需写作。后续版本,已解决的遗留问题要在该章节描述。 + + 已修复问题至少要描述清楚两点: + + 1.解决了什么问题; + 2.因为解决该问题做了什么样的修改。 + + 可以按以下句式: + + * 通过修改xxx,修复了xxx问题 + +**已知遗留问题:** + +如何写: + + * 该章节描述在该版本还存在的问题,或者已给用户承诺的暂时未达标的规格要求。 + * 如果有多个遗留问题,每个遗留问题用无序列表的方式描述,祥见下面的举例。 + * 该部分内容必须包含“现象:”、“原因:”、“影响:”、“规避措施:”四个部分,可参见下面的举例。 + + 举例如下: + + * 080p分辨率以上的视频问题件无法播放。 + + **现象:** + + 详细描述出现该问题后的现象。该现象是用户可以直接看到的,不是问题本身的描述。比如: + + 当播放1080p以上分辨率的视频文件时,会出现卡段或掉帧的现象,导致视频播放失败。 + + **原因:** + + 描述出现该问题的原因,或者用户如何操作才会导致该问题,比如: + + 当前采用的VPU不支持1080p视频文件的播放。 + + **影响:** + + 详细说明该问题对用户可能造成的影响,比如: + + 该问题会导致无法播放1080p分辨率以上的视频文件。 + + **规避措施:** + + 对该问题有无规避措施。如果无规避措,直接写“无规避措施”,如果有规避措施,详细描述该问题的规避措施.比如: + + 无。 + + **注意:** + + 在上一版本记录的已知遗留问题,如果在当前版本依然存在,在当前版本需要继续记录。 + + +。 diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/_static/custom.css b/torch_mlu_ops-v1.3.2/docs/release_notes/_static/custom.css new file mode 100644 index 0000000..fe4b3a0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/_static/custom.css @@ -0,0 +1,5 @@ +@import url("theme.css"); + +.wy-nav-content { + max-width: 100%; +} diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/chapterbkpaper.pdf b/torch_mlu_ops-v1.3.2/docs/release_notes/chapterbkpaper.pdf new file mode 100644 index 0000000..6223b59 Binary files /dev/null and b/torch_mlu_ops-v1.3.2/docs/release_notes/chapterbkpaper.pdf differ diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/conf.json b/torch_mlu_ops-v1.3.2/docs/release_notes/conf.json new file mode 100644 index 0000000..4a34d47 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/conf.json @@ -0,0 +1,165 @@ +{ + "package": [ + "\\\\usepackage{amsmath}", + "\\\\usepackage[table]{xcolor}", + "\\\\usepackage{eso-pic}", + "\\\\usepackage{wallpaper}", + "%\\\\usepackage{titlesec}", + "\\\\usepackage{tocbibind}", + "% \\\\usepackage{draftwatermark}", + "\\\\usepackage{enumitem}", + "\\\\usepackage{tcolorbox}", + "\\\\usepackage{listings}", + "\\\\usepackage{framed}", + "\\\\usepackage{color}", + "\\\\usepackage{multirow}", + "% \\\\usepackage[justification=centering]{caption}" + ], + "replacepackage": { + "comment": "该章节内容是为了替换现有tex里面包的配置,左边为tex文件现有内容,右边是替换内容。", + "\\\\usepackage{hyperref}": "\\\\usepackage[bookmarksnumbered=true]{hyperref}", + "\\\\sphinxtableofcontents": + "\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}\n\\\\sphinxtableofcontents", + "\\\\chapter{([\\s\\S].*)}": + "\\\\chapter{\\1}\n\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}", + "\\\\listoffigures": "\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}\n\\\\listoffigures", + "\\\\listoftables": + "\\\\newpage\n\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}\n\\\\listoftables", + "\\\\footnotesize\\\\raggedright\\\\printindex": + "% \\\\footnotesize\\\\raggedright\\\\printindex", + "\\\\begin{itemize}": "\\\\begin{itemize}[leftmargin=*]", + "\\\\begin{enumerate}": "\\\\begin{enumerate}[leftmargin=*]", + "\\\\setmainfont{FreeSerif}\\[[\\s\\S]*?\\]": "", + "\\\\setsansfont{FreeSans}\\[[\\s\\S]*?\\]": "", + "\\\\setmonofont{FreeMono}\\[[\\s\\S]*?\\]": "", + "Extension([\\s\\S]*?)= .otf,": "Extension = ,", + "\\\\sphinxtoprule": "\\\\hline", + "\\\\sphinxmidrule": "\\\\hline", + "\\\\sphinxbottomrule": "\\\\hline", + "\\\\sphinxhline": "\\\\hline", + "\\\\sphinxhyphen": "{sphinx:5.3.0}\\\\PYGZhy", + "\\\\begin{sphinxalltt}": "\\\\begin{sphinxVerbatim}[commandchars=\\\\\\\\\\\\{\\}]", + "\\\\end{sphinxalltt}": "\\\\end{sphinxVerbatim}", + "\\\\begin{sphinxadmonition}{note}{注解:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=注解:]", + "\\\\begin{sphinxadmonition}{note}{备注:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=注解:]", + "\\\\begin{sphinxadmonition}{warning}{警告:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=警告:]", + "\\\\begin{sphinxadmonition}{tip}{小技巧:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=小技巧:]", + "\\\\begin{sphinxadmonition}{attention}{注意:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=注意:]", + "\\\\begin{sphinxadmonition}{hint}{提示:}": + "\\\\begin{tcolorbox}[colframe={hintframecolor},colback={hintbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=提示:]", + "\\\\end{sphinxadmonition}": "\\\\end{tcolorbox}", + "\\\\begin{sphinxadmonition}{note}{Note:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Note:]", + "\\\\begin{sphinxadmonition}{warning}{Warning:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Warning:]", + "\\\\begin{sphinxadmonition}{tip}{Tip:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Tip:]", + "\\\\begin{sphinxadmonition}{attention}{Attention:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Attention:]", + "\\\\begin{sphinxadmonition}{hint}{Hint:}": + "\\\\begin{tcolorbox}[colframe={hintframecolor},colback={hintbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Hint:]" + }, + "customoptions": [ + "% \\\\numberwithin{figure}{chapter}", + "% \\\\numberwithin{table}{chapter}", + "\\\\titleformat{\\\\chapter}{\\\\raggedleft\\\\huge\\\\bfseries\\\\color{white}}{\\\\thechapter}{0.5em}{}", + "\\\\titlespacing{\\\\chapter}{0pt}{50pt}{25pt}", + "\\\\definecolor{noteframecolor}{RGB}{91,163,235}", + "\\\\definecolor{notebackcolor}{RGB}{222,237,251}", + "\\\\definecolor{warningframecolor}{RGB}{235,162,28}", + "\\\\definecolor{warningbackcolor}{RGB}{255,247,236}", + "\\\\definecolor{hintframecolor}{RGB}{70,193,196}", + "\\\\definecolor{hintbackcolor}{RGB}{226,254,249}", + "\\\\definecolor{camblue}{RGB}{0,89,196}", + "% \\\\SetWatermarkText{Cambricon}", + "% \\\\SetWatermarkLightness{0.9}", + "% \\\\SetWatermarkScale{1}", + "\\\\renewcommand{\\\\labelitemi}{$\\\\vcenter{\\\\hbox{\\\\scriptsize$\\\\bullet$}}$}", + "\\\\definecolor{shadecolor}{RGB}{220,220,220}" + ], + "isfiguretabletoc": { + "comment": "插图目录英文:List of Figures;表格目录英文:List of Tables.", + "isfigurestoc": false, + "istablestoc": false, + "figurestoc": [ + "\\\\renewcommand\\\\listfigurename{插\\ 图\\ 目\\ 录}", + "\\\\listoffigures" + ], + "tablestoc": [ + "\\\\renewcommand\\\\listtablename{表\\ 格\\ 目\\ 录}", + "\\\\listoftables" + ] + }, + "tables": { + "comment": + "isname:true-根据表格的name属性查找表格,false-根据表格的标题查找表格。isnewpage:true-强制在新页生成表格,false-默认。ispartial:true-采用部分匹配查找表格,false-精确匹配查找表格;isLongTable:true-强制设置为长表格,false-不设置长表格。isVertical:true-对第一列进行渲染,false-不对第一列进行渲染。isCusHead:true-对第一行进行渲染,false-不对第一行进行渲染。", + "isname": false, + "rowtype": "", + "headtype": "{camblue!100}", + "headfontcolor": "\\textbf{\\textcolor{white}{}}", + "styles": [ + { + "align": "centering", + "caption": "分区信息xxxxx", + "captionalign": "left", + "isLongTable": true, + "isVertical": false, + "isCusHead": true, + "isnewpage": true, + "ispartial": true + }, + { + "align": "centering", + "caption": "版本记录", + "captionalign": "left", + "isLongTable": false, + "isVertical": true, + "isCusHead": false, + "ispartial": true + } + ] + }, + "image": { + "styles": [ + { + "name": "", + "align": "", + "caption": "" + } + ] + }, + "sensitivewords": [ + "安防", + "监听", + "stream", + "hisi", + "展讯", + "英伟达", + "nvidia", + "讯飞", + "展锐", + "c10", + "c20", + "IK", + "AH", + "US", + "MLU320" + ], + "ignorewarndes": [ + "该字段已废弃不再使用,但不能删掉,否则会导致编译错误!" + ], + "ignorewarnkey": [ + "sdk_memory.rst && 文档没有加入到任何目录树中", + "sdk_memory.rst && document isn't included in any toctree", + "Package hyperref Warning", + "LaTeX Font Warning", + "Package cmap Warning", + "Package xeCJK Warning" + ] + +} diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/conf.py b/torch_mlu_ops-v1.3.2/docs/release_notes/conf.py new file mode 100644 index 0000000..93ccd79 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/conf.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- +# +# TensorFlow V1.10 documentation build configuration file, created by +# sphinx-quickstart on Tue Apr 23 18:33:05 2019. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) +from __future__ import print_function + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +#import sys +#reload(sys) +#sys.setdefaultencoding('utf-8') +extensions = [] +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'寒武纪Torch-MLU-Ops版本说明书' +copyright = u'2024, Cambricon' +author = u'' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = u'' +# The full version, including alpha/beta/rc tags. +release = u'' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language ='zh_CN' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'README.rst'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_copy_source = False +html_css_files = [ + 'custom.css', +] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# This is required for the alabaster theme +# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars +html_sidebars = { + '**': [ + 'relations.html', # needs 'show_related': True theme option to display + 'searchbox.html', + ] +} + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'release notes' + + +# -- Options for LaTeX output --------------------------------------------- + + + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'Cambricon-Torch-MLU-Ops-Release-notes-CN.tex', u'寒武纪Torch-MLU-Ops版本说明书', + author, 'manual'), +] + + +# -- Options for manual page output --------------------------------------- +# xelatex 作为 latex 渲染引擎,因为可能有中文渲染 +latex_engine = 'xelatex' + + # 如果没有使用自定义封面,可以加上封面 logo。这里可以是 pdf 或者 png,jpeg 等 +latex_logo = "./logo.png" + + +latex_show_urls = 'footnote' +latex_use_latex_multicolumn = True +latex_use_xindy = False +# 主要的配置,用于控制格式等 +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +# +'papersize': 'a4paper', + +# The font size ('10pt', '11pt' or '12pt'). +# +# 'pointsize': '8pt', + +'inputenc': '', +'utf8extra': '', +'figure_align': 'H', +'releasename':"", +'sphinxsetup': ''' + verbatimwithframe=false, + verbatimwrapslines=true, + VerbatimColor={RGB}{220,220,220}, + verbatimhintsturnover=false, +''', +'fontpkg':''' + % 设置字体 + + \\usepackage{xeCJK} + \\usepackage{fontspec} + \\CJKsetecglue{\\hskip0.08em plus0.01em minus 0.01em} + \\defaultfontfeatures[\\rmfamily,\\sffamily,\\ttfamily]{} + \\setCJKmainfont{Noto Sans CJK SC}[AutoFakeSlant] + \\setCJKsansfont{Noto Sans CJK SC}[AutoFakeSlant] + \\setCJKmonofont{Noto Sans Mono CJK SC}[AutoFakeSlant] + \\setmainfont{Noto Sans CJK SC}[AutoFakeSlant] +''', +# Additional stuff for the LaTeX preamble. +'preamble': ''' +\\addto\\captionsenglish{\\renewcommand{\\chaptername}{}} +\\LTcapwidth=400pt + +%清除 latex headheight small错误编译告警 +\\setlength{\\headheight}{14.0pt} +% 段首缩进 2 格 + +%\\usepackage{indentfirst} +%\\setlength{\\parindent}{2em} + +\\usepackage{setspace} + +% 1.5 倍行间距 +\\renewcommand{\\baselinestretch}{1.5} +% 表格里的行间距 +\\renewcommand{\\arraystretch}{1.5} + +% list 列表的缩进对齐 +\\usepackage{enumitem} +\\setlist{nosep} + +% 表格类的宏包 +\\usepackage{threeparttable} +\\usepackage{array} +\\usepackage{booktabs} + +% fancy 页眉页脚 +\\usepackage{fancyhdr} +\\pagestyle{fancy} + +% 在 sphinx 生成的 tex 文件里,normal 是指普通页面(每一个章节里,除了第一页外剩下的页面) +% 页眉,L:left,R:right,E:even,O:odd +% 奇数页面:左边是 leftmark,章号和章名称;右边是 rightmark,节号与节名称 +% 偶数页面:左边是 rightmark,节号与节名称;右边是 leftmark,章号和章名称 +% textsl 是字体,slanted shape,对于英语而言,某种斜体。但是不同于 textit 的 italic shape +% 左页脚:版权信息 +% 右页脚:页码数 +% rulewidth:页眉和页脚附近的横线的粗细,当设置为 0pt 时,就没有该横线 +% +\\fancypagestyle{normal} { +\\fancyhf{} +\\fancyhead{} +\\fancyhead[LE,RO]{\\textsl{\\rightmark}} +\\fancyhead[LO,RE]{\\textsl{\\leftmark}} +\\lfoot{Copyright © 2024 Cambricon Corporation.} +\\rfoot{\\thepage} +\\renewcommand{\\headrulewidth}{0.4pt} +\\renewcommand{\\footrulewidth}{0.4pt} +} + +% 在 sphinx 生成的 tex 文件里,plain 是指每个章节的第一页等 +\\fancypagestyle{plain} { +\\fancyhf{} +% left head 还可以内嵌图片,图片可以是 pdf,png,jpeg 等 +% \\lhead{\\includegraphics[height=40pt]{cn_tm.pdf}} +\\lhead{\\large\\textcolor[rgb]{0.1804,0.4588,0.7137}{Cambricon®}} +\\lfoot{Copyright © 2024 Cambricon Corporation.} +\\rfoot{\\thepage} +\\renewcommand{\\headrulewidth}{0.4pt} +\\renewcommand{\\footrulewidth}{0.4pt} +} + +''', + + +# +'printindex': r'\footnotesize\raggedright\printindex', + + +# 移除空白页面 +'extraclassoptions': 'openany,oneside', + +# 如果需要用 latex 自已做封面,可以使用 maketitle +# 下面这个封面的例子来自于互联网 + +# 'maketitle': r''' +# \pagenumbering{Roman} %%% to avoid page 1 conflict with actual page 1 +# +# \begin{titlepage} +# \centering +# +# \vspace*{40mm} %%% * is used to give space from top +# \textbf{\Huge {Sphinx format for Latex and HTML}} +# +# \vspace{0mm} +# \begin{figure}[!h] +# \centering +# \includegraphics[width=0.8\textwidth]{cn.png} +# \end{figure} +# +# % \vspace{0mm} +# % \Large \textbf{{Meher Krishna Patel}} +# +# % \small Created on : Octorber, 2017 +# +# % \vspace*{0mm} +# % \small Last updated : \MonthYearFormat\today +# +# +# %% \vfill adds at the bottom +# % \vfill +# % \small \textit{More documents are freely available at }{\href{http://pythondsp.readthedocs.io/en/latest/pythondsp/toc.html}{PythonDSP}} +# \end{titlepage} +# +# \clearpage +# \pagenumbering{roman} +# \tableofcontents +# \listoffigures +# \listoftables +# \clearpage +# \pagenumbering{arabic} +# +# ''', +# +} # latex_elements + +import os.path as osp +import sys +import traceback + +try: + # conf_callback.py一般和conf.py放在同一目录下即可,如果有特殊需求请根据conf_callback.py的实际路径进行修改。 + if osp.exists(osp.join(osp.dirname(__file__), r'./conf_callback.py')): + sys.path.append(osp.join(osp.dirname(__file__), '.')) + else: + sys.path.append(osp.join(osp.dirname(__file__), '../')) + + from conf_callback import * + +except Exception as e: + # 如果出现了异常,请检查上面默认配置的两个路径是否满足要求,如果不满足要求,请修改上面的路径直到没异常发生。 + print(e) + traceback.print_exc() diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/conf_callback.py b/torch_mlu_ops-v1.3.2/docs/release_notes/conf_callback.py new file mode 100644 index 0000000..d142168 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/conf_callback.py @@ -0,0 +1,319 @@ +#!/usr/bin/python +# -*- coding: UTF-8 + +from __future__ import print_function + +# sphinx事件回调函数 +#如果copy到conf.py文件中,只copy下面的代码即可。 + +import os +import os.path as osp +import sys +import importlib +import traceback +import datetime as dt +import locale + +#如果文档编译失败,请检查下面默认的路径是否正确,请根据实际路径进行修改 + +if osp.exists(osp.join(osp.dirname(__file__), './parsejson.py')): + sys.path.append(osp.join(osp.dirname(__file__), '.')) +else: + sys.path.append(osp.join(osp.dirname(__file__), '../')) + +cusdirect = None +extpathisexist = False # ./_ext/customdirective.py是否存在的标志,用于设置指令扩展 + +extpath = osp.join(osp.dirname(__file__), './_ext') +if osp.exists(extpath): + # 自定义指令所在路径,加到搜索目录,否则无法使用自定义指令 + sys.path.append(extpath) + if osp.exists(osp.join(osp.dirname(__file__),r'./_ext/customdirective.py')): + extpathisexist = True + cusdirect = importlib.import_module('customdirective') +else: + extpath = osp.join(osp.dirname(__file__), '../_ext') + sys.path.append(extpath) + if osp.exists(osp.join(osp.dirname(__file__),r'../_ext/customdirective.py')): + extpathisexist = True + cusdirect = importlib.import_module('customdirective') + +import sphinx +import sphinx.errors as sperror +import parsejson + +try: + import breathe # 仅为了判断版本号,根据breathe版本号添加不同配置 + breathe_version = breathe.__version__[:3] +except Exception as e: + breathe_version = "" + +warnfile = '' # 告警文件,不包含路径 +warnfilepath = '' # 保存告警日志文件,包含完整的路径名 + +months=['January','February','March','April','May','June','July', + 'August','September','October','Novmber','December'] + +gmaketitle = r''' + \pagenumbering{Roman} + \begin{titlepage} + \centering + \vspace*{40mm} + \textbf{\Huge {%(titlename)s}} + + \vspace{10mm} + \textbf{\Large{%(release)s}} + + \vfill + \textbf{\large{%(today)s}} + \end{titlepage} + ''' + +def __ModifyMakeTitle(elementskeys,config): + ''' + 如果有maketitle变量则修改maketitle的内容 + :param config: 配置变量 + :return: + ''' + #得到封面名称 + titlename = config.latex_documents[0][2] + + if 'releasename' in elementskeys: + releasename = config.latex_elements['releasename'] + else: + releasename = '' + + if hasattr(config,'version'): + version = config.version + else: + version = '' + release = releasename +' ' + version + if hasattr(config,'today') and len(config.today) > 0: + today = config.today + else: + todayft = dt.date.today() + if hasattr(config,'language') and config.language=='zh_CN': + today = todayft.strftime('%Y 年 %m 月 %d 日') + else: + #除中文外,一律用英文时间格式 + #得到英文月的全称。 + month = months[int(todayft.strftime('%m'))-1] + today = month + todayft.strftime(" %d, %Y") + + if len(config.latex_elements['maketitle'].strip())==0: + maketitle = gmaketitle + else: + maketitle = config.latex_elements['maketitle'] + #必须转一下,否则赋值不成功 + config.latex_elements['maketitle'] = maketitle % { + 'titlename': titlename, + 'release': release, + 'today': today + } + +def config_inited_handler(app, config): + # 检查年份是否正确,并修改。需配合最新parsejson.py文件使用,否则不支持该接口。 + # 通用配置的修改放在该位置,否则对配置的修改不生效 + try: + # 该函数的第二个参数可以传入版权声明文件相对于conf.py的相对路径,自动修改版权声明的年份 + # 比如:parsejson.CheckCurrentYearAndModify(app,'../source/copyright/cnperf_conpyright.rst') + # 默认路径为“./copyright/copyright.rst”和“./copyright/copyright_zh.rst”或者“./copyright/copyright_en.rst” + # 以上三个路径无需传入 + parsejson.CheckCurrentYearAndModify(app) + except Exception as e: + print('------------------') + print(e) + # traceback.print_stack() + traceback.print_exc() + print('------------------') + + try: + # print(sphinx.__version__) + # app.require_sphinx('3.5') + keys = config.latex_elements.keys() + if 'sphinxsetup' not in keys: + config.latex_elements['sphinxsetup'] = '' + if "3.5" <= sphinx.__display_version__[:3]: + config.latex_elements['sphinxsetup'] = 'verbatimforcewraps=true,verbatimmaxunderfull=2,' + \ + config.latex_elements['sphinxsetup'] + if "5.3" <= sphinx.__display_version__[:3]: + config.latex_table_style = ['standard', 'nocolorrows'] + config.latex_elements['sphinxsetup'] += 'pre_border-radius=0pt,' + if len(breathe_version) > 0 and "4.3" <= breathe_version: + config.latex_elements['preamble'] += ''' + \\renewenvironment{description} + {\\list{}{\\labelwidth=0pt + \\let\\makelabel\\descriptionlabel}} + {\\endlist} + ''' + #如果有自定义maketitle,则修改maketitle的内容 + if 'maketitle' in keys: + __ModifyMakeTitle(keys,config) + + # sphinxsetup = config.latex_elements['sphinxsetup'] + # print('sphinxversion:%s;sphinxsetup:%s' % (sphinx.__version__,sphinxsetup)) + except Exception as e: + # print('sphinxversion:%s %s' % (sphinx.__version__,sphinx.__display_version__[:3])) + pass + + +def build_finished_handler(app, exception): + if exception != None: + # print(exception) + return + + # 判断告警文件是否存在,只有无告警或者告警全是忽略告警才允许继续后续的编译 + if warnfilepath != '' and osp.exists(warnfilepath): + # 判断告警文件中是否全是忽略告警 + iswarn = parsejson.warn_main(warnfilepath, app) + if iswarn: + # 如果为True则说明有不可忽略的告警,报sphinxerror异常,停止继续编译 + raise sperror.SphinxError('There are alarms, please check the file of %s for details' % warnfile) + return + + try: + if app.builder.name == "latex": + selffnlst = app.config.latex_documents + parsejson.Modifylatex_main(app.outdir, selffnlst, app) + except Exception as e: + print('------------------') + print(e) + traceback.print_exc() + + # 检查html标题一致性。该接口最好放在该位置,主索引文件的标题已经解析出来。 + # if app.builder.name == "html": + # result = parsejson.CheckProjectNameIsConsistent(app) + # if result != "": + # raise sperror.SphinxError(result) #带raise的异常如果用except捕捉,则在终端打印告警将不能显示为红色字体。 + + +def build_inited_handler(app): + global warnfile + global warnfilepath + + print(sys.argv) + args = sys.argv[1:] # 0为sphinx-build,需忽略掉 + if '-w' in args: + pos = args.index('-w') # 找到-w所在的索引位置 + warnfile = args[pos + 1] # 得到告警保存的文件名 + # print('warnfile=' + warnfile) + + # 根据工作路径,得到文件名的绝对路径 + # 当前在build阶段,因此工作路径为Makefile所在的目录,-w后面的文件保存在基于Makefile的相对路径下 + filepath = osp.join(os.getcwd(), warnfile) + warnfilepath = osp.abspath(filepath) + # print('warnfilepath = ' + warnfilepath) + + # try: + # # 检查是否有rst_prolog或者rst_epilog替换内容,有的话去掉前后空格 + # # 仅适用于pdf文件和中文文档,英文文档和html不启作用 + # checkDocobj = parsejson.clsCheckDocCopyrightYear(app, "") + # if app.builder.name == 'latex' or app.builder.name == 'latexpdf': + # checkDocobj.CheckReplaceContent() + # except Exception as e: + # print('------------------') + # print(e) + # traceback.print_exc() + # print('------------------') + + # try: + # # 检查html配置是否符合规范。必须放在这个位置,否则app没有builder对象。 + # if app.builder.name == "html": + # error = parsejson.CheckHtmlConfigIsCorrect(app) + # if error != "": + # raise sperror.ConfigError(error) + # except Exception as e: + # print('------------------') + # print(e) + # traceback.print_stack() + # traceback.print_exc() + # print('------------------') + + +def source_read_handler(app, docname, source): + if cusdirect is not None: + cnonlyobj = cusdirect.CNOnlyPro(app, docname, source[0]) + source[0] = cnonlyobj.parsecnonlycontent(source[0]) + + # try: + # 自动添加更新历史代码,默认添加“无内容更新” + # source[0] = parsejson.CheckUpdateHistory(app, docname, source[0]) + # except Exception as e: + # print('------------------') + # print(e) + # traceback.print_exc() + # print('------------------') + + return source + + +def doctree_resolved_handle(app, doctree, docname): + """ + 删除不需要文件的内容,避免编译html的时候还进行编译,导致内容泄漏 + """ + if not hasattr(app.config, 'enable_exclude_html') or not app.config.enable_exclude_html: + # 默认按sphinx方式编译所有rst文件到html + return + + indexname = app.config.master_doc + if docname == indexname: + return + indexdocnames = app.env.toctree_includes[indexname] + isexist = False + # 判断文档是否在索引文件里 + if docname not in indexdocnames: + newtoctree = app.env.toctree_includes.copy() + del newtoctree[indexname] # 删除主索引文件,因为无需比较 + for key in newtoctree.keys(): + values = newtoctree[key] + if docname in values: + isexist = True + # 判断该key是否在主索引文件里面 + if key not in indexdocnames: + # 如果不在整个的索引文件里面,删除该文件内容 + # doctree.attributes['source']='' + while len(doctree.children) > 0: + doctree.remove(doctree.children[0]) + # 删除标题,否则html页面还有标题,还需要维护标题 + for node in app.env.titles[docname].children: + app.env.titles[docname].remove(node) + if not isexist: + # 文档不在toctree字典里,直接删除内容 + while len(doctree.children) > 0: + doctree.remove(doctree.children[0]) + # 删除标题,否则html页面还有标题,还需要维护标题 + for node in app.env.titles[docname].children: + app.env.titles[docname].remove(node) + + +def setup(app): + """ + 该函数中的路径都是相对于工程builder环境的路径而不是相对于conf.py所在的路径, + 该函数外面的相对路径是相对于conf.py的路径没有问题。 + 比如./_ext相对路径是相对于工程的路径,因此如果使用abspath得到绝对路径会得到~/m0_docs/_ext/, + 而不是~ / mo_docs / source / _ext。 + 因此该路径下对路径的判断最好转化为绝对路径判断,否则可能会判断失败。 + :param app: + :return: + """ + app.connect('config-inited', config_inited_handler) + app.connect('build-finished', build_finished_handler) + app.connect('builder-inited', build_inited_handler) + app.connect('source-read', source_read_handler) + app.connect('doctree-resolved', doctree_resolved_handle) + + app.add_config_value('chapterbkpaper_image', '', 'env', [str]) + # 是否排除未添加到index的html文件。在编译html时,sphinx默认会将所有rst文件都生成html,因此可能会造成部分内容的泄露。 + # 为了删除这部分内容,设置该变量。如果置为True的话,将删除不在index.rst里面的html的内容。 + # 默认为false,不做处理,以增加编译效率。 + # 建议还是将不参与编译的rst文件,放到exclude_patterns变量内,放到该变量后,rst不参与编译,能提高编译效率。 + app.add_config_value('enable_exclude_html', False, 'env', [bool]) + #配置版权声明文件完整路径,实现版权声明文件年份的自动修改。 + #parsejson.CheckCurrentYearAndModify如果该函数中传入了版权声明的全路径,则以该函数中传入的为准 + #在该处又增加了一个配置值,是为了配置和代码解偶,为了兼容性,保留了原来的配置方式 + app.add_config_value('copyfile_path', '', 'env', [str]) + #在conf.py中配置的conf.json路径,为了实现conf.json配置的灵活处理 + app.add_config_value('confjson_path', '', 'env', [str]) + + if extpathisexist: + app.setup_extension('customdirective') diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/index.rst b/torch_mlu_ops-v1.3.2/docs/release_notes/index.rst new file mode 100644 index 0000000..96d15e8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/index.rst @@ -0,0 +1,10 @@ + +寒武纪Torch-MLU-Ops版本说明书 +=========================================== + +.. toctree:: + :maxdepth: 4 + :caption: 目录 + + ./overview/overview + ./version diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/logo.png b/torch_mlu_ops-v1.3.2/docs/release_notes/logo.png new file mode 100644 index 0000000..7b6ac72 Binary files /dev/null and b/torch_mlu_ops-v1.3.2/docs/release_notes/logo.png differ diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/makelatexpdf.sh b/torch_mlu_ops-v1.3.2/docs/release_notes/makelatexpdf.sh new file mode 100755 index 0000000..2e0a7c1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/makelatexpdf.sh @@ -0,0 +1,5 @@ +#! /bin/bash + +make clean +make latexpdf +make html&&zip -qr -P"Cambricon" build/torch_mlu_ops_release_notes_html.zip build/html diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/overview/overview.rst b/torch_mlu_ops-v1.3.2/docs/release_notes/overview/overview.rst new file mode 100644 index 0000000..baf2db1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/overview/overview.rst @@ -0,0 +1,5 @@ + +概述 +==================== + +Torch-MLU-Ops是寒武纪设计和开发的PyTorch第三方算子库。对于使用PyTorch框架的开发者,通过Torch-MLU-Ops,能够便捷地使用这些自定义算子,进行算子的集成、评测和业务部署。 diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/parsejson.py b/torch_mlu_ops-v1.3.2/docs/release_notes/parsejson.py new file mode 100644 index 0000000..40c460d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/parsejson.py @@ -0,0 +1,3552 @@ +#!/usr/bin/python +# -*- coding: UTF-8 +''' +注意: +修改该文件,print打印信息中不能包含中文,否则在不支持中文的操作系统中会出现如下的错误, +导致编译失败或者达不到预期的显示效果。 +'ascii' codec can't encode characters in position + +2023年12月04日 修改历史 +------------------------------------ + +主要修改以下bug: +1. 当表格通过class属性设置为longtable,但是表头用“-”分割而不是用“=”分割的时候,在sphinx5.3.0以下版本编译失败的问题。 + 如果没有替换最新的parsejson.py文件,该bug在sphinx 5.3.0以下版本,可以采用以下方式规避: + 当表格设置为longtable时,表头改用“=”分割,而不是“-”分割,可以规避该问题。 + 如果替换了最新的parsejson.py文件,则不存在该问题,无需刻意规避。 +2. 修复了当表头为两行,并且合并列单元格放在表格的最前几列的时候,表头中的横线制表符没有画的bug。 + +优化了以下实现方式: +1. 将sphinx 5.3.0之后版本,为了适配表格改动所做的修改,固定在该文件的ModifyReplacePackage函数中。 + 这样只需替换该文件,而无需再修改conf.json文件,即可兼容sphinx5.3.0之前及之后版本,做到全版本兼容。 +2. 当有合并单元格的表头时,如果用“widths”参数设置的列宽,则表头中的标题不会自动换行,必须手动换行。 + 如果用“tabularcolumns”指令设置的列宽,则默认会自动换行。之前的行为是默认都会自动换行。 + 因此当前的实现要想使表头能够自动换行,必须采用tabularcolumns指令设置每一列的列宽。如果用“widths”设置列宽,只能手动换行。 + 修改原因: + 因为如果要实现自动换行,每一列的列宽必须以“m{0.1\textwidth}”这种方式实现,之前是把“widths”设置的列宽进行了转换。 + 这种转换对于列比较少的表格没什么问题。但是当表格比较大,列比较多时,这种转换就导致表格的显示超出了页面可显示的宽度, + 而且也会导致如何修改“widths”的值,都没什么效果。因此修改了自动换行的规则。 + 结论: + 对于列比较多而且表头有合并单元格的表格,建议用“tabularcolumns”设置每一列的列宽,因为该指令相当于latex的原始指令,设置的列宽直接反应每一列的实际宽度。 + 对于表头没有合并单元格的表格,建议用“widths”设置列宽比较简单,让sphinx转化实际的列宽,同时对latex和html都生效。 + +2023年11月15日 修改历史 +---------------------------- +将路径字符串相加的赋值方式,修改为os.path.join的方式,解决编译文档提示如下告警的问题: +RemovedInSphinx80Warning: Sphinx 8 will drop support for representing paths as strings. Use "pathlib.Path" or "os.fspath" instead. +sphinx 8之后将不再支持字符串表示路径。 +使用os.path.join的方式一定要特别注意以下几点: +1.join函数会自动添加“/”符号,因此一定不要再添加多余的“/”,否则可能得到的路径是错误的。 +2.join可以链接超过两个的路径,但是最后一个要链接的路径一定不能加“/”符号,否则前面的路径都失效。举例如下: + path1 = "home" + path2 = "rst" + path3 = "/meerkat" + path = os.path.join(path1,path2,path3) + 则得到的path为“/meerkat”,因此最后一个被链接的路径一定不能加“/”。 + +''' + +from __future__ import print_function + +import os +import sys +import json +import re +import codecs +import shutil +import platform +import traceback +import datetime as dt +from subprocess import Popen, PIPE, STDOUT +#import importlib + +try: + import sphinx + sphinx_version = sphinx.__display_version__[0:3] +except Exception as e: + sphinx_version = "" + +Py_version = sys.version_info +Py_v_info = str(Py_version.major) + '.' + str(Py_version.minor) + '.' + str(Py_version.micro) + +#if Py_version >= (3,5,0) and Py_version <(3,7,0): +# print(False) +# +#if Py_version >= (3,7,0): +# print(True) +# +#if platform.system().lower() == 'windows': +# print("windows") +#elif platform.system().lower() == 'linux': +# print("linux") + +filepathlist=os.path.split(os.path.realpath(__file__)) +currfilepath=filepathlist[0] +latexfontsize={'tiny','scriptsize','footnotesize','small','normalsize','large','Large','LARGE','huge','Huge'} + +def __dictIsHasKey__(dict,key): + if Py_version >= (3,0,0): + return dict.__contains__(key) + else: + return dict.has_key(key) + +def __openconfjsonfile__(filepath = ''): + jsonfile = filepath + if filepath == '': + jsonfile = currfilepath+'/conf.json' + + with codecs.open(jsonfile,"r+",encoding = 'utf-8') as load_f: + load_dict = json.load(load_f) + return load_dict + +def getignorewarning(): + return __load_dict__["ignorewarnkey"] + +def GetOtherIncludePackage(): +# packagearr = __load_dict__['package'] +# for package in packagearr: +# print(package) + return __load_dict__['package'] + +def GetReplacePackage(): + return __load_dict__['replacepackage'] + +def GetCustomOptions(): + return __load_dict__['customoptions'] + +def GetIsTocContents(): + return __load_dict__['isfiguretabletoc'] + +def GetSensitiveword(): + #得到敏感词数组,便于搜索文档中是否有敏感词存在 + return __load_dict__['sensitivewords'] + +def GetTablesContent(): + return __load_dict__['tables'] + +def GetTablerowtype(): +# packagearr = __load_dict__['tables']['rowtype'] +# print(packagearr) + return __load_dict__['tables']['rowtype'] + +def GetTableheadtype(): +# packagearr = __load_dict__['tables']['headtype'] +# print(packagearr) + return __load_dict__['tables']['headtype'] + +def GetTableHeadFontColor(): + return __load_dict__['tables']['headfontcolor'] + +def GetTableStylesArr(): +# packagearr = __load_dict__['tables']['styles'] +# for package in packagearr: +# print(package) + return __load_dict__['tables']['styles'] + +def GetImageStyleArr(): +# packagearr = __load_dict__['image']['styles'] +# for package in packagearr: +# print(package) + return __load_dict__['image']['styles'] + +#判断是否有忽略告警的类 +class clsIgnoreWarn: + + def __init__(self,warnfile): + self.warnfile = warnfile #保存告警文件 + + #判断该关键字是否在告警中,在告警中则返回true,不在告警中则返回fasle + def __JudgeWarnKey(self,keystr,warnstr): + + #先判断是否为组合关键字 + if "&&" in keystr: + #解析组合关键字为字符串列表 + keyls = keystr.split("&&") + isignore = True #默认该告警为忽略,如果组合关键字其中一个关键字没匹配上,则该告警不能忽略。 + for i in range(0,len(keyls)): + if keyls[i].strip().lower() not in warnstr.lower(): + isignore =False + break + return isignore + else: + if keystr.lower() in warnstr.lower(): + return True #忽略告警在告警字符串里面,则返回true,表示该告警可以忽略 + else: + return False + + #解析告警文件 + def parsewarnfile(self,ignorekey): + ''' + #make latex命令产生的告警都会保存在stderr文件中,只要判断该文件中的告警是否在忽略告警里就可以了,不在忽略告警里,则肯定有错误,停止执行 + ''' + if not os.path.exists(self.warnfile): + return True + + fs = codecs.open(self.warnfile,"r",encoding = 'utf-8') + fstr = fs.read() #先将所有内容读取到字符串中,方便正则表达式查找 + fs.close() + + #查找带warning的内容 + pattern = re.compile(r"([\s\S].*)[WARNING:|ERROR:]([\s\S].*)", re.I | re.U | re.M) + mobjarr = pattern.finditer(fstr) + + errstr = '' #保存不能被忽略的告警 + isNotError = True + for mobj in mobjarr: + amarr = mobj.group() + + if amarr == '': + continue + + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + + #判断该告警是否在忽略列表里,不在忽略列表里,保存在errstr字符串中 + + if ("WARNING:" not in amarr) and ("ERROR:" not in amarr): #直接忽略 + continue + + isWrite = False #默认不能忽略 + for igkey in ignorekey: + isWrite = self.__JudgeWarnKey(igkey,amarr) + if isWrite: + break + + if isWrite is False: + #写入stderr文件 + isNotError = False + #fe.writelines(amarr+'\n') + errstr += amarr + + if errstr != '': + #如果有不能被忽略的告警,则将不能忽略的告警重新写入warnfile文件中 + #先删除源文件,避免原文件无法写入 + fw = codecs.open(self.warnfile, "w",encoding = 'utf-8') + fw.write(errstr) + fw.close() + else: + #如果所有的告警都能忽略则删除之前的告警文件 + #该功能暂时不实现,还是保存原始的告警文件,方便查找。 + #os.remove(self.warnfile) + pass + + return isNotError + + #该函数暂时未用 + def __parsepdfarg(stdout,stderr,ignorelist,ignorekey): + ''' + make all-pdf命令产生的所有输出都只会保存在stdout中,因此从stdout中判断是否有告警内容 + ''' + stdoutfilepath = currfilepath+'/'+stdout + stderrfilepath = currfilepath+'/'+stderr + + if not os.path.exists(stdoutfilepath): + return True + + fs = codecs.open(stdoutfilepath, "r+",encoding = 'utf-8') + fstr = fs.read() #先将所有内容读取到字符串中,方便正则表达式查找 + fs.close + #查找latexmk的位置,latexmk位置的开始即make all-pdf打印输出的开始 + searchstr = r"latexmk" + m = re.search(searchstr, fstr, re.I|re.U ) + if m == None: + return True + + spos = m.span() #获取位置 + latexcont = fstr[spos[0]:len(fstr)] #获取到make all-pdf产生的内容 + + #查找带warning的内容 + pattern = re.compile(r'([\s\S].*)Warning([\s\S].*)', re.I|re.U) + mobjarr = pattern.finditer(latexcont) + + #打开stderr文件,方便后续写 + fe = codecs.open(stderrfilepath, "a+",encoding = 'utf-8') + isNotError = True + for mobj in mobjarr: + amarr = mobj.group() + if amarr == '': + continue + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + + #判断该告警是否在忽略列表里,不在忽略列表里,要写入stderr文件 + if amarr.lower() not in [elem.lower() for elem in ignorelist]: #如果错误不在忽略告警里面 + #如果不能整行匹配,再判断该行是否包含需忽略的告警的关键字 + isWrite = False #默认不能忽略 + for igkey in ignorekey: + + isWrite = self.__JudgeWarnKey(igkey,amarr) + if isWrite: + break; + + if isWrite == False: + #写入stderr文件 + isNotError = False + fe.writelines(amarr) + fe.close + + return isNotError + +def warn_main(warnfile,app=None): + + global __load_dict__ + + if len(__load_dict__) ==0: + GetConfjsondict(app) + + ignorekey = getignorewarning() + clswarn = clsIgnoreWarn(warnfile) + + if not clswarn.parsewarnfile(ignorekey): + #如果存在不可忽略的告警,则返回True,否则返回False + return True + + return False + + +class clsTableattr: + + def __init__(self, tables): + self.rowtype = tables['rowtype'] + if __dictIsHasKey__(tables,'headtype'): + self.headtype = tables['headtype'] + else: + self.headtype = "" + if __dictIsHasKey__(tables,'headfontcolor'): + self.headfontcolor = tables['headfontcolor'] + else: + self.headfontcolor = "" + if __dictIsHasKey__(tables,'styles'): + self.tablestyles = tables['styles'] + else: + self.tablestyles = None + if __dictIsHasKey__(tables,'isname'): + self.isname = tables['isname'] + else: + self.isname = False + if __dictIsHasKey__(tables,'fontsize'): + self.fontsize = tables['fontsize'] + else: + self.fontsize = "" + + if __dictIsHasKey__(tables,'ismulticolleftalign'): + self.ismulticolleftalign = tables['ismulticolleftalign'] + else: + self.ismulticolleftalign = False + + if __dictIsHasKey__(tables,'ismultirowautowrap'): + self.ismultirowautowrap = tables['ismultirowautowrap'] + else: + self.ismultirowautowrap = True + + if __dictIsHasKey__(tables,'ismultirowatupcell'): + self.ismultirowatupcell = tables['ismultirowatupcell'] + else: + self.ismultirowatupcell = False + + if __dictIsHasKey__(tables,'ismultirowatlowcell'): + self.ismultirowatlowcell = tables['ismultirowatlowcell'] + else: + self.ismultirowatlowcell = False + +class clsModifyTex: + + def __init__(self, content,app=None): + self.content = content + self.app = app + self.tablesattrobj = clsTableattr(GetTablesContent()) + + #构建模板字符串 + #自动换行模板 + self.autolinetemplat =r'''\renewcommand{\arraystretch}{0.8} + \begin{tabular}[c]{@{}l@{}}%(autocontent)s\end{tabular}''' + + fontcolor = self.tablesattrobj.headfontcolor + fontcolor = fontcolor.replace('{}', '{%(content)s}', 1) + self.commonstr = r'''%(sphinxstyletheadfamily)s\cellcolor''' + \ + self.tablesattrobj.headtype + "\n{" + fontcolor + "}" + + # multirow模板 + self.multirowstr = r'''\multirow{-%(count)s}{%(coluwidth)s}%(raggedright)s{ + ''' + self.commonstr + "}" + + self.simplemultirow = r'''\multirow{%(count)s}{%(coluwidth)s}{\cellcolor''' + \ + self.tablesattrobj.headtype + "}" + + # multicolumn模板 + self.multicolumnstr = r'''\multicolumn{%(count)s}{%(flag)s}{%(prefix)s''' + \ + self.commonstr + r'''%(suffix)s''' + + # 为了普通标题文本看起来和使用了multicolum你的位置一致,对于非multirow和multicolumn单元格的内容,都用multicolumn{1}{|l|}标志 + self.normalstr = r'''\multicolumn{1}{%(flag)s}{ + \begin{varwidth}[c]{%(colwidth)s} + ''' + self.commonstr + r''' +\par +\vskip-\baselineskip\vbox{\hbox{\strut}}\end{varwidth}}''' + self.comcolstr = r'''\multicolumn{%(count)s}{%(flag)s}{''' + self.commonstr +"}" + # 空的单元格内容,multirow位置变换后,会有空的单元格 + self.emptystr = r"\cellcolor" + self.tablesattrobj.headtype+"{}" + + #删除字符串列表。有时候sphinx生成的文本内容里会带有特殊的latex指令,该指令不支持设置字体颜色,需要删除。删除不影响内容显示。 + self.delstrlst=[r'\begin{quote}',r'\end{quote}'] + + def __checktpacknameandparam(self,packlist,packname,*packarg): + ''' + 检查指定的包是否需要增加或者修改参数 + :param packlist: 已经读出的包含的package + :param packname: 要修改或者添加的包 + :param packarg: 要修改或者添加的包参数,可变参数。不会覆盖已有参数 + ''' + if len(packarg) == 0: + #如果没有参数则直接添加该包 + loadpackagestr = "\\\\usepackage{" + packname + "}" + if loadpackagestr not in packlist: + packlist.append(loadpackagestr) + return + + params = ",".join(packarg) + loadpackagestr = "\\\\usepackage[" + params + "]{" + packname + "}" + #print(loadpackagestr) + packstr = ",".join(packlist) + #print(packstr) + searchstr = r'\\\\usepackage(\[[\S, ][^\\]+\])?\{'+packname+'\}' + match = re.search(searchstr,packstr) + if match is None: + #如果根本没有加载packname包,则加载该包 + packlist.append(loadpackagestr) + return + latexpack = match.group() + #print(latexpack) + if match.group(1) is None: + #如果该包没有携带参数,则重新加载该包 + #得到老包的位置,尽量原位置插入,避免因为位置问题导致编译失败 + index = packlist.index(latexpack) + del packlist[index] + #再重新加载新的包 + packlist.insert(index,loadpackagestr) + return + #如果携带了参数,则更新该参数 + #为了兼容后续新添加参数,并防止覆盖已有参数,将每个参数进行比较添加 + paramstr = match.group(1) + searchstr = '\[([\S, ]+)?\]' + match = re.search(searchstr,paramstr) + if match.group(1).strip() == params: + #print('-------------enter-----------') + #为了提高运行效率,两个参数相等直接返回,不再走下面的循环 + return + params = match.group(1) + for param in packarg: + if param not in params: + params += "," + param + #print(params) + loadpackagestr = "\\\\usepackage[" + params+"]{" + packname +"}" + #先删除旧包,原位置插入,避免因位置错误导致的编译失败。 + index = packlist.index(latexpack) + del packlist[index] + packlist.insert(index,loadpackagestr) + + #加入其它包 + def AddPackageToTex(self): + #得到需要包的数组 + packarr = GetOtherIncludePackage() + if len(packarr)==0: + return False + + self.__checktpacknameandparam(packarr,"multirow") + self.__checktpacknameandparam(packarr,"tcolorbox","most","skins","breakable") + + #如果数组有内容,就需要将包添加到latex文件的导言区 + #搜索\usepackage{sphinx},将包加在它的前面,用正则表达式搜索它的位置 + #采用正则表达式替换的方式,替换搜索到的字符串,因此需要先构建字符串 + #python认为\u后面为unicode字符,因此需要多加一个转义字符\,python才认为是整个的字符串 + #searchstr = r'\\usepackage\[dontkeepoldnames\]{sphinx}' + searchstr = r'\\usepackage(\[\S*\]*)?{sphinx}' + matchstr = re.search(searchstr,self.content) + + replacestr="" + for package in packarr: + replacestr += package+'\n' + + if Py_version >= (3,7,0): + replacestr += "\\" + matchstr.group(0) + else: + replacestr += matchstr.group(0) + + + self.content = re.sub(searchstr, replacestr, self.content, 1, re.M | re.I|re.U) + + return True + + #加入自定义选项,包放在了sphinx包的前面,因此选项放在sphinx包的后面 + def AddCustormOptionsToTex(self): + #得到需要包的数组 + packarr = GetCustomOptions() + if len(packarr)==0: + return False + + #如果数组有内容,就需要将包添加到latex文件的导言区 + #搜索\usepackage{sphinx},将自定义参数放在它的后面,用正则表达式搜索它的位置 + #采用正则表达式替换的方式,替换搜索到的字符串,因此需要先构建字符串 + #python认为\u后面为unicode字符,因此需要多加一个转义字符\,python才认为是整个的字符串 + searchstr = r'\\usepackage(\[\S*\]*)?{sphinx}' + matchstr = re.search(searchstr,self.content) + + replacestr="" + for package in packarr: + replacestr += package+'\n' + + if Py_version >= (3,7,0): + replacestr = "\\" + matchstr.group(0)+'\n'+replacestr + else: + replacestr = matchstr.group(0)+'\n'+replacestr + + self.content = re.sub(searchstr, replacestr, self.content, 1, re.M | re.I|re.U) + return True + + #增加figure和table toc到tex + def AddOtherTocToTex(self): + #得到需要包的数组 + packarr = GetIsTocContents() + if len(packarr)==0: + return False + + replacestr = "" + if packarr['isfigurestoc']: + figlst = packarr['figurestoc'] + for figstr in figlst: + replacestr += figstr + '\n' + + if packarr['istablestoc']: + figlst = packarr['tablestoc'] + for figstr in figlst: + replacestr += figstr + '\n' + if replacestr == "": + return + + #如果数组有内容,就需要将包添加到latex文件的导言区 + #搜索\usepackage{sphinx},将包加在它的前面,用正则表达式搜索它的位置 + #采用正则表达式替换的方式,替换搜索到的字符串,因此需要先构建字符串 + #python认为\u后面为unicode字符,因此需要多加一个转义字符\,python才认为是整个的字符串 + searchstr = r'\\sphinxtableofcontents' + matchstr = re.search(searchstr,self.content) + + if Py_version >= (3,7,0): + replacestr = "\\" + matchstr.group(0) + '\n' + replacestr + else: + replacestr = matchstr.group(0) + '\n' + replacestr + + self.content = re.sub(searchstr, replacestr, self.content, 1, re.M | re.I|re.U) + return True + + def __IsRequireReplace(self, value): + """ + 根据sphinx版本判断是否需要替换,有两种写法: + 1. {sphinx:5.3.0}:这种写法是当前sphinx版本大于等于5.3.0版本时,才执行替换。 + 2. {5.3.0:sphinx}:这种写法是当前sphinx版本小于5.3.0时,才进行替换 + :return: 返回True需要替换,和需要替换的实际value;或者返回fasle,不需要替换 + """ + #根据正则表达式得到需要替换的sphinx的最小版本号 + if len(sphinx_version)==0: + return True,value + #按大于等于版本号进行匹配查找 + searchstr = r"\{sphinx:([0-9\.]+)\}([\s\S]+)" + match = re.match(searchstr,value,re.I | re.U) + if match is not None: + #print(match.group(1),match.group(2)) + try: + require_version = match.group(1)[:3] + if sphinx_version >= require_version: + return True, match.group(2) + else: + return False, "" + except Exception as e: + print(e) + return True,value + + else: + #按小于版本号进行匹配查找 + searchstr = r"\{([0-9\.]+):sphinx\}([\s\S]+)" + match = re.match(searchstr, value, re.I | re.U) + if match is not None: + try: + require_version = match.group(1)[:3] + if sphinx_version < require_version: + return True, match.group(2) + else: + return False, "" + except Exception as e: + print(e) + return True, value + else: + return True,value + + + def ModifyReplacePackage(self): + #得到字典值 + #得到需要替换的包,用正则表达式替换 + redict = GetReplacePackage() + if len(redict) ==0: + return + + #为了兼容sphinx5.3.0版本,固定添加以下配置 + keys = redict.keys() + if "\\\\sphinxtoprule" not in keys: + redict["\\\\sphinxtoprule"] = "\\\\hline" + if "\\\\sphinxmidrule" not in keys: + redict["\\\\sphinxmidrule"] = "\\\\hline" + if "\\\\sphinxbottomrule" not in keys: + redict["\\\\sphinxbottomrule"] = "\\\\hline" + if "\\\\sphinxhline" not in keys: + redict["\\\\sphinxhline"] = "\\\\hline" + if "\\\\sphinxhyphen" not in keys: + redict["\\\\sphinxhyphen"] = "{sphinx:5.3.0}\\\\PYGZhy" + + #返回字典中所有键值的列表 + keylst = list(redict) + for key in keylst: + if key == 'comment' : + continue + keyvalue = redict[key] #得到键对应的值 + result,keyvalue = self.__IsRequireReplace(keyvalue) + if result : + #对键值进行替换 + self.content = re.sub(key, keyvalue, self.content, 0, re.I|re.U) + return + + def __GetReplacedContent(self,afterstr,replacelst): + """ + 对需要替换的内容进行替换 + :param afterstr: 需要替换的字符串 + :return: + """ + if len(afterstr) ==0 or len(replacelst)==0: + return afterstr + for replacelist in replacelst: + if len(replacelist) !=2: + continue + replacedstr = replacelist[0] + replacestr = replacelist[1] + if len(replacedstr) ==0 or len(replacestr)==0: + continue + afterstr = re.sub(replacedstr, replacestr, afterstr, 0, re.U | re.M) + + return afterstr + + def __ParseStyleForSpecialFunction(self,isshaded,apidict,linecontent): + + if isshaded: + envstartstr = "\n\\begin{shaded}\n" # 环境起始字符串 + envendstr = "\n\\end{shaded}\n" # 环境结束字符串 + else: + envstartstr = "" # 环境起始字符串 + envendstr = "" # 环境结束字符串 + + afterstr = linecontent + + if apidict is None or len(apidict)==0: + return envstartstr,envendstr,afterstr + + if not __dictIsHasKey__(apidict, "isenter") or \ + (__dictIsHasKey__(apidict, "isenter") and \ + apidict['isenter'] is True): + afterstr = self.__strreplacepos(afterstr,r"pysiglinewithargsret",r",",r",\\") + if __dictIsHasKey__(apidict, "istcolorbox") and \ + apidict['istcolorbox'] is True: + envstartstr, envendstr = self.__GettcolorboxForFunction(apidict) + elif __dictIsHasKey__(apidict, "fontsize") and \ + apidict['fontsize'] in latexfontsize: + envstartstr += "\\" + apidict['fontsize'] + "\n" + + #检查是否有替换,有替换的话,对内容进行替换 + if __dictIsHasKey__(apidict, "replace"): + replacelst = apidict['replace'] + if len(replacelst) > 0: + afterstr = self.__GetReplacedContent(afterstr,replacelst) + + return envstartstr,envendstr,afterstr + + def __ModifyFunctionBkColorByPython(self,isshaded,apilst,newcontent,linecontent,curpos,prepos): + #根据python生成的格式为类和函数声明添加灰底,并按逗号换行 + + ''' + 根据正则表达式获得以\pysiglinewithargsret开头,以“{}”结尾的中间字符串 + 对中间字符串添加灰底和换行 + ''' + if isshaded: + envstartstr = "\n\\begin{shaded}\n" # 环境起始字符串 + envendstr = "\n\\end{shaded}\n" # 环境结束字符串 + else: + envstartstr = "" # 环境起始字符串 + envendstr = "" # 环境结束字符串 + + startmultiline = '\\pysigstartmultiline\n' + stopmultiline = '\n\\pysigstopmultiline' + #searchstr = r'(?=\\pysiglinewithargsret).*(?<={})' + searchstr=r'(?=\\pysiglinewithargsret).*' + match = re.search(searchstr, linecontent, re.I|re.U) + if match is not None: + #重新组合linecontent + #pos = match.span() + isapi, apidict = self.__IsSpecialFunction(apilst, linecontent) + if isapi is True and apidict is not None: + envstartstr,envendstr,afterstr = self.__ParseStyleForSpecialFunction(isshaded,apidict,match.group()) + newstr = '\n' + linecontent[:match.start()] + envstartstr + startmultiline + afterstr+ stopmultiline + envendstr+ linecontent[match.end():len(linecontent)] + else: + newstr = '\n' + linecontent[:match.start()] + envstartstr+ startmultiline + match.group().replace(r",",r",\\") + stopmultiline + envendstr + linecontent[match.end():len(linecontent)] + #计算替换前的内容 + if len(prepos)==0: + newcontent = self.content[:curpos[0]-1] + newstr + else: + newcontent += self.content[prepos[1]:curpos[0]-1] + newstr + + return newcontent + + def __strreplacepos(self,srcstr,posstr,oldstr,newstr): + + # 从指定字符串开始对行内字符串进行替换 + ''' + srcstr:原始字符串 + posstr:从该字符串开始进行替换 + oldstr:被替换字符串 + newstr:替换字符串 + ''' + #查找字符串的起始位置 + pos = srcstr.find(posstr) + if pos == -1: + return "" #如果查找的字符串没有找到,则返回空。 + + #根据找到的位置进行字符串拆分 + startstr = srcstr[0:pos] + beforestr = srcstr[pos:len(srcstr)] + #对字符串进行替换 + afterstr = beforestr.replace(oldstr,newstr) + return startstr+afterstr + + def __GetleftrightValueForSpecialapi(self,str): + """ + 得到conf.json specialapi配置中left和right的值 + 字符串格式为“[-][数字][单位]”,根据这个格式查找值和单位 + :param str: + :return: + """ + if len(str) ==0: + return 0,"" + + searchstr = r'[-]{0,1}([0-9]{1,2})([A-Za-z]{1,2})' + match = re.search(searchstr, str, re.I) + if match is not None: + value = int(match.group(1)) + unit = match.group(2) + return value,unit + else: + return 0,"" + + def __GettcolorboxForFunction(self,apidict): + #判断有没有设置页面左右边距,的左右边距 + #必须通过\geometry指令设置,否则按默认处理 + + leftvalue = 0 + leftunit = "mm" + left = "0mm" + + rightvalue = 0 + rightunit = "mm" + right="0mm" + + isleftset = False + isrightset = False + + if __dictIsHasKey__(apidict, "left") and \ + len(apidict['left']) > 0: + leftvalue,leftunit=self.__GetleftrightValueForSpecialapi(apidict['left']) + if len(leftunit) > 0: + left = str(leftvalue)+leftunit + isleftset = True + + if isleftset is not True: + leftvalue = 9 #默认值 + leftunit="mm" + left = str(leftvalue)+leftunit + + searchstr = r'\\geometry{[\s\S]*left=([0-9]{1,2})([A-Za-z]{1,2})[\s\S]*}' + match = re.search(searchstr, self.content, re.I | re.M) + if match is not None: + left = match.group(1) + match.group(2) + leftvalue = int(match.group(1)) + leftunit = match.group(2) + + #print(left) + + if __dictIsHasKey__(apidict,"right") and \ + len(apidict['right']) >0 : + rightvalue, rightunit = self.__GetleftrightValueForSpecialapi(apidict['right']) + if len(rightunit) > 0: + right = str(rightvalue) + rightunit + isrightset = True + + if isrightset is not True: + rightvalue = leftvalue - 1 # 为了使显示效果更好看 + right = str(rightvalue)+leftunit + + searchstr = r'\\geometry{[\s\S]*right=([0-9]{1,2})([A-Za-z]{1,2})[\s\S]*}' + match =re.search(searchstr,self.content,re.I|re.M) + if match is not None: + rightvalue = int(match.group(1)) -2 + rightunit = match.group(2) + right = str(rightvalue) + rightunit + #print(right) + + fontsize = "" + if __dictIsHasKey__(apidict, "fontsize") and \ + apidict['fontsize'] in latexfontsize: + fontsize = "fontupper=\\" + apidict['fontsize'] + + envstartstr = "\n\\begin{tcolorbox}[arc=0mm,boxrule=-1mm,left skip=-" \ + + str(leftvalue+1) + leftunit +",left="+left +",right=-" +right \ + + ",colback=shadecolor," + fontsize +"]\n" + envendstr = "\n\\end{tcolorbox}\n" + #print("=====================") + #print(envstartstr) + #print(envendstr) + #print("=====================") + return envstartstr,envendstr + + + def __IsSpecialFunction(self,apilst,linecontent): + ''' + 判断是否为特殊api + :param apilst: + :param linecontent: + :return: + ''' + if len(apilst)==0: + return False,None + + for apidict in apilst: + if __dictIsHasKey__(apidict,"latexapiname"): + apistr=apidict['latexapiname'] #有可能以逗号分割的多个函数名称 + apiarr=apistr.split(",") #以逗号分割函数名称列表 + for api in apiarr: + if api in linecontent: + return True,apidict + return False,None + + def __RemoveLastSpace(self,headstr): + """ + 删除最后的空格,避免函数含义和关键字间有大的空行 + :param headstr: + :return: + """ + headlist = headstr.split('\n') + index = 0 + for i in range(len(headlist)-1,0,-1): + str = headlist[i].strip('\r') + if len(str) == 0: + index = i + break + if index > 0: + headlist.pop(index) + return '\n'.join(headlist) + else: + return headstr + + def __ModifyParamPos(self,paramcontent): + """ + 调整参数位置在最前面,为了适配breathe 4.24.1之后版本 + :param paramcontent: + :return: + """ + #查找第一个\begin{description}\n\sphinxlineitem{}字符串 + searchstr = r"(\\begin\{quote\})?(\r\n|\n)?\\begin\{description\}(\r\n|\n)?\\sphinxlineitem\{(\sphinxstylestrong\{)?(?P[\s\S].+)?(\})?\}(\r\n|\n)?[\s\S]+?(\\end\{description\}){1}(\r\n|\n)?(\\end\{quote\})?" + match = re.search(searchstr,paramcontent,re.I|re.M|re.U) + if match is not None: + + if "\\begin{quote}" in match.group(): + #删掉缩进 + otherstr = match.group()[len('\\begin{quote}'):len(match.group())-len('\\end{quote}')] + else: + otherstr = match.group() + + paramkey = match.group('paramkey') + headstr = paramcontent[0:match.start()] + headstr = self.__RemoveLastSpace(headstr) + if "Parameters" in paramkey: + #说明参数本身在第一个位置,不做处理 + return headstr + '\n' + otherstr + paramcontent[match.end():len(paramcontent)] + + #如果参数不在第一个位置,查找参数的位置并将第一个位置记录下来 + searchstr = r"(\\begin\{quote\})?(\r\n|\n)?\\begin\{description\}(\r\n|\n)?\\sphinxlineitem\{(\sphinxstylestrong\{)?(?PParameters)?(\})?\}(\r\n|\n)?[\s\S]+?(\\end\{description\}){1}(\r\n|\n)?(\\end\{quote\})?" + matchparam = re.search(searchstr,paramcontent,re.I|re.M|re.U) + if matchparam is None: + return headstr + '\n' + otherstr + paramcontent[match.end():len(paramcontent)] + + if "\\begin{quote}" in matchparam.group(): + #删掉缩进 + paramstr = matchparam.group()[len('\\begin{quote}'):len(matchparam.group())-len('\\end{quote}')] + else: + paramstr = matchparam.group() + + #重新组合字符串,将参数放在第一位 + newparamcontent = headstr + '\n' +paramstr + '\n' +otherstr \ + + paramcontent[match.end():matchparam.start()] \ + + paramcontent[matchparam.end():len(paramcontent)] + + return newparamcontent + else: + return paramcontent + + def ModifyFunctionBackColor(self): + ''' + 该函数用来修改函数的背景色,并根据参数换行 + * c/c++ sphinx生成的函数都被以下三行包围: + \pysigstartmultiline + \pysiglinewithargsret{xxxxx} + \pysigstopmultiline + 因此利用正则表达式查找这三行所在的位置进行替换。 + 注意: + 要实现该功能,latex必须包含framed和color包,同时定义以下颜色: + \definecolor{shadecolor}{RGB}{220,220,220} + 如果shadecolor颜色未定义,则该函数执行失败。 + * python生成的latex文件中函数或者类的声明,有\pysiglinewithargsret指令,但是没有pysigstartmultiline指令。 + 因此需要在\pysiglinewithargsret指令句子的结束,前后先添加pysigstartmultiline和pysigstopmultiline。再添加\begin{shaded}和\end{shaded} + 一句的结束根据"{}"来判断,遇到"{}"即为句末。 + ''' + #判断是否有特殊API需要处理 + apilst = [] + isshaded = True + if __dictIsHasKey__(__load_dict__,"specialapi"): + if __dictIsHasKey__(__load_dict__["specialapi"],"styles"): + apilst=list(__load_dict__["specialapi"]["styles"]) + if __dictIsHasKey__(__load_dict__["specialapi"],"isshaded"): + isshaded = __load_dict__["specialapi"]["isshaded"] + + pythontype = False + newcontent = "" + prepos=[] #定义需要替换的字符串列表,该列表中的每一个元素都包含上面提到的三行。 + #searchstr = r'^(?=.*\\pysiglinewithargsret).*$' + #该正则表达式可以把pysiglinewithargsret或者带template的识别出来。有template的会被pysigline标签标识 + #该修改也兼容sphinx 5.3.0及以后版本。否则5.3.0版本之后将无法使用。 + #searchstr=r'(\\pysigline[\s\S].*){0,1}[\n]{0,1}(\\pysiglinewithargsret[\s\S].*)[\n]' + searchstr=r'(\\pysigstartsignatures){0,1}[\r\n|\n]{0,1}(\\pysigstartmultiline){0,1}(\r\n|\n){0,1}((\\pysigline[\s\S].*){0,1}[\n]{0,1}(\\pysiglinewithargsret[\s\S].*))[\n](\\pysigstopmultiline){0,1}[\r\n|\n]{0,1}(\\pysigstopsignatures){0,1}' + m = re.finditer(searchstr, self.content, re.M|re.I|re.U) + for match in m: + + linecontent = match.group() + #判断是否添加了pysigstartmultiline和pysigstopmultiline,没有添加的话需要添加才能添加灰色背景 + #multilen=len(r'\pysigstartmultiline') + #startmultistr = self.content[match.start()-multilen-1:match.start()] + #如果上一行不是\pysigstartmultiline则需要按python风格修改,添加该标志 + #print(startmultistr.strip()) + if match.group(2) is None: + #print('is python') + # 计算替换前的内容 + newcontent = self.\ + __ModifyFunctionBkColorByPython(isshaded,apilst,newcontent,linecontent,match.span(),prepos) + prepos = match.span() + pythontype = True + else: + #tablestr = match.groups() + #计算替换前的内容 + if len(prepos)==0: + newcontent = self.content[0:match.start()] + else: + # 得到参数信息内容,方便后面调整位置,此时调整的参数为上一个函数的参数位置 + paramcontent = self.content[prepos[1]:match.start()] + # breathe 4.24.1之后版本将parameters放在了最后面, + # 为了适配breathe 4.24.1之后版本对格式的调整 + # 检查参数是否在后面,是的话,需要将参数放到最前面 + # breathe 4.24.1之后版本将doxygen的标签都放在\begin{description}\n\sphinxlineitem{}标志对里面 + # 因此对该标志对进行搜索 + newparamcontent = self.__ModifyParamPos(paramcontent) + newcontent += newparamcontent + + + #当有template的时候,pysiglinewithargsret不一定在行首,因此需要从pysiglinewithargsret开始替换逗号,加强制换行符 + if isshaded: + envstartstr = "\n\\begin{shaded}\n" # 环境起始字符串 + envendstr = "\n\\end{shaded}\n" # 环境结束字符串 + else: + envstartstr = "" # 环境起始字符串 + envendstr = "" # 环境结束字符串 + + isapi,apidict = self.__IsSpecialFunction(apilst,linecontent) + if isapi is True and apidict is not None: + envstartstr,envendstr,afterstr=self.__ParseStyleForSpecialFunction(isshaded,apidict,linecontent) + newstr = envstartstr + afterstr + envendstr + else: + afterstr=self.__strreplacepos(linecontent,r"pysiglinewithargsret",r",",r",\\") + #得到替换后的字符串 + newstr = envstartstr + afterstr + envendstr + #得到替换后的内容 + newcontent += newstr + prepos = match.span() + pythontype = False + + if len(prepos) > 0: + paramcontent = self.content[prepos[1]:len(self.content)] + newparamcontent = self.__ModifyParamPos(paramcontent) + self.content = newcontent + newparamcontent + #self.content = newcontent + self.content[prepos[1]:len(self.content)] + #print('===============================================') + + def __GetTableCaption(self,tablecontent): + ''' + 得到表格的caption,方便输出打印 + :param tablecontent: 表格内容 + :return: 返回表格的caption + ''' + tablecaption = "" + searchstr = r'(\\sphinxcaption|\\caption)\{(?P[\s\S]*?)\}(?=\\label|\\\\)' + matchcaption = re.search(searchstr, tablecontent, re.M | re.I | re.U) + if matchcaption is not None: + tablecaption = matchcaption.group('caption') #得到caption的值 + #长表格caption后面会添加\struct的后缀,去掉 + tablecaption = tablecaption.replace(r"\strut","") + + return tablecaption + + def ModifyTablesAttributes(self): + #修改表格属性 + newcontent = self.content + searchstr = r'(\\begin{savenotes})([\s\S]*?)(\\sphinxattablestart|\\sphinxatlongtablestart)([\s\S]*?)(\\sphinxattableend|\\sphinxatlongtableend)([\s\S]*?)(\\end{savenotes})' + matchobj = re.finditer(searchstr, self.content, re.M|re.I|re.U) + #outputstr = "\033[34mModifying Table:\033[0m" + #print(outputstr, end="") + icount = 0 + for match in matchobj: + try: + icount = icount + 1 + tablestr = match.groups() + tablecontent = tablestr[0]+tablestr[1]+tablestr[2]+tablestr[3]+tablestr[4] + columnwidthlst,newtablecontent = self.__GetColumnParameters(tablecontent) + #得到表格caption + tablecaption=self.__GetTableCaption(tablecontent) + #print(tablecaption) + #outputtable="\033[34m[" + tablecaption +"];\033[0m" + #print(outputtable, end="") + #为了兼容纯英文linux系统,在打印信息中不能包含中文,因此按自动计数累加表示表格, + #避免表格标题中有中文时导致修改表格失败。 + prstr = "\033[34m[Modifying table"+str(icount)+":... ...];\033[0m" + print(prstr, end="\r") + if self.tablesattrobj.tablestyles is None: + #先修改表格的通用属性 + newtableattr,islongtable = self.__StartModifyTableAttr(newtablecontent,columnwidthlst, False) + newcontent = newcontent.replace(tablecontent, newtableattr) + else: + caption_dict = self.__CreateTableCaptionDict(self.tablesattrobj.tablestyles) + if len(caption_dict ) > 0 : + newtableattr = self.__ModifySingleTableattr(newtablecontent, columnwidthlst, caption_dict ) #tablestr也是3个内容的数组,因为正则表达式被分为了3组,只取中间分组的内容。 + #重新组成新的字符串 + newcontent = newcontent.replace(tablecontent, newtableattr) + + except Exception as e: + print(e) + traceback.print_exc() + + print("\r\n") + self.content = newcontent + + #替换表格中字体大小 + def __ModifyTableFontSize(self,singletablecontent,fontsize): + + searchstr = r'(\\begin{savenotes})([\s\S]*?)(\\sphinxattablestart|\\sphinxatlongtablestart)' + m = re.search(searchstr,singletablecontent, re.M|re.I|re.U) + tablestr = m.groups() + #替换表格中的字体 + #同时也要使表格的caption字体大小保持一致,修改表格的caption需要添加如下指令:\renewcommand{\tablename}{\scriptsize{表} } + captionfont = "\n\\renewcommand{\\tablename}{\\" + fontsize + "{Table} }\n" + if doc_language=='zh_CN': + captionfont="\n\\renewcommand{\\tablename}{\\" + fontsize + "{表} }\n" + + tablefontsize = '\\'+fontsize + captionfont + + return tablestr[0]+tablefontsize+tablestr[2]+singletablecontent[m.end():len(singletablecontent)] + + def __CreateTableCaptionDict(self, tablestylesarr): + #根据caption生成表格字典,key=caption,value=属性数组 + cap_dict = {} + + for tablestyle_dict in tablestylesarr: + captionarr = tablestyle_dict['caption'] + #该caption可能是一个逗号分隔的字符串数组,因此需要做拆分 + captionlist = captionarr.split(",") + for caption in captionlist: + cap_dict[caption] = tablestyle_dict #以caption为key重新生成字典,便于查找 + return cap_dict + + #查找表格caption是否被匹配 + def __JudgeIsCaption(self,caption_dict,tablecaption): + + captionlst=caption_dict.keys() #得到json配置的所有caption + for caption in captionlst: + tablestyle_dict = caption_dict[caption] + if __dictIsHasKey__(tablestyle_dict, 'ispartial') and tablestyle_dict['ispartial']==True: + if caption in tablecaption: + return True,tablestyle_dict + else: + if caption == tablecaption: + return True,tablestyle_dict + return False,None + + def __ModifySingleTableattr(self, singletablecontent, columnwidthlst, caption_dict): + #修改单个表格属性 + #从单个表格里用正则表达式找caption + #定义正则表达式,查找caption内容 + tablecaption = '' + new_singletablecontent = singletablecontent + if self.tablesattrobj.isname: + searchstr = r'.*\\label.*?:(?P[\s\S].*)}}.*' + matchcaption = re.search(searchstr, singletablecontent, re.M | re.I | re.U) + if matchcaption is not None: + tablecaption = matchcaption.group('caption') # 得到caption的值 + else: + tablecaption = '' + else: + tablecaption = self.__GetTableCaption(singletablecontent) + + #stylefontsize = False + iscaption = False + iscaption,tablestyle_dict=self.__JudgeIsCaption(caption_dict,tablecaption) + if iscaption: + + #if __dictIsHasKey__(tablestyle_dict,'fontsize') and (tablestyle_dict['fontsize'] in latexfontsize): + # stylefontsize = True + # tablestyle_dict['isLongTable'] = True #只有长表格的\caption指令,设置的字体才能生效,所以如果要设置表格字体,需要强制设置为长表格。 + jsonmultirow = [] #保存需要设置自动换行的合并单元格的列号。 + if __dictIsHasKey__(tablestyle_dict, 'multirowcolumn'): + jsonmultirow = tablestyle_dict['multirowcolumn'] + + if __dictIsHasKey__(tablestyle_dict, 'isLongTable') is not True: + tablestyle_dict['isLongTable'] = False + + if __dictIsHasKey__(tablestyle_dict, 'isCusHead') is not True: + tablestyle_dict['isCusHead'] = False + + #修改表格通用属性 + new_singletablecontent,islongtable= self.__StartModifyTableAttr(new_singletablecontent, + columnwidthlst, + tablestyle_dict['isLongTable'], + tablestyle_dict['isCusHead'], + jsonmultirow, + tablestyle_dict) + #设置该表格定制字体 + if __dictIsHasKey__(tablestyle_dict,'fontsize') and (tablestyle_dict['fontsize'] in latexfontsize): + new_singletablecontent=self.__ModifyTableFontSize(new_singletablecontent,tablestyle_dict['fontsize']) + + #渲染竖型表格的第一列 + if __dictIsHasKey__(tablestyle_dict,'isVertical') and tablestyle_dict['isVertical']==True: + if islongtable: + # 长表格自定义sphinxcolwidth的第2个参数默认为5,否则无法实现自动换行 + #self.normalstr = self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth': '5', + # 'sphinxstyletheadfamily': '%(sphinxstyletheadfamily)s', + # 'content': '%(content)s' + #} + new_singletablecontent = self.__ModifyVerticalLongTable(new_singletablecontent,tablestyle_dict,columnwidthlst) + else: + # 普通表格自定义sphinxcolwidth的第2个参数默认为3,否则无法实现自动换行 + #self.normalstr = self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth': '3', + # 'sphinxstyletheadfamily': '%(sphinxstyletheadfamily)s', + # 'content': '%(content)s' + #} + new_singletablecontent = self.__ModifyTableByLine(new_singletablecontent,True,False,tablestyle_dict,columnwidthlst) + + if __dictIsHasKey__(tablestyle_dict,'isnewpage') and tablestyle_dict['isnewpage']==True: + new_singletablecontent = '\\newpage\n' + new_singletablecontent + + #最后修改表格的行高 + if __dictIsHasKey__(tablestyle_dict, 'ht_rowno') and tablestyle_dict['ht_rowno'] is not None: + new_singletablecontent = self.__ModifyTableRowHeight(islongtable,new_singletablecontent,tablestyle_dict['ht_rowno']) + else: + #修改表格的通用属性 + new_singletablecontent,islongtable = self.__StartModifyTableAttr(singletablecontent, columnwidthlst, False) + if new_singletablecontent == '': + new_singletablecontent = singletablecontent + + return new_singletablecontent + + def __NormaltableRowHeight(self,rowno,rowheightinfo,singletablecontent): + #自定义长表格,第一行以“\hline”开头,因此从第一个“\hline”开头进行替换。 + searchstr = r'\\hline' + match = re.search(searchstr,singletablecontent,re.I | re.U) + if match is None: + return singletablecontent + + #latex表格属性内容 + tablehead = singletablecontent[0:match.start()] + #表格内容 + tablecontent = singletablecontent[match.start():len(singletablecontent)] + tablecontent = self.__ModifyNormalRowHeight(rowno,rowheightinfo,0,tablecontent) + + return tablehead + tablecontent + + def __ModifyNormalRowHeight(self, rowno, rowheightinfo, startline, tablecontent): + # 修改除表格外其它行的行高,并返回修改后的内容 + searchstr = r'\\\\' + pattern = re.compile(searchstr, re.I | re.U) + matchiter = pattern.finditer(tablecontent) + i = startline + for m in matchiter: + i += 1 + if rowno == i: + # 开始设置该行的行高 + # 获取起始和结束位置 + posstart = m.start() + posend = m.end() + tablecontent = tablecontent[0:posstart] + rowheightinfo + tablecontent[posend:len(tablecontent)] + break + + return tablecontent + + def __JudgeRowlistRule(self,rowlist): + #判断行号和行高信息是否符合规则 + + if len(rowlist)!=2: + return False + rowno = rowlist[0] #得到行号 + rowhtstr=rowlist[1] #得到行高信息 + + #判断是否为异常信息 + #设置的最大行高不能超过10,因为10的行高太高,正常不需要这么高的行高。而且sphinx会根据换行自动调整行高,该方法只用户行高的微调。 + if rowno <= 0 or len(rowhtstr) == 0 or not rowhtstr.isdigit() or (int(rowhtstr) > 99): + return False + return True + + def __ModifyLongTableRowHeight(self,rowno,rowheightinfo,singletablecontent): + + # 判断是自定义长表格还是sphinx生成的长表格。 + # sphinx自动生成的长表格有两个表头,第一个表头以“\endfirsthead”结尾。 + # 自定义的长表格没有“\endfirsthead”标志 + searchstr = r'\\endfirsthead' + matchobj = re.search(searchstr, singletablecontent, re.I | re.U) + if matchobj is None: + # 可能是自定义长表格,自定义长表格没有endfirsthead和endhead需要单独设置 + singletablecontent = self.__NormaltableRowHeight(rowno,rowheightinfo,singletablecontent) + return singletablecontent + + if rowno==1: #因为长表格有两个表头,因此两个表头都要设置,避免换页后表头格式不一致。 + #修改长表格的表头 + firsthead = singletablecontent[0:matchobj.end()] + rpos = firsthead.rfind(r"\\") + firsthead = firsthead[0:rpos] + rowheightinfo + firsthead[rpos+len(r"\\"):len(firsthead)] + + #查找第2个表头,第2个表头以“\endhead”结尾 + searchstr = r'\\endhead' + matchobj2 = re.search(searchstr, singletablecontent, re.I | re.U) + endhead = singletablecontent[matchobj.end():matchobj2.end()] + rpos = endhead.rfind(r"\\") + + endhead = endhead[0:rpos]+rowheightinfo+ endhead[rpos+len(r"\\"):len(firsthead)] + + singletablecontent = firsthead+endhead+singletablecontent[matchobj2.end():len(singletablecontent)] + + if rowno > 1: + #长表格先将表头去除掉,否则表头会包含很多“\\”,但不代表行。 + #长表格的表头以“\endlastfoot”结尾 + searchstr = r'\\endlastfoot' + matchobj = re.search(searchstr, singletablecontent, re.I | re.U) + tablehead = singletablecontent[0:matchobj.end()] + othercontent = singletablecontent[matchobj.end():len(singletablecontent)] + + #修改表格内容,因为长表格有两个表头,因此表头后的行号起始行从1开始。 + othercontent = self.__ModifyNormalRowHeight(rowno,rowheightinfo,1,othercontent) + + singletablecontent = tablehead + othercontent + return singletablecontent + + def __ModifyTableRowHeight(self,islongtable,singletablecontent, rowlistarry): + + #解析行信息,包含行号和行高信息 + for rowlist in rowlistarry: + + if not self.__JudgeRowlistRule(rowlist): + continue + rowno = rowlist[0] + rowheightinfo = r"\\[" + rowlist[1] +r"ex]" #要替换的行高信息 + #根据正则表达式,查找换行符"\\",每一个“\\"代表表格的一行,需要在“\\”的末尾添加行高数据:“\\[rowhtstr+ex]” + if islongtable: + singletablecontent = self.__ModifyLongTableRowHeight(rowno, rowheightinfo,singletablecontent) + else: + singletablecontent = self.__NormaltableRowHeight(rowno,rowheightinfo,singletablecontent) + + return singletablecontent + + def __StartModifyTableAttr(self, singletablecontent, columnwidthlst, islongtable, isCusHead=True,jsonmultirow=None,tablestyle_dict=None): + #修改表格的通用属性 + searchstr = r'(\\begin{tabular}|\\begin{tabulary})(\[[a-z]\]|{\\linewidth}\[[a-z]\])([\s\S].*)' + #为了添加表格的通用属性,先对字符串做分割 + #python会把正则表达式中的分组自动分割,因此上面的正则表达式会自动分割为三个字符串 + #加上头尾字符串总共分为5个字符串数组。要修改第1维字符串为\\being{longtable}, + #第2维字符串直接删除,第3维字符串不变 + splittable = re.split(searchstr, singletablecontent,0, re.M | re.I|re.U ) + if splittable is None or len(splittable) < 5: + #再修改长表格属性 + searchstr = r'\\begin{longtable}([\s\S].*)' + #为了添加表格的通用属性,先对字符串做分割 + #python会把正则表达式中的分组自动分割,因此上面的正则表达式会自动分割为三个字符串 + #加上头尾字符串总共分为5个字符串数组。要修改第1维字符串为\\being{longtable},第2维字符串直接删除, + #第3维字符串不变 + splittable = re.split(searchstr, singletablecontent,0, re.M | re.I|re.U ) + if len(splittable) < 3: + #至少是3维的数组,否则不是预想的内容,不做处理 + return singletablecontent + if isCusHead: + newtable4 = self.__ModifyLongTableHead(splittable[2], + columnwidthlst, + self.tablesattrobj.headtype, + jsonmultirow, + tablestyle_dict) + # begin{longtable}必须再加上,因为Python并不认为它是正则表达式,因此不再分组里面第0个分组为空。 + singletablecontent = splittable[0]+r'\begin{longtable}'+splittable[1]+newtable4 + if self.tablesattrobj.fontsize !="": + #设置表格字体 + singletablecontent=self.__ModifyTableFontSize(singletablecontent,self.tablesattrobj.fontsize) + return singletablecontent,True + + #拆分后splittable应该为5个字符串的数组,拆分后便于添加通用属性 + if self.tablesattrobj.rowtype != '': + splittable[0] += self.tablesattrobj.rowtype + '\n' + + if isCusHead is True: + #修改表头字体颜色为白色加粗 + newtable4 = self.__ModifyTableHead(splittable[4], + columnwidthlst, + self.tablesattrobj.headtype, + jsonmultirow, + tablestyle_dict) + else: + newtable4 = splittable[4] + + singletablecontent = splittable[0]+splittable[1]+splittable[2]+splittable[3]+newtable4 + #根据配置设置表格为长表格,现在基本通过将sphinx的class属性设置为longtable实现,不再通过手动设置完成。 + if islongtable is True: + singletablecontent = self.__ModifyTableLongHeadAndTail(singletablecontent) + + if self.tablesattrobj.fontsize !="": + #设置表格字体 + singletablecontent=self.__ModifyTableFontSize(singletablecontent,self.tablesattrobj.fontsize) + + return singletablecontent,False + + def __ModifyTableLongHeadAndTail(self,singletablecontent): + #长表格的头尾都要替换,因此单独封装成一个函数,否则长表格的表格序号将加2 + #替换第一行为长表格 + searchstr = r'(\\begin{savenotes}[\S]*?\\sphinxattablestart)' + splittable = re.search(searchstr, singletablecontent,re.M | re.I|re.U ) + firstend =splittable.end() + #替换为长表格 + tablefirstline = re.sub(r'\\sphinxattablestart',r'\n\\sphinxatlongtablestart\n', splittable.group(0), 1,re.M|re.I|re.U) + if r'\sphinxthistablewithglobalstyle' in singletablecontent: + tablefirstline += '\\sphinxthistablewithglobalstyle\n' + if r'\sphinxthistablewithvlinesstyle' in singletablecontent: + tablefirstline += '\\sphinxthistablewithvlinesstyle\n' + #替换第2行为长表格 + searchstr = r'(\\begin{tabular}|\\begin{tabulary})(\[[a-z]\]|{\\linewidth}\[[a-z]\])([\s\S].*)' + splittable = re.search(searchstr, singletablecontent,re.I|re.U ) + #记录表头的位置 + headstartpos= splittable.start() + headlastpos = splittable.end() + tablesecondline = re.sub(r'\\begin{tabular}|\\begin{tabulary}',r'\\begin{longtable}', splittable.group(0), 1,re.I|re.U) + tablesecondline = re.sub(r'\{\\linewidth\}',r'', tablesecondline, 1,re.I|re.U) + + #查找caption + searchstr = r'\\sphinxcaption([\s\S].*)' + splittable = re.search(searchstr, singletablecontent, re.I|re.U) + if splittable is not None: + longcaption = re.sub(r'\\sphinxcaption',r'\\caption', splittable.group(0), 1,re.I|re.U) + #添加长表格专用指令 + longcaption += r"\\*[\sphinxlongtablecapskipadjust]" + + #构建长表个的表头部分 + longhead = tablefirstline + tablesecondline + '\n' + r'\sphinxthelongtablecaptionisattop' + '\n'+ longcaption+'\n' + else: + #如果没有caption需要把中间的内容加上,否则会丢失label导致连接失效 + longhead = tablefirstline + singletablecontent[firstend:headstartpos]+tablesecondline + '\n' + + #替换表尾 + newtablecontent = singletablecontent[headlastpos:len(singletablecontent)] + index = newtablecontent.rfind('\\end{tabulary}') + if index > 0: + endtable = newtablecontent[index:len(newtablecontent)] + othercontent = newtablecontent[0:index] + endtable = re.sub(r'(\\end{tabular}|\\end{tabulary})', r'\\end{longtable}', endtable, 1,re.M | re.I | re.U) + endtable = re.sub(r'\\par', r'', endtable, 1, re.M | re.I | re.U) + endtable = re.sub(r'\\sphinxattableend', r'\\sphinxatlongtableend', endtable, 1, re.M | re.I | re.U) + newtablecontent = othercontent + endtable + else: + index = newtablecontent.rfind('\\end{tabular}') + if index >0: + endtable = newtablecontent[index:len(newtablecontent)] + othercontent = newtablecontent[0:index] + endtable = re.sub(r'(\\end{tabular}|\\end{tabulary})', r'\\end{longtable}', endtable, 1,re.M | re.I | re.U) + endtable = re.sub(r'\\par', r'', endtable, 1, re.M | re.I | re.U) + endtable = re.sub(r'\\sphinxattableend', r'\\sphinxatlongtableend', endtable, 1, re.M | re.I | re.U) + newtablecontent = othercontent + endtable + + singletablecontent = longhead + newtablecontent + return singletablecontent + + #为表格增加h属性对齐方式,避免表格浮动,让表格在当前位置显示 + def __AddHAttrForTable(self,content): + + newcontent="" + searchstr = r'(\\begin{tabular}|\\begin{tabulary}|\\begin{longtable})({\\linewidth})?(\[(?P[a-z]{1,4})\])?([\s\S].*)' + m = re.search(searchstr, content, re.M|re.I|re.U ) + attrcontent = m.group('attr') + posarr = m.span() + if m.group(3) == '': #group(3)是[htp]的组合,如果没有表格属性则添加上表格属性 + newcontent = content[0:posarr[0]]+m.group(1)+m.group(2)+r'[h]'+m.group(5)+content[posarr[1]:len(content)] + else: + replacestr = m.group(4) #需要被替换的表格属性 + #判断该表格是否有P属性,如果有p属性需要删掉 + replacestr.replace('p','') + #给该表格添加h属性,避免表格浮动,让表格在当前位置显示 + replacestr = 'h'+replacestr + newcontent = content[0:posarr[0]]+m.group(1)+m.group(2)+'['+replacestr+']'+m.group(5)+content[posarr[1]:len(content)] + + return newcontent + def __GetColumnParaByWidths(self,content): + # 可能用的widths参数设置的列宽,再根据widths参数解析列宽 + columnwidthlst = [] + # 如果列宽不等于100,为了使表格不超过页面显示范围,以100为基数重新调整各列的列宽 + newcolumnwidth = "|" + isadjustcol = False + newcontent = content + try: + searchstr = r'(?<=\|)\\X\{(?P[0-9]*)\}\{(?P[0-9]*)\}(?=\|)' + pattern = re.compile(searchstr, re.I | re.U) + match = pattern.finditer(content) + if match is None: + return [],content + for m in match: + count1 = int(m.group('count1')) + count2 = int(m.group('count2')) + + tempcount = round(count1/count2,2) + colwidth = "%(float)s\\textwidth" + colwidth = colwidth % { + 'float': str(tempcount) + } + columnwidthlst.append(colwidth) + if count2 != 100: + isadjustcol = True + #重新调整count1和count2的值,并加入列表 + percolwidth = r"\X{" + str(int(tempcount * 100)) + "}{100}|" + newcolumnwidth = newcolumnwidth + percolwidth + + if isadjustcol is True: + searchstr = r"(?<=\{)\|\\X[\s\S]*?\|(?=\})" + match = re.search(searchstr,content,re.I | re.U) + if match is not None: + newcontent = content.replace(match.group(),newcolumnwidth,1) + + except Exception as e: + print(e) + finally: + return columnwidthlst,newcontent + + def __GetColumnParameters(self,content): + ''' + 得到通过tablecolumns参数设置的列宽,multirow的自动换行,必须有列宽参数,否则无法实现自动换行。 + 而且不能通过widths参数设置每列的列宽,否则生成的列宽无法用于multirow,依然无法实现multirow的自动换行 + :param content: + :return: + ''' + # 默认ismultirowautowrap为自动换行 + # 如果表格用tabluarcolumns指令设置的列宽,因为每列的宽度有具体的值,则默认自动换行 + # 如果表格用widths指令设置的列宽,因为每列的宽度不具体,则默认不自动换行,需要手动换行 + self.tablesattrobj.ismultirowautowrap = True #默认ismultirowautowrap为自动换行 + newcontent = content + columnwidthlst=[] #保存每一列宽度的列表 + #根据表格内容,得到每一列宽度参数,方便multirow的时候自动换行 + #m垂直居中,p垂直靠上,b垂直靠下 + searchstr= r'(?<=\|)[\s\S]*?[m|p|b]\{(?P[\s\S]*?)\}[\s\S]*?(?=\|)' + pattern = re.compile(searchstr,re.I | re.U) + match = pattern.finditer(content) + if match is None: + #columnwidthlst,newcontent = self.__GetColumnParaByWidths(content) + self.tablesattrobj.ismultirowautowrap = False #则默认需要手动换行 + return [],newcontent + + for m in match: + columnwidthlst.append(m.group('width')) + if len(columnwidthlst) ==0: + #columnwidthlst,newcontent= self.__GetColumnParaByWidths(content) + self.tablesattrobj.ismultirowautowrap = False #则默认需要手动换行 + return [],newcontent + return columnwidthlst,newcontent + + def __FindTableHeadForFamily(self,content,isLongTable): + ''' + 根据sphinxstyletheadfamily标志查找表格的表头,有sphinxstyletheadfamily标志的表头,源文件必须以“======”分割。 + 具有合并单元格的表头,源文件必须以“===”分割,而且表头最多支持两行,否则合并单元格的表头不支持渲染。 + :param content: 表格内容 + :return: headcontent 表格内容 + ''' + #查找表头的起始位置,以第一个\hline开始。 + if isLongTable is True: + searchstr =r"(\\hline[\s\S]*)(?=\\endfirsthead)" + matchobj = re.search(searchstr, content, re.I|re.M|re.U) + tablehead = matchobj.group() + else: + searchstr = r'\\hline' + matchobj = re.search(searchstr,content,re.I) + if matchobj is None: + return "" + startpos = matchobj.start() #表头的起始位置 + familypos = content.rfind(r'\sphinxstyletheadfamily') #查找最后一个sphinxstyletheadfamily出现的位置 + if familypos ==-1: + return "" + + #从familypos开始,查找到的第一个“\\”,则为表头的行结束符号。 + endpos = content.find(r"\hline",familypos) + # 得到表头内容,表头内容从起始\hline开始,到\hline结束,这种查找方式,兼容表头为两行的情况。 + tablehead = content[startpos:endpos+len(r"\hline")] + + return tablehead + + def __FindAndFlagPos(self,tablehead,andpos): + if andpos != -1: + # 查找&前面有没有转义字符,有转义字符,则&不为单元格分割符,需要再继续查找 + preflag = tablehead[andpos-1] + if preflag == '\\': + #如果是\&说明这个&不是单元格的分隔符,则需要继续查找 + newpos=tablehead.find("&",andpos+1) + andpos = self.__FindAndFlagPos(tablehead,newpos) + return andpos + else: + return andpos + return andpos + + def __FindCellContentForHead(self,tablehead,startpos): + #查找每个单元格的内容 + if startpos == -1: + return 0,-1,None + + # 查找&和\\的位置,&为单元格的分割标志,\\为行结束标志,一个单元格的分割标志可能是&,也可能是\\ + cellcontent = [] + andpos = tablehead.find("&", startpos) + # 得到实际的分隔符位置,避免&内容里面的符号,而不是单元格分隔符 + andpos = self.__FindAndFlagPos(tablehead, andpos) + linepos = tablehead.find(r"\\", startpos) + + if andpos ==-1 and linepos ==-1: + return 0,-1,None + + newpos = -1 + lineflag = 0 # 是否启动新行的标志。0说明在当前行;1说明遇到了\\,需要启动新的一行 + if andpos == -1 or linepos < andpos: + # 该单元格的分割符为\\ + cellcontent.append(tablehead[startpos:linepos]) + cellcontent.append(r"\\") + newpos = linepos + len(r"\\") + lineflag = 1 + else: + cellcontent.append(tablehead[startpos:andpos]) + cellcontent.append(r"&") + newpos = andpos + len(r"&") + return lineflag,newpos,cellcontent + ''' + #暂时注释,不在使用该函数,使用上面函数的写法 + def __delNotMustStrFromContent(self, content): + content = content.strip().strip("\n").strip("\r") + for delstr in self.delstrlst: + if delstr in content: + content = content.replace(delstr,"") + + content = content.strip().strip("\n").strip("\r") + print('content=%s' % content) + strlst = content.split('\n') + newcontent = "" + for i in strlst: + if i != "": + i.strip("\r") + if i != strlst[-1] and i != "": + i = i + r"\\ " + newcontent = newcontent + i + content = self.autolinetemplat % { + 'autocontent': newcontent + } + return content + ''' + + def __delNotMustrowStrFromContent(self, content, ismanualwraplst=None): + """ + 该函数用于multirow的手动换行,为了更好的显示效果,multicolumn的手动换行请参见__delNotMustStrFromContent + :param content: + :param ismanualwraplst: list类型,只有一个元素。是否有手动换行,为了回传,并兼容之前的调用,设计了list类型。 + :return: + """ + # 实现自动换行 + # 去掉首尾换行和空白字符 + content = content.strip().strip("\n").strip("\r") + # 删除单元格内容里非必须的成分 + for delstr in self.delstrlst: + if delstr in content: + content = content.replace(delstr, "") + # 正则匹配既是加了“r”,也要用转义字符,因为r只是去掉了python的转义, + # 但正则本身也需要一层转义 + content = content.strip().strip("\n").strip("\r") + searchstr = r"\\sphinxAtStartPar" + match = re.search(searchstr, content, re.I | re.M) + if match is not None: + # 说明在字符串的起始位置匹配成功 + content = content[match.end():len(content)] + # 再去掉首尾空格和换行 + content = content.strip().strip("\n").strip("\r") + if r"\sphinxAtStartPar" in content: + # 字符串替换比较,如果用了r则无需再加一次转义字符 + # 替换所有的回车换行符,否则autolinetemplat中间不能有回车换行,否则将导致错误 + content = content.replace("\n", "").replace("\r", "") + content = content.replace(r"\sphinxAtStartPar", r"\newline ") + if ismanualwraplst is not None: + ismanualwraplst[0] = True + else: + strlst = content.split('\n') + newcontent = "" + for i in range(0, len(strlst)): + str = strlst[i].strip("\r") + if str != "": + strlst[i] = str + else: + if i > 0 and i < (len(strlst) - 1) and strlst[i - 1] != r"\newline ": + strlst[i] = r"\newline " + + content = "".join(strlst) + if r"\newline" in content and ismanualwraplst is not None: + ismanualwraplst[0] = True + + return content + + def __delNotMustStrFromContent(self,content,ismanualwraplst=None): + """ + 该函数用于multicolumn的手动换行,为了更好的显示效果,multirow的手动换行请参见__delNotMustrowStrFromContent + :param content: + :param ismanualwraplst: list类型,只有一个元素。是否有手动换行,为了回传,并兼容之前的调用,设计了list类型。 + :return: + """ + #实现自动换行 + #去掉首尾换行和空白字符 + content = content.strip().strip("\n").strip("\r") + #删除单元格内容里非必须的成分 + for delstr in self.delstrlst: + if delstr in content: + content = content.replace(delstr,"") + # 正则匹配既是加了“r”,也要用转义字符,因为r只是去掉了python的转义, + # 但正则本身也需要一层转义 + content = content.strip().strip("\n").strip("\r") + searchstr = r"\\sphinxAtStartPar" + match = re.search(searchstr,content,re.I|re.M) + if match is not None: + #说明在字符串的起始位置匹配成功 + content = content[match.end():len(content)] + # 再去掉首尾空格和换行 + content = content.strip().strip("\n").strip("\r") + if r"\sphinxAtStartPar" in content: + #字符串替换比较,如果用了r则无需再加一次转义字符 + #替换所有的回车换行符,否则autolinetemplat中间不能有回车换行,否则将导致错误 + content=content.replace("\n","").replace("\r","") + content = content.replace(r"\sphinxAtStartPar", r"\\ ") + content = self.autolinetemplat % { + 'autocontent': content + } + if ismanualwraplst is not None: + ismanualwraplst[0] = True + else: + strlst = content.split('\n') + newcontent = "" + for i in range(0,len(strlst)): + str = strlst[i].strip("\r") + if str != "": + strlst[i]=str + else: + if i>0 and i< (len(strlst)-1) and strlst[i-1] != r"\\ ": + strlst[i]=r"\\ " + + newcontent = "".join(strlst) + if r"\\ " in newcontent: + content = self.autolinetemplat % { + 'autocontent': newcontent + } + if ismanualwraplst is not None: + ismanualwraplst[0] = True + else: + content = newcontent + + return content + + def __ModifySphinxMultirow(self,cellstr,contentdict,columnwidthlst,jsonmultirow,lineindex,colindex,multirowindex,tablestyle_dict,isvertical=False): + ''' + 修改multirow内容 + 取出关键内容合并单元格的数量和单元格内容 + :param isvertical:如果是渲染第一列,multirow加\raggedright参数为了对齐 + 渲染行,无需加该参数,显示效果没明显差别。 + ''' + #searchstr=r"(?<=\\sphinxmultirow){(?P[0-9]*)}{(?P[0-9]*)}([\s\S]*?)\\sphinxstyletheadfamily(?P[\s\S]*?)\\par([\s\S]*?)(?=\}\%)" + searchstr=r"(?<=\\sphinxmultirow){(?P[1-9]*)}{(?P[0-9]*)}{([\s\S]*?){\\sphinxcolwidth{[0-9]*}{[0-9]*}}[\s]*(\\sphinxstyletheadfamily)?(?P[\s\S]*?)\\par([\s\S]*?)(?=\}\%)" + match = re.finditer(searchstr, cellstr, re.M | re.I | re.U) + count = "" + flag = "" + content = "" + for m in match: + if count == "": + count = m.group('count') + if flag == "": + flag = m.group('flag') + if content == "": + content = m.group('content') + + #如果是自动换行,设置列宽 + coluwidth = "*" + if len(columnwidthlst) > 0: + if tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict,"ismultirowautowrap"): + if tablestyle_dict['ismultirowautowrap'] is True: + coluwidth = columnwidthlst[colindex-1] + else: + if self.tablesattrobj.ismultirowautowrap is True: + coluwidth = columnwidthlst[colindex-1] + + #if (jsonmultirow is not None) and \ + # (multirowindex>=0 and len(jsonmultirow)>multirowindex) and \ + # (len(columnwidthlst)>jsonmultirow[multirowindex]-1): + # #得到要设置自动换行multirow的列号 + # coluwidth = columnwidthlst[jsonmultirow[multirowindex]-1] + # #print("coluwidth = %s" % coluwidth) + #else: + # coluwidth = "*" + + #if r"\sphinxstyletheadfamily" in cellstr: + #实测,multirow要支持渲染背景色,必须带上sphinxstyletheadfamily,否则会编译失败 + sphinxstyletheadfamily = r"\sphinxstyletheadfamily" + raggedright = "" + if (isvertical is True) or \ + (tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict, "ismultirowatupcell") and \ + tablestyle_dict['ismultirowatupcell'] is True) or \ + self.tablesattrobj.ismultirowatupcell is True: + #raggedright = r"{\raggedright}" + # 得到单元格内容 + cellstr = self.commonstr % { + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustrowStrFromContent(content) + } + multistr = self.multirowstr % { + 'count': count, + 'coluwidth': coluwidth, + 'raggedright': raggedright, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': "" + } + #将flag对应的内容加到字典里,以方便后续交换 + contentdict[flag] = str(int(count) + lineindex-1) + "," + multistr #为了支持超过两行的合并单元格,将count数量也带出去。第一个","之前的就是count数量 + + #返回空的字符串 + return cellstr + else: + ismanualwraplst = [False] + if (tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict, "ismultirowatlowcell") and \ + tablestyle_dict['ismultirowatlowcell'] is True ) or \ + self.tablesattrobj.ismultirowatlowcell is True: + + simplemultirow = self.simplemultirow % { + 'count': count, + 'coluwidth': coluwidth + } + + cellstr = self.commonstr % { + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustrowStrFromContent(content, ismanualwraplst) + } + + contentdict[flag] = str(int(count) + lineindex - 1) + "," + cellstr + + return simplemultirow + + multistr = self.multirowstr % { + 'count': count, + 'coluwidth': coluwidth, + 'raggedright': raggedright, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustStrFromContent(content) + } + # 将flag对应的内容加到字典里,以方便后续交换 + contentdict[flag] = str(int(count) + lineindex - 1) + "," + multistr # 为了支持超过两行的合并单元格,将count数量也带出去。第一个","之前的就是count数量 + + # 返回空的字符串 + return self.emptystr + + + + def __ModifyMulticolumn(self,cellstr,tablestyle_dict): + ''' + :param cellstr: + :return: + ''' + searchstr = r"(?<=\\multicolumn){(?P[1-9]*)}\{(?P[\s\S]*?)\}\{%(?P[\s\S]*?)\\sphinxstyletheadfamily(?P[\s\S]*?)\\par(?P[\s\S]*)" + match = re.finditer(searchstr, cellstr, re.M | re.I | re.U) + count = "" + content = "" + flag="" + prefix="" + suffix="" + for m in match: + if count == "": + count = m.group('count') + if flag == "": + flag = m.group('flag') + if content == "": + content = m.group('content') + if prefix == "": + prefix = m.group('prefix') + #如果有\begin{varwidth}[t],替换为\begin{varwidth}[c] + #[t]会使得当自动换行的时候,单元格上面会有很大的空白 + #print('===========================================') + #print(prefix) + searchstr = r"\\begin{varwidth}\[t\]" + replstr = r"\\begin{varwidth}[c]" + prefix = re.sub(searchstr,replstr,prefix,1,re.I|re.U|re.M) + #print(prefix) + #print('===========================================') + if suffix == "": + suffix = '\n\\par' + m.group('suffix') #因为正则表达式把\par去掉了,再加上 + + #行合并单元格,改成居中显示 + newflag=re.sub(r"[a-z]",r"c",flag,1,re.I) + #newflag = flag + if (tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict,'ismulticolleftalign') and \ + tablestyle_dict['ismulticolleftalign'] is True) or \ + self.tablesattrobj.ismulticolleftalign is True: + #multistr = self.comcolstr % { + # 'count': count, + # 'flag': flag, + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(content) + #} + multistr = self.multicolumnstr % { + 'count':count, + 'flag': flag, + 'prefix':prefix, + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content':self.__delNotMustStrFromContent(content), + 'suffix':suffix + } + else: + multistr = self.multicolumnstr % { + 'count':count, + 'flag': newflag, + 'prefix':prefix, + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content':self.__delNotMustStrFromContent(content), + 'suffix':suffix + } + + #multistr = self.comcolstr % { + # 'count': count, + # 'flag': flag, + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(content) + #} + return multistr + + def __ModifyCellContentForMergeCell(self,cellcontent, + tablestyle_dict,islongtable, + contentdict,columnwidthlst,jsonmultirow, + colindex,lineindex,multirowindex,ismulticol = False): + ''' + 修改单元格内容 + :param cellcontent: + :param ismulticol: 当前行是否有列合并单元格。 + 如果有列合并单元格,则独立的单元格需要用multcolumn命令,否则无法对齐。 + :return: + ''' + multirowflag = False #是否是multirow单元格 + cellstr = cellcontent[0] #取得单元格内容 + #print('===========================') + #print(cellstr) + + #查找有没有\cline指令,如果有这个指令,为了划线清晰在前面再多加一个。 + searchstr= r"\\cline\{[\s\S]*?\}" + match = re.search(searchstr,cellstr,re.I | re.M) + if match is not None: + cline = match.group() + #去掉开头的换行符,否则影响后面的正则表达式,导致正则表达式查找失败 + cellstr = cellstr.lstrip() + cellstr = cline + cellstr + + sphinxstyletheadfamily="" + if r'\sphinxmultirow' in cellstr: + cellcontent[0] = self.__ModifySphinxMultirow(cellstr, + contentdict, + columnwidthlst, + jsonmultirow, + lineindex, + colindex, + multirowindex, + tablestyle_dict) + multirowflag = True + elif r'\multicolumn' in cellstr: + cellcontent[0] = self.__ModifyMulticolumn(cellstr,tablestyle_dict) + else: + if colindex == 1: + if len(columnwidthlst)>0 and colindex > 0: + flag = "|m{" + columnwidthlst[colindex - 1] + "}|" + else: + flag = "|l|" + else: + if len(columnwidthlst) > 0 and colindex > 0: + flag = "m{" + columnwidthlst[colindex - 1] + "}|" + else: + flag = "l|" + + #查找\sphinxAtStartPar位置 + precell = "" + pos = cellstr.find(r'\sphinxstyletheadfamily') + #print("cellstr = %s" % cellstr) + #print("pos=%d" % pos) + if pos > 0: + #说明前面还有内容,需要把前面的内容保存下来 + precell = cellstr[0:pos] + "\n" + if pos != -1: + cellstr = cellstr[pos+len(r'\sphinxstyletheadfamily'):len(cellstr)] + sphinxstyletheadfamily = "\\sphinxstyletheadfamily" + + ismanualwraplst=[False] + self.__delNotMustStrFromContent(cellstr,ismanualwraplst) + if ismanualwraplst[0] is True: + if colindex == 1: + flag = "|l|" + else: + flag = "l|" + colwidth = "" + if (len(columnwidthlst) > 0) and (colindex > 0): + colwidth = columnwidthlst[colindex-1] + + fullcontentstr = self.normalstr % { + 'flag': flag, + 'colwidth': colwidth, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustStrFromContent(cellstr) + } + else: + fullcontentstr = self.comcolstr % { + 'count': "1", + 'flag': flag, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustStrFromContent(cellstr) + } + #if ismulticol is True: + #必须用multicolumn命令,否则当单元格后面还有列的时候,刷新会出现问题 + #fullcontentstr = self.comcolstr % { + # 'count': "1", + # 'flag': flag, + # 'sphinxstyletheadfamily': sphinxstyletheadfamily, + # 'content': self.__delNotMustStrFromContent(cellstr) + #} + #else: + # fullcontentstr = self.commonstr % { + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(contentstr) + # } + cellcontent[0] =precell + fullcontentstr + #print('===========================') + return multirowflag + + def __FindMergeRowCountAndContent(self,multirowstr): + #找出合并行的数量和合并单元格的内容 + count = 0 + newmultistr = multirowstr + + try: + searchstr = "([1-9]*?)(?=,)" + match = re.match(searchstr,newmultistr,re.I) + if match is not None: + count = int(match.group(0)) + startpos = match.end() + newmultistr = multirowstr[startpos+1:len(multirowstr)] + finally: + return count,newmultistr + + def __AdjustMultirowContent(self,cellcontent,contentdict,lineindex,colindex,columnwidthlst): + ''' + 调整multirow单元格到对应的下面的单元格 + ''' + #查找flag数量 + cellstr = cellcontent[0] + #查找有没有\cline指令,如果有这个指令,为了划线清晰在前面再多加一个。 + searchstr= r"\\cline\{[\s\S]*?\}" + match = re.search(searchstr,cellstr,re.I | re.M) + if match is not None: + cline = match.group() + #去掉开头的换行符,否则影响后面的正则表达式,导致正则表达式查找失败 + cellstr = cellstr.lstrip() + cellstr = cline + cellstr + + searchstr = r"\\sphinxtablestrut\{(?P[0-9]*)\}" + matchobj = re.search(searchstr, cellstr, re.I | re.M | re.U) + flag = matchobj.group('flag') + #得到修改后的multirow内容 + if __dictIsHasKey__(contentdict,flag): + multirowstr = contentdict[flag] + #替换实际列号 + replstr = r'\\sphinxtablestrut{' + str(colindex)+'}' + cellstr = re.sub(searchstr,replstr,cellstr,1,re.I|re.U|re.M) + #找出合并行的最大数量 + count,multirowstr = self.__FindMergeRowCountAndContent(multirowstr) + if lineindex == count: #是最后一行了再合并,否则会重复 + cellcontent[0] = cellstr+"\n"+multirowstr + else: + cellcontent[0] = cellstr+'\n' + r"\cellcolor" + self.tablesattrobj.headtype + "{}" + + return cellcontent + + def __GetFactColIndex(self,cellcontent): + ''' + 得到真实的列索引。当有multicolumn的时候,列索引需要加合并单元格的数量。 + 返回合并单元格的数量,没合并单元格返回1 + :para cellcontent: 单元格内容字符串 + ''' + if len(cellcontent) ==0: + return 0 + searchstr = r"\\multicolumn{(?P[1-9]+)}" + match = re.search(searchstr,cellcontent,re.I|re.U) + if match is None: + return 1 #不是合并单元格,默认加1列 + factcount = 1 + try: + factcount = int(match.group('count')) + except Exception as e: + print('------------------') + print(e) + # traceback.print_stack() + traceback.print_exc() + print('------------------') + + return factcount + def __ModifyTableHeadForFamilyList(self,tablehead,isLongtable,columnwidthlst,jsonmultirow,tablestyle_dict): + ''' + 用列表的方式逐个拆解每个单元格内容,列表格式为[['cell content','分割符号&或者\\'],[],[]] + :param tablehead: + :return: + ''' + #内容的起始位置 + startpos = tablehead.find(r'\\hline') + len(r'\\hline') + + contentdict ={} #创建一个字典,用来保存mlutirow交换内容 + #得到每个单元格的内容 + colindex = 0 #是否第一个单元格的索引 + lineindex = 1 #行索引,为了支持多行合并,默认起始在第一行 + # multirowindex:第几个multirow单元格的索引号,用于作为conf.json中multirowcolumn字段的索引值, + # 取出当前multirow单元格真实所在列 + #得到该列的列宽,填写到multirow指令里面,实现multirow指令中的内容自动换行 + multirowindex = 0 + lineflag = 0 + newtablehead = "\\hline\n" + while startpos !=-1: + multirowflag = False # 是否是multirow处理,是的话__ModifyCellContentForMergeCell返回true,默认返回false + lineflag,startpos,cellcontent = self.__FindCellContentForHead(tablehead,startpos) + if cellcontent is not None: + factcolcount = self.__GetFactColIndex(cellcontent[0]) + colindex = colindex+factcolcount #当前列号 + if r"\sphinxtablestrut" in cellcontent[0]: + # 需要将multirow的内容copy到该单元格 + self.__AdjustMultirowContent(cellcontent, contentdict,lineindex,colindex,columnwidthlst) + else: + + multirowflag = self.__ModifyCellContentForMergeCell(cellcontent, + tablestyle_dict, + isLongtable, + contentdict, + columnwidthlst, + jsonmultirow, + colindex, + lineindex, + multirowindex) + newtablehead = newtablehead+cellcontent[0] + "\n" + cellcontent[1]+ "\n" + + if multirowflag is True: + multirowindex = multirowindex+1 + lineindex = lineindex + lineflag # 计算新的行数 + if lineflag: + # 如果开启了新行,则列索引从0开始计算 + colindex = 0 + + #组成新的表头内容 + newtablehead =newtablehead + "\\hline\n" + return newtablehead + def __modifyTableHeadForMinusSign(self,content,isLongTable,headtype,tablestyle_dict): + """ + 如果表头不是用“=”分割的,而是用“-”分割,调用该函数修改表头 + """ + # 先找出第一行 + if isLongTable: + searchstr = r'\\endlastfoot[\s]+(\\sphinxtableatstartofbodyhook)?(?P[\s\S]*?)\\hline' + else: + searchstr = r'\\hline(?P[\s\S]*?)\\hline' + m = re.search(searchstr, content, re.M | re.I | re.U) + headcontent = m.group('content') # 匹配到的第一个即为表头内容 + posarr = m.span('content') # 保存起始位置和结束位置,便于组成新的内容 + + if 'multicolumn' in headcontent: + return content + + if r'\sphinxstyletheadfamily' in headcontent: + pattern = re.compile(r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)', + re.M | re.I | re.U) + aftercontent = headcontent + # pattern = re.compile(r'(?<=\\sphinxstylethead{\\sphinxstyletheadfamily)([\s\S]*?)(?=\\unskip}\\relax &)', re.M | re.I|re.U) + else: + aftercontent = headcontent.replace(r'\\', '&', 1) + pattern = re.compile(r'(?P[\s\S]*?)(&\s{1})', re.M | re.I | re.U) + # pattern = re.compile(r'[\s\S]*?&|\\unskip', re.M | re.I|re.U) + + mobjarr = pattern.finditer(aftercontent) + headlist = [] + preposlist = [] + fontcolor = "" + for mobj in mobjarr: + amarr = mobj.group('value') + curposlist = [mobj.start(), mobj.start() + len(amarr)] + + # 用表头内容数组替换 + if self.tablesattrobj.headfontcolor != "": + fontcolor = self.tablesattrobj.headfontcolor + # 先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() # 再去掉首尾空格,避免多余的空格出现 + if amarr == '': + continue + amarr = self.__delNotMustStrFromContent(amarr) + fontcolor = fontcolor.replace('{}', '{' + amarr + '}', 1) + # 当表头没有合并单元格时,不能用cellcolor,否则表头不对齐 + # fontcolor = self.commonstr % { + # 'sphinxstyletheadfamily':'', + # 'content': self.__delNotMustStrFromContent(amarr) + # } + if len(preposlist) > 0: + headlist.append(headcontent[preposlist[1]:curposlist[0]]) + else: + headlist.append(headcontent[0:curposlist[0]]) + headlist.append(fontcolor) + preposlist = curposlist + + headlist.append(headcontent[preposlist[1]:len(headcontent)]) # 把最后一个字符串加上 + headcontent = '' + for prelist in headlist: + headcontent = headcontent + prelist + '\n' + newcontent = content[0:posarr[0]] + r'\rowcolor' + headtype + '\n' + headcontent + content[ + posarr[1]:len(content)] + return newcontent + #修改sphinx自动生成的长表格表头 + def __ModifyLongTableHead(self,content,columnwidthlst,headtype,jsonmultirow,tablestyle_dict): + + tablehead = self.__FindTableHeadForFamily(content,True) + if tablehead != "" and (r"\sphinxmultirow" in tablehead or r"\multicolumn" in tablehead): + #长表格自定义sphinxcolwidth的第2个参数默认为5,否则无法实现自动换行 + #self.normalstr= self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth':'5', + # 'sphinxstyletheadfamily':'%(sphinxstyletheadfamily)s', + # 'content':'%(content)s' + #} + newtablehead = self.__ModifyTableHeadForFamilyList(tablehead, + True, + columnwidthlst, + jsonmultirow, + tablestyle_dict) + return content.replace(tablehead,newtablehead,2) + + #先找出第一行 + if tablehead.strip() == r"\hline": + #进入该分支,说明表格设置为了长表格,但是表头没有用“=”分割,用“-”做的分割。 + return self.__modifyTableHeadForMinusSign(content,True,headtype,tablestyle_dict) + + searchstr = r'\\hline(?P[\s\S]*?)\\hline' + pattern = re.compile(searchstr,re.M | re.I|re.U) + matchiter = pattern.finditer(content) + posarr = [] + i = 0 + for m in matchiter: + + if i > 1: + break + + posarr.append([]) + posarr[i] = m.span() #保存起始位置和结束位置,便于组成新的内容 + + if i ==0: + newcontent = content[0:posarr[i][0]] + else: + newcontent = newcontent+content[posarr[i-1][1]:posarr[i][0]] + + newcontent += r'\hline\rowcolor'+headtype + headcontent = m.group(1) #匹配到的第一个即为表头内容 + + if 'multicolumn' in headcontent: + return content + + headlist = [] + + if r'\sphinxstyletheadfamily' in headcontent: + pattern = re.compile(r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)', re.M | re.I|re.U) + aftercontent = headcontent + mobjarr = pattern.finditer(aftercontent) + + preposlist = [] + for mobj in mobjarr: + amarr = mobj.group('value') + curposlist = mobj.span() + + #用表头内容数组替换 + fontcolor = self.tablesattrobj.headfontcolor + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + if amarr == '': + continue + amarr = self.__delNotMustStrFromContent(amarr) + fontcolor = fontcolor.replace('{}','{'+ amarr+'}',1) + # 当表头没有合并单元格时,不能用cellcolor,否则表头不对齐 + #fontcolor = self.commonstr % { + # 'sphinxstyletheadfamily': '', + # 'content': self.__delNotMustStrFromContent(amarr) + #} + if len(preposlist) > 0: + headlist.append(headcontent[preposlist[1]:curposlist[0]]) + else: + headlist.append(headcontent[0:curposlist[0]]) + headlist.append(fontcolor) + preposlist = curposlist + headlist.append(headcontent[preposlist[1]:len(headcontent)]) #把最后一个字符串加上 + headcontent = '' + for prelist in headlist: + headcontent = headcontent + prelist + '\n' + newcontent += headcontent+r'\hline' + i +=1 + newcontent += content[posarr[i-1][1]:len(content)] + return newcontent + + def __ModifyTableHead(self, content,columnwidthlst, headtype,jsonmultirow,tablestyle_dict): + + tablehead = self.__FindTableHeadForFamily(content,False) + if tablehead != "" and (r"\sphinxmultirow" in tablehead or r"\multicolumn" in tablehead): + #为了实现自动换行,普通表格自定义sphinxcolwidth的第2个参数默认为3,否则无法实现自动换行 + #self.normalstr= self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth':'3', + # 'sphinxstyletheadfamily':'%(sphinxstyletheadfamily)s', + # 'content':'%(content)s' + #} + # 得到每一列的列宽供multirow自动换行使用 + newtablehead = self.__ModifyTableHeadForFamilyList(tablehead, + False, + columnwidthlst, + jsonmultirow, + tablestyle_dict) + return content.replace(tablehead,newtablehead,2) + + return self.__modifyTableHeadForMinusSign(content,False,headtype,tablestyle_dict) + """ + #先找出第一行 + searchstr = r'\\hline(?P[\s\S]*?)\\hline' + m = re.search(searchstr, content, re.M|re.I|re.U ) + headcontent = m.group(1) #匹配到的第一个即为表头内容 + posarr = m.span(1) #保存起始位置和结束位置,便于组成新的内容 + + if 'multicolumn' in headcontent: + return content + + if r'\sphinxstyletheadfamily' in headcontent: + pattern = re.compile(r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)', re.M | re.I|re.U) + aftercontent = headcontent + #pattern = re.compile(r'(?<=\\sphinxstylethead{\\sphinxstyletheadfamily)([\s\S]*?)(?=\\unskip}\\relax &)', re.M | re.I|re.U) + else: + aftercontent = headcontent.replace(r'\\','&',1) + pattern = re.compile(r'(?P[\s\S]*?)(&\s{1})', re.M | re.I|re.U) + #pattern = re.compile(r'[\s\S]*?&|\\unskip', re.M | re.I|re.U) + + mobjarr = pattern.finditer(aftercontent) + headlist = [] + preposlist = [] + fontcolor="" + for mobj in mobjarr: + amarr = mobj.group('value') + curposlist = [mobj.start(),mobj.start()+len(amarr)] + + #用表头内容数组替换 + if self.tablesattrobj.headfontcolor !="": + fontcolor = self.tablesattrobj.headfontcolor + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + if amarr == '': + continue + amarr = self.__delNotMustStrFromContent(amarr) + fontcolor = fontcolor.replace('{}','{'+ amarr+'}',1) + #当表头没有合并单元格时,不能用cellcolor,否则表头不对齐 + #fontcolor = self.commonstr % { + # 'sphinxstyletheadfamily':'', + # 'content': self.__delNotMustStrFromContent(amarr) + #} + if len(preposlist) > 0: + headlist.append(headcontent[preposlist[1]:curposlist[0]]) + else: + headlist.append(headcontent[0:curposlist[0]]) + headlist.append(fontcolor) + preposlist = curposlist + + headlist.append(headcontent[preposlist[1]:len(headcontent)]) #把最后一个字符串加上 + headcontent = '' + for prelist in headlist: + headcontent = headcontent + prelist + '\n' + newcontent = content[0:posarr[0]]+r'\rowcolor'+headtype+'\n'+headcontent+content[posarr[1]:len(content)] + return newcontent + """ + + def __ModifySecondTableHeadForLongTable(self,headcontent): + #如果表头没有渲染过,对于只渲染第一列的长表格,分页的表格不能有重复的表头内容,需要将其删除 + startpos = headcontent.find(r"\hline") + newhead = headcontent[0:startpos] + "\\hline\n\\endhead" #因为前面删除了行结束符,最后再添加上 + return newhead + + def __GetTableRowCount(self,tablecontent): + ''' + 得到传进来的表格行数量 + 统计原则:latex以“\\”+回车换行为表格的一行,因此以"\\CRLF"或者“\\LF”作为一行, + 正则表达式统计"\\CRLF"或者“\\LF”的数量作为行的数量。 + :param tablecontent: 表格内容,必须去除了latex表格的前面参数部分,否则得到的第一行也包含其内容 + :return: 返回行的数量和每一行内容的列表,如果找不到的话,返回0和None。 + ''' + #以下正则表达式还可以解析出表格的每一行内容 + #在末尾补充一个换行符,避免如果最后一行以“\\”结尾,漏掉最后一行。 + tablecontent = tablecontent+'\n' + searchstr = "[\s\S]*?(?\\\\[\r|\r\n|\n|\n\r])[\s\S]*?" + pattern = re.compile(searchstr,re.I | re.M |re.U) + linelst = pattern.findall(tablecontent) + if linelst is not None: + return len(linelst),linelst + + return 0,None + + + def __ModifyVerticalLongTable(self, singletablecontent,tablestyle_dict,columnwidthlst): + #修改长表格的第一列 + #查找长表格的表头,如果是通过sphinx指令设置的长表格,表头分为三个部分, + #第一个表头以第一个\hline开始,以\endfirsthead结束 + #第二个表头,以\endfirsthead后续的内容开始,\endhead结束 + #下页继续的说明,以\endhead后续内容开始,以\endlastfoot结束,该部分内容不做处理 + #以\endlastfoot后续内容的第一部分即第一列的内容,需要添加背景色和字体颜色。后续的\hline按正常处理即可 + #先将所有表头内容找出来 + isCusHead = tablestyle_dict['isCusHead'] + headstartpos = singletablecontent.find(r"\hline") + headendpos = singletablecontent.find(r"\endfirsthead") + fcontent = singletablecontent[0:headstartpos] #字符串起始内容 + fhead = singletablecontent[headstartpos:headendpos+len(r"\endfirsthead")] #第一个表头内容 + #修改第一个表头内容 + if not isCusHead: #如果表头已经渲染过了,则不再对第一个单元格再渲染一边。 + fhead = self.__ModifyTableByLine(fhead,True,True,tablestyle_dict,columnwidthlst) + + #获得第2个表头的原始内容 + headstartpos = singletablecontent.find(r"\endhead") + shead = singletablecontent[headendpos+len(r"\endfirsthead"):headstartpos+len(r"\endhead")] #第2个表头的内容 + #修改第2个表头内容 + if not isCusHead: #如果表头已经渲染过了,则不再对第一个单元格再渲染一遍。 + shead = self.__ModifySecondTableHeadForLongTable(shead) + #获得第3部分内容 + headendpos = singletablecontent.find(r"\endlastfoot") + thead = singletablecontent[headstartpos+len(r"\endhead"):headendpos+len(r"\endlastfoot")] #第3部分表头内容 + othercontent = singletablecontent[headendpos+len(r"\endlastfoot"):len(singletablecontent)] #其它表格内容 + + #因为第3部分后续紧跟的内容即为下一行的内容,为了统一处理,在前面固定添加行元素,再修改完成后,再将其删除 + othercontent = "\\hline\n" + othercontent + #修改表格内容部分 + othercontent = self.__ModifyTableByLine(othercontent,True,True,tablestyle_dict,columnwidthlst) + #修改完成后,去掉固定添加的内容 + othercontent = othercontent[len("\\hline\n"):len(othercontent)] + return fcontent + fhead + shead + thead + othercontent + + def __RenderTableHeadByLine(self, linecontent, lineindex, isLongtable, columnwidthlst, jsonmultirow, tablestyle_dict, contentdict): + ''' + 该函数暂时未用 + 用列表的方式逐个拆解每个单元格内容,列表格式为[['cell content','分割符号&或者\\'],[],[]] + :return: + ''' + # 内容的起始位置 + searchstr = r"(\\hline|\\cline|\\sphinxcline)" + match = re.match(searchstr,linecontent) + if match is None: + return linecontent + if match.group() == r"\hline": + startpos = match.end() + newtablehead = "\\hline\n" + else: + startpos = 0 + newtablehead = "" + + # 得到每个单元格的内容 + colindex = 0 # 是否第一个单元格的索引 + lineindex = 1 # 行索引,为了支持多行合并,默认起始在第一行 + # multirowindex:第几个multirow单元格的索引号,用于作为conf.json中multirowcolumn字段的索引值, + # 取出当前multirow单元格真实所在列 + # 得到该列的列宽,填写到multirow指令里面,实现multirow指令中的内容自动换行 + multirowindex = 0 + lineflag = 0 + + while startpos != -1: + multirowflag = False # 是否是multirow处理,是的话__ModifyCellContentForMergeCell返回true,默认返回false + lineflag, startpos, cellcontent = self.__FindCellContentForHead(linecontent, startpos) + if cellcontent is not None: + factcolcount = self.__GetFactColIndex(cellcontent[0]) + colindex = colindex + factcolcount # 当前列号 + if r"\sphinxtablestrut" in cellcontent[0]: + # 需要将multirow的内容copy到该单元格 + self.__AdjustMultirowContent(cellcontent, contentdict, lineindex,colindex,columnwidthlst) + else: + + multirowflag = self.__ModifyCellContentForMergeCell(cellcontent, + tablestyle_dict, + isLongtable, + contentdict, + columnwidthlst, + jsonmultirow, + colindex, + lineindex, + multirowindex) + newtablehead = newtablehead + cellcontent[0] + "\n" + cellcontent[1] + "\n" + + if multirowflag is True: + multirowindex = multirowindex + 1 + lineindex = lineindex + lineflag # 计算新的行数 + if lineflag: + # 如果开启了新行,则列索引从0开始计算 + colindex = 0 + + # 组成新的表头内容 + #newtablehead = newtablehead + "\\hline\n" + return newtablehead + + def __ModifyLineContent(self,linecontent,isvertical,islongtable,lineindex,columnwidthlst,colno,maxcolno,tablestyle_dict,rowdict): + ''' + 修改行内容,当colno为-1时,则修改整行内容,否则根据colno的值修改单元格 + :param linecontent: 该行的内容,以行结束标志“\\”结束。 + :param columnwidthlst: 保存的每一列的列宽内容,根据tablecolumns参数解析出来 + 如果实现自动换行,必须用tablecolumns这个参数,不能用widths参数 + :param colno: 要修改的列号,如果为-1,则修改整行内容。否则根据colno的值修改单元格。 + 如果该行有合并列单元格,则colno为合并单元格后的列号。 + :param maxcolno: 将合并单元格拆分后,得到的最大的列号,该列号为了取得列宽,使multirow内容自动换行。 + :param rowdict: 传出参数,如果是multirow的合并单元格,则返回单元格的内容 + :return: 返回修改后的行的内容。 + ''' + if isvertical is False: + #修改行,暂时不支持修改任意行,表头的渲染由其他函数完成 + #按表头内容渲染 + + jsonmultirow = [] # 保存需要设置自动换行的合并单元格的列号。 + if __dictIsHasKey__(tablestyle_dict, 'multirowcolumn'): + jsonmultirow = tablestyle_dict['multirowcolumn'] + newlinecontent = self.__RenderTableHeadByLine(linecontent, + lineindex, + islongtable, + columnwidthlst, + jsonmultirow, + tablestyle_dict, + rowdict) + return newlinecontent + + #下面的内容修改列,即只修改行的第一个单元格 + linecontentstr=linecontent + #startpos = linecontent.find('\\hline') + #lineprefix = "" + #if startpos != -1: + # linecontentstr = linecontent[startpos + len("\\hline"):len(linecontent)] + # lineprefix = linecontent[0:startpos+len("\\hline")] + #得到行内容和前缀 + lineprefix = "" + searchstr=r"\\hline|\\cline{\S+?}|\\sphinxcline{\S+?}\\sphinxfixclines{\d+?}" + match = re.search(searchstr,linecontent,re.I|re.U|re.M) + if match is not None: + linecontentstr = linecontent[match.end():len(linecontent)] + lineprefix = linecontent[0:match.start()+ len(match.group())] + + othercontent = "" + newlinecontent = "" + + #暂时仅支持修改第一列 + lineflag,nextpos,cellcontent = self.__FindCellContentForHead(linecontentstr,0) + if (cellcontent is None) or (r"\cellcolor" in cellcontent[0]): + #说明该单元格已经修改过了,直接返回 + return linecontent + + othercontent = linecontentstr[nextpos:len(linecontentstr)] + + if r"\sphinxtablestrut" in cellcontent[0]: + # 需要将multirow的内容copy到该单元格 + self.__AdjustMultirowContent(cellcontent, rowdict, lineindex,maxcolno,columnwidthlst) + else: + jsonmultirow =[1] + cellstr = cellcontent[0] # 取得单元格内容 + + if r'\sphinxmultirow' in cellstr: + cellcontent[0] = self.__ModifySphinxMultirow(cellstr, + rowdict, + columnwidthlst, + jsonmultirow, + lineindex, + maxcolno, + 0, + tablestyle_dict, + isvertical) + else: + coluwidth = "*" + if len(columnwidthlst) > 0: + coluwidth = columnwidthlst[0] + # 查找\sphinxAtStartPar位置 + pos = cellstr.find(r'\sphinxstyletheadfamily') + if pos != -1: + contentstr = cellstr[pos + len(r'\sphinxstyletheadfamily'):len(cellstr)] + + fullcontentstr = self.commonstr % { + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content': self.__delNotMustStrFromContent(contentstr) + } + #fullcontentstr = self.multirowstr % { + # 'count': 1, + # 'coluwidth': coluwidth, + # 'sphinxstyletheadfamily': r"\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(contentstr) + #} + #fullcontentstr = self.normalstr % { + # 'flag': '|l|', + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(contentstr) + #} + + cellcontent[0] = fullcontentstr + else: + fullcontentstr = self.commonstr % { + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content': self.__delNotMustStrFromContent(cellstr) + } + #fullcontentstr = self.multirowstr % { + # 'count': 1, + # 'coluwidth': coluwidth, + # 'sphinxstyletheadfamily': r"\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(cellstr) + #} + #fullcontentstr = self.normalstr % { + # 'flag': '|l|', + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(cellstr) + #} + + cellcontent[0] = fullcontentstr + + newlinecontent = lineprefix +"\n" + cellcontent[0] + "\n" + cellcontent[1] + "\n" + othercontent + return newlinecontent + + + def __ModifyTableByLine(self,tablecontent,isvertical,islongtable,tablestyle_dict,columnwidthlst): + ''' + 按行渲染每一个单元格。根据正则表达式解析出每一行的内容, + 然后再解析这一行中每一个单元格的内容,根据条件判断是否需要渲染该单元格 + :param tablecontent: + 表格内容,对于长表格去除表格头的表格,对于普通表格,整个表格内容 + :param isvertical: + 是否修改列还是修改行 + :return: 返回修改后的内容 + ''' + #找出每一行的内容 + + searchstr = r'(\\hline|\\cline{\S+}|\\sphinxcline{\S+})(?P[\s\S]*?)(?=\\hline|\\cline{\S+}|\\sphinxcline{\S+})' + pattern = re.compile(searchstr,re.M | re.I|re.U) + # 得到每一行的内容,以\hline或者\cline开头,以\\结尾 + linelst = pattern.findall(tablecontent) + if len(linelst) ==0: + return tablecontent + + tableprefix = "" #得到latex表格的前缀参数部分,除表格内容外的部分 + tablepostfix = "" #得到latex表格的后缀参数部分,出表格内容外的部分 + prepos = tablecontent.find(''.join(linelst[0])) + if prepos > 0: + tableprefix = tablecontent[0:prepos] + lastline = ''.join(linelst[-1]) + postpos = tablecontent.find(lastline) + if (postpos+len(lastline)) < len(tablecontent): + tablepostfix = tablecontent[postpos+len(lastline):len(tablecontent)] + + newtablecontent = "" + rowdict = {} #保存合并单元格的字典,用于multirow + for lineindex in range(0,len(linelst)): + line =''.join(linelst[lineindex]) + newline= self.__ModifyLineContent(line,isvertical,islongtable,lineindex+1,columnwidthlst,1,1,tablestyle_dict,rowdict) + newtablecontent = newtablecontent + newline + + newtablecontent = tableprefix + newtablecontent+tablepostfix + return newtablecontent + + #渲染树型表格的第一列 + def __ModifyVerticalTable(self, singletablecontent, tablestyle_dict, columnwidthlst): + + #找出每一行的内容 + searchstr = r'(\\hline|\\cline{\S+}|\\sphinxcline{\S+})(?P[\s\S]*?)(?=\\hline|\\cline{\S+}|\\sphinxcline{\S+})' + pattern = re.compile(searchstr,re.M | re.I|re.U) + matchiter = pattern.finditer(singletablecontent) + posarr=[] #保存位置,便于组合 + i = 0 + for m in matchiter: + + posarr.append([]) + posarr[i] = m.span() + + if i ==0: + newcontent = singletablecontent[0:posarr[i][0]]+m.group(1) + else: + newcontent = newcontent+singletablecontent[posarr[i-1][1]:posarr[i][0]]+m.group(1) + + cellcontent = m.group('content') + #匹配到的第一个即为表头内容 + #将第一个单元格内容渲染成蓝底白字 + firstcellcontent = self.__ModifyFirstColumnType(cellcontent) + newcontent += firstcellcontent + i+=1 + newcontent += singletablecontent[posarr[i-1][1]:len(singletablecontent)] + return newcontent + + #渲染第一个单元格内容 + def __ModifyFirstColumnType(self,cellcontent): + + new_cellcontent = "" + + if r'\sphinxstyletheadfamily' in cellcontent: + + searchstr = r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)' + + aftercontent = cellcontent.strip() + aftercontent = aftercontent.strip('\r') + aftercontent = aftercontent.strip('\n') + aftercontent = aftercontent.strip() + mobj = re.search(searchstr, aftercontent, re.M|re.I|re.U ) #匹配到的第一个既是需要修改的内容 + + #修改字体颜色 + amarr = mobj.group('value') + posarr = mobj.span() + new_cellcontent = aftercontent[0:posarr[0]]+'\n'+r'\cellcolor'+self.tablesattrobj.headtype + #用表头内容数组替换 + fontcolor = self.tablesattrobj.headfontcolor + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + #if amarr == '': + # continue + if (r'\textbf' or r'\textcolor') in amarr: + return cellcontent + fontcolor = fontcolor.replace('{}','{'+ amarr+'}',1) + new_cellcontent +=r'{'+fontcolor + '}\n' + aftercontent[posarr[1]:len(aftercontent)] + else: + aftercontent = cellcontent.replace(r'\\','&',1) + #去掉首尾空格和换行符 + aftercontent = aftercontent.strip() + aftercontent = aftercontent.strip('\r') + aftercontent = aftercontent.strip('\n') + aftercontent = aftercontent.strip() + + tmplist = re.split(r'&',aftercontent) + + preposlist = 0 + #只对第一个做修改 + onelist = tmplist[0] + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + onelist = onelist.strip() + onelist = onelist.strip('\r') + onelist = onelist.strip('\n') + onelist = onelist.strip() #再去掉首尾空格,避免多余的空格出现 + #if onelist == '': + # continue + if (r'\textbf' or r'\textcolor') in onelist: + return cellcontent + new_cellcontent = self.__ModifyFirstColumnContentPart(onelist) + + for i in range(1,len(tmplist)): + if len(tmplist[i])>0: + new_cellcontent += '&' +tmplist[i] + new_cellcontent+=r'\\' #将最后被替换掉的\\再加上 + + return new_cellcontent + '\n' + + def __ModifyFirstColumnContentPart(self,firstcontent): + #查找firtcontent的内容部分,因为如果存在合并单元格,第一列的内容起始部分不一定是内容,可能是latex的其它标签。 + #sphinx生成的latex表格都以“\sphinxAtStartPar”标识开始 + startpos = firstcontent.find(r"\sphinxAtStartPar") + firstpart = "" + if startpos==-1: + contentpart = firstcontent + else: + firstpart = firstcontent[0:startpos] + contentpart=firstcontent[startpos:len(firstcontent)] + fontcolor = self.tablesattrobj.headfontcolor + # 先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + new_cellcontent = '\n'+r'\cellcolor'+self.tablesattrobj.headtype+r'{'+fontcolor.replace('{}','{'+ contentpart+'}',1)+r'}'+'\n' + return firstpart+new_cellcontent + + def __ModifyReplaceContent(self): + ''' + 修改replace内容。为了删除sphinx在替换时多加的空格,在替换时,每个替换内容前后都加了"@@"。 + 对于前后都加了“@@”的内容,删除前后的"@@",同时删除前后的空格。 + :return: + ''' + searchstr = r"@@(?P[\s\S]+?)@@" + pattern = re.compile(searchstr) + match = pattern.search(self.content) + if match is None: + return None + str = match.group() + replace = match.group('content') + + prepos = match.start() + endpos = match.end() + replaceed = str + if prepos > 1: + #判断前一个字符是否为空格 + prechar = self.content[prepos-1] + if prechar==" ": + replaceed = " "+replaceed + + endchar = self.content[endpos] + if endchar == " ": + replaceed = replaceed + " " + + #进行替换 + self.content = self.content.replace(replaceed,replace) + + return self.content + + + def ModifyReplaceContent(self): + + while True: + ''' + 可能会有多个需要替换,直到找不到为止 + ''' + result = self.__ModifyReplaceContent() + if result is None: + break + +class clsSensitiveWords: + + def __init__(self,entext): + self.passkey="Hello123" + self.entext = entext + self.jsonwords = GetSensitiveword() + + def GetFullSensitiveWords(self): + senwords = "|".join(self.jsonwords) + detext = "" + if self.entext !="": + command="echo '" + self.entext +"'| openssl aes-128-cbc -d -base64 -pbkdf2 -k " + self.passkey + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + detext = detext + line.decode().strip() + senwords = senwords + '|' + detext + + return senwords + +class clsCheckDocCopyrightYear: + ''' + 检查文档是否符合规范。 + 2022年07月12日更新 + ------------------- + 增加对当前年份的检查,如果年份不符合要求,自动修改conf.py和copyright.rst文档中的年份。 + ''' + def __init__(self,app,copyrightfile): + self.app=app + self.confdir = app.confdir + self.language = app.config.language + self.confright = app.config.copyright + if (app.config.latex_elements is not None) and __dictIsHasKey__(app.config.latex_elements,'preamble'): + self.latex_preamble = app.config.latex_elements['preamble'] + else: + self.latex_preamble = "" + + if len(copyrightfile) > 0: + self.copyrightfile = copyrightfile + elif hasattr(app.config,"copyfile_path") and len(app.config.copyfile_path)>0: + self.copyrightfile = app.config.copyfile_path + else: + self.copyrightfile = "" + + def __FindCopyrightFile(self): + ''' + #如果copyrightfile为空,则查找copyrightfile文件 + ''' + if self.copyrightfile =="": + + copyrightpre = "copyright/copyright" + filepath = os.path.join(self.app.srcdir,copyrightpre + ".rst") + + if os.path.isfile(filepath): + return filepath + + if self.language=="zh_CN": + filepath = os.path.join(self.app.srcdir,copyrightpre + "_zh.rst") + else: + filepath = os.path.join(self.app.srcdir,copyrightpre +"_en.rst") + + if os.path.isfile(filepath): + return filepath + else: + return "" + else: + filepath = os.path.join(self.confdir , self.copyrightfile) + + #转为绝对路径 + filepath = os.path.abspath(filepath) + if os.path.isfile(filepath): + return filepath + else: + return "" + + def __IsModifyfootYear(self,year): + ''' + 查找lfoot设置版权的语句年份是否需要修改 + :return: True - 需要修改年份 + False - 不需要修改年份 + ''' + + ismodify = False + searchstr = r'\\lfoot{[ ]*Copyright [\s\S]*?}' + pattern = re.compile(searchstr,re.I|re.U) + matchiter = pattern.finditer(self.latex_preamble) + if matchiter is None: + return ismodify + + for m in matchiter: + lfoot = m.group() + if year in lfoot: + #如果当前年份已在脚注中,则认为年份已修改,则不再修改。 + break + else: + # 如果当前年份不在脚注中,则认为年份还未修改,需要修改。 + ismodify = True + break + return ismodify + + def __ModifyYearofConfforDetail(self,year,ispremble=False): + ''' + 修改conf.py中的年份 + :param ispremble: 是否只修改ispremble中脚注的年份。 + True-只修改脚注的年份;False-同时修改copyright和脚注的年份。 + ''' + #打开文件 + conffile = self.confdir +"/conf.py" + fo = codecs.open(conffile, "r+", encoding='utf-8') + textcontent = fo.read() + fo.close() + + if ispremble is False: + #修改copyright的年份 + searchstr = r"[0-9]{4}" + matchobj = re.search(searchstr,self.confright) + if matchobj is not None: + copyright = re.sub(searchstr,year,self.confright) + textcontent=textcontent.replace(self.confright,copyright,1) + #为了看到最终效果,同时修改已经读取的变量,因为调用该函数前,conf.py的配置已经被sphinx读取 + self.app.config.copyright=copyright + + #开始替换脚注中的年份 + #得到脚注中带有年份的声明 + searchstr = r'\\lfoot{[ ]*Copyright [\s\S]*?}' + pattern = re.compile(searchstr, re.I | re.U) + matchiter = pattern.finditer(self.latex_preamble) + if matchiter is not None: + for m in matchiter: + lfoot = m.group() + searchstr= r"[0-9]{4}" + matchobj = re.search(searchstr,lfoot) + if matchobj is not None: + newlfoot = re.sub(searchstr,year,lfoot) + textcontent=textcontent.replace(lfoot,newlfoot) + #两个脚注相同,因此只替换一次即可 + # 为了看到最终效果,同时修改已经读取的变量,因为调用该函数前,conf.py的配置已经被sphinx读取 + self.app.config.latex_elements['preamble'] = self.latex_preamble.replace(lfoot,newlfoot) + break + #将替换后的内容重新写入文件 + fw = codecs.open(conffile, "w+", encoding='utf-8') + fw.write(textcontent) + fw.close() + + def __ModifyYearofCopyright(self,year): + + #得到copyright绝对文件名 + filepath = self.__FindCopyrightFile() + if filepath == "": + return + + fo = codecs.open(filepath, "r+", encoding='utf-8') + textcontent = fo.read() + fo.close() + #新的版权年份 + newyear = '© ' + year + #if newyear in textcontent: + # return + isupdate = False + #避免替换错误,循环找到最后一个进行替换 + searchstr = "©( )*([0-9]{4})" + pattern = re.compile(searchstr, re.I | re.U) + matchiter = pattern.finditer(textcontent) + for m in matchiter: + oldyear = m.group() + if oldyear == newyear: + continue + textcontent = textcontent.replace(oldyear,newyear,1) + isupdate = True + + if isupdate: + fw = codecs.open(filepath, "w+", encoding='utf-8') + fw.write(textcontent) + fw.close() + #pattern = re.compile(searchstr, re.I | re.U) + #matchlst = pattern.findall(textcontent) + #if len(matchlst)==0: + # #如果没有找到四数字的年份,直接返回 + # return + + #i=0 + #content="" + #matchiter = pattern.finditer(textcontent) + #for m in matchiter: + # if (i+1)==len(matchlst): + # #对最后一个年份进行替换 + # content = textcontent[0:m.start()] + year + textcontent[m.end():len(textcontent)] + # break + # else: + # i = i+1 + + #fw = codecs.open(filepath, "w+", encoding='utf-8') + #fw.write(content) + #fw.close + + def __ModifyYearofConf(self,year): + #开始修改conf.py中的年份 + if (year not in self.confright) and (self.confright != ""): + #修改conf.py中的年份 + self.__ModifyYearofConfforDetail(year) + return + #开始修改conf.py中脚注的年份 + if (self.latex_preamble!="") and (self.__IsModifyfootYear(year) is True): + #需要修改年份,走到这一步conf.py的copyright年份肯定修改过的,因此只修改foot中的年份即可 + self.__ModifyYearofConfforDetail(year,True) + return + + def CheckYearIsRight(self): + ''' + 检查年份是否正确 + ''' + #得到系统当前年份 + today = dt.date.today() + year = today.strftime('%Y') + #开始检查conf.py中的年份 + self.__ModifyYearofConf(year) + #开始检查copyright文件中的年份 + self.__ModifyYearofCopyright(year) + + def __AddAdSymbolForReplaceZh(self,content): + ''' + 解析是否需要替换,需要的话就进行替换。 + 替换的条件:1.被替换字符串全部为中文,2。被替换字符的首尾有一个不为英文 + :param content: 要替换的内容 + :return: 已替换或者未替换的字符串 + ''' + #用正则表达式解析出要替换的内容 + searchstr = r"[\s\S]+ replace:: (?P[\s\S]+)" + match = re.search(searchstr,content) + if match is None: + return content + replacestr = match.group('replace').strip('\r') + #取得字符串的首和尾字符 + #prochr = ord(replacestr[0]) + #print(prochr) + #epichr = ord(replacestr[-1]) + #print(epichr) + #if ((64 < prochr and prochr<91) or \ + # (96 < prochr and prochr<123)) and \ + # ((64 < epichr and epichr < 91) or \ + # (96 < epichr and epichr < 123)): + # #首尾都为英文则无需替换空格,直接返回 + # return content + newreplacestr = "@@"+replacestr+"@@" + return content.replace(replacestr,newreplacestr) + + def __ParseReplaceList(self,contentlst): + ''' + 为中文替换词前后添加@@符号 + :param contentlst: 需要替换的list内容 + :return: 新生成列表,如果无需替换,返回None。 + ''' + if contentlst is None or len(contentlst)==0: + return None + + for i in range(0,len(contentlst)): + content = contentlst[i] + if content=="": + continue + #进行替换 + newcontent = self.__AddAdSymbolForReplaceZh(content) + contentlst[i] = newcontent + + return contentlst + + def CheckReplaceContent(self): + ''' + 该函数用来判断conf.py文件中是否包含rst_epilog或者rst_prolog的配置 + 如果有的话,修改这两个参数中的配置在需替换的名词前后都加@@符号,以方便在latex文件中删除前后的空格。 + 该修改为了解决sphinx对中文的替换前后都加空格的bug。 + :return: + ''' + epilogstr = self.app.config.rst_epilog + prologstr = self.app.config.rst_prolog + + if self.app.config.language != "zh_CN": + ''' + 如果不是中文文档,则直接返回。 + ''' + return + + if prologstr is not None: + # 为rst_prolog的中文替换前后添加@@符号 + prologlst = prologstr.split('\n') + newlst = self.__ParseReplaceList(prologlst) + if newlst is not None: + self.app.config.rst_prolog = '\n'.join(newlst) + + if epilogstr is not None: + #为rst_epilog的中文替换前后添加@@符号 + epiloglst = epilogstr.split('\n') + newlst = self.__ParseReplaceList(epiloglst) + if newlst is not None: + self.app.config.rst_epilog = '\n'.join(newlst) + + + +class clsCheckHtmlRule: + ''' + 检查html配置是否符合规范 + ''' + def __init__(self,app): + self.app=app + #错误描述模板 + #配置参数错误 + self.configparam_error="Parameter %(param)s configuration error!" + self.correctvalue="The correct value is %(value)s!" + self.config_error=self.configparam_error + self.correctvalue+"\n" + #工程名称配置不一致 + self.inconsistent_error="The project names configured by conf.py and %(rootdoc)s.rst are inconsistent!\n" + #缺少custom.css文件 + self.csslack_error="custom.css file not found!\n" + self.config_detail = "For detailed configuration, see wiki: pageId=46137795\n" + + def __CheckCommonParamForHtml(self, theme): + ''' + 检查相关配置是否正确 + :param theme:为主题文件,如果和配置的不一致返回错误。默认为sphinx_rtd_theme。 + ''' + error = "" + if self.app.config.html_theme != theme: + error = self.config_error % { + 'param': 'html_theme', + 'value': theme + } + if self.app.config.html_copy_source is True: + error = error + self.config_error % { + 'param': 'html_copy_source', + 'value': 'False' + } + # 判断配置的css文件是否存在 + if len(self.app.config.html_css_files) == 0: + error = error + self.configparam_error % { + 'param': 'html_css_files', + } + \ + "Please configure a css file, such as custom.css." + "The css file needs to be found under the path specified by html_static_path!" + if len(self.app.config.html_static_path) == 0: + error = error + self.configparam_error % { + 'param': 'html_static_path', + } + \ + "Please configure the path of the file specified by html_css_files." + return error + + def CheckProjectNameConsistent(self): + ''' + 检查conf.py中project配置工程名称和主索引文件index.rst中的标题是否一致 + 两者不一致返回错误。 + 该接口必须在builder-inited事件之后调用,否则得不到index.rst的标题,无法比较。 + ''' + #为了兼容性获取config对象的属性,没有的属性不做比较 + cfgattrlst=dir(self.app.config) + #主索引文件名 + if "master_doc" in cfgattrlst: + indexname = self.app.config.master_doc + elif "root_doc" in cfgattrlst: + indexname = self.app.config.root_doc + else: + return "" + + #得到主索引文件的标题 + indextitle = self.app.env.titles[indexname][0] + #得到配置的工程名称 + projectname = self.app.config.project + if indextitle != projectname: + error = self.inconsistent_error % { + 'rootdoc': indexname + } + return error + else: + return "" + + def CheckCommConfigHtmlRule(self, theme): + ''' + 检查生成html前参数配置是否符合规范 + ''' + # 开始检查参数配置 + error = self.__CheckCommonParamForHtml(theme) + return error + +def CheckProjectNameIsConsistent(app): + ''' + 检查conf.py中配置的工程名称与主索引文件的标题是否一致 + ''' + checkDocobj = clsCheckHtmlRule(app) + return checkDocobj.CheckProjectNameConsistent() + +def CheckHtmlConfigIsCorrect(app,theme='sphinx_rtd_theme'): + ''' + 检查html配置参数是否正确 + ''' + checkDocobj = clsCheckHtmlRule(app) + return checkDocobj.CheckCommConfigHtmlRule(theme) + +def CheckCurrentYearAndModify(app,copyrightfile=""): + ''' + 检查conf.py和copyright中的年份是否为当前年份,如果不是当前年份则进行修改。 + Conf.py主要检查latex_elements和copyright中的年份是否为当前年份 + copyright主要检查copyright.rst文件最后一句的年份是否为当前系统年份。 + :param app: + :param copyrightfile: 相对于conf.py的相对路径。 + ''' + checkDocobj = clsCheckDocCopyrightYear(app,copyrightfile) + #检查年份是否需要修改 + checkDocobj.CheckYearIsRight() + #if (app.builder is not None) and \ + # (app.builder.name=='latex' or app.builder.name=='latexpdf'): + # checkDocobj.CheckReplaceContent() + +def __returnCfgDate(app): + ''' + 返回conf.py中配置的today的日期。如果today为空,则返回当前日期。 + ''' + if app.config.today != "": + return app.config.today + today = dt.date.today() + if app.config.language =="zh_CN": + return today.strftime('%Y年%m月%d日') + else: + return today.strftime('%b %d, %Y') + +def CheckUpdateHistory(app,docname,source): + ''' + 检查是否需要添加更新历史,在source_read_handler这个事件里调用。 + 为了提高效率,不写在类里,减少创建和销毁对象的消耗。 + :param app: + :param docname: + :param source: 源文件内容 + ''' + #判断文档中是否包含更新历史 + # 默认为英文 + title = "Update History" + update="Date:" + changes="Changes:" + language = app.config.language + if language == "zh_CN": + title = "更新历史" + update="更新时间:" + changes="更新内容:" + index = source.find(title) + if index==-1: + return source + + #说明该文档中包含更新历史,判断是否为更新历史标题 + precontent = source[0:index] + leftcontent = source[index:len(source)] + #按行读取剩余内容,理论上第2行为标题的下标,如果不是,则认为这个更新历史不是标题,直接忽略。 + linelst = leftcontent.splitlines(True) + flagline = linelst[1] + search = r"([\x21-\x2f|\x3a-\x40|\x5b-\x60|\x7b-\x7d])\1+(?=[\r|\n|\r\n])" + matchobj=re.match(search,flagline) + if matchobj is None: + return source + #说明是更新历史章节,再比较版本号是否一致 + search = r"\* \*\*V(?P[0-9]{1,2}\.[0-9]{1,2}\.[0-9]{1,2})\*\*" + matchobj=re.search(search,leftcontent) + if matchobj is None: + return source + version=matchobj.group('version') + if version==app.config.version: + return source + #如果版本不一致,自动添加新的版本内容 + dateformat = __returnCfgDate(app) + updatehistory="\n* **V%(version)s**\n\n" \ + " **%(date)s** %(dateformat)s\n\n" \ + " **%(changes)s**\n\n" \ + " - 无内容更新。\n" + + upcontent= updatehistory % { + 'version':app.config.version, + 'date':update, + 'dateformat':dateformat, + 'changes':changes + } + #拆分列表 + titlestr =''.join(linelst[0:2]) + otherstr =''.join(linelst[2:len(linelst)+1]) + newcontent = precontent + titlestr + upcontent + otherstr + #将新内容写入文档 + filepath = app.project.doc2path(docname) + fw = codecs.open(filepath, "w+", encoding='utf-8') + fw.write(newcontent) + fw.close() + return newcontent + + +def __exe_command_forprint(command): + """ + 执行 shell 命令并实时打印输出 + :param command: shell 命令 + :return: process, exitcode + """ + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + print(line.decode().strip()) + exitcode = process.wait() + return process, exitcode + +def CheckSensitiveWordsForLinux(app,docname,entext): + ''' + 该接口对外提供,可在conf.py的source_read_handler处理函数中调用。对sphinx工程中的文件检查敏感词。 + + **注意:** + + 因为检查敏感词使用的linux命令,如果敏感词中包含中文,则在不支持中文的linux系统下不能使用该接口, + 否则无法完成敏感词。 + + :param app: sphinx环境变量,可以获取输出目录、文件后缀等相关信息。 + :param docname: 需检查的文件路径。该路径是基于源目录的相对路径,而且不包含文件扩展名, + 因此在检查时需要补充上扩展名,并转为绝对路径。 + :param entext: 已加密的密文,如果为空,则只对conf.json中保存的敏感词做检查。 + 如何加密参考GetFullSensitiveWords的实现,加解密保持一致。 + :return: 无返回值。 + ''' + #windows系统请使用工具组开发的sensitive-filter.exe工具 + try: + if platform.system().lower() == 'windows': + return + + filepath = app.project.doc2path(docname,True) + + #print("Checking file:["+docname+"]") + sensitivewordsobj = clsSensitiveWords(entext) + sensitivewords = sensitivewordsobj.GetFullSensitiveWords() + #print(sensitivewords) + cmd = "egrep -rni --color=always '"+ sensitivewords + "' " + filepath + __exe_command_forprint(cmd) + except Exception as e: + print("Sensitive words check failed, the system may not support Chinese. " + "The detailed error information is as follows:") + print(e) + return + +# 打开Makefile文件查找source和build文件夹 +def OpenMakefile(): + global source_dir + global build_dir + source_dir = '' + build_dir = '' + try: + if platform.system().lower() == 'windows': + with open('make.bat',"r") as f: + fstr = f.read() + + #用正则表达式查找source和build文件夹具体路径 + searchstr = r"set *SOURCEDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + source_dir = m.group(1) #匹配到的第一个即为source所在目录 + + searchstr = r"set *BUILDDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + build_dir = m.group(1) #匹配到的第一个即为build所在目录 + else: + with open('Makefile',"r") as f: + fstr = f.read() + + #用正则表达式查找source和build文件夹具体路径 + searchstr = r"SOURCEDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + source_dir = m.group(1) #匹配到的第一个即为源所在目录 + + searchstr = r"BUILDDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + build_dir = m.group(1) #匹配到的第一个即为源所在目录 + + except Exception as e: + print(e) + return + +def GetLatex_documents(): + global source_dir + if source_dir == '': + return + #得到配置文件conf.py的路径 + if source_dir == '.': + confdir = './conf.py' + else: + confdir = './' + source_dir +'/conf.py' + + conffile = os.path.abspath(confdir) + #打开conf.py文件 + with codecs.open(conffile,"r+",encoding='utf-8') as f: + fstr = f.read() + list = [] + versioninfo = GetVersionInfo(fstr) + fileprefile = GetFilePreInfo(fstr) + if versioninfo=="" or fileprefile=="": + #根据正则表达式,找出latex_documents内容 + searchstr = r"latex_documents *= *\[([\s\S]*?)\]" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + latex_documents = m.group(1) #匹配到的第一个即为源所在目录 + #拆分二维数组,兼容多个情况 + list = latex_documents.split(")") + for i in range(len(list)): + if IsComment(list[i]): + list[i]= list[i].split(",") + list.pop() + else: + #创建2维列表 + for i in range(1): + list.append([]) + for j in range(2): + list[i].append("") + list[0][0] = "comment" #为了兼容文件解析内容,进行补位。 + list[0][1] = '"' + fileprefile + versioninfo + '.tex"' #为了兼容解析过程,添加双引号。 + return list + +#得到版本信息 +def GetVersionInfo(fstr): + + if releasever != '': + return releasever + + versioninfo = "" + searchstr = r"version *= *(u*)'(.*?)'" + m = re.search(searchstr, fstr, re.I) + if m is not None: + versioninfo = m.group(2) #匹配到的第一个即为源所在目录 + print("version = " + versioninfo) + return versioninfo + +#得到文件前缀信息 +def GetFilePreInfo(fstr): + filepre = "" + searchstr = r"curfnpre *= *(u*)'(.*?)'" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + if m is not None: + filepre = m.group(2) #匹配到的第一个即为源所在目录 + return filepre + +#判断是否为注释行 +def IsComment(instr): + + if instr.strip() is None: + return False + + rule = re.compile('^#.*$') + if rule.match(instr.strip()) is None: + return True + else: + return False + +#根据正则表达式取单引号和双引号中的内容 +def getquomarkcontent(strarr): + #根据正则表达式,找出双引号和单引号中的内容 + searchstr = r"[\"|'](.*?)[\"|']" + m = re.search(searchstr, strarr, re.M|re.I|re.U ) + if m is None: + return None + return m.group(1).strip() #匹配到的第一个即为源所在目录 + +def GetConfjsondict(app): + ''' + 得到配置文件conf.json的字典 + :param app: + :return: + ''' + global __load_dict__ + + + if app is not None: + #先判断conf.json路径在配置文件中是否配置 + if hasattr(app.config, "confjson_path") and len(app.config.confjson_path) > 0: + confpath = os.path.abspath(app.confdir + "/" + app.config.confjson_path) + else: + confpath = os.path.join(app.confdir ,"conf.json") + + if os.path.exists(confpath): + __load_dict__ = __openconfjsonfile__(confpath) + else: + __load_dict__ = __openconfjsonfile__() + else: + __load_dict__ = __openconfjsonfile__() + + +def Modifylatex_main(build_dir,latexdocumentslst,app=None,language='zh_CN'): + + global __load_dict__ + global doc_language + doc_language=language + + if len(__load_dict__) == 0: + GetConfjsondict(app) + + + doclen = len(latexdocumentslst) + for i in range(0,doclen): + #得到latex路径 + latexpath = build_dir + desbkpaperfile= os.path.join(latexpath,'chapterbkpaper.pdf') + #copy 背景图到latex路径,背景图必须与该文件在同一个目录下 + if (app is not None) and hasattr(app.config,'chapterbkpaper_image') and (app.config.chapterbkpaper_image!=""): + #转为绝对路径进行判断,兼容conf.py和parsejson.py不在同一目录的情况 + bkpaperimage = os.path.abspath(os.path.join(app.confdir,app.config.chapterbkpaper_image)) + if os.path.isfile (bkpaperimage): + shutil.copy(bkpaperimage,desbkpaperfile) + elif os.path.exists('./chapterbkpaper.pdf'): + shutil.copy('./chapterbkpaper.pdf', latexpath) + else: + #取配置文件所在目录下的chapterbkpaper.pdf,将其copy到latex文件夹 + bkpaperimage = os.path.abspath(os.path.join(app.confdir,"chapterbkpaper.pdf")) + if os.path.isfile(bkpaperimage): + shutil.copy(bkpaperimage, latexpath) + + #得到相对路径 + filename = latexdocumentslst[i][1] + if filename is None: + continue + texfilepath = os.path.join(latexpath,filename) + #相对路径转绝对路径 + texfile = os.path.abspath(texfilepath) + if not os.path.exists(texfile): + continue + fo = codecs.open(texfile, "r+",encoding = 'utf-8') + texcontent = fo.read() + fo.close() + + #得到修改tex文件的对象 + ModTexobj = clsModifyTex(texcontent) + ModTexobj.AddPackageToTex() + ModTexobj.AddOtherTocToTex() + ModTexobj.AddCustormOptionsToTex() + ModTexobj.ModifyReplacePackage() + ModTexobj.ModifyFunctionBackColor() + ModTexobj.ModifyTablesAttributes() + #if (app is not None) and \ + # (app.builder.name=='latex' or app.builder.name=='latexpdf'): + # ModTexobj.ModifyReplaceContent() + + fw = codecs.open(texfile, "w+",encoding = 'utf-8') + fw.write(ModTexobj.content) + fw.close() + +def parseargv(): + global source_dir + global latex_dir + if len(sys.argv) <= 1: + return + for i in range(1,len(sys.argv)): + param = sys.argv[i] + if param=='-c': #解析conf.py所在目录 + source_dir = sys.argv[i+1] #保存相对路径 + #print("argv source=" + source_dir) + i = i+2 + elif param=='-l': #保存输出的latex目录 + latex_dir = sys.argv[i+1] #保存相对路径 + i = i+2 + #print("latex_dir = "+latex_dir) + +source_dir = '' #保存源文件所在目录 +build_dir = '' #保存生成文件所在目录 +releasever = '' +latex_dir = '' +doc_language ='' +__load_dict__ = {} + +if __name__ == '__main__': + + parseargv() #解析系统参数 + OpenMakefile() + + latex_documents = GetLatex_documents() #保存latex文档所在数组 + __load_dict__ = __openconfjsonfile__() + + latexdir="" + if latex_dir == '': + latexdir = '/latex/' + else: + latexdir = '/' + latex_dir + '/' + + doclen = len(latex_documents) + for i in range(0,doclen): + #得到latex路径 + latexpath = './' + build_dir + latexdir + #copy 背景图到latex路径,背景图必须与该文件在同一个目录下 + if os.path.exists('./chapterbkpaper.pdf'): + shutil.copy('./chapterbkpaper.pdf',latexpath) + #得到相对路径 + if getquomarkcontent(latex_documents[i][1]) is None: + continue + texfilepath = latexpath + getquomarkcontent(latex_documents[i][1]) + #相对路径转绝对路径 + texfile = os.path.abspath(texfilepath) + if not os.path.exists(texfile): + continue + fo = codecs.open(texfile, "r+",encoding = 'utf-8') + texcontent = fo.read() + fo.close() + + #得到修改tex文件的对象 + ModTexobj = clsModifyTex(texcontent) + ModTexobj.AddPackageToTex() + ModTexobj.AddOtherTocToTex() + ModTexobj.AddCustormOptionsToTex() + ModTexobj.ModifyFunctionBackColor() + ModTexobj.ModifyTablesAttributes() + ModTexobj.ModifyReplacePackage() + + fw = codecs.open(texfile, "w+",encoding = 'utf-8') + fw.write(ModTexobj.content) + fw.close() diff --git a/torch_mlu_ops-v1.3.2/docs/release_notes/version.rst b/torch_mlu_ops-v1.3.2/docs/release_notes/version.rst new file mode 100644 index 0000000..9e945dd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/release_notes/version.rst @@ -0,0 +1,220 @@ +V1.3.2 +=================== + +特性变更 +----------------- + +- 废弃原legacy目录下的BangTransformer大模型推理网络性能评测代码。 +- ``fused_layer_norm`` 与 ``fused_rms_norm`` 的输入和输出支持更多的 ``stride`` 组合。 + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 + +V1.3.1 +=================== + +特性变更 +----------------- + +- ``smooth_quant_group_gemm`` 与 ``fused_moe`` 算子支持w4w8混合量化。 +- ``reshape_paged_cache`` 算子支持 ``v`` 和 ``v_cache`` 传入None。 +- ``quant_to_paged_cache`` 算子支持 ``v`` 、 ``v_cache`` 和 ``v_cache_quant_scale`` 传入None。 +- ``offline_quant_to_paged_cache`` 算子支持 ``v`` 和 ``v_cache_scale`` 传入None。 +- ``copy_blocks`` 算子支持 ``v_caches`` 传入None。 +- ``moe_softmax_topk`` 支持 ``mask`` 广播及定制化 ``normalize`` 。该算子的 ``Input`` 和 ``Mask`` 必须保证连续。 +- ``single_query_cached_kv_attn`` 支持 ``window_size_left`` 。 +- 新增 ``dequant_from_paged_cache`` 算子。 +- ``dequant_from_linear_cache`` 算子不再支持float32的 ``key`` 和 ``value``。 + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 + + +V1.3.0 +=================== + +特性变更 +----------------- + +- ``group_gemm`` 与 ``smooth_quant_group_gemm`` 算子的max_m参数会影响性能,其默认值不一定是最佳性能,将由参数可选修改成必填参数。 +- ``moe_softmax_topk`` 支持mask功能。 +- 支持导出算子 ``gen_case`` 功能。 +- 新增 ``dequant_from_linear_cache`` 算子。 +- ``moe_softmax_topk`` 支持mask广播及定制化normalize。 + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 + + +V1.2.3 +=================== + +特性变更 +----------------- + +- 适配CNToolkit 3.15.X相关特性。 +- 不再支持Ubuntu20.04操作系统。 +- ``single_query_cached_kv_attn`` 与 ``flash_attention`` 算子支持 ``head_size_qk != head_size_v``。 +- ``group_gemm`` 与 ``smooth_quant_group_gemm`` 算子支持bias。 +- ``matmul`` 删除原位输出参数,增加指定输出类型参数。 + +已修复问题 +--------------------- + +- ``quant_to_linear_cache`` 算子修复精度问题。 +- 修复 ``matmul`` 算子形状推导问题。 +- 修复Debug模式下 ``quant_to_linear_cache`` 算子编译问题。 + +已知遗留问题 +-------------- + +无。 + + +V1.2.2 +=================== + +特性变更 +----------------- + +- ``smooth_quant_group_gemm`` 与 ``fused_moe`` 支持int4 group量化功能。 +- ``allreduce`` 类算子删除 ``act_mode`` 。 +- ``weight_only_quant_matmul`` 与 ``smooth_quant_matmul`` 新增控制激活计算方式。 + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 + +V1.2.1 +=================== + +特性变更 +----------------- + +- ``moe_combine_result`` 算子优化吞吐场景下性能。 +- ``quant_to_linear_cache`` 算子新增group量化和int4量化功能。 +- 新增 ``moe_cast_gating`` 算子。 +- 新增 ``update_out_and_lse`` 算子。 +- ``fused_rope`` 算子支持int8/int4 kv cache。 +- ``matmul`` 与 ``batch_matmul`` 新增支持trans_a, trans_b。 +- 新增 ``single_query_mixed_cached_kv_attn`` 算子。 +- ``single_query_cached_kv_attn`` 支持output_lse。 +- ``fused_layer_norm`` 与 ``fused_rms_norm`` 支持输出动态量化。 +- legacy目录下的BangTransformer网络评测代码,仅支持在PyTorch2.1环境下编译和运行。 + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 + +V1.2.0 +=================== + +特性变更 +----------------- + +- ``moe_softmax_topk`` 算子新增grouped_topk功能。 +- ``moe_softmax_topk`` 算子不再支持原位功能。 +- ``moe_gen_idx`` 算子不再支持原位功能。 +- Torch-MLU-Ops首次支持PyTorch2.5,不再支持PyTorch2.3。 +- 新增 ``fused_rope`` 算子。 +- ``matmul`` 算子新增支持INT8输入。 +- 新增 ``batch_matmul`` 算子。 + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 + +V1.1.4 +=================== + +特性变更 +----------------- + +- 新增 ``offline_quant_to_paged_cache`` 算子。 +- 新增 ``moe_gen_idx`` 算子。 +- 新增 ``moe_expand_input`` 算子。 +- 新增 ``moe_combine_result`` 算子。 +- 新增 ``moe_quantize`` 算子。 +- 新增 ``moe_softmax_topk`` 算子。 +- 删除 ``quant_matmul`` 算子,由 ``smooth_quant_matmul`` 和 ``weight_only_quant_matmul`` 实现其功能。 +- ``flash_attention`` 算子新增 ``block_tables``, ``k/v_cache_quant_scale`` 参数。 +- ``matmul`` 算子支持激活配置参数。 +- ``fused_moe`` 算子支持量化EP。 +- 新增 ``moe_active`` 算子。 + +已修复问题 +--------------------- + +- 修复 ``fused_moe`` 算子通算融合模式的精度问题。 +- 修复 ``moe_combine_result`` 算子在EP模式下特定规模下的coredump问题。 +- 修复 ``fused_norm`` 算子非连续情况下的精度问题。 + +已知遗留问题 +-------------- + +无。 + +V1.1.3 +=================== + +特性变更 +----------------- + +- BangTransformer更名为Torch-MLU-Ops, 定位PyTorch第三方算子库。对于使用PyTorch框架的开发者,通过Torch-MLU-Ops,能够便捷地使用这些自定义算子,进行算子的集成、评测和业务部署。 +- bt_ops的命名空间变化为torch_mlu_ops。 +- 原BangTransformer的LLM网络推理评测相关内容被迁移到lagacy目录下进行维护。 +- 后续在寒武纪计算卡上的LLM网络推理评测建议使用Cambricon vLLM、Cambricon TGI、Cambricon Stable Diffusion web UI、Cambricon ComfyUI以及Cambricon Diffusers组件。 +- Single Query Cached Attention算子支持per_token量化和per_channel量化。 +- Fused Moe算子在非量化情况下支持EP模式,内部支持group gemm和allreduce并行。 +- 提供smooth_quant_matmul_allreduce、matmul_allreduce、flash_attn_sq_mm_allreduce通算融合。 +- 更新flash_attention、single_query_cached_kv_attn、fused_rms_norm、fused_layer_norm的接口说明。 + + +已修复问题 +--------------------- + +无。 + +已知遗留问题 +-------------- + +无。 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/Makefile b/torch_mlu_ops-v1.3.2/docs/user_guide/Makefile new file mode 100644 index 0000000..12c851d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = torch_mlu_ops_rst +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/README.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/README.rst new file mode 100644 index 0000000..81a94dd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/README.rst @@ -0,0 +1,90 @@ +使用说明 +============================= + +重要文件及文件夹说明 +----------------------------- + +**注意:** + + 工程下载后的所有内容,都不能删除,否则会造成工程编译失败。 + +* source文件夹:保存rst源文件。所有新增的rst文件及文件夹务必都保存在source文件夹中,否则生成的文件将不包含新增内容。 +* build文件夹:保存构建后的文件,该文件在执行make html和./makelatexpdf.sh后会自动生成。build及其子文件夹都是构建后自动生成,所以无需关注。 + + - 执行make html生成html文件,保存在“build/html”文件夹中。 + - 执行./makelatexpdf.sh生成latex和pdf文件,保存在“build/latex”文件夹中。 + +* makelatexpdf.sh:核心编译脚本,生成latex和pdf文件。 + + **注意:** + + 下载工程后,请执行“chmod +x makelatexpdf.sh”命令修改该文件的权限,否则脚本无法执行。 + +* source/index.rst:核心索引文件,sphinx依据该文件生成文件的大纲目录。所有新增的rst文件都要添加到该文件中,否则生成的文档中将不包含新增内容。 + +文档写作约束 +-------------------------------- + +**rst语法参考网址** :http://www.pythondoc.com/sphinx/rest.html + +rst文件为纯文本文件,因此可以使用vim、notepad++或者vscode等任何文本编辑器编写。其中vim天然支持rst,会对rst语法进行高亮显示。 +vscode可以通过下载插件的方式实现rst语法高亮和实时预览,但是vscode对rst表格的兼容性不是很好。 + +**注意:** + + 不管用哪种编辑器,都务必将tab设置为4个空格,否则格式可能会达不到预期,或者引入编译告警。尤其影响表格的编辑,因此务必将tab设置为4个空格。 + +* 文档编译服务器,工程可以下载到该服务器编写编译。目前只有该服务器搭建了文档编译环境,能够完成文档的编译。 +* 所有新增内容,都务必放在“source”文件夹下,不能放在其它文件夹,避免工程编译失败,达不到预期效果。 +* 一级大纲或者新建的独立模块建议新建文件夹,rst源文件都保存在新建文件夹下,以便工程有清晰的层级结构,方便维护。可以参考“preface”文件夹的构建。 +* 文件目录深度最多为4级,超过4级不方便用户阅读。4个级别大纲的标识按以下顺序: + + 1. 一级大纲标识:=; + #. 二级大纲标识:-; + #. 三级大纲标识:^; + #. 四级大纲标识:~; + + 举例如下: + + :: + + 我是一级大纲举例 + ======================== + 我是二级大纲举例 + ------------------------ + 我是三级大纲举例 + ^^^^^^^^^^^^^^^^^^^^^^^^ + 我是四级大纲举例 + ~~~~~~~~~~~~~~~~~~~~~~~~ + +使用举例 +---------------------------------------- + +以新建“前言”章节为例,举例说明如何新增内容,新增内容后,如何编译生成html和pdf文件。下载的工程中“前言”章节已经添加完成,可以参考。 + +* 步骤一:通过ssh登陆文档编译服务器。 +* 步骤二:在Torch-MLU-Ops仓库下载user_guide文件夹及其内容。 +* 步骤三:执行命令“chmod +x makelatexpdf.sh”修改makelatexpdf.sh脚本的权限,避免无法编译。 +* 步骤四:在“source”文件夹下,新建“preface”文件,保存前言相关的rst文件。 +* 步骤五:在“preface”文件夹下,新建两个文件:preface.rst和index.rst。 + + - preface.rst:用来保存前言相关内容 + - index.rst:用来保存"preface"文件夹下所有rst文件的索引。 + + 将preface.rst添加到index.rst。 + + **说明** + + preface文件夹下新建index.rst是为了方便管理同一文件夹下的内容,这样在外层的主索引文件“index.rst”中只要添加该文件夹的index索引文件即可。 + + 也可以不新建index.rst,将preface文件夹下的文件一一添加到外层的主索引文件“index.rst”中。为了方便管理建议每个文件夹下都新建index.rst统一管理该文件夹下的索引。 + +* 步骤六:按照rst语法和文档写作约束编写preface.rst”文件。 +* 步骤七:将“preface/index.rst”文件添加到外层的主索引文件“index.rst”中。 +* 步骤八:执行“./makelatexpdf.sh”脚本生成pdf文件,保存在“build/latex”文件夹下。执行“make html”命令生成html文件,保存在“build/html”文件夹下。 + + **注意:** + + 在执行“./makelatexpdf.sh”或者“make html”过程中,要注意sphinx红色字体的“WARNING"告警,务必将告警清除,否则可能无法生成文件或者生成的文件无法达到预期。 + +经过以上步骤,“前言”章节添加成功,可以查看最终生成的pdf和html文件,检查最终效果。 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/_ext/customdirective.py b/torch_mlu_ops-v1.3.2/docs/user_guide/_ext/customdirective.py new file mode 100644 index 0000000..22eaa49 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/_ext/customdirective.py @@ -0,0 +1,1610 @@ +#!/usr/bin/python +# -*- coding: UTF-8 + +import os +import sys +import warnings +import re +from sphinx.directives.other import TocTree +from docutils.parsers.rst import Directive +from docutils.parsers.rst import directives +from docutils.parsers.rst.roles import set_classes +from docutils import nodes + +from sphinx.util.docutils import SphinxDirective +#from sphinx.util.typing import OptionSpec +from docutils.statemachine import ViewList +#from sphinx.util.nodes import nested_parse_with_titles +#import traceback + +class CNOnly(Directive): + + ''' + cnonly指令,实现在一个列表内、一个段落内、一个表格内按定义的tag,条件输出不同内容。 + cnonly指令包含的内容为同一个列表、同一个段落、同一个表格,如果本身就是两个段落的内容,请使用原始only指令。 + cnonly指令和only指令的区别: + 1. 可以用在一个列表内,被指令包含的item与前后item之间无空行。only指令默认就是两个段落,因此指令内容和前后item之间有空行,效果比较难看。 + 2. 用在一个段落内,按tag条件输出。only指令只能用于两个段落,不支持在同一段落内使用。 + 3. 可用于一个表格内,按tag条件输出。only指令不支持用于表格内。 + ''' + + has_content=True + + ''' + notailspace + --------------- + cnonly指令参数,是否在指令内容后添加空行,无需符值。 + 添加了该参数,则在指令内容和下一段之间不添加空行,则指令内容和下一段内容属于同一段落。 + 不添加该参数,则指令内容和下一段之间添加空行,指令内容和下一段内容为独立的两个段落。 + cnonly指令默认不添加该参数。 + ''' + option=["notailspace"] + + def run(self): + return [] + +class CNOnlyPro(): + + def __init__(self,app,docname,sourcestr): + self.tags=app.tags + self.docname=docname + self.source = sourcestr + + #判断cnoly指令是否存在notailspace参数,如果有该参数,则在指令内容和剩余内容间不添加空行,比如表格。 + #添加空行的目的是为了本来该是两个段落的内容,避免合并成一个段落。 + #默认没有该参数,添加空行。 + def __parsecnonlyoption(self,optionstr)->bool: + if optionstr.strip()==":notailspace:": + return True + else: + return False + #根据tag判断cnonly指令是否应该包含 + def __parsecnonlylst(self,cnonlystr,onlyspacecount,onlylst,spacelst): + + regstr = r".. cnonly::([\s\S]*)" + pattern = re.compile(regstr,re.I|re.U) + tagstr = pattern.search(cnonlystr).group(1).strip() + #判断tag是否存在 + #print('tagstr=%s,onlycontent=%d' % (tagstr,len(onlycontent))) + tagsflag = parsejointtags(self.tags,tagstr) + #print(tagsflag) + if (not tagsflag) or (len(onlylst)==0): + #不包含该tag,则tag包含内容全部删除 + return "" + else: + delspacecount = 0 + for i,j in enumerate(onlylst): + if j.strip()=="": + continue + if delspacecount==0: #cnonly后的第一个列表前面添加0个空格,因为要和cnonly前的空格保持一致。 + delspacecount= spacelst[i]-onlyspacecount + #print("delspacecount=%d" % delspacecount) + #后续的所有列表都删掉前面的delspacecount数量的空格 + onlylst[i] = dellefspaces(j,delspacecount) + #print('onlylst[%d]=%s' % (i,onlylst[i])) + + onlycontentstr = "\n".join(onlylst)+'\n' + + return onlycontentstr + #根据cnonly指令将字符串拆分为三部分 + def parsecnonlycontent(self,sourcestr): + regstr = r"( *).. cnonly::([\s\S]*)" + pattern = re.compile(regstr,re.I|re.U) + matchobj = pattern.search(sourcestr) + + if matchobj != None: + #转字符串为列表,方便操作 + startpos = matchobj.start() #得到起始位置 + if startpos>0: + #先转成列表,删除指令前的空行 + presourcelst = sourcestr[0:startpos].split('\n') + presourcelst.pop() + if presourcelst[-1].strip()=="": + presourcelst.pop() + presource=('\n').join(presourcelst)+'\n' + else: + presource="" + leftsource = sourcestr[startpos:len(sourcestr)] + #将带cnonly的字符串转成列表,方便解析 + sourcelst = leftsource.split('\n') + #列表的第一个元素即为cnonly字符串 + #循环找出cnonly字符串涵盖的内容 + onlyspacecount = countspaces(sourcelst[0]) + onlycontent=[] + spacelst=[] + notailspace= self.__parsecnonlyoption(sourcelst[1]) #判断第一行是否为“notailspace”参数行。 + if notailspace: + startpos=3 #忽略掉指令行、参数行、和空行 + else: + startpos=2 #忽略掉指令行、参数行 + for i,j in enumerate(sourcelst[startpos:]): + if j=="" or j.isspace()==True: + onlycontent.append(j) + continue + + spacecount=countspaces(j) + + if spacecount>=(onlyspacecount+3): + spacelst.append(spacecount) #将空格数量加到空格数量列表 + onlycontent.append(j) + else: + break + + #删除指令后,内容前的空行 + while True: + if len(onlycontent) > 0 and (onlycontent[0]=="" or onlycontent[0].isspace()==True): + onlycontent.pop(0) + else: + break + + #删除指令内容后的空行 + while True: + if len(onlycontent) > 0 and (onlycontent[-1]=="" or onlycontent[-1].isspace()==True): + onlycontent.pop() + else: + break + #print("i=%d" % i) + #经过解析后onlycontent包含cnonly指令包含的所有内容 + onlycontentstr = self.__parsecnonlylst(sourcelst[0],onlyspacecount,onlycontent,spacelst) + leftsource ='\n'.join(sourcelst[i+startpos:]) + #print(leftsource) + leftsource = self.parsecnonlycontent(leftsource) + if notailspace: + return presource+onlycontentstr+leftsource + else: + return presource+ onlycontentstr +'\n' + leftsource + else: + return sourcestr + +def parsejointnobrackets(tags,tagstr): + ''' + 解析没有括号的and or组合。因为and的优先级高于or,因此从or开始拆分字符串。 + :param tags: + :param tagstr: + :return: + ''' + orlst = tagstr.split(r" or ") + if len(orlst)<=1: + #说明全是and的组合 + return parsemultitags(tags,tagstr) + + result = False + for i, j in enumerate(orlst): + if j ==" or " or len(j)==0 or j.isspace() is True: + continue + #and组合的内容 + result = parsemultitags(tags,orlst[i]) + if result is True: + break + + return result + +def parsejointtags(tags,tagstr): + ''' + 解析带括号的逻辑判断,支持括号嵌套 + :param tags: + :param tagstr: + :return: + ''' + #根据正则表达式,先将括号内容筛选出来 + searchstr = r"\(([^\(\)]*?)\)" + mobjlst = re.findall(searchstr,tagstr,re.I) + if len(mobjlst) == 0: + #如果没有括号,就按and or的混合组合处理 + return parsejointnobrackets(tags,tagstr) + + mobjarr = re.finditer(searchstr,tagstr,re.I) + prepos = 0 #字符串拆分的起始位置 + newtagstr = ""#组合后的新的字符串 + for mobj in mobjarr: + newtagstr += tagstr[prepos:mobj.start()] + #判断括号中的组合结果,根据结果连接成新的字符串 + #内部第一层括号 + result = parsejointnobrackets(tags,mobj.group(1)) + newtagstr += "_"+str(result)+"_" #为了避免和True、Flase标签冲突,前后添加下划线 + prepos = mobj.end() + + if len(tagstr) > prepos: + #再将剩余的内容加入到firsttaglst中 + newtagstr += tagstr[prepos:len(tagstr)] + + #递归调用,避免有括号嵌套 + return parsejointtags(tags,newtagstr) + +#判断是否有指定的tag +def parsemultitags(tags,tagstr)->bool: + #该函数判断是否有多个tags的组合,比如 or组合,and组合。 + #当前仅支持全是or或者全是and的组合,不支持or和and的混合组合。 + #or组合优先,有or组合先按or组合判断。 + strlst=tagstr.split(r" or ") + #print("tags or:" + ",".join(strlst)) + if len(strlst) <=1: + #如果不能根据or分割出数据,再根据and分割 + strlst=tagstr.split(r" and ") + #print("tags and:" + ",".join(strlst)) + if len(strlst) <=1: + #"_True_"是括号中的内容已经被解析后的结果,因此不用在重新解析了 + if ("_True_" == tagstr.strip()) or tags.has(tagstr.strip()): + return True + else: + return False + else: + count = len(strlst) + icount = 0 + for i in range(0,count): + if ("_True_" == strlst[i].strip()) or tags.has(strlst[i].strip()): + icount+=1 + #and组合全部都有则返回true,有一个标签不包含就返回fasle + if icount == count: + return True + else: + return False + else: + for i in range(0,len(strlst)): + if ("_True_" == strlst[i].strip()) or tags.has(strlst[i].strip()): + return True + return False +#得到字符串开头的空格数量 +def countspaces(str): + for i, j in enumerate(str): + if j != ' ': + return i + +#删除字符串前指定数量的空格 +def dellefspaces(str,count): + if count<=0: + return str + + for i in range(0,count): + if str[0]==" ": + str = str[1:] + return str + +#自定义cntoctree指令,使其支持only指令 +#使用方式:使用toctree指令的地方,直接修改为cntoctree,则在cntoctree指令的下面可以使用only指令,根据条件编译包含的文件。 +class TocTreeFilt(TocTree): + ''' + 该指令实现在大纲目录中按定义的条件输出,即在大纲目录中可以使用only指令。 + + 使用方法 + ----------------- + 将toctree直接替换为cntoctree即可。 + + 注意 + ----------------- + cntoctree中的only指令仅支持tag的全or或者全and的组合,不支持or和and的混合组合,也不支持使用“()”括号进行分组。 + ''' + def __GetOnlyStrByReg(self,contentstr): + #根据正则表达式,得到字符串中所有的only完整的字符串,并放到list列表里 + onlylst=[] + regstr = r"(.. only::[\s\S]*?)," + pattern = re.compile(regstr,re.I|re.U) + onlylst = pattern.findall(contentstr) + return onlylst + + #判断only的内容 + def __GetOnlyContent(self,content): + + onlystr = content[0] #第一个元素为only指令所在的字符串 + #得到only字符串前面的空格数 + onlyspacecount=countspaces(onlystr) + #print("onlyspacecount=%d,onlystr=%s" % (onlyspacecount,onlystr)) + #找出only包含的所有内容,以缩进的空格为依据。only指令包含的内容,必须缩进数量>(onlyspacecount+3) + onlycontent=[] + leftcontent=[] + endonly=False + for i,j in enumerate(content[1:]): + if j=="": + continue + spacecount=countspaces(j) + #print("spacecount=%d" % spacecount) + if spacecount>=(onlyspacecount+3) and (not endonly): + onlycontent.append(j) + else: + leftcontent.append(j) + endonly = True + + if len(onlycontent)>0: + #判断only标签是否已定义 + #根据only字符串,返回修改后的列表 + #根据正则表达式,解析出only后的tag标签 + regstr = r".. only::([\s\S]*)" + pattern = re.compile(regstr,re.I|re.U) + tagstr = pattern.search(onlystr).group(1).strip() + #print('tagstr=%s' % tagstr) + tagsflag = parsejointtags(self.env.app.tags,tagstr) + if not tagsflag: + onlycontent.clear() + return leftcontent,onlycontent + + def __filter_only(self,content): + #解析only指令 + #print(content) + onlystr = '.. only::' + #将列表转为字符串,用英文逗号链接,方便判断 + contentstr = ','.join(content) + newcontent=[] + count = len(content) + for i,j in enumerate(content): + #print("j=%s,i=%d" % (j,i)) + + if onlystr in j: + #print(content[i:]) + leftcontent,onlycontent = self.__GetOnlyContent(content[i:]) + if len(onlycontent)>0: + newcontent += self.__filter_only(onlycontent) + if len(leftcontent)>0: + newcontent += self.__filter_only(leftcontent) + + break + else: + newcontent.append(j) + + return newcontent + + def run(self): + # 过滤only条件 + self.content = self.__filter_only(self.content) + #清除目录前后的空格,否则无法找到该目录 + for i,j in enumerate(self.content): + self.content[i]=self.content[i].strip() + + return super().run() + +#--------------------------------------------------------------------------- + +#latex字体与pt的对应表,数值的单位为pt,pt与px的对应关系pt=px * 3/4 +latex_fontsize={5:r'\tiny',7:r'\scriptsize',8:r'\footnotesize', + 9:r'\small',10:r'\normalsize',11:r'\large', + 14.4:r'\Large',17.28:r'\LARGE',20.74:r'\huge',24.88:r'\Huge'} + +fontmaxsize_pt = 25 #超过25pt,字体为最大字体Huge,仅用于latex +fontmaxsieze_px = 34 #超过34px,字体为最大字体Huge,仅用于latex + + + +class cnautowrapnode(nodes.Inline, nodes.TextElement): + """ + cnautowrap node类,用于解析autowrap节点,支持长英文的换行 + """ + pass + +def __ModifyPureTextforlatex(text): + """ + 修改原来的文本,如果文本中有\sphinxhyphen{}或者\PYGZhy{} + 需要添加大括号进行分割,否则\seqsplit命令无效 + """ + #用正则表达式查找\sphinxhyphen{}或者\PYGZhy{} + newtext = "" + startpos = 0 + #searchstr = r"\\sphinxhyphen\{\}|\\PYGZhy\{\}" + searchstr = r"-|—|——" + m = re.finditer(searchstr, text, re.I | re.U) + for match in m: + newtext = newtext + text[startpos: match.start()] + " :raw-latex:`{\PYGXhy{}}` " + startpos = match.end() + if len(text) > startpos: + newtext = newtext + text[startpos:len(text)] + + #print(newtext) + return newtext + +def autowrap_role(name, rawtext, text, lineno, inliner, options={}, content=[]): + """ + autowrap role主函数,该函数实现对有符号连续长英文的自动换行。 + """ + #attrs ={'style':'color:RGB(255,0,0);font-weight:bold;font-style:italic;font-size:15px'} + #判断文本中间有没有其它指令,默认有红色属性 + #puretext = text + #print(text) + env = inliner.document.settings.env + if hasattr(env, 'autowrapfile') and (env.docname not in env.autowrapfile): + #print("-----------html enter---------------------") + env.autowrapfile.append(env.docname) + else: + #print("-----------html enter1---------------------") + env.autowrapfile = [] + env.autowrapfile.append(env.docname) + #print('===================') + #print('make latex') + #print('===================') + #puretext = __ModifyPureTextforlatex(text) #为特殊指令添加大括号,否则将失败 + nodelst = [] + nodelst.append(cnautowrapnode(rawtext,text,*content,**options)) + #print(cnautowrap(rawtext,puretext,*content,**attrs)) + return nodelst,[] + + +def latex_visit_autowrap_node(self, node): + #print('-----------------') + #print(node.attributes) + #print(node.rawsource) + # print(type(self)) + #print('-----------------') + __modifyNodeforlatexraw(node) + self.body.append(r"\seqsplit{") + + +def latex_depart_autowrap_node(self, node): + #print(node) + self.body.append('}') + + +def html_visit_autowrap_node(self, node): + pass + +def html_depart_autowrap_node(self, node): + pass + +class cncolorrolenode(nodes.Inline, nodes.TextElement): + """ + cncolorrole node类,用于解析cncolor role节点 + """ + pass +def __GetComplexColorAttr(attrstr,colorkey): + """ + :param colorkey: color或者bg,查找字体颜色还是背景颜色。 + """ + clrkey = colorkey +':' + + match = re.search(clrkey,attrstr) + if match is None: + return '',attrstr + + #如果有类似关键字,为了正则表达式的准确性,先查找是否有rgb颜色值 + searchstr = clrkey + '(RGB\((0\.)?[0-9]+,(0\.)?[0-9]+,(0\.)?[0-9]+\))' + match = re.search(searchstr,attrstr,re.I) + if match: + pre = '' + left = '' + if match.start()>0: + pre = attrstr[0:match.start()] + if match.end() < len(attrstr): + left = attrstr[match.end()+1:len(attrstr)] + + return match.group(1),pre + left + attrstr +=',' #末尾加逗号,避免在最后的位置正则表达式匹配失败 + #再查找是否有纯颜色,比如:color:yellow + searchstr = clrkey +'([a-zA-Z]+?)(?=,)' + match = re.search(searchstr,attrstr,re.I) + if match: + pre = '' + left = '' + if match.start() > 0: + pre = attrstr[0:match.start()] + if match.end() < len(attrstr)-1: + left = attrstr[match.end()+1:len(attrstr)-1] #因为正则表达式排除了逗号,因此删除剩余的逗号 + + return match.group(1), pre + left + else: + return '', attrstr.strip(',') #去掉末尾的逗号 + + + +def __GetcomplexfzAttr(attrstr): + attrstr += ',' # 末尾加逗号,避免在最后的位置正则表达式匹配失败 + #查找是否有字体大小属性 + searchstr ='fz:([\s\S]+?)(?=,)' + match = re.search(searchstr,attrstr,re.I) + if match: + pre = '' + left = '' + if match.start() > 0: + pre = attrstr[0:match.start()] + if match.end() < len(attrstr)-1: #去掉末尾逗号 + left = attrstr[match.end()+1:len(attrstr)-1] + + return match.group(1), pre + left + else: + return '',attrstr.strip(',') #去掉末尾逗号 + +def __GetRGBColorandAttrlist(attrstr): + """ + 判断属性是否包含RGB()格式的自定义颜色属性 + 返回RGB格式颜色和其它属性列表,无RGB格式属性,该返回值为""。 + """ + #先判断是否有独立颜色 + searchstr = 'rgb' + matchobj = re.search(searchstr, attrstr, re.I) + if matchobj is None: + # 没有自定义的RGB颜色值,返回属性列表 + return "", attrstr.split(",") + + #在判断是否有RGB颜色 + searchstr = r'RGB\([0-9]+,[0-9]+,[0-9]+\)' + matchobj = re.search(searchstr,attrstr) + if matchobj: + newattr = attrstr[0:matchobj.start()]+attrstr[matchobj.end()+1:len(attrstr)] + return matchobj.group(),newattr.split(",") + + #在判断是否有rgb颜色 + searchstr = r'rgb\(0\.[1-9]+,0\.[1-9]+,0\.[1-9]+\)' + matchobj = re.search(searchstr,attrstr) + if matchobj: + newattr = attrstr[0:matchobj.start()]+attrstr[matchobj.end()+1:len(attrstr)] + return matchobj.group(),newattr.split(",") + + +def __GetAttrAndPureText(text): + """ + 判断是否有属性内容,属性内容放到纯文本的前面,被“{}”包围,如果要显示“{}”, + 请在开头添加转义字符,即“\{}”。如果内容不符合规则也会当作纯文本。 + 属性内容的写法:{属性1,属性2,属性3,属性4},中间用英文逗号“,”分割,属性无固定顺序关系。 + 支持的属性: + + * 字体颜色属性,语法:color:<要设置的颜色>,支持的关键字:red,green,blue,black,white,yellow,RGB(,,),rgb(,,)。 + 除了上面提到的固定关键子的颜色外,为了同时兼容html和latex,请使用大写RGB的方式定义颜色。 + 默认字体颜色为红色,添加red属性关键字。 + * 背景颜色属性,语法:bg:<要设置的颜色>,支持的关键字同上。 + * 字体大小属性,语法:fz:<要设置的字体大小>。 + * 是否为斜体,关键字:itatic,添加了该关键字,则设置字体为斜体,默认非斜体。 + * 字体是否加粗,关键字:bold,添加了该关键字,则对字体加粗处理,默认不加粗。 + 注意:斜体和粗体属性,如果同时被设置,latex只对字体加粗,斜体设置无效。html支持同时设置。 + * 是否添加下画线,关键字:underline,添加了该关键字,则对字体加添加下划线,默认不加下划线。 + * 是否添加删除线,关键字:strike,添加了该关键字,则对字体加添加删除线,默认不加删除线。 + + 比如,如果将一段字体设置为红色带下划线的斜体,则写法如下: + :cncolor:`{color:red,bg:yellow,fz:50px,strike,itatic}描述` + 以上写法将“描述”设置为红色带下划线的斜体。 + """ + + #判断文本开头是否符合属性定义的语法,如果有转义字符或不是以“{”开头,则直接返回。 + attrs = {} + if ord(text[0]) == 92 or ord(text[0]) != 123: + return {'style': 'color:red'},text + #查找属性定义 + #searchstr = "(?<=\{)([\s\S]+?)(?=\})" #使用match匹配,不支持?<=语法 + searchstr = "{([\s\S]+?)}" + matchobj = re.match(searchstr,text,re.I|re.M) + if matchobj is None: + return {'style': 'color:red'},text + + #中间有可能换行,将换行符替换掉 + attrstr = re.sub(r'\n|\r','',matchobj.group(1).strip('\n').strip('\r'),re.M) + endpos = matchobj.end() + puretext = text[endpos:len(text)] + isattr = False #是否有属性内容,默认没有属性内容 + #先判断是否设置了字体颜色 + rgbcolor,leftstr = __GetComplexColorAttr(attrstr,'color') + if rgbcolor == "": + attrs['style'] = 'color:red' + else: + attrs['style'] = 'color:' + rgbcolor + isattr = True + #判断是否有背景色 + bgcolor,leftstr = __GetComplexColorAttr(leftstr,'bg') + if bgcolor !='': + attrs['style'] = attrs['style'] + ';background:' + bgcolor + isattr = True + #判断是否设置了字体尺寸 + fontsize,leftstr = __GetcomplexfzAttr(leftstr) + if fontsize !='': + attrs['style'] = attrs['style'] + ';font-size:' + fontsize + isattr = True + #用逗号分割leftstr + attrlst = leftstr.split(',') + # 判断是否有斜体属性 + if 'italic' in attrlst: + attrs['style'] = attrs['style'] + ";font-style:italic" + isattr = True + if 'bold' in attrlst: + attrs['style'] = attrs['style'] + ";font-weight:bold" + isattr = True + if 'underline' in attrlst: + attrs['style'] = attrs['style'] + ";text-decoration:underline" + isattr = True + if 'strike' in attrlst: + attrs['style'] = attrs['style'] + ";text-decoration:line-through" + isattr = True + if isattr: + return attrs, puretext + else: + return attrs, text + ''' + #根据逗号进行分割 + rgbcolor,attrlst = __GetRGBColorandAttrlist(attrstr) + if rgbcolor !="": + attrs['style'] = 'color:' + rgbcolor + isattr = True + else: + #判断属性里是否有颜色设置,先判断是否有固定颜色值的设置 + cuscolor = list(set(colorlst) & set(attrlst)) + if len(cuscolor)==1: + #有多个固定颜色的设置认为出错 + # 有固定颜色值设置,属性换为自定义颜色 + attrs['style'] = 'color:'+cuscolor[0] + isattr = True + else: + attrs['style']= 'color:red' + + #判断是否有斜体属性 + if 'italic' in attrlst: + attrs['style'] = attrs['style'] + ";font-style:italic" + isattr = True + if 'bold' in attrlst: + attrs['style'] = attrs['style'] + ";font-weight:bold" + isattr = True + if 'underline' in attrlst: + attrs['style'] = attrs['style'] + ";text-decoration:underline" + isattr = True + if 'strike' in attrlst: + attrs['style'] = attrs['style'] + ";text-decoration:line-through" + isattr = True + + if isattr: + return attrs,puretext + else: + return attrs,text + ''' + +def cncolor_role(name, rawtext, text, lineno, inliner, options={}, content=[]): + """ + cncolor role主函数,该函数实现对指定文本的颜色渲染,默认渲染为红色。 + """ + #attrs ={'style':'color:RGB(255,0,0);font-weight:bold;font-style:italic;font-size:15px'} + #判断文本中间有没有其它指令,默认有红色属性 + #app = inliner.document.settings.env.app + attrs,puretext = __GetAttrAndPureText(text) + nodelst = [] + nodelst.append(cncolorrolenode(rawtext,puretext,*content,**attrs)) + #print(cncolor(rawtext,puretext,*content,**attrs)) + return nodelst,[] + +def __IsStrikeAttr(attrstr): + """ + 判断属性里面是否有删除标识,有删除标志将删除标识删除掉,同时返回true。 + """ + searchstr = "text-decoration:line-through" + matchobj = re.search("text-decoration:line-through",attrstr,re.I) + if matchobj is None: + return attrstr,False + attrstr = attrstr[0:matchobj.start()]+attrstr[matchobj.end():len(attrstr)] + return attrstr,True + +def html_visit_cncolor_node(self, node): + + #判断属性里面有没有text-decoration:line-through删除标志,有删除标志使用标签 + #否则和text-decoration:underline标志会存在冲突 + #print(node.attributes['style']) + strikeflag = False + attrstr,strikeflag = __IsStrikeAttr(node.attributes['style']) + attrs={'style':attrstr} + tag = self.starttag(node,'span','',CLASS='cncolorrole',**attrs) + self.body.append(tag) + if strikeflag: + self.body.append(self.starttag(node, 'del')) + #print(node) + #print(node.attributes) + #print(node.tagname) + #print(node.astext()) + #print(type(self)) + + +def html_depart_cncolor_node(self, node): + strikeflag = False + attrstr, strikeflag = __IsStrikeAttr(node.attributes['style']) + if strikeflag: + self.body.append('') + self.body.append('') + #print(node) + +def __GetColorAttr(attrstr,colorkey): + ''' + :param colorkey: color或者bg,得到字体或者背景色 + ''' + #为了正则表达式好查找,给属性字符串添加{} + newattr = '{'+attrstr+'}' + searchstr = colorkey + ':([\s\S]+?)(?=[;|\}])' + matchobj = re.search(searchstr,newattr,re.I) + if matchobj is None: + return "{}" + colorattr = matchobj.group(1) + + #得到颜色 + searchstr = 'rgb' + matchobj = re.match(searchstr, colorattr, re.I) + if matchobj is None: + return "{" + colorattr + "}" # 独立颜色 + # 判断是大写RGB还是小写rgb + # 先按大写RGB搜索 + searchstr = 'RGB\(([0-9]+,[0-9]+,[0-9]+)\)' + matchobj = re.match(searchstr, colorattr) + if matchobj: + return "[RGB]{" + matchobj.group(1) + "}" + # 再搜索是否有小写的rgb颜色设置 + searchstr = 'rgb\((0\.[1-9]+,0\.[1-9]+,0\.[1-9]+)\)' + matchobj = re.match(searchstr, colorattr) + if matchobj: + return "[rgb]{" + matchobj.group(1) + "}" + +def __GetLatexFontSizeFlag(attrstr,fontsize): + attrstr +=';' #最后添加分号,方便正则表达式查找 + searchstr = fontsize+":([0-9]+)(pt|px)(?=;)" + + match = re.search(searchstr,attrstr,re.I) + if match is None: + return "" + try: + fontvalue = int(match.group(1)) + fontunit = match.group(2) + if fontunit == 'px': + #将px转为pt,按96dpi进行的转化,pt = px * (72/96) + fontvalue = fontvalue * 0.75 + if fontvalue >= fontmaxsize_pt: + return '\Huge' + dicttmp = {} + #取出latex字典的所有key值 + keylst = latex_fontsize.keys() + for value in keylst: + diffabs = abs(fontvalue-value) + if not dicttmp: + dicttmp[diffabs]=value + else: + prevalue = next(iter(dicttmp)) + if diffabs < prevalue: + dicttmp[diffabs]=value + del dicttmp[prevalue] #只保存最小的一个,即最接近的一个。 + + #取出保存最接近的pt值,即字典的第一个,因为该字典只有一个 + value = dicttmp.get(next(iter(dicttmp))) + return latex_fontsize[value] + + except Exception as e: + print(e) + return "" + +def latex_visit_cncolor_node(self, node): + #print('-----------------') + # print(node.attributes['style']) + #print(node) + #print(type(self)) + #print('-----------------') + + attrstr = node.attributes['style'] + self.in_production_list = 0 + latexattr = "" + + #设置字体尺寸,如果是pt或者px,取和latex最相近的字体设置 + if "font-size:" in attrstr: + latexflag = __GetLatexFontSizeFlag(attrstr,'font-size') + if latexflag != "": + latexattr = latexflag+'{'+latexattr + self.in_production_list+=1 + + #如果有斜体属性 + if "font-style:italic" in attrstr: + latexattr = r'\textit{'+latexattr + self.in_production_list+=1 + #如果有粗体属性 + if "font-weight:bold" in attrstr: + latexattr = r'\textbf{'+latexattr + self.in_production_list+=1 + #取出设置的颜色 + if 'color:' in attrstr: + color = __GetColorAttr(attrstr,'color') + colorattr = r'\textcolor' + color + latexattr = colorattr + '{' + latexattr + self.in_production_list += 1 + #如果有删除线属性 + if "text-decoration:line-through" in attrstr: + latexattr = r"\sout{" + latexattr + self.in_production_list += 1 + #如果有下划线属性 + if "text-decoration:underline" in attrstr: + latexattr = r"\uline{" + latexattr + self.in_production_list += 1 + if "background:" in attrstr: + color = __GetColorAttr(attrstr,'background') + colorattr = r'\colorbox' + color + latexattr = colorattr + '{' + latexattr + self.in_production_list += 1 + + self.body.append(latexattr) + +def latex_depart_cncolor_node(self, node): + for i in range(self.in_production_list, 0, -1): + self.body.append('}') + self.in_production_list = 0 + +#--------------------------------cncolordirective------------------------------------------------ + +class cncolordirective(nodes.Inline, nodes.TextElement): + """ + cncolorrole node类,用于解析cncolor role节点 + """ + pass +class cncolorliteral(nodes.literal_block): + """ + 为了能够自定义html标签和latex符号 + """ + pass +class cncolorline(nodes.literal_block): + """ + 该节点是为了逐行添加latex指令,如果cncolor指令设置了下划线和删除线选项,则只能逐行添加参数。 + """ + pass +class cncolorsection(nodes.section): + """ + 为了能够自定义html标签和latex符号 + """ + pass +class cncolortext(nodes.Text): + """ + 为了给空行添加换行符 + """ + pass +#自定义cncolor指令 +class cnColorDirectivecls(SphinxDirective): + ''' + 自定义cncolor指令核心类 + 8个options: + * literal: 内容是否当作纯文本解析,默认解析文本中的指令。如果添加了该字段,则内容会被当作纯文本。 + 只解析行内指令,其它sphinx指令不支持。比如list不会被解析。 + * color:字体颜色,字符串。支持red,green,blue,black,white,yellow和RGB(,,)格式的颜色设置。 + * bgcolor:背景颜色,字符串。支持red,green,blue,black,white,yellow和RGB(,,)格式的颜色设置。 + * italic:斜体的标识,添加该参数则将字体设置为斜体,不添加该参数斜体无效。 + * bold:粗体的标识,添加该参数则将字体设置为粗体,不添加该参数粗体无效。 + * underline:添加下划线,添加该参数则为字体设置下划线,不添加该参数不添加下划线。latex暂不支持。 + * strike:添加删除线,添加该参数则为字体设置删除线,不添加该参数不添加删除线。latex暂不支持。 + * istitle: chapter 或者 section。当添加了istitle后,如果内容包含标题,则认为是一级大纲chapter,否则认为是section。 + * fontsize: 字体大小,以px或者pt为单位。html直接使用已设置的值,latex会转为latex字体大小进行设置。 + latex最大字体\Huge为34px,因此当字体大小设置为34px以上时,对latex文件来说字体大小都是\Huge。 + ''' + required_arguments = 0 + optional_arguments = 1 + final_argument_whitespace = False + has_content = True + option_spec={ + 'literal': directives.flag, + 'color': directives.unchanged, + 'bgcolor':directives.unchanged, + 'italic': directives.flag, + 'bold': directives.flag, + 'underline': directives.flag, + 'strike': directives.flag, + 'class': directives.class_option, + 'name': directives.unchanged, + 'istitle':directives.flag, + 'fontsize':directives.unchanged, + } + def __MakeSectionorTitleNode(self,sectionlst): + ''' + 生成section或者title节点 + ''' + titlelst, messages = self.state.inline_text(sectionlst[0], self.lineno) + + if 'istitle' in self.options.keys(): + # htmlfile列表负责将html一级标题的h0修改为h1,否则标题显示达不到预期 + if hasattr(self.env, 'htmlfile') and (self.env.docname not in self.env.htmlfile): + # print(self.env.docname) + self.env.htmlfile.append(self.env.docname) + else: + # print(self.env.docname) + self.env.htmlfile = [] + self.env.htmlfile.append(self.env.docname) + + # 有title的section,在生成latex的时候,需要删除title子节点,否则生成的标题达不到预期 + # 因此有title的内容,每次都得重新编译,否则在增量编译时,无法修改title节点 + if not hasattr(self.env, 'sectionfile'): + # print(self.env.docname) + self.env.sectionfile = [] + self.env.sectionfile.append(self.env.docname) + elif self.env.docname not in self.env.sectionfile: + # print(self.env.docname) + self.env.sectionfile.append(self.env.docname) + + if len(titlelst) == 1: + # 说明标题不带sphinx行内指令,则无需生成title node。 + targetid = 'cncolorsection-%d' % self.env.new_serialno('cncolorsection') + sectionnode = cncolorsection( + '', + nodes.title(text=sectionlst[0]), + ids=[targetid], + names=[nodes.fully_normalize_name(sectionlst[0])], + *[], + **self.options) + return sectionnode + else: + + title_node = nodes.title(sectionlst[0], '', *titlelst, **{}) + targetid = 'cncolorsection-%d' % self.env.new_serialno('cncolorsection') + sectionnode = cncolorsection(ids=[targetid], names=[nodes.fully_normalize_name(sectionlst[0])]) + sectionnode.attributes = {**sectionnode.attributes, **self.options} + sectionnode += title_node + return sectionnode + # 以下注释保留,为后续开发提供参考 + ''' + #rst = ViewList() + #rst.append('\n'.join(sectionlst),'',0) + #section = nodes.section() + #section.document = self.state.document + #nested_parse_with_titles(self.state, rst, section) + + #self.env.titledict = {sectionlst[0]:''} + + ##说明标题内有sphinx行内指令,需要通过nodes.title解析,然后在doctree-read处理事件中进行替换删除 + title_node = nodes.title(sectionlst[0], '', *titlelst, **self.options) + sectionnode = cncolorsection( + '', + nodes.title(text=sectionlst[0]), + ids=[nodes.make_id(sectionlst[0])], + names=[nodes.fully_normalize_name(sectionlst[0])], + *[], + **self.options) + return [title_node,sectionnode,node] + ''' + + def __MakeContentNode(self,contentlst): + #print('ccccccccccccccccccccccc') + text = '\n'.join(contentlst) + '\n' # 末尾加一个换行符,否则如果正则表达式查不出来最后的符号行。 + #print(text) + searchstr = r'([\x21-\x2f|\x3a-\x40|\x5b-\x60|\x7b-\x7e])\1+(?=[\r|\n|\r\n])\n?' + matchobj = re.search(searchstr,text) + if matchobj: + #生成新的属性,主要用于纯文本 + #newdict = self.options.copy() + #if 'istitle' in newdict.keys(): + # del newdict['istitle'] + + parentnode = cncolordirective('','',*[],**self.options) + startpos = matchobj.start() + endpos = matchobj.end() + if startpos > 0: + pretext = text[0:startpos] + if len(pretext) > 0: + #print(len(pretext)) + #print('xxx') + #print(pretext.split('\n')) + #print('zzzz') + pretextlst = ViewList(pretext.split('\n')) + prenode = cncolordirective(pretext, '', *[], **self.options) + self.state.nested_parse(pretextlst, 0, prenode) + parentnode += prenode + + titlenode=nodes.Text(matchobj.group()) + parentnode +=titlenode + + if endpos < len(text): + endtext = text[endpos:len(text)] + if len(endtext) >0:# + #print(len(endtext)) + #print('xxx2') + #print(endtext) + #print('zzzz2') + newlist = ViewList(endtext.split('\n')) + leftnode = self.__MakeContentNode(newlist) + parentnode += leftnode + + return parentnode + + else: + #print('~~~~~~~~~~~~~~~~~~~~~~') + #print(text) + #print('~~~~~~~~~~~~~~~~~~~~~~') + node = cncolordirective(text.strip('\n'), '', *[], **self.options) + self.state.nested_parse(contentlst, 0, node) + return node + + def __MakeContentNode_v2(self, contentlst): + # print('ccccccccccccccccccccccc') + text = '\n'.join(contentlst) + '\n' # 末尾加一个换行符,否则如果正则表达式查不出来最后的符号行。 + # print(text) + searchstr = r'([\x21-\x2f|\x3a-\x40|\x5b-\x60|\x7b-\x7e])\1+(?=[\r|\n|\r\n])\n?' + matchobj = re.search(searchstr, text) + if matchobj: + # 生成新的属性,主要用于纯文本 + # newdict = self.options.copy() + # if 'istitle' in newdict.keys(): + # del newdict['istitle'] + + parentnode = nodes.paragraph() + startpos = matchobj.start() + endpos = matchobj.end() + if startpos > 0: + pretext = text[0:startpos] + if len(pretext) > 0: + # print(len(pretext)) + # print('xxx') + # print(pretext.split('\n')) + # print('zzzz') + pretextlst = ViewList(pretext.split('\n')) + prenode = nodes.inline() + self.state.nested_parse(pretextlst, 0, prenode) + parentnode += prenode + + titlenode = nodes.Text('\n'+matchobj.group().strip('\n')) + parentnode += titlenode + + if endpos < len(text): + endtext = text[endpos:len(text)].strip('\n') + if len(endtext) > 0: # + # print(len(endtext)) + # print('xxx2') + # print(endtext) + # print('zzzz2') + newlist = ViewList(endtext.split('\n')) + leftnode = self.__MakeContentNode_v2(newlist) + parentnode += leftnode + + return parentnode + + else: + # print('~~~~~~~~~~~~~~~~~~~~~~') + # print(text) + # print('~~~~~~~~~~~~~~~~~~~~~~') + node =nodes.paragraph() + self.state.nested_parse(contentlst, 0, node) + return node + + def __GetNoliteralNode(self): + + # 解析是否有标题 + if len(self.content) > 1: + sectionlst, contentlst = self.__GetContentDictbySection(self.content) + else: + sectionlst = [] + contentlst = self.content + + newdict = {} + if len(self.arguments)>0: + newdict['caption'] = self.arguments[0] + + parentnode = cncolordirective('', '', *[], **{**self.options,**newdict}) + + node = self.__MakeContentNode_v2(contentlst) + node.line = self.content_offset + 1 + self.add_name(node) + if sectionlst: + # 增加section + sectionnode = self.__MakeSectionorTitleNode(sectionlst) + parentnode += node + return [sectionnode,parentnode] + else: + parentnode +=node + return [parentnode] + + def __GetLineNode(self): + ''' + 解析每一行,为每一行添加下划线和删除线指令 + ''' + #print(self.content) + newdict = {} + if len(self.arguments)>0: + newdict['caption'] = self.arguments[0] + + parentnode = cncolordirective('', '', *[], **{**self.options,**newdict}) + for line in self.content: + if len(line)>0: + text_nodes, messages = self.state.inline_text(line, 0) + node = cncolorline(line, '', *text_nodes, **self.options) + parentnode+=node + else: + node = cncolortext(line) + parentnode+=node + return [parentnode] + + def run(self): + + set_classes(self.options) + classes = ['cncolorblock'] + if 'classes' in self.options: + classes.extend(self.options['classes']) + self.assert_has_content() + if 'literal' in self.options.keys(): + del self.options['literal'] #先将非必须的option删除 + # 将标题保存传出去 + newdict={} + if len(self.arguments) > 0: + newdict['caption'] = self.arguments[0] + + #全部当作纯文本解析 + text = '\n'.join(self.content) + text_nodes, messages = self.state.inline_text(text, self.lineno) + node = cncolorliteral(text, '',*text_nodes, **{**self.options,**newdict}) + node.line = self.content_offset + 1 + self.add_name(node) + return [node] + messages + elif (('underline' or 'strike') in self.options.keys()) and \ + (self.env.app.builder.name=='latex' or self.env.app.builder.name=='latexpdf'): + return self.__GetLineNode() + else: + return self.__GetNoliteralNode() + + + def __GetContentDictbySection(self,content): + ''' + **注意:** + cncolor指令下的内容,要么开始就是标题名称,要么不包含标题。 + + 带有section的内容使用literal_block或者nested_parse解析,都会给出以下告警:: + + CRITICAL: Unexpected section title. + + 因此检查文本中是否带有section内容,并根据section内容对文本进行拆分, + 返回标题名称列表和段落内容列表。 + ''' + #搜索可能的标题 + searchstr = r"([\x21-\x2f|\x3a-\x40|\x5b-\x60|\x7b-\x7e]+)\n?" + text = content[1] + #判断第2个是否为标题符号行 + match = re.match(searchstr,text) + if match and len(content[0])>0: + #说明内容的前两个是标题内容 + return content[0:2],content[2:len(content)] + + return None,content + + +def __GetHtmlStyleAttr(nodeattr): + ''' + 根据node属性生成,html标签的style属性。 + node属性是一个字典对象 + ''' + attr={} + attr['style'] ='' + delflag = False #是否有删除标志 + keys = nodeattr.keys() + #指令参数名称即为字典的key值 + if 'color' in keys: + attr['style'] = 'color:' + nodeattr['color'] + if 'bgcolor' in keys: + attr['style'] = attr['style'] + ';background-color:' + nodeattr['bgcolor'] + if 'italic' in keys: + attr['style'] = attr['style'] + ';font-style:italic' + if 'bold' in keys: + attr['style'] = attr['style'] + ';font-weight:bold' + if 'underline' in keys: + attr['style'] = attr['style'] + ';text-decoration:underline' + if 'fontsize' in keys: + attr['style'] = attr['style'] + ';font-size:' + nodeattr['fontsize'] + if 'strike' in keys: + delflag = True + #print(attr) + return attr,delflag + +def __GetColorAttrValue(attrstr): + ''' + 判断颜色属性是否为rgb格式,如果是rgb格式需要重定义颜色 + 如果不是rgb格式,就当作独立颜色 + ''' + #先判断开头是否为rgb或者RGB。 + searchstr = 'rgb' + matchobj = re.match(searchstr,attrstr,re.I) + if matchobj is None: + return attrstr,False #独立颜色 + #判断是大写RGB还是小写rgb + #先按大写RGB搜索 + searchstr = 'RGB\(([0-9]+,[0-9]+,[0-9]+)\)' + matchobj = re.match(searchstr, attrstr) + if matchobj: + return "{RGB}{" + matchobj.group(1)+"}",True + #再搜索是否有小写的rgb颜色设置 + searchstr = 'rgb\((0\.[1-9]+,0\.[1-9]+,0\.[1-9]+)\)' + matchobj = re.match(searchstr, attrstr) + if matchobj: + return "{rgb}{"+matchobj.group(1)+"}",True + +def __GetlatexStyleAttr(self,nodeattr): + ''' + 根据node属性生成latex属性,latex暂时不支持下划线和删除线。 + node属性是一个字典对象 + ''' + attr = '' + fontupper = 'fontupper=' + colorlst = [] #自定义颜色列表 + keys = nodeattr.keys() + #指令参数名称即为字典的key值 + if 'color' in keys: + colorvalue,flag = __GetColorAttrValue(nodeattr['color']) + if flag: + #含有rgb的颜色,需要定义 + #生成唯一的颜色名称 + colorid = 'definecoltext%d' % self.settings.env.new_serialno('coltext=' + colorvalue) + definecolor = r'\definecolor{'+colorid+'}'+colorvalue+'\n' + attr +="coltext=" + colorid + ',' + colorlst.append(definecolor) + else: + attr +="coltext=" + nodeattr['color'] + ',' + if 'bgcolor' in keys: + colorvalue,flag = __GetColorAttrValue(nodeattr['bgcolor']) + if flag: + #含有rgb的颜色,需要定义 + #生成唯一的颜色名称 + colorid = 'definecolback%d' % self.settings.env.new_serialno('colback=' + colorvalue) + definecolor = r'\definecolor{'+colorid+'}'+colorvalue+'\n' + attr +="colback=" + colorid +',' + colorlst.append(definecolor) + else: + attr +="colback=" + nodeattr['bgcolor'] + ',' + else: + attr += "colback=white,colframe=white,left*=0mm," #没有配置背景色,则采用白色背景色,制造没有背景色的假象。否则默认为灰色。 + + if 'caption' in keys: + attr += 'title='+nodeattr['caption'] +r',fonttitle=\bfseries,' + #给title添加背景色,否则title默认为白色,显示不出来 + attr += 'colbacktitle = black!50!blue,' + + if 'italic' in keys: + fontupper += r'\itshape' + if 'bold' in keys: + fontupper += r'\bfseries' + if 'fontsize' in keys: + # 得到latex字体大小 + # 组合下面的接口可以认识的字符串 + attrstr = 'font-size:' + nodeattr['fontsize'] + latexflag = __GetLatexFontSizeFlag(attrstr, 'font-size') + fontupper += latexflag + + #print(attr+fontupper) + return attr+fontupper, colorlst + +def __ModifyNodeAttrtoHtmlStyle(nodeattr): + ''' + 根据node属性生成,html标签的style属性。 + node属性是一个字典对象 + ''' + delflag = False #是否有删除标志 + keys = nodeattr.keys() + #指令参数名称即为字典的key值 + if 'color' in keys: + nodeattr['style'] = 'color:' + nodeattr['color'] + del nodeattr['color'] + if 'bgcolor' in keys: + nodeattr['style'] = nodeattr['style'] + ';background-color:' + nodeattr['bgcolor'] + del nodeattr['bgcolor'] + if 'italic' in keys: + nodeattr['style'] = nodeattr['style'] + ';font-style:italic' + del nodeattr['italic'] + if 'bold' in keys: + nodeattr['style'] = nodeattr['style'] + ';font-weight:bold' + del nodeattr['bold'] + if 'underline' in keys: + nodeattr['style'] = nodeattr['style'] + ';text-decoration:underline' + del nodeattr['underline'] + if 'strike' in keys: + del nodeattr['strike'] + if 'caption' in keys: + del nodeattr['caption'] + #print(attr) + return nodeattr + +def html_visit_cncolorsection_node(self, node): + self.strikeflag = False + attrs, self.strikeflag = __GetHtmlStyleAttr(node.attributes) + tag = self.starttag(node, 'div','',CLASS='cncolorsection', **attrs) + self.body.append(tag) + if self.strikeflag: + self.body.append(self.starttag(node, 'del')) + +def html_depart_cncolorsection_node(self, node): + if self.strikeflag: + self.body.append('') + self.strikeflag = False + self.body.append('
\n') + +def latex_visit_cncolorsection_node(self, node): + + if 'istitle' in node.attributes: + self.body.append(r'\chapter{') + else: + self.body.append(r'\section{') + +def latex_depart_cncolorsection_node(self, node): + self.body.append('}') + +def html_visit_cncolordirective_node(self, node): + self.strikeflag = False #是否有删除标志 + attrs,self.strikeflag = __GetHtmlStyleAttr(node.attributes) + tag = self.starttag(node,'div','',CLASS='cncolorblock',**attrs) + self.body.append(tag) + + if self.strikeflag: + self.body.append(self.starttag(node, 'del')) + if 'caption' in node.attributes.keys(): + self.body.append("

"+node.attributes['caption']+'

\n') + + if node.tagname =='cncolorliteral': + self.body.append(self.starttag(node,"pre")) + + +def html_depart_cncolordirective_node(self, node): + if self.strikeflag: + self.body.append('') + self.strikeflag = False + if node.tagname == 'cncolorliteral': + self.body.append('') + self.body.append('
') + +def latex_visit_cncolordirective_node(self, node): + + latexattr,colorlst = __GetlatexStyleAttr(self,node.attributes) + self.body.append('\n') + for color in colorlst: + self.body.append(color) + self.body.append('\n') + latexpre = r'\begin{tcolorbox}[arc=0mm,boxrule=0mm,left=0mm,' + latexattr + ']' + '\n' + #print(latexpre) + self.body.append(latexpre) + if node.tagname == 'cncolorliteral': + self.body.append(r'\begin{Verbatim}[commandchars=\\\{\}]' + '\n') + +def latex_depart_cncolordirective_node(self, node): + + #self.body.append('\n') + if node.tagname == 'cncolorliteral': + self.body.append('\n'+r'\end{Verbatim}' + '\n') + + self.body.append('\n'+ r'\end{tcolorbox}' + '\n') + +def latex_visit_cncolorline_node(self,node): + #print('cncolorline') + latexattr = '' + self.linecount = 0 + if 'strike' in node.attributes: + latexattr = r"\sout{" + latexattr + self.linecount += 1 + if 'underline' in node.attributes: + #print("underline") + latexattr = r"\uline{" + latexattr + self.linecount += 1 + self.body.append(latexattr) + +def latex_depart_cncolorline_node(self,node): + for i in range(self.linecount, 0, -1): + #print(self.linecount) + self.body.append('}') + self.linecount = 0 + +def latex_visit_cncolortext_node(self,node): + self.body.append('\n\\\\') + +def latex_depart_cncolortext_node(self,node): + pass + +def doctree_resolved_process_titlenodes(app, doctree, fromdocname): + #if fromdocname=='preface/preface': + # print(doctree) + pass + #以下注释保留,为后续开发提供参考 + ''' + env = app.builder.env + if not hasattr(env, 'titledict'): + return + + if fromdocname=='index': + + #print(doctree) + for node in doctree.traverse(addnodes.toctree): + print('=======================') + toc = app.env.resolve_toctree(fromdocname, app.builder, node) + print(type(toc)) + print(toc) + print('=======================') + for nodechild in toc.children: + print('----------------------') + print(nodechild.astext()) + print('----------------------') + if nodechild.astext() in env.titledict.keys(): + newnode = env.titledict[nodechild.astext] + nodechild.replace_self(newnode) + node.remove(nodechild) + ''' +def __modifyNodeforlatexraw(node): + #print('~~~~~~~~~~~~~~~~~~~~~~~~~~~') + #print(node.astext()) # 或者其他处理逻辑 + #print('~~~~~~~~~~~~~~~~~~~~~~~~~~~') + text = node.astext() + if "-" in text: + #print(node.children) + node.children = [] # 先清空老节点内容 + # 对文本进行拆分 + textlst = text.split("-") + #print(textlst) + for i in range(0, len(textlst)): + textnode = nodes.Text(textlst[i], textlst[i]) + node.append(textnode) + if i < len(textlst) - 1: + rawnode = nodes.raw('', r"{\PYGZhy{}}", format='latex') + node.append(rawnode) + +def __addpygzhychildnode(doctree): + """ + 如果文本中含有“-”或者“——”或者“—”,则增加latex子节点 + """ + def traverse(node): + if isinstance(node, cnautowrapnode): + __modifyNodeforlatexraw(node) + #print('~~~~~~~~~~~~~~~~~~~~~~~~~~~') + #print(node.astext()) # 或者其他处理逻辑 + #print('~~~~~~~~~~~~~~~~~~~~~~~~~~~') + #text = node.astext() + #if "-" in text: + # print(node.children) + # node.children=[] #先清空老节点内容 + # #对文本进行拆分 + # textlst = text.split("-") + # print(textlst) + # for i in range(0,len(textlst)): + # textnode = nodes.Text(textlst[i],textlst[i]) + # node.append(textnode) + # if i < len(textlst)-1: + # rawnode = nodes.raw('', r"{\PYGZhy{}}", format='latex') + # node.append(rawnode) + + #newtext = newtext.replace("-",r"{\PYGZhy{}}") + #newtext = newtext.replace("_",r"\_") + #rawnode = nodes.raw('', newtext, format='latex') + #node.append(rawnode) #再添加新的子节点 + + for child in node.children: + traverse(child) + + traverse(doctree) + #print(doctree) + +def doctree_read_modify_titlenodes(app, doctree): + #print(doctree) + #docname = app.env.docname + #print(docname) + #print(doctree) + if app.builder.name=='latex' or app.builder.name=='latexpdf': + chaptercount = 0 + for node in doctree.children: + if node.tagname=='chapter': + chaptercount += 1 + if node.tagname=='section': + for nodechild in node.children: + if nodechild.tagname=='cncolorsection': + for childrentitle in nodechild.children: + if childrentitle.tagname =='title': + newchildren = [] + for childnode in childrentitle.children: + newchildren.append(childnode) + childrentitle.replace_self(newchildren) + elif node.tagname=='cncolorsection': + if chaptercount == 0: + node.attributes['istitle']=True + for childrentitle in node.children: + if childrentitle.tagname =='title': + newchildren = [] + for childnode in childrentitle.children: + newchildren.append(childnode) + childrentitle.replace_self(newchildren) + ''' + 如果是latex需要把section中间的title节点删除,否则生成的tex文件章节标题中包含\part指令, + 在latex标题中包含\part指令将导致无法编译通过 + 以下代码只支持sphinx 4.x以上版本,否则findall函数找不到。 + ''' + #for node in doctree.findall(cncolorsection): + # for childrentitle in node.children: + # if childrentitle.tagname =='title': + # newchildren = [] + # for childnode in childrentitle.children: + # newchildren.append(childnode) + # childrentitle.replace_self(newchildren) + #print(doctree) + #pass + #以下注释保留,为后续开发提供参考 + ''' + #print(doctree) + #print(type(doctree)) + #判断是否有cncolortitle节点,有该节点取出解析后的title,替换cncolorsection节点内的title,然后删除cncolortitle节点。 + env = app.builder.env + if not hasattr(env, 'titledict'): + return + + for node in doctree.findall(cncolortitle): + #生成新的title节点 + newnodes = nodes.title() + for childrentitle in node.children: + newnodes +=childrentitle + print(newnodes) + + index = node.parent.index(node) + sectionnode = node.parent[index+1] + sectionstr = sectionnode.astext() + #if env.titledict: + env.titledict[sectionstr] = newnodes.deepcopy() + print(newnodes.deepcopy()) + print(type(newnodes.deepcopy())) + print(env.titledict) + #对老节点进行替换删除 + for sectionchildren in sectionnode.children: + if sectionchildren.tagname == 'title': + sectionchildren.replace_self(newnodes) + node.parent.remove(node) + ''' +def __ModfiyTitleh0toh1(context): + #修改标题h0为h1,因为cncolorsection为自定义node,当标题在一级时,级别为h0,应该为h1 + searchstr = '
([\s\S]+?))[\s\S]+?
' + matchobj = re.search(searchstr,context,re.I|re.U|re.M) + if matchobj: + #group()和group(0)相同 + match = matchobj.group() + h0str = matchobj.group(1) + index = match.find(h0str) + pre = context[0:matchobj.start()+index] + left = context[matchobj.start()+index+len(h0str):len(context)] + newcontent = pre+"

"+matchobj.group(2)+"

"+left + __ModfiyTitleh0toh1(newcontent) + return newcontent + else: + return context + +def html_page_context_handle(app, pagename, templatename, context, doctree): + if hasattr(app.env,"htmlfile"): + if pagename in app.env.htmlfile: + #修改body内容 + context['body']= __ModfiyTitleh0toh1(context['body']) + +def env_get_outdated_handler(app,env,docnames): + #实现增量编译,将带cncolor指令的文件重新编译,否则latex可能无法生成 + if hasattr(env, 'sectionfile'): + for file in env.sectionfile: + if file not in docnames: + docnames.append(file) + if hasattr(env, 'autowrapfile'): + #print('----outdated enter----------') + for file in env.autowrapfile: + if file not in docnames: + #print(file) + docnames.append(file) +def env_updated_handler(app,env): + #print('-----updated------------') + #print(env.all_docs) + #print('-----updated------------') + pass + +def config_inited_handler(app,config): + #加载ulem包,否则latex不支持下划线和删除线 + if 'extrapackages' not in app.config.latex_elements.keys(): + #app.config.latex_elements['extrapackages'] = r'\usepackage[normalem]{ulem}' + app.config.latex_elements['extrapackages'] = r'\usepackage{seqsplit}' + elif r'\usepackage{seqsplit}' not in app.config.latex_elements['extrapackages']: + app.config.latex_elements['extrapackages'] += r'\usepackage{seqsplit}' + #elif r'\usepackage[normalem]{ulem}' not in app.config.latex_elements['extrapackages']: + # app.config.latex_elements['extrapackages'] += r'\usepackage[normalem]{ulem}' + +def setup(app): + + app.add_directive('cntoctree', TocTreeFilt) + app.add_directive('cnonly',CNOnly) + app.connect('doctree-resolved', doctree_resolved_process_titlenodes) + app.connect('doctree-read', doctree_read_modify_titlenodes) + app.connect('html-page-context',html_page_context_handle) + app.connect('env-before-read-docs',env_get_outdated_handler) + app.connect('env-updated', env_updated_handler) + app.connect('config-inited', config_inited_handler) + + #-----------cncolor directive---------------------------------------------------------------- + app.add_directive('cncolor',cnColorDirectivecls) + app.add_node(cncolorliteral, + html=(html_visit_cncolordirective_node, html_depart_cncolordirective_node), + latex=(latex_visit_cncolordirective_node, latex_depart_cncolordirective_node)) + app.add_node(cncolordirective, + html=(html_visit_cncolordirective_node, html_depart_cncolordirective_node), + latex=(latex_visit_cncolordirective_node, latex_depart_cncolordirective_node)) + app.add_node(cncolorsection, + html=(html_visit_cncolorsection_node, html_depart_cncolorsection_node), + latex=(latex_visit_cncolorsection_node, latex_depart_cncolorsection_node)) + app.add_node(cncolorline, + latex=(latex_visit_cncolorline_node, latex_depart_cncolorline_node)) + app.add_node(cncolortext, + latex=(latex_visit_cncolortext_node, latex_depart_cncolortext_node)) + #------------cncolor role-------------------------------------------------------------------- + app.add_node(cncolorrolenode, + html=(html_visit_cncolor_node, html_depart_cncolor_node), + latex=(latex_visit_cncolor_node, latex_depart_cncolor_node)) + app.add_role('cncolor',cncolor_role) + + app.add_node(cnautowrapnode, + html=(html_visit_autowrap_node, html_depart_autowrap_node), + latex=(latex_visit_autowrap_node, latex_depart_autowrap_node)) + app.add_role('autowrap',autowrap_role) + + #增加并行能力,否则,在其它地方再定义事件处理函数会有告警。 + return { + 'version': '0.1', + 'parallel_read_safe': True, + 'parallel_write_safe': True, + } diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/_static/custom.css b/torch_mlu_ops-v1.3.2/docs/user_guide/_static/custom.css new file mode 100644 index 0000000..fa9ad12 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/_static/custom.css @@ -0,0 +1,63 @@ +@import url("theme.css"); + +.wy-nav-content { + max-width: 100%; +} +pre.literal-block{ + background-color: #eeffcc; +} +.highlight .k{ + color: #404040; + font-weight: 400; +} +a.reference.internal:after{ + content: " "; + white-space: pre; +} +div.wy-menu.wy-menu-vertical{ + height:100%; + overflow-y:auto; +} +.dldpdf-btn{ + position:fixed; + top:20px; + right:100px; + background-color:#ff6c37; + color:#ffffff; + font-size:18px; + font-weight:bold; + padding:16px 30px; + border-radius:50%; + cursor:pointer; + box-shadow: 0px 4px 10px 0px rgba(0,0,0,0.5); + } + .feedback-btn{ + position:fixed; + right:5px; + background-color:#1E90FF; + color:#ffffff; + width: 25px; + border: 2px solid #FFEFD5; + height: 90px; + line-height: 20px; + word-wrap: break-word; + text-align: center; + cursor:pointer; + font-weight:bold; + top:40%; + box-shadow: 0px 4px 10px 0px rgba(0,0,0,0.5); + } +.showdiv{ + display: none; + position:fixed; + background-color: #ff6c37; + color: #fff; + top: 45%; + right: 30px; + padding: 8px; + border-radius: 6px; + font-size: 14px; + } + .doccenter{ + transform: translateX(5%); + } diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/chapterbkpaper.pdf b/torch_mlu_ops-v1.3.2/docs/user_guide/chapterbkpaper.pdf new file mode 100644 index 0000000..68afbd8 Binary files /dev/null and b/torch_mlu_ops-v1.3.2/docs/user_guide/chapterbkpaper.pdf differ diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/make.bat b/torch_mlu_ops-v1.3.2/docs/user_guide/make.bat new file mode 100644 index 0000000..9eee76b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build +set SPHINXPROJ=torch_mlu_ops_rst + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/makelatexpdf.sh b/torch_mlu_ops-v1.3.2/docs/user_guide/makelatexpdf.sh new file mode 100755 index 0000000..a6ff2de --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/makelatexpdf.sh @@ -0,0 +1,5 @@ +#! /bin/bash +apt-get install -y dvipng +make clean +make latexpdf +make html&&zip -qr -P"Cambricon" build/torch_mlu_ops_user_guide_html.zip build/html diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/parsejson.py b/torch_mlu_ops-v1.3.2/docs/user_guide/parsejson.py new file mode 100644 index 0000000..40c460d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/parsejson.py @@ -0,0 +1,3552 @@ +#!/usr/bin/python +# -*- coding: UTF-8 +''' +注意: +修改该文件,print打印信息中不能包含中文,否则在不支持中文的操作系统中会出现如下的错误, +导致编译失败或者达不到预期的显示效果。 +'ascii' codec can't encode characters in position + +2023年12月04日 修改历史 +------------------------------------ + +主要修改以下bug: +1. 当表格通过class属性设置为longtable,但是表头用“-”分割而不是用“=”分割的时候,在sphinx5.3.0以下版本编译失败的问题。 + 如果没有替换最新的parsejson.py文件,该bug在sphinx 5.3.0以下版本,可以采用以下方式规避: + 当表格设置为longtable时,表头改用“=”分割,而不是“-”分割,可以规避该问题。 + 如果替换了最新的parsejson.py文件,则不存在该问题,无需刻意规避。 +2. 修复了当表头为两行,并且合并列单元格放在表格的最前几列的时候,表头中的横线制表符没有画的bug。 + +优化了以下实现方式: +1. 将sphinx 5.3.0之后版本,为了适配表格改动所做的修改,固定在该文件的ModifyReplacePackage函数中。 + 这样只需替换该文件,而无需再修改conf.json文件,即可兼容sphinx5.3.0之前及之后版本,做到全版本兼容。 +2. 当有合并单元格的表头时,如果用“widths”参数设置的列宽,则表头中的标题不会自动换行,必须手动换行。 + 如果用“tabularcolumns”指令设置的列宽,则默认会自动换行。之前的行为是默认都会自动换行。 + 因此当前的实现要想使表头能够自动换行,必须采用tabularcolumns指令设置每一列的列宽。如果用“widths”设置列宽,只能手动换行。 + 修改原因: + 因为如果要实现自动换行,每一列的列宽必须以“m{0.1\textwidth}”这种方式实现,之前是把“widths”设置的列宽进行了转换。 + 这种转换对于列比较少的表格没什么问题。但是当表格比较大,列比较多时,这种转换就导致表格的显示超出了页面可显示的宽度, + 而且也会导致如何修改“widths”的值,都没什么效果。因此修改了自动换行的规则。 + 结论: + 对于列比较多而且表头有合并单元格的表格,建议用“tabularcolumns”设置每一列的列宽,因为该指令相当于latex的原始指令,设置的列宽直接反应每一列的实际宽度。 + 对于表头没有合并单元格的表格,建议用“widths”设置列宽比较简单,让sphinx转化实际的列宽,同时对latex和html都生效。 + +2023年11月15日 修改历史 +---------------------------- +将路径字符串相加的赋值方式,修改为os.path.join的方式,解决编译文档提示如下告警的问题: +RemovedInSphinx80Warning: Sphinx 8 will drop support for representing paths as strings. Use "pathlib.Path" or "os.fspath" instead. +sphinx 8之后将不再支持字符串表示路径。 +使用os.path.join的方式一定要特别注意以下几点: +1.join函数会自动添加“/”符号,因此一定不要再添加多余的“/”,否则可能得到的路径是错误的。 +2.join可以链接超过两个的路径,但是最后一个要链接的路径一定不能加“/”符号,否则前面的路径都失效。举例如下: + path1 = "home" + path2 = "rst" + path3 = "/meerkat" + path = os.path.join(path1,path2,path3) + 则得到的path为“/meerkat”,因此最后一个被链接的路径一定不能加“/”。 + +''' + +from __future__ import print_function + +import os +import sys +import json +import re +import codecs +import shutil +import platform +import traceback +import datetime as dt +from subprocess import Popen, PIPE, STDOUT +#import importlib + +try: + import sphinx + sphinx_version = sphinx.__display_version__[0:3] +except Exception as e: + sphinx_version = "" + +Py_version = sys.version_info +Py_v_info = str(Py_version.major) + '.' + str(Py_version.minor) + '.' + str(Py_version.micro) + +#if Py_version >= (3,5,0) and Py_version <(3,7,0): +# print(False) +# +#if Py_version >= (3,7,0): +# print(True) +# +#if platform.system().lower() == 'windows': +# print("windows") +#elif platform.system().lower() == 'linux': +# print("linux") + +filepathlist=os.path.split(os.path.realpath(__file__)) +currfilepath=filepathlist[0] +latexfontsize={'tiny','scriptsize','footnotesize','small','normalsize','large','Large','LARGE','huge','Huge'} + +def __dictIsHasKey__(dict,key): + if Py_version >= (3,0,0): + return dict.__contains__(key) + else: + return dict.has_key(key) + +def __openconfjsonfile__(filepath = ''): + jsonfile = filepath + if filepath == '': + jsonfile = currfilepath+'/conf.json' + + with codecs.open(jsonfile,"r+",encoding = 'utf-8') as load_f: + load_dict = json.load(load_f) + return load_dict + +def getignorewarning(): + return __load_dict__["ignorewarnkey"] + +def GetOtherIncludePackage(): +# packagearr = __load_dict__['package'] +# for package in packagearr: +# print(package) + return __load_dict__['package'] + +def GetReplacePackage(): + return __load_dict__['replacepackage'] + +def GetCustomOptions(): + return __load_dict__['customoptions'] + +def GetIsTocContents(): + return __load_dict__['isfiguretabletoc'] + +def GetSensitiveword(): + #得到敏感词数组,便于搜索文档中是否有敏感词存在 + return __load_dict__['sensitivewords'] + +def GetTablesContent(): + return __load_dict__['tables'] + +def GetTablerowtype(): +# packagearr = __load_dict__['tables']['rowtype'] +# print(packagearr) + return __load_dict__['tables']['rowtype'] + +def GetTableheadtype(): +# packagearr = __load_dict__['tables']['headtype'] +# print(packagearr) + return __load_dict__['tables']['headtype'] + +def GetTableHeadFontColor(): + return __load_dict__['tables']['headfontcolor'] + +def GetTableStylesArr(): +# packagearr = __load_dict__['tables']['styles'] +# for package in packagearr: +# print(package) + return __load_dict__['tables']['styles'] + +def GetImageStyleArr(): +# packagearr = __load_dict__['image']['styles'] +# for package in packagearr: +# print(package) + return __load_dict__['image']['styles'] + +#判断是否有忽略告警的类 +class clsIgnoreWarn: + + def __init__(self,warnfile): + self.warnfile = warnfile #保存告警文件 + + #判断该关键字是否在告警中,在告警中则返回true,不在告警中则返回fasle + def __JudgeWarnKey(self,keystr,warnstr): + + #先判断是否为组合关键字 + if "&&" in keystr: + #解析组合关键字为字符串列表 + keyls = keystr.split("&&") + isignore = True #默认该告警为忽略,如果组合关键字其中一个关键字没匹配上,则该告警不能忽略。 + for i in range(0,len(keyls)): + if keyls[i].strip().lower() not in warnstr.lower(): + isignore =False + break + return isignore + else: + if keystr.lower() in warnstr.lower(): + return True #忽略告警在告警字符串里面,则返回true,表示该告警可以忽略 + else: + return False + + #解析告警文件 + def parsewarnfile(self,ignorekey): + ''' + #make latex命令产生的告警都会保存在stderr文件中,只要判断该文件中的告警是否在忽略告警里就可以了,不在忽略告警里,则肯定有错误,停止执行 + ''' + if not os.path.exists(self.warnfile): + return True + + fs = codecs.open(self.warnfile,"r",encoding = 'utf-8') + fstr = fs.read() #先将所有内容读取到字符串中,方便正则表达式查找 + fs.close() + + #查找带warning的内容 + pattern = re.compile(r"([\s\S].*)[WARNING:|ERROR:]([\s\S].*)", re.I | re.U | re.M) + mobjarr = pattern.finditer(fstr) + + errstr = '' #保存不能被忽略的告警 + isNotError = True + for mobj in mobjarr: + amarr = mobj.group() + + if amarr == '': + continue + + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + + #判断该告警是否在忽略列表里,不在忽略列表里,保存在errstr字符串中 + + if ("WARNING:" not in amarr) and ("ERROR:" not in amarr): #直接忽略 + continue + + isWrite = False #默认不能忽略 + for igkey in ignorekey: + isWrite = self.__JudgeWarnKey(igkey,amarr) + if isWrite: + break + + if isWrite is False: + #写入stderr文件 + isNotError = False + #fe.writelines(amarr+'\n') + errstr += amarr + + if errstr != '': + #如果有不能被忽略的告警,则将不能忽略的告警重新写入warnfile文件中 + #先删除源文件,避免原文件无法写入 + fw = codecs.open(self.warnfile, "w",encoding = 'utf-8') + fw.write(errstr) + fw.close() + else: + #如果所有的告警都能忽略则删除之前的告警文件 + #该功能暂时不实现,还是保存原始的告警文件,方便查找。 + #os.remove(self.warnfile) + pass + + return isNotError + + #该函数暂时未用 + def __parsepdfarg(stdout,stderr,ignorelist,ignorekey): + ''' + make all-pdf命令产生的所有输出都只会保存在stdout中,因此从stdout中判断是否有告警内容 + ''' + stdoutfilepath = currfilepath+'/'+stdout + stderrfilepath = currfilepath+'/'+stderr + + if not os.path.exists(stdoutfilepath): + return True + + fs = codecs.open(stdoutfilepath, "r+",encoding = 'utf-8') + fstr = fs.read() #先将所有内容读取到字符串中,方便正则表达式查找 + fs.close + #查找latexmk的位置,latexmk位置的开始即make all-pdf打印输出的开始 + searchstr = r"latexmk" + m = re.search(searchstr, fstr, re.I|re.U ) + if m == None: + return True + + spos = m.span() #获取位置 + latexcont = fstr[spos[0]:len(fstr)] #获取到make all-pdf产生的内容 + + #查找带warning的内容 + pattern = re.compile(r'([\s\S].*)Warning([\s\S].*)', re.I|re.U) + mobjarr = pattern.finditer(latexcont) + + #打开stderr文件,方便后续写 + fe = codecs.open(stderrfilepath, "a+",encoding = 'utf-8') + isNotError = True + for mobj in mobjarr: + amarr = mobj.group() + if amarr == '': + continue + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + + #判断该告警是否在忽略列表里,不在忽略列表里,要写入stderr文件 + if amarr.lower() not in [elem.lower() for elem in ignorelist]: #如果错误不在忽略告警里面 + #如果不能整行匹配,再判断该行是否包含需忽略的告警的关键字 + isWrite = False #默认不能忽略 + for igkey in ignorekey: + + isWrite = self.__JudgeWarnKey(igkey,amarr) + if isWrite: + break; + + if isWrite == False: + #写入stderr文件 + isNotError = False + fe.writelines(amarr) + fe.close + + return isNotError + +def warn_main(warnfile,app=None): + + global __load_dict__ + + if len(__load_dict__) ==0: + GetConfjsondict(app) + + ignorekey = getignorewarning() + clswarn = clsIgnoreWarn(warnfile) + + if not clswarn.parsewarnfile(ignorekey): + #如果存在不可忽略的告警,则返回True,否则返回False + return True + + return False + + +class clsTableattr: + + def __init__(self, tables): + self.rowtype = tables['rowtype'] + if __dictIsHasKey__(tables,'headtype'): + self.headtype = tables['headtype'] + else: + self.headtype = "" + if __dictIsHasKey__(tables,'headfontcolor'): + self.headfontcolor = tables['headfontcolor'] + else: + self.headfontcolor = "" + if __dictIsHasKey__(tables,'styles'): + self.tablestyles = tables['styles'] + else: + self.tablestyles = None + if __dictIsHasKey__(tables,'isname'): + self.isname = tables['isname'] + else: + self.isname = False + if __dictIsHasKey__(tables,'fontsize'): + self.fontsize = tables['fontsize'] + else: + self.fontsize = "" + + if __dictIsHasKey__(tables,'ismulticolleftalign'): + self.ismulticolleftalign = tables['ismulticolleftalign'] + else: + self.ismulticolleftalign = False + + if __dictIsHasKey__(tables,'ismultirowautowrap'): + self.ismultirowautowrap = tables['ismultirowautowrap'] + else: + self.ismultirowautowrap = True + + if __dictIsHasKey__(tables,'ismultirowatupcell'): + self.ismultirowatupcell = tables['ismultirowatupcell'] + else: + self.ismultirowatupcell = False + + if __dictIsHasKey__(tables,'ismultirowatlowcell'): + self.ismultirowatlowcell = tables['ismultirowatlowcell'] + else: + self.ismultirowatlowcell = False + +class clsModifyTex: + + def __init__(self, content,app=None): + self.content = content + self.app = app + self.tablesattrobj = clsTableattr(GetTablesContent()) + + #构建模板字符串 + #自动换行模板 + self.autolinetemplat =r'''\renewcommand{\arraystretch}{0.8} + \begin{tabular}[c]{@{}l@{}}%(autocontent)s\end{tabular}''' + + fontcolor = self.tablesattrobj.headfontcolor + fontcolor = fontcolor.replace('{}', '{%(content)s}', 1) + self.commonstr = r'''%(sphinxstyletheadfamily)s\cellcolor''' + \ + self.tablesattrobj.headtype + "\n{" + fontcolor + "}" + + # multirow模板 + self.multirowstr = r'''\multirow{-%(count)s}{%(coluwidth)s}%(raggedright)s{ + ''' + self.commonstr + "}" + + self.simplemultirow = r'''\multirow{%(count)s}{%(coluwidth)s}{\cellcolor''' + \ + self.tablesattrobj.headtype + "}" + + # multicolumn模板 + self.multicolumnstr = r'''\multicolumn{%(count)s}{%(flag)s}{%(prefix)s''' + \ + self.commonstr + r'''%(suffix)s''' + + # 为了普通标题文本看起来和使用了multicolum你的位置一致,对于非multirow和multicolumn单元格的内容,都用multicolumn{1}{|l|}标志 + self.normalstr = r'''\multicolumn{1}{%(flag)s}{ + \begin{varwidth}[c]{%(colwidth)s} + ''' + self.commonstr + r''' +\par +\vskip-\baselineskip\vbox{\hbox{\strut}}\end{varwidth}}''' + self.comcolstr = r'''\multicolumn{%(count)s}{%(flag)s}{''' + self.commonstr +"}" + # 空的单元格内容,multirow位置变换后,会有空的单元格 + self.emptystr = r"\cellcolor" + self.tablesattrobj.headtype+"{}" + + #删除字符串列表。有时候sphinx生成的文本内容里会带有特殊的latex指令,该指令不支持设置字体颜色,需要删除。删除不影响内容显示。 + self.delstrlst=[r'\begin{quote}',r'\end{quote}'] + + def __checktpacknameandparam(self,packlist,packname,*packarg): + ''' + 检查指定的包是否需要增加或者修改参数 + :param packlist: 已经读出的包含的package + :param packname: 要修改或者添加的包 + :param packarg: 要修改或者添加的包参数,可变参数。不会覆盖已有参数 + ''' + if len(packarg) == 0: + #如果没有参数则直接添加该包 + loadpackagestr = "\\\\usepackage{" + packname + "}" + if loadpackagestr not in packlist: + packlist.append(loadpackagestr) + return + + params = ",".join(packarg) + loadpackagestr = "\\\\usepackage[" + params + "]{" + packname + "}" + #print(loadpackagestr) + packstr = ",".join(packlist) + #print(packstr) + searchstr = r'\\\\usepackage(\[[\S, ][^\\]+\])?\{'+packname+'\}' + match = re.search(searchstr,packstr) + if match is None: + #如果根本没有加载packname包,则加载该包 + packlist.append(loadpackagestr) + return + latexpack = match.group() + #print(latexpack) + if match.group(1) is None: + #如果该包没有携带参数,则重新加载该包 + #得到老包的位置,尽量原位置插入,避免因为位置问题导致编译失败 + index = packlist.index(latexpack) + del packlist[index] + #再重新加载新的包 + packlist.insert(index,loadpackagestr) + return + #如果携带了参数,则更新该参数 + #为了兼容后续新添加参数,并防止覆盖已有参数,将每个参数进行比较添加 + paramstr = match.group(1) + searchstr = '\[([\S, ]+)?\]' + match = re.search(searchstr,paramstr) + if match.group(1).strip() == params: + #print('-------------enter-----------') + #为了提高运行效率,两个参数相等直接返回,不再走下面的循环 + return + params = match.group(1) + for param in packarg: + if param not in params: + params += "," + param + #print(params) + loadpackagestr = "\\\\usepackage[" + params+"]{" + packname +"}" + #先删除旧包,原位置插入,避免因位置错误导致的编译失败。 + index = packlist.index(latexpack) + del packlist[index] + packlist.insert(index,loadpackagestr) + + #加入其它包 + def AddPackageToTex(self): + #得到需要包的数组 + packarr = GetOtherIncludePackage() + if len(packarr)==0: + return False + + self.__checktpacknameandparam(packarr,"multirow") + self.__checktpacknameandparam(packarr,"tcolorbox","most","skins","breakable") + + #如果数组有内容,就需要将包添加到latex文件的导言区 + #搜索\usepackage{sphinx},将包加在它的前面,用正则表达式搜索它的位置 + #采用正则表达式替换的方式,替换搜索到的字符串,因此需要先构建字符串 + #python认为\u后面为unicode字符,因此需要多加一个转义字符\,python才认为是整个的字符串 + #searchstr = r'\\usepackage\[dontkeepoldnames\]{sphinx}' + searchstr = r'\\usepackage(\[\S*\]*)?{sphinx}' + matchstr = re.search(searchstr,self.content) + + replacestr="" + for package in packarr: + replacestr += package+'\n' + + if Py_version >= (3,7,0): + replacestr += "\\" + matchstr.group(0) + else: + replacestr += matchstr.group(0) + + + self.content = re.sub(searchstr, replacestr, self.content, 1, re.M | re.I|re.U) + + return True + + #加入自定义选项,包放在了sphinx包的前面,因此选项放在sphinx包的后面 + def AddCustormOptionsToTex(self): + #得到需要包的数组 + packarr = GetCustomOptions() + if len(packarr)==0: + return False + + #如果数组有内容,就需要将包添加到latex文件的导言区 + #搜索\usepackage{sphinx},将自定义参数放在它的后面,用正则表达式搜索它的位置 + #采用正则表达式替换的方式,替换搜索到的字符串,因此需要先构建字符串 + #python认为\u后面为unicode字符,因此需要多加一个转义字符\,python才认为是整个的字符串 + searchstr = r'\\usepackage(\[\S*\]*)?{sphinx}' + matchstr = re.search(searchstr,self.content) + + replacestr="" + for package in packarr: + replacestr += package+'\n' + + if Py_version >= (3,7,0): + replacestr = "\\" + matchstr.group(0)+'\n'+replacestr + else: + replacestr = matchstr.group(0)+'\n'+replacestr + + self.content = re.sub(searchstr, replacestr, self.content, 1, re.M | re.I|re.U) + return True + + #增加figure和table toc到tex + def AddOtherTocToTex(self): + #得到需要包的数组 + packarr = GetIsTocContents() + if len(packarr)==0: + return False + + replacestr = "" + if packarr['isfigurestoc']: + figlst = packarr['figurestoc'] + for figstr in figlst: + replacestr += figstr + '\n' + + if packarr['istablestoc']: + figlst = packarr['tablestoc'] + for figstr in figlst: + replacestr += figstr + '\n' + if replacestr == "": + return + + #如果数组有内容,就需要将包添加到latex文件的导言区 + #搜索\usepackage{sphinx},将包加在它的前面,用正则表达式搜索它的位置 + #采用正则表达式替换的方式,替换搜索到的字符串,因此需要先构建字符串 + #python认为\u后面为unicode字符,因此需要多加一个转义字符\,python才认为是整个的字符串 + searchstr = r'\\sphinxtableofcontents' + matchstr = re.search(searchstr,self.content) + + if Py_version >= (3,7,0): + replacestr = "\\" + matchstr.group(0) + '\n' + replacestr + else: + replacestr = matchstr.group(0) + '\n' + replacestr + + self.content = re.sub(searchstr, replacestr, self.content, 1, re.M | re.I|re.U) + return True + + def __IsRequireReplace(self, value): + """ + 根据sphinx版本判断是否需要替换,有两种写法: + 1. {sphinx:5.3.0}:这种写法是当前sphinx版本大于等于5.3.0版本时,才执行替换。 + 2. {5.3.0:sphinx}:这种写法是当前sphinx版本小于5.3.0时,才进行替换 + :return: 返回True需要替换,和需要替换的实际value;或者返回fasle,不需要替换 + """ + #根据正则表达式得到需要替换的sphinx的最小版本号 + if len(sphinx_version)==0: + return True,value + #按大于等于版本号进行匹配查找 + searchstr = r"\{sphinx:([0-9\.]+)\}([\s\S]+)" + match = re.match(searchstr,value,re.I | re.U) + if match is not None: + #print(match.group(1),match.group(2)) + try: + require_version = match.group(1)[:3] + if sphinx_version >= require_version: + return True, match.group(2) + else: + return False, "" + except Exception as e: + print(e) + return True,value + + else: + #按小于版本号进行匹配查找 + searchstr = r"\{([0-9\.]+):sphinx\}([\s\S]+)" + match = re.match(searchstr, value, re.I | re.U) + if match is not None: + try: + require_version = match.group(1)[:3] + if sphinx_version < require_version: + return True, match.group(2) + else: + return False, "" + except Exception as e: + print(e) + return True, value + else: + return True,value + + + def ModifyReplacePackage(self): + #得到字典值 + #得到需要替换的包,用正则表达式替换 + redict = GetReplacePackage() + if len(redict) ==0: + return + + #为了兼容sphinx5.3.0版本,固定添加以下配置 + keys = redict.keys() + if "\\\\sphinxtoprule" not in keys: + redict["\\\\sphinxtoprule"] = "\\\\hline" + if "\\\\sphinxmidrule" not in keys: + redict["\\\\sphinxmidrule"] = "\\\\hline" + if "\\\\sphinxbottomrule" not in keys: + redict["\\\\sphinxbottomrule"] = "\\\\hline" + if "\\\\sphinxhline" not in keys: + redict["\\\\sphinxhline"] = "\\\\hline" + if "\\\\sphinxhyphen" not in keys: + redict["\\\\sphinxhyphen"] = "{sphinx:5.3.0}\\\\PYGZhy" + + #返回字典中所有键值的列表 + keylst = list(redict) + for key in keylst: + if key == 'comment' : + continue + keyvalue = redict[key] #得到键对应的值 + result,keyvalue = self.__IsRequireReplace(keyvalue) + if result : + #对键值进行替换 + self.content = re.sub(key, keyvalue, self.content, 0, re.I|re.U) + return + + def __GetReplacedContent(self,afterstr,replacelst): + """ + 对需要替换的内容进行替换 + :param afterstr: 需要替换的字符串 + :return: + """ + if len(afterstr) ==0 or len(replacelst)==0: + return afterstr + for replacelist in replacelst: + if len(replacelist) !=2: + continue + replacedstr = replacelist[0] + replacestr = replacelist[1] + if len(replacedstr) ==0 or len(replacestr)==0: + continue + afterstr = re.sub(replacedstr, replacestr, afterstr, 0, re.U | re.M) + + return afterstr + + def __ParseStyleForSpecialFunction(self,isshaded,apidict,linecontent): + + if isshaded: + envstartstr = "\n\\begin{shaded}\n" # 环境起始字符串 + envendstr = "\n\\end{shaded}\n" # 环境结束字符串 + else: + envstartstr = "" # 环境起始字符串 + envendstr = "" # 环境结束字符串 + + afterstr = linecontent + + if apidict is None or len(apidict)==0: + return envstartstr,envendstr,afterstr + + if not __dictIsHasKey__(apidict, "isenter") or \ + (__dictIsHasKey__(apidict, "isenter") and \ + apidict['isenter'] is True): + afterstr = self.__strreplacepos(afterstr,r"pysiglinewithargsret",r",",r",\\") + if __dictIsHasKey__(apidict, "istcolorbox") and \ + apidict['istcolorbox'] is True: + envstartstr, envendstr = self.__GettcolorboxForFunction(apidict) + elif __dictIsHasKey__(apidict, "fontsize") and \ + apidict['fontsize'] in latexfontsize: + envstartstr += "\\" + apidict['fontsize'] + "\n" + + #检查是否有替换,有替换的话,对内容进行替换 + if __dictIsHasKey__(apidict, "replace"): + replacelst = apidict['replace'] + if len(replacelst) > 0: + afterstr = self.__GetReplacedContent(afterstr,replacelst) + + return envstartstr,envendstr,afterstr + + def __ModifyFunctionBkColorByPython(self,isshaded,apilst,newcontent,linecontent,curpos,prepos): + #根据python生成的格式为类和函数声明添加灰底,并按逗号换行 + + ''' + 根据正则表达式获得以\pysiglinewithargsret开头,以“{}”结尾的中间字符串 + 对中间字符串添加灰底和换行 + ''' + if isshaded: + envstartstr = "\n\\begin{shaded}\n" # 环境起始字符串 + envendstr = "\n\\end{shaded}\n" # 环境结束字符串 + else: + envstartstr = "" # 环境起始字符串 + envendstr = "" # 环境结束字符串 + + startmultiline = '\\pysigstartmultiline\n' + stopmultiline = '\n\\pysigstopmultiline' + #searchstr = r'(?=\\pysiglinewithargsret).*(?<={})' + searchstr=r'(?=\\pysiglinewithargsret).*' + match = re.search(searchstr, linecontent, re.I|re.U) + if match is not None: + #重新组合linecontent + #pos = match.span() + isapi, apidict = self.__IsSpecialFunction(apilst, linecontent) + if isapi is True and apidict is not None: + envstartstr,envendstr,afterstr = self.__ParseStyleForSpecialFunction(isshaded,apidict,match.group()) + newstr = '\n' + linecontent[:match.start()] + envstartstr + startmultiline + afterstr+ stopmultiline + envendstr+ linecontent[match.end():len(linecontent)] + else: + newstr = '\n' + linecontent[:match.start()] + envstartstr+ startmultiline + match.group().replace(r",",r",\\") + stopmultiline + envendstr + linecontent[match.end():len(linecontent)] + #计算替换前的内容 + if len(prepos)==0: + newcontent = self.content[:curpos[0]-1] + newstr + else: + newcontent += self.content[prepos[1]:curpos[0]-1] + newstr + + return newcontent + + def __strreplacepos(self,srcstr,posstr,oldstr,newstr): + + # 从指定字符串开始对行内字符串进行替换 + ''' + srcstr:原始字符串 + posstr:从该字符串开始进行替换 + oldstr:被替换字符串 + newstr:替换字符串 + ''' + #查找字符串的起始位置 + pos = srcstr.find(posstr) + if pos == -1: + return "" #如果查找的字符串没有找到,则返回空。 + + #根据找到的位置进行字符串拆分 + startstr = srcstr[0:pos] + beforestr = srcstr[pos:len(srcstr)] + #对字符串进行替换 + afterstr = beforestr.replace(oldstr,newstr) + return startstr+afterstr + + def __GetleftrightValueForSpecialapi(self,str): + """ + 得到conf.json specialapi配置中left和right的值 + 字符串格式为“[-][数字][单位]”,根据这个格式查找值和单位 + :param str: + :return: + """ + if len(str) ==0: + return 0,"" + + searchstr = r'[-]{0,1}([0-9]{1,2})([A-Za-z]{1,2})' + match = re.search(searchstr, str, re.I) + if match is not None: + value = int(match.group(1)) + unit = match.group(2) + return value,unit + else: + return 0,"" + + def __GettcolorboxForFunction(self,apidict): + #判断有没有设置页面左右边距,的左右边距 + #必须通过\geometry指令设置,否则按默认处理 + + leftvalue = 0 + leftunit = "mm" + left = "0mm" + + rightvalue = 0 + rightunit = "mm" + right="0mm" + + isleftset = False + isrightset = False + + if __dictIsHasKey__(apidict, "left") and \ + len(apidict['left']) > 0: + leftvalue,leftunit=self.__GetleftrightValueForSpecialapi(apidict['left']) + if len(leftunit) > 0: + left = str(leftvalue)+leftunit + isleftset = True + + if isleftset is not True: + leftvalue = 9 #默认值 + leftunit="mm" + left = str(leftvalue)+leftunit + + searchstr = r'\\geometry{[\s\S]*left=([0-9]{1,2})([A-Za-z]{1,2})[\s\S]*}' + match = re.search(searchstr, self.content, re.I | re.M) + if match is not None: + left = match.group(1) + match.group(2) + leftvalue = int(match.group(1)) + leftunit = match.group(2) + + #print(left) + + if __dictIsHasKey__(apidict,"right") and \ + len(apidict['right']) >0 : + rightvalue, rightunit = self.__GetleftrightValueForSpecialapi(apidict['right']) + if len(rightunit) > 0: + right = str(rightvalue) + rightunit + isrightset = True + + if isrightset is not True: + rightvalue = leftvalue - 1 # 为了使显示效果更好看 + right = str(rightvalue)+leftunit + + searchstr = r'\\geometry{[\s\S]*right=([0-9]{1,2})([A-Za-z]{1,2})[\s\S]*}' + match =re.search(searchstr,self.content,re.I|re.M) + if match is not None: + rightvalue = int(match.group(1)) -2 + rightunit = match.group(2) + right = str(rightvalue) + rightunit + #print(right) + + fontsize = "" + if __dictIsHasKey__(apidict, "fontsize") and \ + apidict['fontsize'] in latexfontsize: + fontsize = "fontupper=\\" + apidict['fontsize'] + + envstartstr = "\n\\begin{tcolorbox}[arc=0mm,boxrule=-1mm,left skip=-" \ + + str(leftvalue+1) + leftunit +",left="+left +",right=-" +right \ + + ",colback=shadecolor," + fontsize +"]\n" + envendstr = "\n\\end{tcolorbox}\n" + #print("=====================") + #print(envstartstr) + #print(envendstr) + #print("=====================") + return envstartstr,envendstr + + + def __IsSpecialFunction(self,apilst,linecontent): + ''' + 判断是否为特殊api + :param apilst: + :param linecontent: + :return: + ''' + if len(apilst)==0: + return False,None + + for apidict in apilst: + if __dictIsHasKey__(apidict,"latexapiname"): + apistr=apidict['latexapiname'] #有可能以逗号分割的多个函数名称 + apiarr=apistr.split(",") #以逗号分割函数名称列表 + for api in apiarr: + if api in linecontent: + return True,apidict + return False,None + + def __RemoveLastSpace(self,headstr): + """ + 删除最后的空格,避免函数含义和关键字间有大的空行 + :param headstr: + :return: + """ + headlist = headstr.split('\n') + index = 0 + for i in range(len(headlist)-1,0,-1): + str = headlist[i].strip('\r') + if len(str) == 0: + index = i + break + if index > 0: + headlist.pop(index) + return '\n'.join(headlist) + else: + return headstr + + def __ModifyParamPos(self,paramcontent): + """ + 调整参数位置在最前面,为了适配breathe 4.24.1之后版本 + :param paramcontent: + :return: + """ + #查找第一个\begin{description}\n\sphinxlineitem{}字符串 + searchstr = r"(\\begin\{quote\})?(\r\n|\n)?\\begin\{description\}(\r\n|\n)?\\sphinxlineitem\{(\sphinxstylestrong\{)?(?P[\s\S].+)?(\})?\}(\r\n|\n)?[\s\S]+?(\\end\{description\}){1}(\r\n|\n)?(\\end\{quote\})?" + match = re.search(searchstr,paramcontent,re.I|re.M|re.U) + if match is not None: + + if "\\begin{quote}" in match.group(): + #删掉缩进 + otherstr = match.group()[len('\\begin{quote}'):len(match.group())-len('\\end{quote}')] + else: + otherstr = match.group() + + paramkey = match.group('paramkey') + headstr = paramcontent[0:match.start()] + headstr = self.__RemoveLastSpace(headstr) + if "Parameters" in paramkey: + #说明参数本身在第一个位置,不做处理 + return headstr + '\n' + otherstr + paramcontent[match.end():len(paramcontent)] + + #如果参数不在第一个位置,查找参数的位置并将第一个位置记录下来 + searchstr = r"(\\begin\{quote\})?(\r\n|\n)?\\begin\{description\}(\r\n|\n)?\\sphinxlineitem\{(\sphinxstylestrong\{)?(?PParameters)?(\})?\}(\r\n|\n)?[\s\S]+?(\\end\{description\}){1}(\r\n|\n)?(\\end\{quote\})?" + matchparam = re.search(searchstr,paramcontent,re.I|re.M|re.U) + if matchparam is None: + return headstr + '\n' + otherstr + paramcontent[match.end():len(paramcontent)] + + if "\\begin{quote}" in matchparam.group(): + #删掉缩进 + paramstr = matchparam.group()[len('\\begin{quote}'):len(matchparam.group())-len('\\end{quote}')] + else: + paramstr = matchparam.group() + + #重新组合字符串,将参数放在第一位 + newparamcontent = headstr + '\n' +paramstr + '\n' +otherstr \ + + paramcontent[match.end():matchparam.start()] \ + + paramcontent[matchparam.end():len(paramcontent)] + + return newparamcontent + else: + return paramcontent + + def ModifyFunctionBackColor(self): + ''' + 该函数用来修改函数的背景色,并根据参数换行 + * c/c++ sphinx生成的函数都被以下三行包围: + \pysigstartmultiline + \pysiglinewithargsret{xxxxx} + \pysigstopmultiline + 因此利用正则表达式查找这三行所在的位置进行替换。 + 注意: + 要实现该功能,latex必须包含framed和color包,同时定义以下颜色: + \definecolor{shadecolor}{RGB}{220,220,220} + 如果shadecolor颜色未定义,则该函数执行失败。 + * python生成的latex文件中函数或者类的声明,有\pysiglinewithargsret指令,但是没有pysigstartmultiline指令。 + 因此需要在\pysiglinewithargsret指令句子的结束,前后先添加pysigstartmultiline和pysigstopmultiline。再添加\begin{shaded}和\end{shaded} + 一句的结束根据"{}"来判断,遇到"{}"即为句末。 + ''' + #判断是否有特殊API需要处理 + apilst = [] + isshaded = True + if __dictIsHasKey__(__load_dict__,"specialapi"): + if __dictIsHasKey__(__load_dict__["specialapi"],"styles"): + apilst=list(__load_dict__["specialapi"]["styles"]) + if __dictIsHasKey__(__load_dict__["specialapi"],"isshaded"): + isshaded = __load_dict__["specialapi"]["isshaded"] + + pythontype = False + newcontent = "" + prepos=[] #定义需要替换的字符串列表,该列表中的每一个元素都包含上面提到的三行。 + #searchstr = r'^(?=.*\\pysiglinewithargsret).*$' + #该正则表达式可以把pysiglinewithargsret或者带template的识别出来。有template的会被pysigline标签标识 + #该修改也兼容sphinx 5.3.0及以后版本。否则5.3.0版本之后将无法使用。 + #searchstr=r'(\\pysigline[\s\S].*){0,1}[\n]{0,1}(\\pysiglinewithargsret[\s\S].*)[\n]' + searchstr=r'(\\pysigstartsignatures){0,1}[\r\n|\n]{0,1}(\\pysigstartmultiline){0,1}(\r\n|\n){0,1}((\\pysigline[\s\S].*){0,1}[\n]{0,1}(\\pysiglinewithargsret[\s\S].*))[\n](\\pysigstopmultiline){0,1}[\r\n|\n]{0,1}(\\pysigstopsignatures){0,1}' + m = re.finditer(searchstr, self.content, re.M|re.I|re.U) + for match in m: + + linecontent = match.group() + #判断是否添加了pysigstartmultiline和pysigstopmultiline,没有添加的话需要添加才能添加灰色背景 + #multilen=len(r'\pysigstartmultiline') + #startmultistr = self.content[match.start()-multilen-1:match.start()] + #如果上一行不是\pysigstartmultiline则需要按python风格修改,添加该标志 + #print(startmultistr.strip()) + if match.group(2) is None: + #print('is python') + # 计算替换前的内容 + newcontent = self.\ + __ModifyFunctionBkColorByPython(isshaded,apilst,newcontent,linecontent,match.span(),prepos) + prepos = match.span() + pythontype = True + else: + #tablestr = match.groups() + #计算替换前的内容 + if len(prepos)==0: + newcontent = self.content[0:match.start()] + else: + # 得到参数信息内容,方便后面调整位置,此时调整的参数为上一个函数的参数位置 + paramcontent = self.content[prepos[1]:match.start()] + # breathe 4.24.1之后版本将parameters放在了最后面, + # 为了适配breathe 4.24.1之后版本对格式的调整 + # 检查参数是否在后面,是的话,需要将参数放到最前面 + # breathe 4.24.1之后版本将doxygen的标签都放在\begin{description}\n\sphinxlineitem{}标志对里面 + # 因此对该标志对进行搜索 + newparamcontent = self.__ModifyParamPos(paramcontent) + newcontent += newparamcontent + + + #当有template的时候,pysiglinewithargsret不一定在行首,因此需要从pysiglinewithargsret开始替换逗号,加强制换行符 + if isshaded: + envstartstr = "\n\\begin{shaded}\n" # 环境起始字符串 + envendstr = "\n\\end{shaded}\n" # 环境结束字符串 + else: + envstartstr = "" # 环境起始字符串 + envendstr = "" # 环境结束字符串 + + isapi,apidict = self.__IsSpecialFunction(apilst,linecontent) + if isapi is True and apidict is not None: + envstartstr,envendstr,afterstr=self.__ParseStyleForSpecialFunction(isshaded,apidict,linecontent) + newstr = envstartstr + afterstr + envendstr + else: + afterstr=self.__strreplacepos(linecontent,r"pysiglinewithargsret",r",",r",\\") + #得到替换后的字符串 + newstr = envstartstr + afterstr + envendstr + #得到替换后的内容 + newcontent += newstr + prepos = match.span() + pythontype = False + + if len(prepos) > 0: + paramcontent = self.content[prepos[1]:len(self.content)] + newparamcontent = self.__ModifyParamPos(paramcontent) + self.content = newcontent + newparamcontent + #self.content = newcontent + self.content[prepos[1]:len(self.content)] + #print('===============================================') + + def __GetTableCaption(self,tablecontent): + ''' + 得到表格的caption,方便输出打印 + :param tablecontent: 表格内容 + :return: 返回表格的caption + ''' + tablecaption = "" + searchstr = r'(\\sphinxcaption|\\caption)\{(?P[\s\S]*?)\}(?=\\label|\\\\)' + matchcaption = re.search(searchstr, tablecontent, re.M | re.I | re.U) + if matchcaption is not None: + tablecaption = matchcaption.group('caption') #得到caption的值 + #长表格caption后面会添加\struct的后缀,去掉 + tablecaption = tablecaption.replace(r"\strut","") + + return tablecaption + + def ModifyTablesAttributes(self): + #修改表格属性 + newcontent = self.content + searchstr = r'(\\begin{savenotes})([\s\S]*?)(\\sphinxattablestart|\\sphinxatlongtablestart)([\s\S]*?)(\\sphinxattableend|\\sphinxatlongtableend)([\s\S]*?)(\\end{savenotes})' + matchobj = re.finditer(searchstr, self.content, re.M|re.I|re.U) + #outputstr = "\033[34mModifying Table:\033[0m" + #print(outputstr, end="") + icount = 0 + for match in matchobj: + try: + icount = icount + 1 + tablestr = match.groups() + tablecontent = tablestr[0]+tablestr[1]+tablestr[2]+tablestr[3]+tablestr[4] + columnwidthlst,newtablecontent = self.__GetColumnParameters(tablecontent) + #得到表格caption + tablecaption=self.__GetTableCaption(tablecontent) + #print(tablecaption) + #outputtable="\033[34m[" + tablecaption +"];\033[0m" + #print(outputtable, end="") + #为了兼容纯英文linux系统,在打印信息中不能包含中文,因此按自动计数累加表示表格, + #避免表格标题中有中文时导致修改表格失败。 + prstr = "\033[34m[Modifying table"+str(icount)+":... ...];\033[0m" + print(prstr, end="\r") + if self.tablesattrobj.tablestyles is None: + #先修改表格的通用属性 + newtableattr,islongtable = self.__StartModifyTableAttr(newtablecontent,columnwidthlst, False) + newcontent = newcontent.replace(tablecontent, newtableattr) + else: + caption_dict = self.__CreateTableCaptionDict(self.tablesattrobj.tablestyles) + if len(caption_dict ) > 0 : + newtableattr = self.__ModifySingleTableattr(newtablecontent, columnwidthlst, caption_dict ) #tablestr也是3个内容的数组,因为正则表达式被分为了3组,只取中间分组的内容。 + #重新组成新的字符串 + newcontent = newcontent.replace(tablecontent, newtableattr) + + except Exception as e: + print(e) + traceback.print_exc() + + print("\r\n") + self.content = newcontent + + #替换表格中字体大小 + def __ModifyTableFontSize(self,singletablecontent,fontsize): + + searchstr = r'(\\begin{savenotes})([\s\S]*?)(\\sphinxattablestart|\\sphinxatlongtablestart)' + m = re.search(searchstr,singletablecontent, re.M|re.I|re.U) + tablestr = m.groups() + #替换表格中的字体 + #同时也要使表格的caption字体大小保持一致,修改表格的caption需要添加如下指令:\renewcommand{\tablename}{\scriptsize{表} } + captionfont = "\n\\renewcommand{\\tablename}{\\" + fontsize + "{Table} }\n" + if doc_language=='zh_CN': + captionfont="\n\\renewcommand{\\tablename}{\\" + fontsize + "{表} }\n" + + tablefontsize = '\\'+fontsize + captionfont + + return tablestr[0]+tablefontsize+tablestr[2]+singletablecontent[m.end():len(singletablecontent)] + + def __CreateTableCaptionDict(self, tablestylesarr): + #根据caption生成表格字典,key=caption,value=属性数组 + cap_dict = {} + + for tablestyle_dict in tablestylesarr: + captionarr = tablestyle_dict['caption'] + #该caption可能是一个逗号分隔的字符串数组,因此需要做拆分 + captionlist = captionarr.split(",") + for caption in captionlist: + cap_dict[caption] = tablestyle_dict #以caption为key重新生成字典,便于查找 + return cap_dict + + #查找表格caption是否被匹配 + def __JudgeIsCaption(self,caption_dict,tablecaption): + + captionlst=caption_dict.keys() #得到json配置的所有caption + for caption in captionlst: + tablestyle_dict = caption_dict[caption] + if __dictIsHasKey__(tablestyle_dict, 'ispartial') and tablestyle_dict['ispartial']==True: + if caption in tablecaption: + return True,tablestyle_dict + else: + if caption == tablecaption: + return True,tablestyle_dict + return False,None + + def __ModifySingleTableattr(self, singletablecontent, columnwidthlst, caption_dict): + #修改单个表格属性 + #从单个表格里用正则表达式找caption + #定义正则表达式,查找caption内容 + tablecaption = '' + new_singletablecontent = singletablecontent + if self.tablesattrobj.isname: + searchstr = r'.*\\label.*?:(?P[\s\S].*)}}.*' + matchcaption = re.search(searchstr, singletablecontent, re.M | re.I | re.U) + if matchcaption is not None: + tablecaption = matchcaption.group('caption') # 得到caption的值 + else: + tablecaption = '' + else: + tablecaption = self.__GetTableCaption(singletablecontent) + + #stylefontsize = False + iscaption = False + iscaption,tablestyle_dict=self.__JudgeIsCaption(caption_dict,tablecaption) + if iscaption: + + #if __dictIsHasKey__(tablestyle_dict,'fontsize') and (tablestyle_dict['fontsize'] in latexfontsize): + # stylefontsize = True + # tablestyle_dict['isLongTable'] = True #只有长表格的\caption指令,设置的字体才能生效,所以如果要设置表格字体,需要强制设置为长表格。 + jsonmultirow = [] #保存需要设置自动换行的合并单元格的列号。 + if __dictIsHasKey__(tablestyle_dict, 'multirowcolumn'): + jsonmultirow = tablestyle_dict['multirowcolumn'] + + if __dictIsHasKey__(tablestyle_dict, 'isLongTable') is not True: + tablestyle_dict['isLongTable'] = False + + if __dictIsHasKey__(tablestyle_dict, 'isCusHead') is not True: + tablestyle_dict['isCusHead'] = False + + #修改表格通用属性 + new_singletablecontent,islongtable= self.__StartModifyTableAttr(new_singletablecontent, + columnwidthlst, + tablestyle_dict['isLongTable'], + tablestyle_dict['isCusHead'], + jsonmultirow, + tablestyle_dict) + #设置该表格定制字体 + if __dictIsHasKey__(tablestyle_dict,'fontsize') and (tablestyle_dict['fontsize'] in latexfontsize): + new_singletablecontent=self.__ModifyTableFontSize(new_singletablecontent,tablestyle_dict['fontsize']) + + #渲染竖型表格的第一列 + if __dictIsHasKey__(tablestyle_dict,'isVertical') and tablestyle_dict['isVertical']==True: + if islongtable: + # 长表格自定义sphinxcolwidth的第2个参数默认为5,否则无法实现自动换行 + #self.normalstr = self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth': '5', + # 'sphinxstyletheadfamily': '%(sphinxstyletheadfamily)s', + # 'content': '%(content)s' + #} + new_singletablecontent = self.__ModifyVerticalLongTable(new_singletablecontent,tablestyle_dict,columnwidthlst) + else: + # 普通表格自定义sphinxcolwidth的第2个参数默认为3,否则无法实现自动换行 + #self.normalstr = self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth': '3', + # 'sphinxstyletheadfamily': '%(sphinxstyletheadfamily)s', + # 'content': '%(content)s' + #} + new_singletablecontent = self.__ModifyTableByLine(new_singletablecontent,True,False,tablestyle_dict,columnwidthlst) + + if __dictIsHasKey__(tablestyle_dict,'isnewpage') and tablestyle_dict['isnewpage']==True: + new_singletablecontent = '\\newpage\n' + new_singletablecontent + + #最后修改表格的行高 + if __dictIsHasKey__(tablestyle_dict, 'ht_rowno') and tablestyle_dict['ht_rowno'] is not None: + new_singletablecontent = self.__ModifyTableRowHeight(islongtable,new_singletablecontent,tablestyle_dict['ht_rowno']) + else: + #修改表格的通用属性 + new_singletablecontent,islongtable = self.__StartModifyTableAttr(singletablecontent, columnwidthlst, False) + if new_singletablecontent == '': + new_singletablecontent = singletablecontent + + return new_singletablecontent + + def __NormaltableRowHeight(self,rowno,rowheightinfo,singletablecontent): + #自定义长表格,第一行以“\hline”开头,因此从第一个“\hline”开头进行替换。 + searchstr = r'\\hline' + match = re.search(searchstr,singletablecontent,re.I | re.U) + if match is None: + return singletablecontent + + #latex表格属性内容 + tablehead = singletablecontent[0:match.start()] + #表格内容 + tablecontent = singletablecontent[match.start():len(singletablecontent)] + tablecontent = self.__ModifyNormalRowHeight(rowno,rowheightinfo,0,tablecontent) + + return tablehead + tablecontent + + def __ModifyNormalRowHeight(self, rowno, rowheightinfo, startline, tablecontent): + # 修改除表格外其它行的行高,并返回修改后的内容 + searchstr = r'\\\\' + pattern = re.compile(searchstr, re.I | re.U) + matchiter = pattern.finditer(tablecontent) + i = startline + for m in matchiter: + i += 1 + if rowno == i: + # 开始设置该行的行高 + # 获取起始和结束位置 + posstart = m.start() + posend = m.end() + tablecontent = tablecontent[0:posstart] + rowheightinfo + tablecontent[posend:len(tablecontent)] + break + + return tablecontent + + def __JudgeRowlistRule(self,rowlist): + #判断行号和行高信息是否符合规则 + + if len(rowlist)!=2: + return False + rowno = rowlist[0] #得到行号 + rowhtstr=rowlist[1] #得到行高信息 + + #判断是否为异常信息 + #设置的最大行高不能超过10,因为10的行高太高,正常不需要这么高的行高。而且sphinx会根据换行自动调整行高,该方法只用户行高的微调。 + if rowno <= 0 or len(rowhtstr) == 0 or not rowhtstr.isdigit() or (int(rowhtstr) > 99): + return False + return True + + def __ModifyLongTableRowHeight(self,rowno,rowheightinfo,singletablecontent): + + # 判断是自定义长表格还是sphinx生成的长表格。 + # sphinx自动生成的长表格有两个表头,第一个表头以“\endfirsthead”结尾。 + # 自定义的长表格没有“\endfirsthead”标志 + searchstr = r'\\endfirsthead' + matchobj = re.search(searchstr, singletablecontent, re.I | re.U) + if matchobj is None: + # 可能是自定义长表格,自定义长表格没有endfirsthead和endhead需要单独设置 + singletablecontent = self.__NormaltableRowHeight(rowno,rowheightinfo,singletablecontent) + return singletablecontent + + if rowno==1: #因为长表格有两个表头,因此两个表头都要设置,避免换页后表头格式不一致。 + #修改长表格的表头 + firsthead = singletablecontent[0:matchobj.end()] + rpos = firsthead.rfind(r"\\") + firsthead = firsthead[0:rpos] + rowheightinfo + firsthead[rpos+len(r"\\"):len(firsthead)] + + #查找第2个表头,第2个表头以“\endhead”结尾 + searchstr = r'\\endhead' + matchobj2 = re.search(searchstr, singletablecontent, re.I | re.U) + endhead = singletablecontent[matchobj.end():matchobj2.end()] + rpos = endhead.rfind(r"\\") + + endhead = endhead[0:rpos]+rowheightinfo+ endhead[rpos+len(r"\\"):len(firsthead)] + + singletablecontent = firsthead+endhead+singletablecontent[matchobj2.end():len(singletablecontent)] + + if rowno > 1: + #长表格先将表头去除掉,否则表头会包含很多“\\”,但不代表行。 + #长表格的表头以“\endlastfoot”结尾 + searchstr = r'\\endlastfoot' + matchobj = re.search(searchstr, singletablecontent, re.I | re.U) + tablehead = singletablecontent[0:matchobj.end()] + othercontent = singletablecontent[matchobj.end():len(singletablecontent)] + + #修改表格内容,因为长表格有两个表头,因此表头后的行号起始行从1开始。 + othercontent = self.__ModifyNormalRowHeight(rowno,rowheightinfo,1,othercontent) + + singletablecontent = tablehead + othercontent + return singletablecontent + + def __ModifyTableRowHeight(self,islongtable,singletablecontent, rowlistarry): + + #解析行信息,包含行号和行高信息 + for rowlist in rowlistarry: + + if not self.__JudgeRowlistRule(rowlist): + continue + rowno = rowlist[0] + rowheightinfo = r"\\[" + rowlist[1] +r"ex]" #要替换的行高信息 + #根据正则表达式,查找换行符"\\",每一个“\\"代表表格的一行,需要在“\\”的末尾添加行高数据:“\\[rowhtstr+ex]” + if islongtable: + singletablecontent = self.__ModifyLongTableRowHeight(rowno, rowheightinfo,singletablecontent) + else: + singletablecontent = self.__NormaltableRowHeight(rowno,rowheightinfo,singletablecontent) + + return singletablecontent + + def __StartModifyTableAttr(self, singletablecontent, columnwidthlst, islongtable, isCusHead=True,jsonmultirow=None,tablestyle_dict=None): + #修改表格的通用属性 + searchstr = r'(\\begin{tabular}|\\begin{tabulary})(\[[a-z]\]|{\\linewidth}\[[a-z]\])([\s\S].*)' + #为了添加表格的通用属性,先对字符串做分割 + #python会把正则表达式中的分组自动分割,因此上面的正则表达式会自动分割为三个字符串 + #加上头尾字符串总共分为5个字符串数组。要修改第1维字符串为\\being{longtable}, + #第2维字符串直接删除,第3维字符串不变 + splittable = re.split(searchstr, singletablecontent,0, re.M | re.I|re.U ) + if splittable is None or len(splittable) < 5: + #再修改长表格属性 + searchstr = r'\\begin{longtable}([\s\S].*)' + #为了添加表格的通用属性,先对字符串做分割 + #python会把正则表达式中的分组自动分割,因此上面的正则表达式会自动分割为三个字符串 + #加上头尾字符串总共分为5个字符串数组。要修改第1维字符串为\\being{longtable},第2维字符串直接删除, + #第3维字符串不变 + splittable = re.split(searchstr, singletablecontent,0, re.M | re.I|re.U ) + if len(splittable) < 3: + #至少是3维的数组,否则不是预想的内容,不做处理 + return singletablecontent + if isCusHead: + newtable4 = self.__ModifyLongTableHead(splittable[2], + columnwidthlst, + self.tablesattrobj.headtype, + jsonmultirow, + tablestyle_dict) + # begin{longtable}必须再加上,因为Python并不认为它是正则表达式,因此不再分组里面第0个分组为空。 + singletablecontent = splittable[0]+r'\begin{longtable}'+splittable[1]+newtable4 + if self.tablesattrobj.fontsize !="": + #设置表格字体 + singletablecontent=self.__ModifyTableFontSize(singletablecontent,self.tablesattrobj.fontsize) + return singletablecontent,True + + #拆分后splittable应该为5个字符串的数组,拆分后便于添加通用属性 + if self.tablesattrobj.rowtype != '': + splittable[0] += self.tablesattrobj.rowtype + '\n' + + if isCusHead is True: + #修改表头字体颜色为白色加粗 + newtable4 = self.__ModifyTableHead(splittable[4], + columnwidthlst, + self.tablesattrobj.headtype, + jsonmultirow, + tablestyle_dict) + else: + newtable4 = splittable[4] + + singletablecontent = splittable[0]+splittable[1]+splittable[2]+splittable[3]+newtable4 + #根据配置设置表格为长表格,现在基本通过将sphinx的class属性设置为longtable实现,不再通过手动设置完成。 + if islongtable is True: + singletablecontent = self.__ModifyTableLongHeadAndTail(singletablecontent) + + if self.tablesattrobj.fontsize !="": + #设置表格字体 + singletablecontent=self.__ModifyTableFontSize(singletablecontent,self.tablesattrobj.fontsize) + + return singletablecontent,False + + def __ModifyTableLongHeadAndTail(self,singletablecontent): + #长表格的头尾都要替换,因此单独封装成一个函数,否则长表格的表格序号将加2 + #替换第一行为长表格 + searchstr = r'(\\begin{savenotes}[\S]*?\\sphinxattablestart)' + splittable = re.search(searchstr, singletablecontent,re.M | re.I|re.U ) + firstend =splittable.end() + #替换为长表格 + tablefirstline = re.sub(r'\\sphinxattablestart',r'\n\\sphinxatlongtablestart\n', splittable.group(0), 1,re.M|re.I|re.U) + if r'\sphinxthistablewithglobalstyle' in singletablecontent: + tablefirstline += '\\sphinxthistablewithglobalstyle\n' + if r'\sphinxthistablewithvlinesstyle' in singletablecontent: + tablefirstline += '\\sphinxthistablewithvlinesstyle\n' + #替换第2行为长表格 + searchstr = r'(\\begin{tabular}|\\begin{tabulary})(\[[a-z]\]|{\\linewidth}\[[a-z]\])([\s\S].*)' + splittable = re.search(searchstr, singletablecontent,re.I|re.U ) + #记录表头的位置 + headstartpos= splittable.start() + headlastpos = splittable.end() + tablesecondline = re.sub(r'\\begin{tabular}|\\begin{tabulary}',r'\\begin{longtable}', splittable.group(0), 1,re.I|re.U) + tablesecondline = re.sub(r'\{\\linewidth\}',r'', tablesecondline, 1,re.I|re.U) + + #查找caption + searchstr = r'\\sphinxcaption([\s\S].*)' + splittable = re.search(searchstr, singletablecontent, re.I|re.U) + if splittable is not None: + longcaption = re.sub(r'\\sphinxcaption',r'\\caption', splittable.group(0), 1,re.I|re.U) + #添加长表格专用指令 + longcaption += r"\\*[\sphinxlongtablecapskipadjust]" + + #构建长表个的表头部分 + longhead = tablefirstline + tablesecondline + '\n' + r'\sphinxthelongtablecaptionisattop' + '\n'+ longcaption+'\n' + else: + #如果没有caption需要把中间的内容加上,否则会丢失label导致连接失效 + longhead = tablefirstline + singletablecontent[firstend:headstartpos]+tablesecondline + '\n' + + #替换表尾 + newtablecontent = singletablecontent[headlastpos:len(singletablecontent)] + index = newtablecontent.rfind('\\end{tabulary}') + if index > 0: + endtable = newtablecontent[index:len(newtablecontent)] + othercontent = newtablecontent[0:index] + endtable = re.sub(r'(\\end{tabular}|\\end{tabulary})', r'\\end{longtable}', endtable, 1,re.M | re.I | re.U) + endtable = re.sub(r'\\par', r'', endtable, 1, re.M | re.I | re.U) + endtable = re.sub(r'\\sphinxattableend', r'\\sphinxatlongtableend', endtable, 1, re.M | re.I | re.U) + newtablecontent = othercontent + endtable + else: + index = newtablecontent.rfind('\\end{tabular}') + if index >0: + endtable = newtablecontent[index:len(newtablecontent)] + othercontent = newtablecontent[0:index] + endtable = re.sub(r'(\\end{tabular}|\\end{tabulary})', r'\\end{longtable}', endtable, 1,re.M | re.I | re.U) + endtable = re.sub(r'\\par', r'', endtable, 1, re.M | re.I | re.U) + endtable = re.sub(r'\\sphinxattableend', r'\\sphinxatlongtableend', endtable, 1, re.M | re.I | re.U) + newtablecontent = othercontent + endtable + + singletablecontent = longhead + newtablecontent + return singletablecontent + + #为表格增加h属性对齐方式,避免表格浮动,让表格在当前位置显示 + def __AddHAttrForTable(self,content): + + newcontent="" + searchstr = r'(\\begin{tabular}|\\begin{tabulary}|\\begin{longtable})({\\linewidth})?(\[(?P[a-z]{1,4})\])?([\s\S].*)' + m = re.search(searchstr, content, re.M|re.I|re.U ) + attrcontent = m.group('attr') + posarr = m.span() + if m.group(3) == '': #group(3)是[htp]的组合,如果没有表格属性则添加上表格属性 + newcontent = content[0:posarr[0]]+m.group(1)+m.group(2)+r'[h]'+m.group(5)+content[posarr[1]:len(content)] + else: + replacestr = m.group(4) #需要被替换的表格属性 + #判断该表格是否有P属性,如果有p属性需要删掉 + replacestr.replace('p','') + #给该表格添加h属性,避免表格浮动,让表格在当前位置显示 + replacestr = 'h'+replacestr + newcontent = content[0:posarr[0]]+m.group(1)+m.group(2)+'['+replacestr+']'+m.group(5)+content[posarr[1]:len(content)] + + return newcontent + def __GetColumnParaByWidths(self,content): + # 可能用的widths参数设置的列宽,再根据widths参数解析列宽 + columnwidthlst = [] + # 如果列宽不等于100,为了使表格不超过页面显示范围,以100为基数重新调整各列的列宽 + newcolumnwidth = "|" + isadjustcol = False + newcontent = content + try: + searchstr = r'(?<=\|)\\X\{(?P[0-9]*)\}\{(?P[0-9]*)\}(?=\|)' + pattern = re.compile(searchstr, re.I | re.U) + match = pattern.finditer(content) + if match is None: + return [],content + for m in match: + count1 = int(m.group('count1')) + count2 = int(m.group('count2')) + + tempcount = round(count1/count2,2) + colwidth = "%(float)s\\textwidth" + colwidth = colwidth % { + 'float': str(tempcount) + } + columnwidthlst.append(colwidth) + if count2 != 100: + isadjustcol = True + #重新调整count1和count2的值,并加入列表 + percolwidth = r"\X{" + str(int(tempcount * 100)) + "}{100}|" + newcolumnwidth = newcolumnwidth + percolwidth + + if isadjustcol is True: + searchstr = r"(?<=\{)\|\\X[\s\S]*?\|(?=\})" + match = re.search(searchstr,content,re.I | re.U) + if match is not None: + newcontent = content.replace(match.group(),newcolumnwidth,1) + + except Exception as e: + print(e) + finally: + return columnwidthlst,newcontent + + def __GetColumnParameters(self,content): + ''' + 得到通过tablecolumns参数设置的列宽,multirow的自动换行,必须有列宽参数,否则无法实现自动换行。 + 而且不能通过widths参数设置每列的列宽,否则生成的列宽无法用于multirow,依然无法实现multirow的自动换行 + :param content: + :return: + ''' + # 默认ismultirowautowrap为自动换行 + # 如果表格用tabluarcolumns指令设置的列宽,因为每列的宽度有具体的值,则默认自动换行 + # 如果表格用widths指令设置的列宽,因为每列的宽度不具体,则默认不自动换行,需要手动换行 + self.tablesattrobj.ismultirowautowrap = True #默认ismultirowautowrap为自动换行 + newcontent = content + columnwidthlst=[] #保存每一列宽度的列表 + #根据表格内容,得到每一列宽度参数,方便multirow的时候自动换行 + #m垂直居中,p垂直靠上,b垂直靠下 + searchstr= r'(?<=\|)[\s\S]*?[m|p|b]\{(?P[\s\S]*?)\}[\s\S]*?(?=\|)' + pattern = re.compile(searchstr,re.I | re.U) + match = pattern.finditer(content) + if match is None: + #columnwidthlst,newcontent = self.__GetColumnParaByWidths(content) + self.tablesattrobj.ismultirowautowrap = False #则默认需要手动换行 + return [],newcontent + + for m in match: + columnwidthlst.append(m.group('width')) + if len(columnwidthlst) ==0: + #columnwidthlst,newcontent= self.__GetColumnParaByWidths(content) + self.tablesattrobj.ismultirowautowrap = False #则默认需要手动换行 + return [],newcontent + return columnwidthlst,newcontent + + def __FindTableHeadForFamily(self,content,isLongTable): + ''' + 根据sphinxstyletheadfamily标志查找表格的表头,有sphinxstyletheadfamily标志的表头,源文件必须以“======”分割。 + 具有合并单元格的表头,源文件必须以“===”分割,而且表头最多支持两行,否则合并单元格的表头不支持渲染。 + :param content: 表格内容 + :return: headcontent 表格内容 + ''' + #查找表头的起始位置,以第一个\hline开始。 + if isLongTable is True: + searchstr =r"(\\hline[\s\S]*)(?=\\endfirsthead)" + matchobj = re.search(searchstr, content, re.I|re.M|re.U) + tablehead = matchobj.group() + else: + searchstr = r'\\hline' + matchobj = re.search(searchstr,content,re.I) + if matchobj is None: + return "" + startpos = matchobj.start() #表头的起始位置 + familypos = content.rfind(r'\sphinxstyletheadfamily') #查找最后一个sphinxstyletheadfamily出现的位置 + if familypos ==-1: + return "" + + #从familypos开始,查找到的第一个“\\”,则为表头的行结束符号。 + endpos = content.find(r"\hline",familypos) + # 得到表头内容,表头内容从起始\hline开始,到\hline结束,这种查找方式,兼容表头为两行的情况。 + tablehead = content[startpos:endpos+len(r"\hline")] + + return tablehead + + def __FindAndFlagPos(self,tablehead,andpos): + if andpos != -1: + # 查找&前面有没有转义字符,有转义字符,则&不为单元格分割符,需要再继续查找 + preflag = tablehead[andpos-1] + if preflag == '\\': + #如果是\&说明这个&不是单元格的分隔符,则需要继续查找 + newpos=tablehead.find("&",andpos+1) + andpos = self.__FindAndFlagPos(tablehead,newpos) + return andpos + else: + return andpos + return andpos + + def __FindCellContentForHead(self,tablehead,startpos): + #查找每个单元格的内容 + if startpos == -1: + return 0,-1,None + + # 查找&和\\的位置,&为单元格的分割标志,\\为行结束标志,一个单元格的分割标志可能是&,也可能是\\ + cellcontent = [] + andpos = tablehead.find("&", startpos) + # 得到实际的分隔符位置,避免&内容里面的符号,而不是单元格分隔符 + andpos = self.__FindAndFlagPos(tablehead, andpos) + linepos = tablehead.find(r"\\", startpos) + + if andpos ==-1 and linepos ==-1: + return 0,-1,None + + newpos = -1 + lineflag = 0 # 是否启动新行的标志。0说明在当前行;1说明遇到了\\,需要启动新的一行 + if andpos == -1 or linepos < andpos: + # 该单元格的分割符为\\ + cellcontent.append(tablehead[startpos:linepos]) + cellcontent.append(r"\\") + newpos = linepos + len(r"\\") + lineflag = 1 + else: + cellcontent.append(tablehead[startpos:andpos]) + cellcontent.append(r"&") + newpos = andpos + len(r"&") + return lineflag,newpos,cellcontent + ''' + #暂时注释,不在使用该函数,使用上面函数的写法 + def __delNotMustStrFromContent(self, content): + content = content.strip().strip("\n").strip("\r") + for delstr in self.delstrlst: + if delstr in content: + content = content.replace(delstr,"") + + content = content.strip().strip("\n").strip("\r") + print('content=%s' % content) + strlst = content.split('\n') + newcontent = "" + for i in strlst: + if i != "": + i.strip("\r") + if i != strlst[-1] and i != "": + i = i + r"\\ " + newcontent = newcontent + i + content = self.autolinetemplat % { + 'autocontent': newcontent + } + return content + ''' + + def __delNotMustrowStrFromContent(self, content, ismanualwraplst=None): + """ + 该函数用于multirow的手动换行,为了更好的显示效果,multicolumn的手动换行请参见__delNotMustStrFromContent + :param content: + :param ismanualwraplst: list类型,只有一个元素。是否有手动换行,为了回传,并兼容之前的调用,设计了list类型。 + :return: + """ + # 实现自动换行 + # 去掉首尾换行和空白字符 + content = content.strip().strip("\n").strip("\r") + # 删除单元格内容里非必须的成分 + for delstr in self.delstrlst: + if delstr in content: + content = content.replace(delstr, "") + # 正则匹配既是加了“r”,也要用转义字符,因为r只是去掉了python的转义, + # 但正则本身也需要一层转义 + content = content.strip().strip("\n").strip("\r") + searchstr = r"\\sphinxAtStartPar" + match = re.search(searchstr, content, re.I | re.M) + if match is not None: + # 说明在字符串的起始位置匹配成功 + content = content[match.end():len(content)] + # 再去掉首尾空格和换行 + content = content.strip().strip("\n").strip("\r") + if r"\sphinxAtStartPar" in content: + # 字符串替换比较,如果用了r则无需再加一次转义字符 + # 替换所有的回车换行符,否则autolinetemplat中间不能有回车换行,否则将导致错误 + content = content.replace("\n", "").replace("\r", "") + content = content.replace(r"\sphinxAtStartPar", r"\newline ") + if ismanualwraplst is not None: + ismanualwraplst[0] = True + else: + strlst = content.split('\n') + newcontent = "" + for i in range(0, len(strlst)): + str = strlst[i].strip("\r") + if str != "": + strlst[i] = str + else: + if i > 0 and i < (len(strlst) - 1) and strlst[i - 1] != r"\newline ": + strlst[i] = r"\newline " + + content = "".join(strlst) + if r"\newline" in content and ismanualwraplst is not None: + ismanualwraplst[0] = True + + return content + + def __delNotMustStrFromContent(self,content,ismanualwraplst=None): + """ + 该函数用于multicolumn的手动换行,为了更好的显示效果,multirow的手动换行请参见__delNotMustrowStrFromContent + :param content: + :param ismanualwraplst: list类型,只有一个元素。是否有手动换行,为了回传,并兼容之前的调用,设计了list类型。 + :return: + """ + #实现自动换行 + #去掉首尾换行和空白字符 + content = content.strip().strip("\n").strip("\r") + #删除单元格内容里非必须的成分 + for delstr in self.delstrlst: + if delstr in content: + content = content.replace(delstr,"") + # 正则匹配既是加了“r”,也要用转义字符,因为r只是去掉了python的转义, + # 但正则本身也需要一层转义 + content = content.strip().strip("\n").strip("\r") + searchstr = r"\\sphinxAtStartPar" + match = re.search(searchstr,content,re.I|re.M) + if match is not None: + #说明在字符串的起始位置匹配成功 + content = content[match.end():len(content)] + # 再去掉首尾空格和换行 + content = content.strip().strip("\n").strip("\r") + if r"\sphinxAtStartPar" in content: + #字符串替换比较,如果用了r则无需再加一次转义字符 + #替换所有的回车换行符,否则autolinetemplat中间不能有回车换行,否则将导致错误 + content=content.replace("\n","").replace("\r","") + content = content.replace(r"\sphinxAtStartPar", r"\\ ") + content = self.autolinetemplat % { + 'autocontent': content + } + if ismanualwraplst is not None: + ismanualwraplst[0] = True + else: + strlst = content.split('\n') + newcontent = "" + for i in range(0,len(strlst)): + str = strlst[i].strip("\r") + if str != "": + strlst[i]=str + else: + if i>0 and i< (len(strlst)-1) and strlst[i-1] != r"\\ ": + strlst[i]=r"\\ " + + newcontent = "".join(strlst) + if r"\\ " in newcontent: + content = self.autolinetemplat % { + 'autocontent': newcontent + } + if ismanualwraplst is not None: + ismanualwraplst[0] = True + else: + content = newcontent + + return content + + def __ModifySphinxMultirow(self,cellstr,contentdict,columnwidthlst,jsonmultirow,lineindex,colindex,multirowindex,tablestyle_dict,isvertical=False): + ''' + 修改multirow内容 + 取出关键内容合并单元格的数量和单元格内容 + :param isvertical:如果是渲染第一列,multirow加\raggedright参数为了对齐 + 渲染行,无需加该参数,显示效果没明显差别。 + ''' + #searchstr=r"(?<=\\sphinxmultirow){(?P[0-9]*)}{(?P[0-9]*)}([\s\S]*?)\\sphinxstyletheadfamily(?P[\s\S]*?)\\par([\s\S]*?)(?=\}\%)" + searchstr=r"(?<=\\sphinxmultirow){(?P[1-9]*)}{(?P[0-9]*)}{([\s\S]*?){\\sphinxcolwidth{[0-9]*}{[0-9]*}}[\s]*(\\sphinxstyletheadfamily)?(?P[\s\S]*?)\\par([\s\S]*?)(?=\}\%)" + match = re.finditer(searchstr, cellstr, re.M | re.I | re.U) + count = "" + flag = "" + content = "" + for m in match: + if count == "": + count = m.group('count') + if flag == "": + flag = m.group('flag') + if content == "": + content = m.group('content') + + #如果是自动换行,设置列宽 + coluwidth = "*" + if len(columnwidthlst) > 0: + if tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict,"ismultirowautowrap"): + if tablestyle_dict['ismultirowautowrap'] is True: + coluwidth = columnwidthlst[colindex-1] + else: + if self.tablesattrobj.ismultirowautowrap is True: + coluwidth = columnwidthlst[colindex-1] + + #if (jsonmultirow is not None) and \ + # (multirowindex>=0 and len(jsonmultirow)>multirowindex) and \ + # (len(columnwidthlst)>jsonmultirow[multirowindex]-1): + # #得到要设置自动换行multirow的列号 + # coluwidth = columnwidthlst[jsonmultirow[multirowindex]-1] + # #print("coluwidth = %s" % coluwidth) + #else: + # coluwidth = "*" + + #if r"\sphinxstyletheadfamily" in cellstr: + #实测,multirow要支持渲染背景色,必须带上sphinxstyletheadfamily,否则会编译失败 + sphinxstyletheadfamily = r"\sphinxstyletheadfamily" + raggedright = "" + if (isvertical is True) or \ + (tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict, "ismultirowatupcell") and \ + tablestyle_dict['ismultirowatupcell'] is True) or \ + self.tablesattrobj.ismultirowatupcell is True: + #raggedright = r"{\raggedright}" + # 得到单元格内容 + cellstr = self.commonstr % { + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustrowStrFromContent(content) + } + multistr = self.multirowstr % { + 'count': count, + 'coluwidth': coluwidth, + 'raggedright': raggedright, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': "" + } + #将flag对应的内容加到字典里,以方便后续交换 + contentdict[flag] = str(int(count) + lineindex-1) + "," + multistr #为了支持超过两行的合并单元格,将count数量也带出去。第一个","之前的就是count数量 + + #返回空的字符串 + return cellstr + else: + ismanualwraplst = [False] + if (tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict, "ismultirowatlowcell") and \ + tablestyle_dict['ismultirowatlowcell'] is True ) or \ + self.tablesattrobj.ismultirowatlowcell is True: + + simplemultirow = self.simplemultirow % { + 'count': count, + 'coluwidth': coluwidth + } + + cellstr = self.commonstr % { + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustrowStrFromContent(content, ismanualwraplst) + } + + contentdict[flag] = str(int(count) + lineindex - 1) + "," + cellstr + + return simplemultirow + + multistr = self.multirowstr % { + 'count': count, + 'coluwidth': coluwidth, + 'raggedright': raggedright, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustStrFromContent(content) + } + # 将flag对应的内容加到字典里,以方便后续交换 + contentdict[flag] = str(int(count) + lineindex - 1) + "," + multistr # 为了支持超过两行的合并单元格,将count数量也带出去。第一个","之前的就是count数量 + + # 返回空的字符串 + return self.emptystr + + + + def __ModifyMulticolumn(self,cellstr,tablestyle_dict): + ''' + :param cellstr: + :return: + ''' + searchstr = r"(?<=\\multicolumn){(?P[1-9]*)}\{(?P[\s\S]*?)\}\{%(?P[\s\S]*?)\\sphinxstyletheadfamily(?P[\s\S]*?)\\par(?P[\s\S]*)" + match = re.finditer(searchstr, cellstr, re.M | re.I | re.U) + count = "" + content = "" + flag="" + prefix="" + suffix="" + for m in match: + if count == "": + count = m.group('count') + if flag == "": + flag = m.group('flag') + if content == "": + content = m.group('content') + if prefix == "": + prefix = m.group('prefix') + #如果有\begin{varwidth}[t],替换为\begin{varwidth}[c] + #[t]会使得当自动换行的时候,单元格上面会有很大的空白 + #print('===========================================') + #print(prefix) + searchstr = r"\\begin{varwidth}\[t\]" + replstr = r"\\begin{varwidth}[c]" + prefix = re.sub(searchstr,replstr,prefix,1,re.I|re.U|re.M) + #print(prefix) + #print('===========================================') + if suffix == "": + suffix = '\n\\par' + m.group('suffix') #因为正则表达式把\par去掉了,再加上 + + #行合并单元格,改成居中显示 + newflag=re.sub(r"[a-z]",r"c",flag,1,re.I) + #newflag = flag + if (tablestyle_dict is not None and \ + __dictIsHasKey__(tablestyle_dict,'ismulticolleftalign') and \ + tablestyle_dict['ismulticolleftalign'] is True) or \ + self.tablesattrobj.ismulticolleftalign is True: + #multistr = self.comcolstr % { + # 'count': count, + # 'flag': flag, + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(content) + #} + multistr = self.multicolumnstr % { + 'count':count, + 'flag': flag, + 'prefix':prefix, + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content':self.__delNotMustStrFromContent(content), + 'suffix':suffix + } + else: + multistr = self.multicolumnstr % { + 'count':count, + 'flag': newflag, + 'prefix':prefix, + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content':self.__delNotMustStrFromContent(content), + 'suffix':suffix + } + + #multistr = self.comcolstr % { + # 'count': count, + # 'flag': flag, + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(content) + #} + return multistr + + def __ModifyCellContentForMergeCell(self,cellcontent, + tablestyle_dict,islongtable, + contentdict,columnwidthlst,jsonmultirow, + colindex,lineindex,multirowindex,ismulticol = False): + ''' + 修改单元格内容 + :param cellcontent: + :param ismulticol: 当前行是否有列合并单元格。 + 如果有列合并单元格,则独立的单元格需要用multcolumn命令,否则无法对齐。 + :return: + ''' + multirowflag = False #是否是multirow单元格 + cellstr = cellcontent[0] #取得单元格内容 + #print('===========================') + #print(cellstr) + + #查找有没有\cline指令,如果有这个指令,为了划线清晰在前面再多加一个。 + searchstr= r"\\cline\{[\s\S]*?\}" + match = re.search(searchstr,cellstr,re.I | re.M) + if match is not None: + cline = match.group() + #去掉开头的换行符,否则影响后面的正则表达式,导致正则表达式查找失败 + cellstr = cellstr.lstrip() + cellstr = cline + cellstr + + sphinxstyletheadfamily="" + if r'\sphinxmultirow' in cellstr: + cellcontent[0] = self.__ModifySphinxMultirow(cellstr, + contentdict, + columnwidthlst, + jsonmultirow, + lineindex, + colindex, + multirowindex, + tablestyle_dict) + multirowflag = True + elif r'\multicolumn' in cellstr: + cellcontent[0] = self.__ModifyMulticolumn(cellstr,tablestyle_dict) + else: + if colindex == 1: + if len(columnwidthlst)>0 and colindex > 0: + flag = "|m{" + columnwidthlst[colindex - 1] + "}|" + else: + flag = "|l|" + else: + if len(columnwidthlst) > 0 and colindex > 0: + flag = "m{" + columnwidthlst[colindex - 1] + "}|" + else: + flag = "l|" + + #查找\sphinxAtStartPar位置 + precell = "" + pos = cellstr.find(r'\sphinxstyletheadfamily') + #print("cellstr = %s" % cellstr) + #print("pos=%d" % pos) + if pos > 0: + #说明前面还有内容,需要把前面的内容保存下来 + precell = cellstr[0:pos] + "\n" + if pos != -1: + cellstr = cellstr[pos+len(r'\sphinxstyletheadfamily'):len(cellstr)] + sphinxstyletheadfamily = "\\sphinxstyletheadfamily" + + ismanualwraplst=[False] + self.__delNotMustStrFromContent(cellstr,ismanualwraplst) + if ismanualwraplst[0] is True: + if colindex == 1: + flag = "|l|" + else: + flag = "l|" + colwidth = "" + if (len(columnwidthlst) > 0) and (colindex > 0): + colwidth = columnwidthlst[colindex-1] + + fullcontentstr = self.normalstr % { + 'flag': flag, + 'colwidth': colwidth, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustStrFromContent(cellstr) + } + else: + fullcontentstr = self.comcolstr % { + 'count': "1", + 'flag': flag, + 'sphinxstyletheadfamily': sphinxstyletheadfamily, + 'content': self.__delNotMustStrFromContent(cellstr) + } + #if ismulticol is True: + #必须用multicolumn命令,否则当单元格后面还有列的时候,刷新会出现问题 + #fullcontentstr = self.comcolstr % { + # 'count': "1", + # 'flag': flag, + # 'sphinxstyletheadfamily': sphinxstyletheadfamily, + # 'content': self.__delNotMustStrFromContent(cellstr) + #} + #else: + # fullcontentstr = self.commonstr % { + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(contentstr) + # } + cellcontent[0] =precell + fullcontentstr + #print('===========================') + return multirowflag + + def __FindMergeRowCountAndContent(self,multirowstr): + #找出合并行的数量和合并单元格的内容 + count = 0 + newmultistr = multirowstr + + try: + searchstr = "([1-9]*?)(?=,)" + match = re.match(searchstr,newmultistr,re.I) + if match is not None: + count = int(match.group(0)) + startpos = match.end() + newmultistr = multirowstr[startpos+1:len(multirowstr)] + finally: + return count,newmultistr + + def __AdjustMultirowContent(self,cellcontent,contentdict,lineindex,colindex,columnwidthlst): + ''' + 调整multirow单元格到对应的下面的单元格 + ''' + #查找flag数量 + cellstr = cellcontent[0] + #查找有没有\cline指令,如果有这个指令,为了划线清晰在前面再多加一个。 + searchstr= r"\\cline\{[\s\S]*?\}" + match = re.search(searchstr,cellstr,re.I | re.M) + if match is not None: + cline = match.group() + #去掉开头的换行符,否则影响后面的正则表达式,导致正则表达式查找失败 + cellstr = cellstr.lstrip() + cellstr = cline + cellstr + + searchstr = r"\\sphinxtablestrut\{(?P[0-9]*)\}" + matchobj = re.search(searchstr, cellstr, re.I | re.M | re.U) + flag = matchobj.group('flag') + #得到修改后的multirow内容 + if __dictIsHasKey__(contentdict,flag): + multirowstr = contentdict[flag] + #替换实际列号 + replstr = r'\\sphinxtablestrut{' + str(colindex)+'}' + cellstr = re.sub(searchstr,replstr,cellstr,1,re.I|re.U|re.M) + #找出合并行的最大数量 + count,multirowstr = self.__FindMergeRowCountAndContent(multirowstr) + if lineindex == count: #是最后一行了再合并,否则会重复 + cellcontent[0] = cellstr+"\n"+multirowstr + else: + cellcontent[0] = cellstr+'\n' + r"\cellcolor" + self.tablesattrobj.headtype + "{}" + + return cellcontent + + def __GetFactColIndex(self,cellcontent): + ''' + 得到真实的列索引。当有multicolumn的时候,列索引需要加合并单元格的数量。 + 返回合并单元格的数量,没合并单元格返回1 + :para cellcontent: 单元格内容字符串 + ''' + if len(cellcontent) ==0: + return 0 + searchstr = r"\\multicolumn{(?P[1-9]+)}" + match = re.search(searchstr,cellcontent,re.I|re.U) + if match is None: + return 1 #不是合并单元格,默认加1列 + factcount = 1 + try: + factcount = int(match.group('count')) + except Exception as e: + print('------------------') + print(e) + # traceback.print_stack() + traceback.print_exc() + print('------------------') + + return factcount + def __ModifyTableHeadForFamilyList(self,tablehead,isLongtable,columnwidthlst,jsonmultirow,tablestyle_dict): + ''' + 用列表的方式逐个拆解每个单元格内容,列表格式为[['cell content','分割符号&或者\\'],[],[]] + :param tablehead: + :return: + ''' + #内容的起始位置 + startpos = tablehead.find(r'\\hline') + len(r'\\hline') + + contentdict ={} #创建一个字典,用来保存mlutirow交换内容 + #得到每个单元格的内容 + colindex = 0 #是否第一个单元格的索引 + lineindex = 1 #行索引,为了支持多行合并,默认起始在第一行 + # multirowindex:第几个multirow单元格的索引号,用于作为conf.json中multirowcolumn字段的索引值, + # 取出当前multirow单元格真实所在列 + #得到该列的列宽,填写到multirow指令里面,实现multirow指令中的内容自动换行 + multirowindex = 0 + lineflag = 0 + newtablehead = "\\hline\n" + while startpos !=-1: + multirowflag = False # 是否是multirow处理,是的话__ModifyCellContentForMergeCell返回true,默认返回false + lineflag,startpos,cellcontent = self.__FindCellContentForHead(tablehead,startpos) + if cellcontent is not None: + factcolcount = self.__GetFactColIndex(cellcontent[0]) + colindex = colindex+factcolcount #当前列号 + if r"\sphinxtablestrut" in cellcontent[0]: + # 需要将multirow的内容copy到该单元格 + self.__AdjustMultirowContent(cellcontent, contentdict,lineindex,colindex,columnwidthlst) + else: + + multirowflag = self.__ModifyCellContentForMergeCell(cellcontent, + tablestyle_dict, + isLongtable, + contentdict, + columnwidthlst, + jsonmultirow, + colindex, + lineindex, + multirowindex) + newtablehead = newtablehead+cellcontent[0] + "\n" + cellcontent[1]+ "\n" + + if multirowflag is True: + multirowindex = multirowindex+1 + lineindex = lineindex + lineflag # 计算新的行数 + if lineflag: + # 如果开启了新行,则列索引从0开始计算 + colindex = 0 + + #组成新的表头内容 + newtablehead =newtablehead + "\\hline\n" + return newtablehead + def __modifyTableHeadForMinusSign(self,content,isLongTable,headtype,tablestyle_dict): + """ + 如果表头不是用“=”分割的,而是用“-”分割,调用该函数修改表头 + """ + # 先找出第一行 + if isLongTable: + searchstr = r'\\endlastfoot[\s]+(\\sphinxtableatstartofbodyhook)?(?P[\s\S]*?)\\hline' + else: + searchstr = r'\\hline(?P[\s\S]*?)\\hline' + m = re.search(searchstr, content, re.M | re.I | re.U) + headcontent = m.group('content') # 匹配到的第一个即为表头内容 + posarr = m.span('content') # 保存起始位置和结束位置,便于组成新的内容 + + if 'multicolumn' in headcontent: + return content + + if r'\sphinxstyletheadfamily' in headcontent: + pattern = re.compile(r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)', + re.M | re.I | re.U) + aftercontent = headcontent + # pattern = re.compile(r'(?<=\\sphinxstylethead{\\sphinxstyletheadfamily)([\s\S]*?)(?=\\unskip}\\relax &)', re.M | re.I|re.U) + else: + aftercontent = headcontent.replace(r'\\', '&', 1) + pattern = re.compile(r'(?P[\s\S]*?)(&\s{1})', re.M | re.I | re.U) + # pattern = re.compile(r'[\s\S]*?&|\\unskip', re.M | re.I|re.U) + + mobjarr = pattern.finditer(aftercontent) + headlist = [] + preposlist = [] + fontcolor = "" + for mobj in mobjarr: + amarr = mobj.group('value') + curposlist = [mobj.start(), mobj.start() + len(amarr)] + + # 用表头内容数组替换 + if self.tablesattrobj.headfontcolor != "": + fontcolor = self.tablesattrobj.headfontcolor + # 先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() # 再去掉首尾空格,避免多余的空格出现 + if amarr == '': + continue + amarr = self.__delNotMustStrFromContent(amarr) + fontcolor = fontcolor.replace('{}', '{' + amarr + '}', 1) + # 当表头没有合并单元格时,不能用cellcolor,否则表头不对齐 + # fontcolor = self.commonstr % { + # 'sphinxstyletheadfamily':'', + # 'content': self.__delNotMustStrFromContent(amarr) + # } + if len(preposlist) > 0: + headlist.append(headcontent[preposlist[1]:curposlist[0]]) + else: + headlist.append(headcontent[0:curposlist[0]]) + headlist.append(fontcolor) + preposlist = curposlist + + headlist.append(headcontent[preposlist[1]:len(headcontent)]) # 把最后一个字符串加上 + headcontent = '' + for prelist in headlist: + headcontent = headcontent + prelist + '\n' + newcontent = content[0:posarr[0]] + r'\rowcolor' + headtype + '\n' + headcontent + content[ + posarr[1]:len(content)] + return newcontent + #修改sphinx自动生成的长表格表头 + def __ModifyLongTableHead(self,content,columnwidthlst,headtype,jsonmultirow,tablestyle_dict): + + tablehead = self.__FindTableHeadForFamily(content,True) + if tablehead != "" and (r"\sphinxmultirow" in tablehead or r"\multicolumn" in tablehead): + #长表格自定义sphinxcolwidth的第2个参数默认为5,否则无法实现自动换行 + #self.normalstr= self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth':'5', + # 'sphinxstyletheadfamily':'%(sphinxstyletheadfamily)s', + # 'content':'%(content)s' + #} + newtablehead = self.__ModifyTableHeadForFamilyList(tablehead, + True, + columnwidthlst, + jsonmultirow, + tablestyle_dict) + return content.replace(tablehead,newtablehead,2) + + #先找出第一行 + if tablehead.strip() == r"\hline": + #进入该分支,说明表格设置为了长表格,但是表头没有用“=”分割,用“-”做的分割。 + return self.__modifyTableHeadForMinusSign(content,True,headtype,tablestyle_dict) + + searchstr = r'\\hline(?P[\s\S]*?)\\hline' + pattern = re.compile(searchstr,re.M | re.I|re.U) + matchiter = pattern.finditer(content) + posarr = [] + i = 0 + for m in matchiter: + + if i > 1: + break + + posarr.append([]) + posarr[i] = m.span() #保存起始位置和结束位置,便于组成新的内容 + + if i ==0: + newcontent = content[0:posarr[i][0]] + else: + newcontent = newcontent+content[posarr[i-1][1]:posarr[i][0]] + + newcontent += r'\hline\rowcolor'+headtype + headcontent = m.group(1) #匹配到的第一个即为表头内容 + + if 'multicolumn' in headcontent: + return content + + headlist = [] + + if r'\sphinxstyletheadfamily' in headcontent: + pattern = re.compile(r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)', re.M | re.I|re.U) + aftercontent = headcontent + mobjarr = pattern.finditer(aftercontent) + + preposlist = [] + for mobj in mobjarr: + amarr = mobj.group('value') + curposlist = mobj.span() + + #用表头内容数组替换 + fontcolor = self.tablesattrobj.headfontcolor + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + if amarr == '': + continue + amarr = self.__delNotMustStrFromContent(amarr) + fontcolor = fontcolor.replace('{}','{'+ amarr+'}',1) + # 当表头没有合并单元格时,不能用cellcolor,否则表头不对齐 + #fontcolor = self.commonstr % { + # 'sphinxstyletheadfamily': '', + # 'content': self.__delNotMustStrFromContent(amarr) + #} + if len(preposlist) > 0: + headlist.append(headcontent[preposlist[1]:curposlist[0]]) + else: + headlist.append(headcontent[0:curposlist[0]]) + headlist.append(fontcolor) + preposlist = curposlist + headlist.append(headcontent[preposlist[1]:len(headcontent)]) #把最后一个字符串加上 + headcontent = '' + for prelist in headlist: + headcontent = headcontent + prelist + '\n' + newcontent += headcontent+r'\hline' + i +=1 + newcontent += content[posarr[i-1][1]:len(content)] + return newcontent + + def __ModifyTableHead(self, content,columnwidthlst, headtype,jsonmultirow,tablestyle_dict): + + tablehead = self.__FindTableHeadForFamily(content,False) + if tablehead != "" and (r"\sphinxmultirow" in tablehead or r"\multicolumn" in tablehead): + #为了实现自动换行,普通表格自定义sphinxcolwidth的第2个参数默认为3,否则无法实现自动换行 + #self.normalstr= self.normalstr % { + # 'flag': '%(flag)s', + # 'colwidth':'3', + # 'sphinxstyletheadfamily':'%(sphinxstyletheadfamily)s', + # 'content':'%(content)s' + #} + # 得到每一列的列宽供multirow自动换行使用 + newtablehead = self.__ModifyTableHeadForFamilyList(tablehead, + False, + columnwidthlst, + jsonmultirow, + tablestyle_dict) + return content.replace(tablehead,newtablehead,2) + + return self.__modifyTableHeadForMinusSign(content,False,headtype,tablestyle_dict) + """ + #先找出第一行 + searchstr = r'\\hline(?P[\s\S]*?)\\hline' + m = re.search(searchstr, content, re.M|re.I|re.U ) + headcontent = m.group(1) #匹配到的第一个即为表头内容 + posarr = m.span(1) #保存起始位置和结束位置,便于组成新的内容 + + if 'multicolumn' in headcontent: + return content + + if r'\sphinxstyletheadfamily' in headcontent: + pattern = re.compile(r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)', re.M | re.I|re.U) + aftercontent = headcontent + #pattern = re.compile(r'(?<=\\sphinxstylethead{\\sphinxstyletheadfamily)([\s\S]*?)(?=\\unskip}\\relax &)', re.M | re.I|re.U) + else: + aftercontent = headcontent.replace(r'\\','&',1) + pattern = re.compile(r'(?P[\s\S]*?)(&\s{1})', re.M | re.I|re.U) + #pattern = re.compile(r'[\s\S]*?&|\\unskip', re.M | re.I|re.U) + + mobjarr = pattern.finditer(aftercontent) + headlist = [] + preposlist = [] + fontcolor="" + for mobj in mobjarr: + amarr = mobj.group('value') + curposlist = [mobj.start(),mobj.start()+len(amarr)] + + #用表头内容数组替换 + if self.tablesattrobj.headfontcolor !="": + fontcolor = self.tablesattrobj.headfontcolor + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + if amarr == '': + continue + amarr = self.__delNotMustStrFromContent(amarr) + fontcolor = fontcolor.replace('{}','{'+ amarr+'}',1) + #当表头没有合并单元格时,不能用cellcolor,否则表头不对齐 + #fontcolor = self.commonstr % { + # 'sphinxstyletheadfamily':'', + # 'content': self.__delNotMustStrFromContent(amarr) + #} + if len(preposlist) > 0: + headlist.append(headcontent[preposlist[1]:curposlist[0]]) + else: + headlist.append(headcontent[0:curposlist[0]]) + headlist.append(fontcolor) + preposlist = curposlist + + headlist.append(headcontent[preposlist[1]:len(headcontent)]) #把最后一个字符串加上 + headcontent = '' + for prelist in headlist: + headcontent = headcontent + prelist + '\n' + newcontent = content[0:posarr[0]]+r'\rowcolor'+headtype+'\n'+headcontent+content[posarr[1]:len(content)] + return newcontent + """ + + def __ModifySecondTableHeadForLongTable(self,headcontent): + #如果表头没有渲染过,对于只渲染第一列的长表格,分页的表格不能有重复的表头内容,需要将其删除 + startpos = headcontent.find(r"\hline") + newhead = headcontent[0:startpos] + "\\hline\n\\endhead" #因为前面删除了行结束符,最后再添加上 + return newhead + + def __GetTableRowCount(self,tablecontent): + ''' + 得到传进来的表格行数量 + 统计原则:latex以“\\”+回车换行为表格的一行,因此以"\\CRLF"或者“\\LF”作为一行, + 正则表达式统计"\\CRLF"或者“\\LF”的数量作为行的数量。 + :param tablecontent: 表格内容,必须去除了latex表格的前面参数部分,否则得到的第一行也包含其内容 + :return: 返回行的数量和每一行内容的列表,如果找不到的话,返回0和None。 + ''' + #以下正则表达式还可以解析出表格的每一行内容 + #在末尾补充一个换行符,避免如果最后一行以“\\”结尾,漏掉最后一行。 + tablecontent = tablecontent+'\n' + searchstr = "[\s\S]*?(?\\\\[\r|\r\n|\n|\n\r])[\s\S]*?" + pattern = re.compile(searchstr,re.I | re.M |re.U) + linelst = pattern.findall(tablecontent) + if linelst is not None: + return len(linelst),linelst + + return 0,None + + + def __ModifyVerticalLongTable(self, singletablecontent,tablestyle_dict,columnwidthlst): + #修改长表格的第一列 + #查找长表格的表头,如果是通过sphinx指令设置的长表格,表头分为三个部分, + #第一个表头以第一个\hline开始,以\endfirsthead结束 + #第二个表头,以\endfirsthead后续的内容开始,\endhead结束 + #下页继续的说明,以\endhead后续内容开始,以\endlastfoot结束,该部分内容不做处理 + #以\endlastfoot后续内容的第一部分即第一列的内容,需要添加背景色和字体颜色。后续的\hline按正常处理即可 + #先将所有表头内容找出来 + isCusHead = tablestyle_dict['isCusHead'] + headstartpos = singletablecontent.find(r"\hline") + headendpos = singletablecontent.find(r"\endfirsthead") + fcontent = singletablecontent[0:headstartpos] #字符串起始内容 + fhead = singletablecontent[headstartpos:headendpos+len(r"\endfirsthead")] #第一个表头内容 + #修改第一个表头内容 + if not isCusHead: #如果表头已经渲染过了,则不再对第一个单元格再渲染一边。 + fhead = self.__ModifyTableByLine(fhead,True,True,tablestyle_dict,columnwidthlst) + + #获得第2个表头的原始内容 + headstartpos = singletablecontent.find(r"\endhead") + shead = singletablecontent[headendpos+len(r"\endfirsthead"):headstartpos+len(r"\endhead")] #第2个表头的内容 + #修改第2个表头内容 + if not isCusHead: #如果表头已经渲染过了,则不再对第一个单元格再渲染一遍。 + shead = self.__ModifySecondTableHeadForLongTable(shead) + #获得第3部分内容 + headendpos = singletablecontent.find(r"\endlastfoot") + thead = singletablecontent[headstartpos+len(r"\endhead"):headendpos+len(r"\endlastfoot")] #第3部分表头内容 + othercontent = singletablecontent[headendpos+len(r"\endlastfoot"):len(singletablecontent)] #其它表格内容 + + #因为第3部分后续紧跟的内容即为下一行的内容,为了统一处理,在前面固定添加行元素,再修改完成后,再将其删除 + othercontent = "\\hline\n" + othercontent + #修改表格内容部分 + othercontent = self.__ModifyTableByLine(othercontent,True,True,tablestyle_dict,columnwidthlst) + #修改完成后,去掉固定添加的内容 + othercontent = othercontent[len("\\hline\n"):len(othercontent)] + return fcontent + fhead + shead + thead + othercontent + + def __RenderTableHeadByLine(self, linecontent, lineindex, isLongtable, columnwidthlst, jsonmultirow, tablestyle_dict, contentdict): + ''' + 该函数暂时未用 + 用列表的方式逐个拆解每个单元格内容,列表格式为[['cell content','分割符号&或者\\'],[],[]] + :return: + ''' + # 内容的起始位置 + searchstr = r"(\\hline|\\cline|\\sphinxcline)" + match = re.match(searchstr,linecontent) + if match is None: + return linecontent + if match.group() == r"\hline": + startpos = match.end() + newtablehead = "\\hline\n" + else: + startpos = 0 + newtablehead = "" + + # 得到每个单元格的内容 + colindex = 0 # 是否第一个单元格的索引 + lineindex = 1 # 行索引,为了支持多行合并,默认起始在第一行 + # multirowindex:第几个multirow单元格的索引号,用于作为conf.json中multirowcolumn字段的索引值, + # 取出当前multirow单元格真实所在列 + # 得到该列的列宽,填写到multirow指令里面,实现multirow指令中的内容自动换行 + multirowindex = 0 + lineflag = 0 + + while startpos != -1: + multirowflag = False # 是否是multirow处理,是的话__ModifyCellContentForMergeCell返回true,默认返回false + lineflag, startpos, cellcontent = self.__FindCellContentForHead(linecontent, startpos) + if cellcontent is not None: + factcolcount = self.__GetFactColIndex(cellcontent[0]) + colindex = colindex + factcolcount # 当前列号 + if r"\sphinxtablestrut" in cellcontent[0]: + # 需要将multirow的内容copy到该单元格 + self.__AdjustMultirowContent(cellcontent, contentdict, lineindex,colindex,columnwidthlst) + else: + + multirowflag = self.__ModifyCellContentForMergeCell(cellcontent, + tablestyle_dict, + isLongtable, + contentdict, + columnwidthlst, + jsonmultirow, + colindex, + lineindex, + multirowindex) + newtablehead = newtablehead + cellcontent[0] + "\n" + cellcontent[1] + "\n" + + if multirowflag is True: + multirowindex = multirowindex + 1 + lineindex = lineindex + lineflag # 计算新的行数 + if lineflag: + # 如果开启了新行,则列索引从0开始计算 + colindex = 0 + + # 组成新的表头内容 + #newtablehead = newtablehead + "\\hline\n" + return newtablehead + + def __ModifyLineContent(self,linecontent,isvertical,islongtable,lineindex,columnwidthlst,colno,maxcolno,tablestyle_dict,rowdict): + ''' + 修改行内容,当colno为-1时,则修改整行内容,否则根据colno的值修改单元格 + :param linecontent: 该行的内容,以行结束标志“\\”结束。 + :param columnwidthlst: 保存的每一列的列宽内容,根据tablecolumns参数解析出来 + 如果实现自动换行,必须用tablecolumns这个参数,不能用widths参数 + :param colno: 要修改的列号,如果为-1,则修改整行内容。否则根据colno的值修改单元格。 + 如果该行有合并列单元格,则colno为合并单元格后的列号。 + :param maxcolno: 将合并单元格拆分后,得到的最大的列号,该列号为了取得列宽,使multirow内容自动换行。 + :param rowdict: 传出参数,如果是multirow的合并单元格,则返回单元格的内容 + :return: 返回修改后的行的内容。 + ''' + if isvertical is False: + #修改行,暂时不支持修改任意行,表头的渲染由其他函数完成 + #按表头内容渲染 + + jsonmultirow = [] # 保存需要设置自动换行的合并单元格的列号。 + if __dictIsHasKey__(tablestyle_dict, 'multirowcolumn'): + jsonmultirow = tablestyle_dict['multirowcolumn'] + newlinecontent = self.__RenderTableHeadByLine(linecontent, + lineindex, + islongtable, + columnwidthlst, + jsonmultirow, + tablestyle_dict, + rowdict) + return newlinecontent + + #下面的内容修改列,即只修改行的第一个单元格 + linecontentstr=linecontent + #startpos = linecontent.find('\\hline') + #lineprefix = "" + #if startpos != -1: + # linecontentstr = linecontent[startpos + len("\\hline"):len(linecontent)] + # lineprefix = linecontent[0:startpos+len("\\hline")] + #得到行内容和前缀 + lineprefix = "" + searchstr=r"\\hline|\\cline{\S+?}|\\sphinxcline{\S+?}\\sphinxfixclines{\d+?}" + match = re.search(searchstr,linecontent,re.I|re.U|re.M) + if match is not None: + linecontentstr = linecontent[match.end():len(linecontent)] + lineprefix = linecontent[0:match.start()+ len(match.group())] + + othercontent = "" + newlinecontent = "" + + #暂时仅支持修改第一列 + lineflag,nextpos,cellcontent = self.__FindCellContentForHead(linecontentstr,0) + if (cellcontent is None) or (r"\cellcolor" in cellcontent[0]): + #说明该单元格已经修改过了,直接返回 + return linecontent + + othercontent = linecontentstr[nextpos:len(linecontentstr)] + + if r"\sphinxtablestrut" in cellcontent[0]: + # 需要将multirow的内容copy到该单元格 + self.__AdjustMultirowContent(cellcontent, rowdict, lineindex,maxcolno,columnwidthlst) + else: + jsonmultirow =[1] + cellstr = cellcontent[0] # 取得单元格内容 + + if r'\sphinxmultirow' in cellstr: + cellcontent[0] = self.__ModifySphinxMultirow(cellstr, + rowdict, + columnwidthlst, + jsonmultirow, + lineindex, + maxcolno, + 0, + tablestyle_dict, + isvertical) + else: + coluwidth = "*" + if len(columnwidthlst) > 0: + coluwidth = columnwidthlst[0] + # 查找\sphinxAtStartPar位置 + pos = cellstr.find(r'\sphinxstyletheadfamily') + if pos != -1: + contentstr = cellstr[pos + len(r'\sphinxstyletheadfamily'):len(cellstr)] + + fullcontentstr = self.commonstr % { + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content': self.__delNotMustStrFromContent(contentstr) + } + #fullcontentstr = self.multirowstr % { + # 'count': 1, + # 'coluwidth': coluwidth, + # 'sphinxstyletheadfamily': r"\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(contentstr) + #} + #fullcontentstr = self.normalstr % { + # 'flag': '|l|', + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(contentstr) + #} + + cellcontent[0] = fullcontentstr + else: + fullcontentstr = self.commonstr % { + 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + 'content': self.__delNotMustStrFromContent(cellstr) + } + #fullcontentstr = self.multirowstr % { + # 'count': 1, + # 'coluwidth': coluwidth, + # 'sphinxstyletheadfamily': r"\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(cellstr) + #} + #fullcontentstr = self.normalstr % { + # 'flag': '|l|', + # 'sphinxstyletheadfamily': "\\sphinxstyletheadfamily", + # 'content': self.__delNotMustStrFromContent(cellstr) + #} + + cellcontent[0] = fullcontentstr + + newlinecontent = lineprefix +"\n" + cellcontent[0] + "\n" + cellcontent[1] + "\n" + othercontent + return newlinecontent + + + def __ModifyTableByLine(self,tablecontent,isvertical,islongtable,tablestyle_dict,columnwidthlst): + ''' + 按行渲染每一个单元格。根据正则表达式解析出每一行的内容, + 然后再解析这一行中每一个单元格的内容,根据条件判断是否需要渲染该单元格 + :param tablecontent: + 表格内容,对于长表格去除表格头的表格,对于普通表格,整个表格内容 + :param isvertical: + 是否修改列还是修改行 + :return: 返回修改后的内容 + ''' + #找出每一行的内容 + + searchstr = r'(\\hline|\\cline{\S+}|\\sphinxcline{\S+})(?P[\s\S]*?)(?=\\hline|\\cline{\S+}|\\sphinxcline{\S+})' + pattern = re.compile(searchstr,re.M | re.I|re.U) + # 得到每一行的内容,以\hline或者\cline开头,以\\结尾 + linelst = pattern.findall(tablecontent) + if len(linelst) ==0: + return tablecontent + + tableprefix = "" #得到latex表格的前缀参数部分,除表格内容外的部分 + tablepostfix = "" #得到latex表格的后缀参数部分,出表格内容外的部分 + prepos = tablecontent.find(''.join(linelst[0])) + if prepos > 0: + tableprefix = tablecontent[0:prepos] + lastline = ''.join(linelst[-1]) + postpos = tablecontent.find(lastline) + if (postpos+len(lastline)) < len(tablecontent): + tablepostfix = tablecontent[postpos+len(lastline):len(tablecontent)] + + newtablecontent = "" + rowdict = {} #保存合并单元格的字典,用于multirow + for lineindex in range(0,len(linelst)): + line =''.join(linelst[lineindex]) + newline= self.__ModifyLineContent(line,isvertical,islongtable,lineindex+1,columnwidthlst,1,1,tablestyle_dict,rowdict) + newtablecontent = newtablecontent + newline + + newtablecontent = tableprefix + newtablecontent+tablepostfix + return newtablecontent + + #渲染树型表格的第一列 + def __ModifyVerticalTable(self, singletablecontent, tablestyle_dict, columnwidthlst): + + #找出每一行的内容 + searchstr = r'(\\hline|\\cline{\S+}|\\sphinxcline{\S+})(?P[\s\S]*?)(?=\\hline|\\cline{\S+}|\\sphinxcline{\S+})' + pattern = re.compile(searchstr,re.M | re.I|re.U) + matchiter = pattern.finditer(singletablecontent) + posarr=[] #保存位置,便于组合 + i = 0 + for m in matchiter: + + posarr.append([]) + posarr[i] = m.span() + + if i ==0: + newcontent = singletablecontent[0:posarr[i][0]]+m.group(1) + else: + newcontent = newcontent+singletablecontent[posarr[i-1][1]:posarr[i][0]]+m.group(1) + + cellcontent = m.group('content') + #匹配到的第一个即为表头内容 + #将第一个单元格内容渲染成蓝底白字 + firstcellcontent = self.__ModifyFirstColumnType(cellcontent) + newcontent += firstcellcontent + i+=1 + newcontent += singletablecontent[posarr[i-1][1]:len(singletablecontent)] + return newcontent + + #渲染第一个单元格内容 + def __ModifyFirstColumnType(self,cellcontent): + + new_cellcontent = "" + + if r'\sphinxstyletheadfamily' in cellcontent: + + searchstr = r'(?<=\\sphinxstyletheadfamily)(?P[\s\S]*?)(?=(\\unskip|&)|\\\\)' + + aftercontent = cellcontent.strip() + aftercontent = aftercontent.strip('\r') + aftercontent = aftercontent.strip('\n') + aftercontent = aftercontent.strip() + mobj = re.search(searchstr, aftercontent, re.M|re.I|re.U ) #匹配到的第一个既是需要修改的内容 + + #修改字体颜色 + amarr = mobj.group('value') + posarr = mobj.span() + new_cellcontent = aftercontent[0:posarr[0]]+'\n'+r'\cellcolor'+self.tablesattrobj.headtype + #用表头内容数组替换 + fontcolor = self.tablesattrobj.headfontcolor + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + amarr = amarr.strip() + amarr = amarr.strip('\r') + amarr = amarr.strip('\n') + amarr = amarr.strip() #再去掉首尾空格,避免多余的空格出现 + #if amarr == '': + # continue + if (r'\textbf' or r'\textcolor') in amarr: + return cellcontent + fontcolor = fontcolor.replace('{}','{'+ amarr+'}',1) + new_cellcontent +=r'{'+fontcolor + '}\n' + aftercontent[posarr[1]:len(aftercontent)] + else: + aftercontent = cellcontent.replace(r'\\','&',1) + #去掉首尾空格和换行符 + aftercontent = aftercontent.strip() + aftercontent = aftercontent.strip('\r') + aftercontent = aftercontent.strip('\n') + aftercontent = aftercontent.strip() + + tmplist = re.split(r'&',aftercontent) + + preposlist = 0 + #只对第一个做修改 + onelist = tmplist[0] + #先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + onelist = onelist.strip() + onelist = onelist.strip('\r') + onelist = onelist.strip('\n') + onelist = onelist.strip() #再去掉首尾空格,避免多余的空格出现 + #if onelist == '': + # continue + if (r'\textbf' or r'\textcolor') in onelist: + return cellcontent + new_cellcontent = self.__ModifyFirstColumnContentPart(onelist) + + for i in range(1,len(tmplist)): + if len(tmplist[i])>0: + new_cellcontent += '&' +tmplist[i] + new_cellcontent+=r'\\' #将最后被替换掉的\\再加上 + + return new_cellcontent + '\n' + + def __ModifyFirstColumnContentPart(self,firstcontent): + #查找firtcontent的内容部分,因为如果存在合并单元格,第一列的内容起始部分不一定是内容,可能是latex的其它标签。 + #sphinx生成的latex表格都以“\sphinxAtStartPar”标识开始 + startpos = firstcontent.find(r"\sphinxAtStartPar") + firstpart = "" + if startpos==-1: + contentpart = firstcontent + else: + firstpart = firstcontent[0:startpos] + contentpart=firstcontent[startpos:len(firstcontent)] + fontcolor = self.tablesattrobj.headfontcolor + # 先去掉首尾空格,避免首尾有空格无法去掉回车换行符 + new_cellcontent = '\n'+r'\cellcolor'+self.tablesattrobj.headtype+r'{'+fontcolor.replace('{}','{'+ contentpart+'}',1)+r'}'+'\n' + return firstpart+new_cellcontent + + def __ModifyReplaceContent(self): + ''' + 修改replace内容。为了删除sphinx在替换时多加的空格,在替换时,每个替换内容前后都加了"@@"。 + 对于前后都加了“@@”的内容,删除前后的"@@",同时删除前后的空格。 + :return: + ''' + searchstr = r"@@(?P[\s\S]+?)@@" + pattern = re.compile(searchstr) + match = pattern.search(self.content) + if match is None: + return None + str = match.group() + replace = match.group('content') + + prepos = match.start() + endpos = match.end() + replaceed = str + if prepos > 1: + #判断前一个字符是否为空格 + prechar = self.content[prepos-1] + if prechar==" ": + replaceed = " "+replaceed + + endchar = self.content[endpos] + if endchar == " ": + replaceed = replaceed + " " + + #进行替换 + self.content = self.content.replace(replaceed,replace) + + return self.content + + + def ModifyReplaceContent(self): + + while True: + ''' + 可能会有多个需要替换,直到找不到为止 + ''' + result = self.__ModifyReplaceContent() + if result is None: + break + +class clsSensitiveWords: + + def __init__(self,entext): + self.passkey="Hello123" + self.entext = entext + self.jsonwords = GetSensitiveword() + + def GetFullSensitiveWords(self): + senwords = "|".join(self.jsonwords) + detext = "" + if self.entext !="": + command="echo '" + self.entext +"'| openssl aes-128-cbc -d -base64 -pbkdf2 -k " + self.passkey + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + detext = detext + line.decode().strip() + senwords = senwords + '|' + detext + + return senwords + +class clsCheckDocCopyrightYear: + ''' + 检查文档是否符合规范。 + 2022年07月12日更新 + ------------------- + 增加对当前年份的检查,如果年份不符合要求,自动修改conf.py和copyright.rst文档中的年份。 + ''' + def __init__(self,app,copyrightfile): + self.app=app + self.confdir = app.confdir + self.language = app.config.language + self.confright = app.config.copyright + if (app.config.latex_elements is not None) and __dictIsHasKey__(app.config.latex_elements,'preamble'): + self.latex_preamble = app.config.latex_elements['preamble'] + else: + self.latex_preamble = "" + + if len(copyrightfile) > 0: + self.copyrightfile = copyrightfile + elif hasattr(app.config,"copyfile_path") and len(app.config.copyfile_path)>0: + self.copyrightfile = app.config.copyfile_path + else: + self.copyrightfile = "" + + def __FindCopyrightFile(self): + ''' + #如果copyrightfile为空,则查找copyrightfile文件 + ''' + if self.copyrightfile =="": + + copyrightpre = "copyright/copyright" + filepath = os.path.join(self.app.srcdir,copyrightpre + ".rst") + + if os.path.isfile(filepath): + return filepath + + if self.language=="zh_CN": + filepath = os.path.join(self.app.srcdir,copyrightpre + "_zh.rst") + else: + filepath = os.path.join(self.app.srcdir,copyrightpre +"_en.rst") + + if os.path.isfile(filepath): + return filepath + else: + return "" + else: + filepath = os.path.join(self.confdir , self.copyrightfile) + + #转为绝对路径 + filepath = os.path.abspath(filepath) + if os.path.isfile(filepath): + return filepath + else: + return "" + + def __IsModifyfootYear(self,year): + ''' + 查找lfoot设置版权的语句年份是否需要修改 + :return: True - 需要修改年份 + False - 不需要修改年份 + ''' + + ismodify = False + searchstr = r'\\lfoot{[ ]*Copyright [\s\S]*?}' + pattern = re.compile(searchstr,re.I|re.U) + matchiter = pattern.finditer(self.latex_preamble) + if matchiter is None: + return ismodify + + for m in matchiter: + lfoot = m.group() + if year in lfoot: + #如果当前年份已在脚注中,则认为年份已修改,则不再修改。 + break + else: + # 如果当前年份不在脚注中,则认为年份还未修改,需要修改。 + ismodify = True + break + return ismodify + + def __ModifyYearofConfforDetail(self,year,ispremble=False): + ''' + 修改conf.py中的年份 + :param ispremble: 是否只修改ispremble中脚注的年份。 + True-只修改脚注的年份;False-同时修改copyright和脚注的年份。 + ''' + #打开文件 + conffile = self.confdir +"/conf.py" + fo = codecs.open(conffile, "r+", encoding='utf-8') + textcontent = fo.read() + fo.close() + + if ispremble is False: + #修改copyright的年份 + searchstr = r"[0-9]{4}" + matchobj = re.search(searchstr,self.confright) + if matchobj is not None: + copyright = re.sub(searchstr,year,self.confright) + textcontent=textcontent.replace(self.confright,copyright,1) + #为了看到最终效果,同时修改已经读取的变量,因为调用该函数前,conf.py的配置已经被sphinx读取 + self.app.config.copyright=copyright + + #开始替换脚注中的年份 + #得到脚注中带有年份的声明 + searchstr = r'\\lfoot{[ ]*Copyright [\s\S]*?}' + pattern = re.compile(searchstr, re.I | re.U) + matchiter = pattern.finditer(self.latex_preamble) + if matchiter is not None: + for m in matchiter: + lfoot = m.group() + searchstr= r"[0-9]{4}" + matchobj = re.search(searchstr,lfoot) + if matchobj is not None: + newlfoot = re.sub(searchstr,year,lfoot) + textcontent=textcontent.replace(lfoot,newlfoot) + #两个脚注相同,因此只替换一次即可 + # 为了看到最终效果,同时修改已经读取的变量,因为调用该函数前,conf.py的配置已经被sphinx读取 + self.app.config.latex_elements['preamble'] = self.latex_preamble.replace(lfoot,newlfoot) + break + #将替换后的内容重新写入文件 + fw = codecs.open(conffile, "w+", encoding='utf-8') + fw.write(textcontent) + fw.close() + + def __ModifyYearofCopyright(self,year): + + #得到copyright绝对文件名 + filepath = self.__FindCopyrightFile() + if filepath == "": + return + + fo = codecs.open(filepath, "r+", encoding='utf-8') + textcontent = fo.read() + fo.close() + #新的版权年份 + newyear = '© ' + year + #if newyear in textcontent: + # return + isupdate = False + #避免替换错误,循环找到最后一个进行替换 + searchstr = "©( )*([0-9]{4})" + pattern = re.compile(searchstr, re.I | re.U) + matchiter = pattern.finditer(textcontent) + for m in matchiter: + oldyear = m.group() + if oldyear == newyear: + continue + textcontent = textcontent.replace(oldyear,newyear,1) + isupdate = True + + if isupdate: + fw = codecs.open(filepath, "w+", encoding='utf-8') + fw.write(textcontent) + fw.close() + #pattern = re.compile(searchstr, re.I | re.U) + #matchlst = pattern.findall(textcontent) + #if len(matchlst)==0: + # #如果没有找到四数字的年份,直接返回 + # return + + #i=0 + #content="" + #matchiter = pattern.finditer(textcontent) + #for m in matchiter: + # if (i+1)==len(matchlst): + # #对最后一个年份进行替换 + # content = textcontent[0:m.start()] + year + textcontent[m.end():len(textcontent)] + # break + # else: + # i = i+1 + + #fw = codecs.open(filepath, "w+", encoding='utf-8') + #fw.write(content) + #fw.close + + def __ModifyYearofConf(self,year): + #开始修改conf.py中的年份 + if (year not in self.confright) and (self.confright != ""): + #修改conf.py中的年份 + self.__ModifyYearofConfforDetail(year) + return + #开始修改conf.py中脚注的年份 + if (self.latex_preamble!="") and (self.__IsModifyfootYear(year) is True): + #需要修改年份,走到这一步conf.py的copyright年份肯定修改过的,因此只修改foot中的年份即可 + self.__ModifyYearofConfforDetail(year,True) + return + + def CheckYearIsRight(self): + ''' + 检查年份是否正确 + ''' + #得到系统当前年份 + today = dt.date.today() + year = today.strftime('%Y') + #开始检查conf.py中的年份 + self.__ModifyYearofConf(year) + #开始检查copyright文件中的年份 + self.__ModifyYearofCopyright(year) + + def __AddAdSymbolForReplaceZh(self,content): + ''' + 解析是否需要替换,需要的话就进行替换。 + 替换的条件:1.被替换字符串全部为中文,2。被替换字符的首尾有一个不为英文 + :param content: 要替换的内容 + :return: 已替换或者未替换的字符串 + ''' + #用正则表达式解析出要替换的内容 + searchstr = r"[\s\S]+ replace:: (?P[\s\S]+)" + match = re.search(searchstr,content) + if match is None: + return content + replacestr = match.group('replace').strip('\r') + #取得字符串的首和尾字符 + #prochr = ord(replacestr[0]) + #print(prochr) + #epichr = ord(replacestr[-1]) + #print(epichr) + #if ((64 < prochr and prochr<91) or \ + # (96 < prochr and prochr<123)) and \ + # ((64 < epichr and epichr < 91) or \ + # (96 < epichr and epichr < 123)): + # #首尾都为英文则无需替换空格,直接返回 + # return content + newreplacestr = "@@"+replacestr+"@@" + return content.replace(replacestr,newreplacestr) + + def __ParseReplaceList(self,contentlst): + ''' + 为中文替换词前后添加@@符号 + :param contentlst: 需要替换的list内容 + :return: 新生成列表,如果无需替换,返回None。 + ''' + if contentlst is None or len(contentlst)==0: + return None + + for i in range(0,len(contentlst)): + content = contentlst[i] + if content=="": + continue + #进行替换 + newcontent = self.__AddAdSymbolForReplaceZh(content) + contentlst[i] = newcontent + + return contentlst + + def CheckReplaceContent(self): + ''' + 该函数用来判断conf.py文件中是否包含rst_epilog或者rst_prolog的配置 + 如果有的话,修改这两个参数中的配置在需替换的名词前后都加@@符号,以方便在latex文件中删除前后的空格。 + 该修改为了解决sphinx对中文的替换前后都加空格的bug。 + :return: + ''' + epilogstr = self.app.config.rst_epilog + prologstr = self.app.config.rst_prolog + + if self.app.config.language != "zh_CN": + ''' + 如果不是中文文档,则直接返回。 + ''' + return + + if prologstr is not None: + # 为rst_prolog的中文替换前后添加@@符号 + prologlst = prologstr.split('\n') + newlst = self.__ParseReplaceList(prologlst) + if newlst is not None: + self.app.config.rst_prolog = '\n'.join(newlst) + + if epilogstr is not None: + #为rst_epilog的中文替换前后添加@@符号 + epiloglst = epilogstr.split('\n') + newlst = self.__ParseReplaceList(epiloglst) + if newlst is not None: + self.app.config.rst_epilog = '\n'.join(newlst) + + + +class clsCheckHtmlRule: + ''' + 检查html配置是否符合规范 + ''' + def __init__(self,app): + self.app=app + #错误描述模板 + #配置参数错误 + self.configparam_error="Parameter %(param)s configuration error!" + self.correctvalue="The correct value is %(value)s!" + self.config_error=self.configparam_error + self.correctvalue+"\n" + #工程名称配置不一致 + self.inconsistent_error="The project names configured by conf.py and %(rootdoc)s.rst are inconsistent!\n" + #缺少custom.css文件 + self.csslack_error="custom.css file not found!\n" + self.config_detail = "For detailed configuration, see wiki: pageId=46137795\n" + + def __CheckCommonParamForHtml(self, theme): + ''' + 检查相关配置是否正确 + :param theme:为主题文件,如果和配置的不一致返回错误。默认为sphinx_rtd_theme。 + ''' + error = "" + if self.app.config.html_theme != theme: + error = self.config_error % { + 'param': 'html_theme', + 'value': theme + } + if self.app.config.html_copy_source is True: + error = error + self.config_error % { + 'param': 'html_copy_source', + 'value': 'False' + } + # 判断配置的css文件是否存在 + if len(self.app.config.html_css_files) == 0: + error = error + self.configparam_error % { + 'param': 'html_css_files', + } + \ + "Please configure a css file, such as custom.css." + "The css file needs to be found under the path specified by html_static_path!" + if len(self.app.config.html_static_path) == 0: + error = error + self.configparam_error % { + 'param': 'html_static_path', + } + \ + "Please configure the path of the file specified by html_css_files." + return error + + def CheckProjectNameConsistent(self): + ''' + 检查conf.py中project配置工程名称和主索引文件index.rst中的标题是否一致 + 两者不一致返回错误。 + 该接口必须在builder-inited事件之后调用,否则得不到index.rst的标题,无法比较。 + ''' + #为了兼容性获取config对象的属性,没有的属性不做比较 + cfgattrlst=dir(self.app.config) + #主索引文件名 + if "master_doc" in cfgattrlst: + indexname = self.app.config.master_doc + elif "root_doc" in cfgattrlst: + indexname = self.app.config.root_doc + else: + return "" + + #得到主索引文件的标题 + indextitle = self.app.env.titles[indexname][0] + #得到配置的工程名称 + projectname = self.app.config.project + if indextitle != projectname: + error = self.inconsistent_error % { + 'rootdoc': indexname + } + return error + else: + return "" + + def CheckCommConfigHtmlRule(self, theme): + ''' + 检查生成html前参数配置是否符合规范 + ''' + # 开始检查参数配置 + error = self.__CheckCommonParamForHtml(theme) + return error + +def CheckProjectNameIsConsistent(app): + ''' + 检查conf.py中配置的工程名称与主索引文件的标题是否一致 + ''' + checkDocobj = clsCheckHtmlRule(app) + return checkDocobj.CheckProjectNameConsistent() + +def CheckHtmlConfigIsCorrect(app,theme='sphinx_rtd_theme'): + ''' + 检查html配置参数是否正确 + ''' + checkDocobj = clsCheckHtmlRule(app) + return checkDocobj.CheckCommConfigHtmlRule(theme) + +def CheckCurrentYearAndModify(app,copyrightfile=""): + ''' + 检查conf.py和copyright中的年份是否为当前年份,如果不是当前年份则进行修改。 + Conf.py主要检查latex_elements和copyright中的年份是否为当前年份 + copyright主要检查copyright.rst文件最后一句的年份是否为当前系统年份。 + :param app: + :param copyrightfile: 相对于conf.py的相对路径。 + ''' + checkDocobj = clsCheckDocCopyrightYear(app,copyrightfile) + #检查年份是否需要修改 + checkDocobj.CheckYearIsRight() + #if (app.builder is not None) and \ + # (app.builder.name=='latex' or app.builder.name=='latexpdf'): + # checkDocobj.CheckReplaceContent() + +def __returnCfgDate(app): + ''' + 返回conf.py中配置的today的日期。如果today为空,则返回当前日期。 + ''' + if app.config.today != "": + return app.config.today + today = dt.date.today() + if app.config.language =="zh_CN": + return today.strftime('%Y年%m月%d日') + else: + return today.strftime('%b %d, %Y') + +def CheckUpdateHistory(app,docname,source): + ''' + 检查是否需要添加更新历史,在source_read_handler这个事件里调用。 + 为了提高效率,不写在类里,减少创建和销毁对象的消耗。 + :param app: + :param docname: + :param source: 源文件内容 + ''' + #判断文档中是否包含更新历史 + # 默认为英文 + title = "Update History" + update="Date:" + changes="Changes:" + language = app.config.language + if language == "zh_CN": + title = "更新历史" + update="更新时间:" + changes="更新内容:" + index = source.find(title) + if index==-1: + return source + + #说明该文档中包含更新历史,判断是否为更新历史标题 + precontent = source[0:index] + leftcontent = source[index:len(source)] + #按行读取剩余内容,理论上第2行为标题的下标,如果不是,则认为这个更新历史不是标题,直接忽略。 + linelst = leftcontent.splitlines(True) + flagline = linelst[1] + search = r"([\x21-\x2f|\x3a-\x40|\x5b-\x60|\x7b-\x7d])\1+(?=[\r|\n|\r\n])" + matchobj=re.match(search,flagline) + if matchobj is None: + return source + #说明是更新历史章节,再比较版本号是否一致 + search = r"\* \*\*V(?P[0-9]{1,2}\.[0-9]{1,2}\.[0-9]{1,2})\*\*" + matchobj=re.search(search,leftcontent) + if matchobj is None: + return source + version=matchobj.group('version') + if version==app.config.version: + return source + #如果版本不一致,自动添加新的版本内容 + dateformat = __returnCfgDate(app) + updatehistory="\n* **V%(version)s**\n\n" \ + " **%(date)s** %(dateformat)s\n\n" \ + " **%(changes)s**\n\n" \ + " - 无内容更新。\n" + + upcontent= updatehistory % { + 'version':app.config.version, + 'date':update, + 'dateformat':dateformat, + 'changes':changes + } + #拆分列表 + titlestr =''.join(linelst[0:2]) + otherstr =''.join(linelst[2:len(linelst)+1]) + newcontent = precontent + titlestr + upcontent + otherstr + #将新内容写入文档 + filepath = app.project.doc2path(docname) + fw = codecs.open(filepath, "w+", encoding='utf-8') + fw.write(newcontent) + fw.close() + return newcontent + + +def __exe_command_forprint(command): + """ + 执行 shell 命令并实时打印输出 + :param command: shell 命令 + :return: process, exitcode + """ + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + print(line.decode().strip()) + exitcode = process.wait() + return process, exitcode + +def CheckSensitiveWordsForLinux(app,docname,entext): + ''' + 该接口对外提供,可在conf.py的source_read_handler处理函数中调用。对sphinx工程中的文件检查敏感词。 + + **注意:** + + 因为检查敏感词使用的linux命令,如果敏感词中包含中文,则在不支持中文的linux系统下不能使用该接口, + 否则无法完成敏感词。 + + :param app: sphinx环境变量,可以获取输出目录、文件后缀等相关信息。 + :param docname: 需检查的文件路径。该路径是基于源目录的相对路径,而且不包含文件扩展名, + 因此在检查时需要补充上扩展名,并转为绝对路径。 + :param entext: 已加密的密文,如果为空,则只对conf.json中保存的敏感词做检查。 + 如何加密参考GetFullSensitiveWords的实现,加解密保持一致。 + :return: 无返回值。 + ''' + #windows系统请使用工具组开发的sensitive-filter.exe工具 + try: + if platform.system().lower() == 'windows': + return + + filepath = app.project.doc2path(docname,True) + + #print("Checking file:["+docname+"]") + sensitivewordsobj = clsSensitiveWords(entext) + sensitivewords = sensitivewordsobj.GetFullSensitiveWords() + #print(sensitivewords) + cmd = "egrep -rni --color=always '"+ sensitivewords + "' " + filepath + __exe_command_forprint(cmd) + except Exception as e: + print("Sensitive words check failed, the system may not support Chinese. " + "The detailed error information is as follows:") + print(e) + return + +# 打开Makefile文件查找source和build文件夹 +def OpenMakefile(): + global source_dir + global build_dir + source_dir = '' + build_dir = '' + try: + if platform.system().lower() == 'windows': + with open('make.bat',"r") as f: + fstr = f.read() + + #用正则表达式查找source和build文件夹具体路径 + searchstr = r"set *SOURCEDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + source_dir = m.group(1) #匹配到的第一个即为source所在目录 + + searchstr = r"set *BUILDDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + build_dir = m.group(1) #匹配到的第一个即为build所在目录 + else: + with open('Makefile',"r") as f: + fstr = f.read() + + #用正则表达式查找source和build文件夹具体路径 + searchstr = r"SOURCEDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + source_dir = m.group(1) #匹配到的第一个即为源所在目录 + + searchstr = r"BUILDDIR *= *(\S+)" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + build_dir = m.group(1) #匹配到的第一个即为源所在目录 + + except Exception as e: + print(e) + return + +def GetLatex_documents(): + global source_dir + if source_dir == '': + return + #得到配置文件conf.py的路径 + if source_dir == '.': + confdir = './conf.py' + else: + confdir = './' + source_dir +'/conf.py' + + conffile = os.path.abspath(confdir) + #打开conf.py文件 + with codecs.open(conffile,"r+",encoding='utf-8') as f: + fstr = f.read() + list = [] + versioninfo = GetVersionInfo(fstr) + fileprefile = GetFilePreInfo(fstr) + if versioninfo=="" or fileprefile=="": + #根据正则表达式,找出latex_documents内容 + searchstr = r"latex_documents *= *\[([\s\S]*?)\]" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + latex_documents = m.group(1) #匹配到的第一个即为源所在目录 + #拆分二维数组,兼容多个情况 + list = latex_documents.split(")") + for i in range(len(list)): + if IsComment(list[i]): + list[i]= list[i].split(",") + list.pop() + else: + #创建2维列表 + for i in range(1): + list.append([]) + for j in range(2): + list[i].append("") + list[0][0] = "comment" #为了兼容文件解析内容,进行补位。 + list[0][1] = '"' + fileprefile + versioninfo + '.tex"' #为了兼容解析过程,添加双引号。 + return list + +#得到版本信息 +def GetVersionInfo(fstr): + + if releasever != '': + return releasever + + versioninfo = "" + searchstr = r"version *= *(u*)'(.*?)'" + m = re.search(searchstr, fstr, re.I) + if m is not None: + versioninfo = m.group(2) #匹配到的第一个即为源所在目录 + print("version = " + versioninfo) + return versioninfo + +#得到文件前缀信息 +def GetFilePreInfo(fstr): + filepre = "" + searchstr = r"curfnpre *= *(u*)'(.*?)'" + m = re.search(searchstr, fstr, re.M|re.I|re.U ) + if m is not None: + filepre = m.group(2) #匹配到的第一个即为源所在目录 + return filepre + +#判断是否为注释行 +def IsComment(instr): + + if instr.strip() is None: + return False + + rule = re.compile('^#.*$') + if rule.match(instr.strip()) is None: + return True + else: + return False + +#根据正则表达式取单引号和双引号中的内容 +def getquomarkcontent(strarr): + #根据正则表达式,找出双引号和单引号中的内容 + searchstr = r"[\"|'](.*?)[\"|']" + m = re.search(searchstr, strarr, re.M|re.I|re.U ) + if m is None: + return None + return m.group(1).strip() #匹配到的第一个即为源所在目录 + +def GetConfjsondict(app): + ''' + 得到配置文件conf.json的字典 + :param app: + :return: + ''' + global __load_dict__ + + + if app is not None: + #先判断conf.json路径在配置文件中是否配置 + if hasattr(app.config, "confjson_path") and len(app.config.confjson_path) > 0: + confpath = os.path.abspath(app.confdir + "/" + app.config.confjson_path) + else: + confpath = os.path.join(app.confdir ,"conf.json") + + if os.path.exists(confpath): + __load_dict__ = __openconfjsonfile__(confpath) + else: + __load_dict__ = __openconfjsonfile__() + else: + __load_dict__ = __openconfjsonfile__() + + +def Modifylatex_main(build_dir,latexdocumentslst,app=None,language='zh_CN'): + + global __load_dict__ + global doc_language + doc_language=language + + if len(__load_dict__) == 0: + GetConfjsondict(app) + + + doclen = len(latexdocumentslst) + for i in range(0,doclen): + #得到latex路径 + latexpath = build_dir + desbkpaperfile= os.path.join(latexpath,'chapterbkpaper.pdf') + #copy 背景图到latex路径,背景图必须与该文件在同一个目录下 + if (app is not None) and hasattr(app.config,'chapterbkpaper_image') and (app.config.chapterbkpaper_image!=""): + #转为绝对路径进行判断,兼容conf.py和parsejson.py不在同一目录的情况 + bkpaperimage = os.path.abspath(os.path.join(app.confdir,app.config.chapterbkpaper_image)) + if os.path.isfile (bkpaperimage): + shutil.copy(bkpaperimage,desbkpaperfile) + elif os.path.exists('./chapterbkpaper.pdf'): + shutil.copy('./chapterbkpaper.pdf', latexpath) + else: + #取配置文件所在目录下的chapterbkpaper.pdf,将其copy到latex文件夹 + bkpaperimage = os.path.abspath(os.path.join(app.confdir,"chapterbkpaper.pdf")) + if os.path.isfile(bkpaperimage): + shutil.copy(bkpaperimage, latexpath) + + #得到相对路径 + filename = latexdocumentslst[i][1] + if filename is None: + continue + texfilepath = os.path.join(latexpath,filename) + #相对路径转绝对路径 + texfile = os.path.abspath(texfilepath) + if not os.path.exists(texfile): + continue + fo = codecs.open(texfile, "r+",encoding = 'utf-8') + texcontent = fo.read() + fo.close() + + #得到修改tex文件的对象 + ModTexobj = clsModifyTex(texcontent) + ModTexobj.AddPackageToTex() + ModTexobj.AddOtherTocToTex() + ModTexobj.AddCustormOptionsToTex() + ModTexobj.ModifyReplacePackage() + ModTexobj.ModifyFunctionBackColor() + ModTexobj.ModifyTablesAttributes() + #if (app is not None) and \ + # (app.builder.name=='latex' or app.builder.name=='latexpdf'): + # ModTexobj.ModifyReplaceContent() + + fw = codecs.open(texfile, "w+",encoding = 'utf-8') + fw.write(ModTexobj.content) + fw.close() + +def parseargv(): + global source_dir + global latex_dir + if len(sys.argv) <= 1: + return + for i in range(1,len(sys.argv)): + param = sys.argv[i] + if param=='-c': #解析conf.py所在目录 + source_dir = sys.argv[i+1] #保存相对路径 + #print("argv source=" + source_dir) + i = i+2 + elif param=='-l': #保存输出的latex目录 + latex_dir = sys.argv[i+1] #保存相对路径 + i = i+2 + #print("latex_dir = "+latex_dir) + +source_dir = '' #保存源文件所在目录 +build_dir = '' #保存生成文件所在目录 +releasever = '' +latex_dir = '' +doc_language ='' +__load_dict__ = {} + +if __name__ == '__main__': + + parseargv() #解析系统参数 + OpenMakefile() + + latex_documents = GetLatex_documents() #保存latex文档所在数组 + __load_dict__ = __openconfjsonfile__() + + latexdir="" + if latex_dir == '': + latexdir = '/latex/' + else: + latexdir = '/' + latex_dir + '/' + + doclen = len(latex_documents) + for i in range(0,doclen): + #得到latex路径 + latexpath = './' + build_dir + latexdir + #copy 背景图到latex路径,背景图必须与该文件在同一个目录下 + if os.path.exists('./chapterbkpaper.pdf'): + shutil.copy('./chapterbkpaper.pdf',latexpath) + #得到相对路径 + if getquomarkcontent(latex_documents[i][1]) is None: + continue + texfilepath = latexpath + getquomarkcontent(latex_documents[i][1]) + #相对路径转绝对路径 + texfile = os.path.abspath(texfilepath) + if not os.path.exists(texfile): + continue + fo = codecs.open(texfile, "r+",encoding = 'utf-8') + texcontent = fo.read() + fo.close() + + #得到修改tex文件的对象 + ModTexobj = clsModifyTex(texcontent) + ModTexobj.AddPackageToTex() + ModTexobj.AddOtherTocToTex() + ModTexobj.AddCustormOptionsToTex() + ModTexobj.ModifyFunctionBackColor() + ModTexobj.ModifyTablesAttributes() + ModTexobj.ModifyReplacePackage() + + fw = codecs.open(texfile, "w+",encoding = 'utf-8') + fw.write(ModTexobj.content) + fw.close() diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/requirements.txt b/torch_mlu_ops-v1.3.2/docs/user_guide/requirements.txt new file mode 100755 index 0000000..8194e34 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/requirements.txt @@ -0,0 +1,6 @@ +sphinx==3.3.1 +breathe==4.24.1 +sphinx_rtd_theme + +jinja2==2.10.2 +markupsafe==1.1.1 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf.json b/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf.json new file mode 100644 index 0000000..07c07e8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf.json @@ -0,0 +1,164 @@ +{ + "package": [ + "\\\\usepackage{amsmath}", + "\\\\usepackage[table]{xcolor}", + "\\\\usepackage{eso-pic}", + "\\\\usepackage{wallpaper}", + "%\\\\usepackage{titlesec}", + "\\\\usepackage{tocbibind}", + "% \\\\usepackage{draftwatermark}", + "\\\\usepackage{enumitem}", + "\\\\usepackage{tcolorbox}", + "\\\\usepackage{listings}", + "\\\\usepackage{framed}", + "\\\\usepackage{color}", + "\\\\usepackage{multirow}", + "% \\\\usepackage[justification=centering]{caption}" + ], + "replacepackage": { + "comment": "该章节内容是为了替换现有tex里面包的配置,左边为tex文件现有内容,右边是替换内容。", + "\\\\usepackage{hyperref}": "\\\\usepackage[bookmarksnumbered=true]{hyperref}", + "\\\\sphinxtableofcontents": + "\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}\n\\\\sphinxtableofcontents", + "\\\\chapter{([\\s\\S].*)}": + "\\\\chapter{\\1}\n\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}", + "\\\\listoffigures": "\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}\n\\\\listoffigures", + "\\\\listoftables": + "\\\\newpage\n\\\\ThisURCornerWallPaper{1}{chapterbkpaper.pdf}\n\\\\listoftables", + "\\\\footnotesize\\\\raggedright\\\\printindex": + "% \\\\footnotesize\\\\raggedright\\\\printindex", + "\\\\begin{itemize}": "\\\\begin{itemize}[leftmargin=*]", + "\\\\begin{enumerate}": "\\\\begin{enumerate}[leftmargin=*]", + "\\\\setmainfont{FreeSerif}\\[[\\s\\S]*?\\]": "", + "\\\\setsansfont{FreeSans}\\[[\\s\\S]*?\\]": "", + "\\\\setmonofont{FreeMono}\\[[\\s\\S]*?\\]": "", + "Extension([\\s\\S]*?)= .otf,": "Extension = ,", + "\\\\sphinxtoprule": "\\\\hline", + "\\\\sphinxmidrule": "\\\\hline", + "\\\\sphinxbottomrule": "\\\\hline", + "\\\\sphinxhline": "\\\\hline", + "\\\\sphinxhyphen": "{sphinx:5.3.0}\\\\PYGZhy", + "\\\\begin{sphinxalltt}": "\\\\begin{sphinxVerbatim}[commandchars=\\\\\\\\\\\\{\\}]", + "\\\\end{sphinxalltt}": "\\\\end{sphinxVerbatim}", + "\\\\begin{sphinxadmonition}{note}{注解:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=注解:]", + "\\\\begin{sphinxadmonition}{warning}{警告:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=警告:]", + "\\\\begin{sphinxadmonition}{tip}{小技巧:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=小技巧:]", + "\\\\begin{sphinxadmonition}{attention}{注意:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=注意:]", + "\\\\begin{sphinxadmonition}{hint}{提示:}": + "\\\\begin{tcolorbox}[colframe={hintframecolor},colback={hintbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=提示:]", + "\\\\end{sphinxadmonition}": "\\\\end{tcolorbox}", + "\\\\begin{sphinxadmonition}{note}{Note:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Note:]", + "\\\\begin{sphinxadmonition}{warning}{Warning:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Warning:]", + "\\\\begin{sphinxadmonition}{tip}{Tip:}": + "\\\\begin{tcolorbox}[colframe={noteframecolor},colback={notebackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Tip:]", + "\\\\begin{sphinxadmonition}{attention}{Attention:}": + "\\\\begin{tcolorbox}[colframe={warningframecolor},colback={warningbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Attention:]", + "\\\\begin{sphinxadmonition}{hint}{Hint:}": + "\\\\begin{tcolorbox}[colframe={hintframecolor},colback={hintbackcolor},coltitle=white,fonttitle=\\\\bfseries,title=Hint:]" + }, + "customoptions": [ + "% \\\\numberwithin{figure}{chapter}", + "% \\\\numberwithin{table}{chapter}", + "\\\\titleformat{\\\\chapter}{\\\\raggedleft\\\\huge\\\\bfseries\\\\color{white}}{\\\\thechapter}{0.5em}{}", + "\\\\titlespacing{\\\\chapter}{10pt}{40pt}{25pt}", + "\\\\definecolor{noteframecolor}{RGB}{91,163,235}", + "\\\\definecolor{notebackcolor}{RGB}{222,237,251}", + "\\\\definecolor{warningframecolor}{RGB}{235,162,28}", + "\\\\definecolor{warningbackcolor}{RGB}{255,247,236}", + "\\\\definecolor{hintframecolor}{RGB}{70,193,196}", + "\\\\definecolor{hintbackcolor}{RGB}{226,254,249}", + "\\\\definecolor{camblue}{RGB}{0,89,196}", + "% \\\\SetWatermarkText{Cambricon}", + "% \\\\SetWatermarkLightness{0.9}", + "% \\\\SetWatermarkScale{1}", + "\\\\renewcommand{\\\\labelitemi}{$\\\\vcenter{\\\\hbox{\\\\scriptsize$\\\\bullet$}}$}", + "\\\\definecolor{shadecolor}{RGB}{220,220,220}" + ], + "isfiguretabletoc": { + "comment": "插图目录英文:List of Figures;表格目录英文:List of Tables.", + "isfigurestoc": false, + "istablestoc": true, + "figurestoc": [ + "\\\\renewcommand\\\\listfigurename{插\\ 图\\ 目\\ 录}", + "\\\\listoffigures" + ], + "tablestoc": [ + "\\\\renewcommand\\\\listtablename{表\\ 格\\ 目\\ 录}", + "\\\\listoftables" + ] + }, + "tables": { + "comment": + "isname:true-根据表格的name属性查找表格,false-根据表格的标题查找表格。isLongTable:true-强制设置为长表格,false-不设置长表格。isVertical:true-对第一列进行渲染,false-不对第一列进行渲染。isCusHead:true-对第一行进行渲染,false-不对第一行进行渲染。", + "isname": true, + "rowtype": "", + "headtype": "{camblue!100}", + "headfontcolor": "\\textbf{\\textcolor{white}{}}", + "styles": [ + { + "align": "centering", + "caption": "allfilestandard", + "captionalign": "left", + "isLongTable": false, + "isVertical": false, + "isCusHead": true, + "fontsize": "footnotesize", + "ispartial": true + }, + { + "align": "centering", + "caption": "gptconfignumber", + "captionalign": "left", + "isLongTable": false, + "isVertical": false, + "isCusHead": true, + "fontsize": "footnotesize", + "ispartial": true + } + ] + }, + "image": { + "styles": [ + { + "name": "", + "align": "", + "caption": "" + } + ] + }, + "sensitivewords": [ + "安防", + "监听", + "stream", + "hisi", + "展讯", + "英伟达", + "nvidia", + "讯飞", + "展锐", + "c10", + "c20", + "IK", + "AH", + "US", + "MLU320" + ], + "ignorewarndes": [ + "该字段已废弃不再使用,但不能删掉,否则会导致编译错误!" + ], + "ignorewarnkey": [ + "sdk_memory.rst && 文档没有加入到任何目录树中", + "sdk_memory.rst && document isn't included in any toctree", + "Package hyperref Warning", + "LaTeX Font Warning", + "Package cmap Warning", + "Package xeCJK Warning" + ] + +} diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf.py b/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf.py new file mode 100755 index 0000000..71eb0ff --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- +# pylint: disable=C0301, C0305 +# +# Torch-MLU-Ops documentation build configuration file, created by +# sphinx-quickstart on Tue Apr 23 18:33:05 2020. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) +from __future__ import print_function + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +#import sys +#reload(sys) +#sys.setdefaultencoding('utf-8') +extensions = [ + 'sphinx.ext.mathjax', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'寒武纪Torch-MLU-Ops用户手册' +copyright = u'2024, Cambricon' +author = u'' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = u'1.3.2' +# The full version, including alpha/beta/rc tags. +release = version +curfnpre=u'Cambricon-Torch-MLU-Ops-User-Guide-CN-v' +curfn=curfnpre + version + '.tex' +today = "" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'zh_CN' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build','_venv','_templates'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +smartquotes = False +numfig = True +numfig_secnum_depth = 1 +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_copy_source = False +html_css_files = [ + 'custom.css', +] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['../_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# This is required for the alabaster theme +# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars +html_sidebars = { + '**': [ + 'relations.html', # needs 'show_related': True theme option to display + 'searchbox.html', + ] +} + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'Torch-MLU-Opsdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, curfn, u'寒武纪Torch-MLU-Ops用户手册', + author, 'manual'), +] + + +# xelatex 作为 latex 渲染引擎,因为可能有中文渲染 +latex_engine = 'xelatex' + + # 如果没有使用自定义封面,可以加上封面 logo。这里可以是 pdf 或者 png,jpeg 等 +latex_logo = "./logo.png" + +latex_use_latex_multicolumn = True +latex_show_urls = 'footnote' + +latex_use_xindy = False +# 主要的配置,用于控制格式等 +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +# +'papersize': 'a4paper', + +# The font size ('10pt', '11pt' or '12pt'). +# +# 'pointsize': '8pt', + +'inputenc': '', +'utf8extra': '', +'figure_align': 'H', +'releasename':"版本", +'sphinxsetup': ''' + verbatimwithframe=false, + verbatimwrapslines=true, + VerbatimColor={RGB}{220,220,220}, + verbatimhintsturnover=false, +''', +# Additional stuff for the LaTeX preamble. +'preamble': ''' +\\addto\\captionsenglish{\\renewcommand{\\chaptername}{}} + +% 设置字体 + +\\usepackage{xeCJK} +\\usepackage{fontspec} +\\CJKsetecglue{\\hskip0.08em plus0.01em minus 0.01em} +\\defaultfontfeatures[\\rmfamily,\\sffamily,\\ttfamily]{} +\\setCJKmainfont{Noto Sans CJK SC}[AutoFakeSlant] +\\setCJKsansfont{Noto Sans CJK SC}[AutoFakeSlant] +\\setCJKmonofont{Noto Sans Mono CJK SC}[AutoFakeSlant] +\\setmainfont{Noto Sans CJK SC}[AutoFakeSlant] + +% 以下配置将页面的可显示范围调宽 +% 调宽后背景图一定要使用chapterbkpaper_width.pdf,否则一级标题的显示效果会很差 +\\geometry{ +top=26mm, +bottom=22mm, +left=10mm, +right=10mm, +} + +%%清除 latex headheight small错误编译告警 +\\setlength{\\headheight}{14.0pt} + +% 段首缩进 2 格 + +%\\usepackage{indentfirst} +%\\setlength{\\parindent}{2em} + +\\usepackage{setspace} + +% 1.5 倍行间距 +\\renewcommand{\\baselinestretch}{1.5} +% 表格里的行间距 +\\renewcommand{\\arraystretch}{1.5} + +% list 列表的缩进对齐 +\\usepackage{enumitem} +\\setlist{nosep} + +% 表格类的宏包 +\\usepackage{threeparttable} +\\usepackage{array} +\\usepackage{booktabs} + +% fancy 页眉页脚 +\\usepackage{fancyhdr} +\\pagestyle{fancy} + +% 在 sphinx 生成的 tex 文件里,normal 是指普通页面(每一个章节里,除了第一页外剩下的页面) +% 页眉,L:left,R:right,E:even,O:odd +% 奇数页面:左边是 leftmark,章号和章名称;右边是 rightmark,节号与节名称 +% 偶数页面:左边是 rightmark,节号与节名称;右边是 leftmark,章号和章名称 +% textsl 是字体,slanted shape,对于英语而言,某种斜体。但是不同于 textit 的 italic shape +% 左页脚:版权信息 +% 右页脚:页码数 +% rulewidth:页眉和页脚附近的横线的粗细,当设置为 0pt 时,就没有该横线 +% +\\fancypagestyle{normal} { +\\fancyhf{} +\\fancyhead{} +\\fancyhead[LE,RO]{\\textsl{\\rightmark}} +\\fancyhead[LO,RE]{\\textsl{\\leftmark}} +\\lfoot{Copyright © 2024 Cambricon Corporation.} +\\rfoot{\\thepage} +\\renewcommand{\\headrulewidth}{0.4pt} +\\renewcommand{\\footrulewidth}{0.4pt} +} + +% 在 sphinx 生成的 tex 文件里,plain 是指每个章节的第一页等 +\\fancypagestyle{plain} { +\\fancyhf{} +% left head 还可以内嵌图片,图片可以是 pdf,png,jpeg 等 +% \\lhead{\\includegraphics[height=40pt]{cn_tm.pdf}} +\\lhead{\\large\\textcolor[rgb]{0.1804,0.4588,0.7137}{Cambricon®}} +\\lfoot{Copyright © 2024 Cambricon Corporation.} +\\rfoot{\\thepage} +\\renewcommand{\\headrulewidth}{0.4pt} +\\renewcommand{\\footrulewidth}{0.4pt} +} +''', + + +# +'printindex': r'\footnotesize\raggedright\printindex', + + +# 移除空白页面 +'extraclassoptions': 'openany,oneside', + +# 如果需要用 latex 自已做封面,可以使用 maketitle +# 下面这个封面的例子来自于互联网 + +# 'maketitle': r''' +# \pagenumbering{Roman} %%% to avoid page 1 conflict with actual page 1 +# +# \begin{titlepage} +# \centering +# +# \vspace*{40mm} %%% * is used to give space from top +# \textbf{\Huge {Sphinx format for Latex and HTML}} +# +# \vspace{0mm} +# \begin{figure}[!h] +# \centering +# \includegraphics[width=0.8\textwidth]{cn.png} +# \end{figure} +# +# % \vspace{0mm} +# % \Large \textbf{{Meher Krishna Patel}} +# +# % \small Created on : Octorber, 2017 +# +# % \vspace*{0mm} +# % \small Last updated : \MonthYearFormat\today +# +# +# %% \vfill adds at the bottom +# % \vfill +# % \small \textit{More documents are freely available at }{\href{http://pythondsp.readthedocs.io/en/latest/pythondsp/toc.html}{PythonDSP}} +# \end{titlepage} +# +# \clearpage +# \pagenumbering{roman} +# \tableofcontents +# \listoffigures +# \listoftables +# \clearpage +# \pagenumbering{arabic} +# +# ''', +# +} # latex_elements + +import os.path as osp +import sys +import traceback + +try: + # conf_callback.py一般和conf.py放在同一目录下即可,如果有特殊需求请根据conf_callback.py的实际路径进行修改。 + if osp.exists(osp.join(osp.dirname(__file__), r'./conf_callback.py')): + sys.path.append(osp.join(osp.dirname(__file__), '.')) + else: + sys.path.append(osp.join(osp.dirname(__file__), '../')) + + from conf_callback import * + +except Exception as e: + # 如果出现了异常,请检查上面默认配置的两个路径是否满足要求,如果不满足要求,请修改上面的路径直到没异常发生。 + print(e) + traceback.print_exc() + #app.setup_extension('custoctree') diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf_callback.py b/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf_callback.py new file mode 100644 index 0000000..dafa534 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/conf_callback.py @@ -0,0 +1,370 @@ +#!/usr/bin/python +# -*- coding: UTF-8 + +from __future__ import print_function + +# sphinx事件回调函数 +#如果copy到conf.py文件中,只copy下面的代码即可。 + +import os +import os.path as osp +import sys +import importlib +import traceback +import datetime as dt +import locale + +#如果文档编译失败,请检查下面默认的路径是否正确,请根据实际路径进行修改 + +curpath = os.path.dirname(os.path.realpath(__file__)) + +if osp.exists(osp.join(osp.dirname(__file__), './parsejson.py')): + sys.path.append(osp.join(osp.dirname(__file__), '.')) +else: + sys.path.append(osp.join(osp.dirname(__file__), '../')) + +cusdirect = None +extpathisexist = False # ./_ext/customdirective.py是否存在的标志,用于设置指令扩展 + +extpath = osp.join(osp.dirname(__file__), './_ext') +if osp.exists(extpath): + # 自定义指令所在路径,加到搜索目录,否则无法使用自定义指令 + sys.path.append(extpath) + if osp.exists(osp.join(osp.dirname(__file__),r'./_ext/customdirective.py')): + extpathisexist = True + cusdirect = importlib.import_module('customdirective') +else: + extpath = osp.join(osp.dirname(__file__), '../_ext') + sys.path.append(extpath) + if osp.exists(osp.join(osp.dirname(__file__),r'../_ext/customdirective.py')): + extpathisexist = True + cusdirect = importlib.import_module('customdirective') + +import sphinx +import sphinx.errors as sperror +import parsejson + +try: + import breathe # 仅为了判断版本号,根据breathe版本号添加不同配置 + breathe_version = breathe.__version__[:3] +except Exception as e: + breathe_version = "" + +warnfile = '' # 告警文件,不包含路径 +warnfilepath = '' # 保存告警日志文件,包含完整的路径名 + +months=['January','February','March','April','May','June','July', + 'August','September','October','Novmber','December'] + +gmaketitle = r''' + \pagenumbering{Roman} + \begin{titlepage} + \centering + \vspace*{40mm} + \textbf{\Huge {%(titlename)s}} + + \vspace{10mm} + \textbf{\Large{%(release)s}} + + \vfill + \textbf{\large{%(today)s}} + \end{titlepage} + ''' + +def __ModifyMakeTitle(elementskeys,config): + ''' + 如果有maketitle变量则修改maketitle的内容 + :param config: 配置变量 + :return: + ''' + #得到封面名称 + titlename = config.latex_documents[0][2] + + if 'releasename' in elementskeys: + releasename = config.latex_elements['releasename'] + else: + releasename = '' + + if hasattr(config,'version'): + version = config.version + else: + version = '' + release = releasename +' ' + version + if hasattr(config,'today') and len(config.today) > 0: + today = config.today + else: + todayft = dt.date.today() + if hasattr(config,'language') and config.language=='zh_CN': + today = todayft.strftime('%Y 年 %m 月 %d 日') + else: + #除中文外,一律用英文时间格式 + #得到英文月的全称。 + month = months[int(todayft.strftime('%m'))-1] + today = month + todayft.strftime(" %d, %Y") + + if len(config.latex_elements['maketitle'].strip())==0: + maketitle = gmaketitle + else: + maketitle = config.latex_elements['maketitle'] + #必须转一下,否则赋值不成功 + config.latex_elements['maketitle'] = maketitle % { + 'titlename': titlename, + 'release': release, + 'today': today + } +def docview_modifyconfig(config): + """ + 以下配置需要在config_init事件中配置,否则有些变量无法读取到,导致配置不生效 + :param config: + :return: + """ + #配置静态路径 + if len(config.html_static_path)==0: + #如果配置了,就以配置路径为准,否则默认提供两个可选路径 + if osp.exists(curpath + '/_static'): + config.html_static_path = [curpath + '/_static'] + else: + config.html_static_path = [curpath + '/../_static'] + + texfilename = config.latex_documents[0][1] #得到tex文件名,构建pdf文件名 + filename,fileext = os.path.splitext(texfilename) + pdfname = filename+'.pdf' + #设置js脚本 + config.html_js_files = [('custom.js', {'pdfname': pdfname}),'cndeveloper.js'] + config.html_css_files = ['custom.css'] + +def modifycommonconfig(config): + """ + 以下配置为通用配置,为了配置生效,放在setup函数中调用。 + :param config: + :return: + """ + #设置html通用配置 + #设置html最新更新时间 + config.html_last_updated_fmt = '%Y-%m-%d %H:%M:%S' + config.html_theme = 'sphinx_rtd_theme' + config.html_copy_source = False + config.html_show_sphinx = False + + #设置latex通用配置 + config.latex_use_latex_multicolumn = True + config.latex_use_xindy = False + config.latex_engine = 'xelatex' + #修改通用配置 + config.smartquotes = False + config.numfig = True + config.numfig_secnum_depth = 1 + + + +def config_inited_handler(app, config): + # 检查年份是否正确,并修改。需配合最新parsejson.py文件使用,否则不支持该接口。 + # 通用配置的修改放在该位置,否则对配置的修改不生效 + try: + # 该函数的第二个参数可以传入版权声明文件相对于conf.py的相对路径,自动修改版权声明的年份 + # 比如:parsejson.CheckCurrentYearAndModify(app,'../source/copyright/cnperf_conpyright.rst') + # 默认路径为“./copyright/copyright.rst”和“./copyright/copyright_zh.rst”或者“./copyright/copyright_en.rst” + # 以上三个路径无需传入 + parsejson.CheckCurrentYearAndModify(app) + except Exception as e: + print('------------------') + print(e) + # traceback.print_stack() + traceback.print_exc() + print('------------------') + + try: + # print(sphinx.__version__) + # app.require_sphinx('3.5') + keys = config.latex_elements.keys() + if 'sphinxsetup' not in keys: + config.latex_elements['sphinxsetup'] = '' + if "3.5" <= sphinx.__display_version__[:3]: + config.latex_elements['sphinxsetup'] = 'verbatimforcewraps=true,verbatimmaxunderfull=2,' + \ + config.latex_elements['sphinxsetup'] + if "5.3" <= sphinx.__display_version__[:3]: + config.latex_table_style = ['standard', 'nocolorrows'] + config.latex_elements['sphinxsetup'] += 'pre_border-radius=0pt,' + if len(breathe_version) > 0 and "4.3" <= breathe_version: + config.latex_elements['preamble'] += ''' + \\renewenvironment{description} + {\\list{}{\\labelwidth=0pt + \\let\\makelabel\\descriptionlabel}} + {\\endlist} + ''' + #如果有自定义maketitle,则修改maketitle的内容 + if 'maketitle' in keys: + __ModifyMakeTitle(keys,config) + + docview_modifyconfig(config) + # sphinxsetup = config.latex_elements['sphinxsetup'] + # print('sphinxversion:%s;sphinxsetup:%s' % (sphinx.__version__,sphinxsetup)) + except Exception as e: + # print('sphinxversion:%s %s' % (sphinx.__version__,sphinx.__display_version__[:3])) + pass + + +def build_finished_handler(app, exception): + if exception != None: + # print(exception) + return + + # 判断告警文件是否存在,只有无告警或者告警全是忽略告警才允许继续后续的编译 + if warnfilepath != '' and osp.exists(warnfilepath): + # 判断告警文件中是否全是忽略告警 + iswarn = parsejson.warn_main(warnfilepath, app) + if iswarn: + # 如果为True则说明有不可忽略的告警,报sphinxerror异常,停止继续编译 + raise sperror.SphinxError('There are alarms, please check the file of %s for details' % warnfile) + return + + try: + if app.builder.name == "latex": + selffnlst = app.config.latex_documents + parsejson.Modifylatex_main(app.outdir, selffnlst, app) + except Exception as e: + print('------------------') + print(e) + traceback.print_exc() + + # 检查html标题一致性。该接口最好放在该位置,主索引文件的标题已经解析出来。 + # if app.builder.name == "html": + # result = parsejson.CheckProjectNameIsConsistent(app) + # if result != "": + # raise sperror.SphinxError(result) #带raise的异常如果用except捕捉,则在终端打印告警将不能显示为红色字体。 + + +def build_inited_handler(app): + global warnfile + global warnfilepath + + print(sys.argv) + args = sys.argv[1:] # 0为sphinx-build,需忽略掉 + if '-w' in args: + pos = args.index('-w') # 找到-w所在的索引位置 + warnfile = args[pos + 1] # 得到告警保存的文件名 + # print('warnfile=' + warnfile) + + # 根据工作路径,得到文件名的绝对路径 + # 当前在build阶段,因此工作路径为Makefile所在的目录,-w后面的文件保存在基于Makefile的相对路径下 + filepath = osp.join(os.getcwd(), warnfile) + warnfilepath = osp.abspath(filepath) + # print('warnfilepath = ' + warnfilepath) + + # try: + # # 检查是否有rst_prolog或者rst_epilog替换内容,有的话去掉前后空格 + # # 仅适用于pdf文件和中文文档,英文文档和html不启作用 + # checkDocobj = parsejson.clsCheckDocCopyrightYear(app, "") + # if app.builder.name == 'latex' or app.builder.name == 'latexpdf': + # checkDocobj.CheckReplaceContent() + # except Exception as e: + # print('------------------') + # print(e) + # traceback.print_exc() + # print('------------------') + + # try: + # # 检查html配置是否符合规范。必须放在这个位置,否则app没有builder对象。 + # if app.builder.name == "html": + # error = parsejson.CheckHtmlConfigIsCorrect(app) + # if error != "": + # raise sperror.ConfigError(error) + # except Exception as e: + # print('------------------') + # print(e) + # traceback.print_stack() + # traceback.print_exc() + # print('------------------') + + +def source_read_handler(app, docname, source): + if cusdirect is not None: + cnonlyobj = cusdirect.CNOnlyPro(app, docname, source[0]) + source[0] = cnonlyobj.parsecnonlycontent(source[0]) + + # try: + # 自动添加更新历史代码,默认添加“无内容更新” + # source[0] = parsejson.CheckUpdateHistory(app, docname, source[0]) + # except Exception as e: + # print('------------------') + # print(e) + # traceback.print_exc() + # print('------------------') + + return source + + +def doctree_resolved_handle(app, doctree, docname): + """ + 删除不需要文件的内容,避免编译html的时候还进行编译,导致内容泄漏 + """ + if not hasattr(app.config, 'enable_exclude_html') or not app.config.enable_exclude_html: + # 默认按sphinx方式编译所有rst文件到html + return + + indexname = app.config.master_doc + if docname == indexname: + return + indexdocnames = app.env.toctree_includes[indexname] + isexist = False + # 判断文档是否在索引文件里 + if docname not in indexdocnames: + newtoctree = app.env.toctree_includes.copy() + del newtoctree[indexname] # 删除主索引文件,因为无需比较 + for key in newtoctree.keys(): + values = newtoctree[key] + if docname in values: + isexist = True + # 判断该key是否在主索引文件里面 + if key not in indexdocnames: + # 如果不在整个的索引文件里面,删除该文件内容 + # doctree.attributes['source']='' + while len(doctree.children) > 0: + doctree.remove(doctree.children[0]) + # 删除标题,否则html页面还有标题,还需要维护标题 + for node in app.env.titles[docname].children: + app.env.titles[docname].remove(node) + if not isexist: + # 文档不在toctree字典里,直接删除内容 + while len(doctree.children) > 0: + doctree.remove(doctree.children[0]) + # 删除标题,否则html页面还有标题,还需要维护标题 + for node in app.env.titles[docname].children: + app.env.titles[docname].remove(node) + + +def setup(app): + """ + 该函数中的路径都是相对于工程builder环境的路径而不是相对于conf.py所在的路径, + 该函数外面的相对路径是相对于conf.py的路径没有问题。 + 比如./_ext相对路径是相对于工程的路径,因此如果使用abspath得到绝对路径会得到~/m0_docs/_ext/, + 而不是~ / mo_docs / source / _ext。 + 因此该路径下对路径的判断最好转化为绝对路径判断,否则可能会判断失败。 + :param app: + :return: + """ + #将通用配置放在这个位置,因为有些配置比如html_theme,在config_init事件前就已读取,不允许再修改 + #放在setup函数中,提前预配置项才能生效。 + modifycommonconfig(app.config) + + app.connect('config-inited', config_inited_handler) + app.connect('build-finished', build_finished_handler) + app.connect('builder-inited', build_inited_handler) + app.connect('source-read', source_read_handler) + app.connect('doctree-resolved', doctree_resolved_handle) + + app.add_config_value('chapterbkpaper_image', '', 'env', [str]) + # 是否排除未添加到index的html文件。在编译html时,sphinx默认会将所有rst文件都生成html,因此可能会造成部分内容的泄露。 + # 为了删除这部分内容,设置该变量。如果置为True的话,将删除不在index.rst里面的html的内容。 + # 默认为false,不做处理,以增加编译效率。 + # 建议还是将不参与编译的rst文件,放到exclude_patterns变量内,放到该变量后,rst不参与编译,能提高编译效率。 + app.add_config_value('enable_exclude_html', False, 'env', [bool]) + #配置版权声明文件完整路径,实现版权声明文件年份的自动修改。 + #parsejson.CheckCurrentYearAndModify如果该函数中传入了版权声明的全路径,则以该函数中传入的为准 + #在该处又增加了一个配置值,是为了配置和代码解偶,为了兼容性,保留了原来的配置方式 + app.add_config_value('copyfile_path', '', 'env', [str]) + #在conf.py中配置的conf.json路径,为了实现conf.json配置的灵活处理 + app.add_config_value('confjson_path', '', 'env', [str]) + + if extpathisexist: + app.setup_extension('customdirective') diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/index.rst new file mode 100644 index 0000000..ac40aac --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/index.rst @@ -0,0 +1,20 @@ +.. torch_mlu_ops_rst documentation master file, created by + sphinx-quickstart on Mon Apr 22 19:17:54 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +寒武纪Torch-MLU-Ops用户手册 +=========================================== + +.. toctree:: + :maxdepth: 5 + :caption: 目录 + :numbered: + + ./tmo_1_copyright/index + ./tmo_2_preface/index + ./tmo_3_overview/index + ./tmo_4_installation/index + ./tmo_5_key_features/index + ./tmo_6_best_practice/index + ./tmo_7_FAQ/index diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/logo.png b/torch_mlu_ops-v1.3.2/docs/user_guide/source/logo.png new file mode 100644 index 0000000..7b6ac72 Binary files /dev/null and b/torch_mlu_ops-v1.3.2/docs/user_guide/source/logo.png differ diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_1_copyright/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_1_copyright/index.rst new file mode 100755 index 0000000..0dd8354 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_1_copyright/index.rst @@ -0,0 +1,30 @@ +.. copyright + +版权声明 +========================================== + +**免责声明** + +中科寒武纪科技股份有限公司(下称“寒武纪”)不代表、担保(明示、暗示或法定的)或保证本文件所含信息,并明示放弃对可销售性、所有权、不侵犯知识产权或特定目的适用性做出任何和所有暗示担保,且寒武纪不承担因应用或使用任何产品或服务而产生的任何责任。寒武纪不应对因下列原因产生的任何违约、损害赔偿、成本或问题承担任何责任:(1)使用寒武纪产品的任何方式违背本指南;或(2)客户产品设计。 + +**责任限制** + +在任何情况下,寒武纪都不对因使用或无法使用本指南而导致的任何损害(包括但不限于利润损失、业务中断和信息损失等损害)承担责任,即便寒武纪已被告知可能遭受该等损害。尽管客户可能因任何理由遭受任何损害,根据寒武纪的产品销售条款与条件,寒武纪为本指南所述产品对客户承担的总共和累计责任应受到限制。 + +**信息准确性** + +本文件提供的信息属于寒武纪所有,且寒武纪保留不经通知随时对本文件信息或对任何产品和服务做出任何更改的权利。本指南所含信息和本指南所引用寒武纪文档的所有其他信息均“按原样”提供。寒武纪不担保信息、文本、图案、链接或本指南内所含其他项目的准确性或完整性。寒武纪可不经通知随时对本指南或本指南所述产品做出更改,但不承诺更新本指南。 + +本指南列出的性能测试和等级要使用特定芯片或计算机系统或组件来测量。经该等测试,本指南所示结果反映了寒武纪产品的大概性能。系统硬件或软件设计或配置的任何不同会影响实际性能。如上所述,寒武纪不代表、担保或保证本指南所述产品将适用于任何特定用途。寒武纪不代表或担保测试每种产品的所有参数。客户全权承担确保产品适合并适用于客户计划的应用以及对应用程序进行必要测试的责任,以期避免应用程序或产品的默认情况。 + +客户产品设计的脆弱性会影响寒武纪产品的质量和可靠性并导致超出本指南范围的额外或不同的情况和/或要求。 + +**知识产权通知** + +寒武纪和寒武纪的标志是中科寒武纪科技股份有限公司在中国和其他国家的商标和/或注册商标。其他公司和产品名称应为与其关联的各自公司的商标。 + +本指南为版权所有并受全世界版权法律和条约条款的保护。未经寒武纪的事先书面许可,不可以任何方式复制、重制、修改、出版、上传、发布、传输或分发本指南。除了客户使用本指南信息和产品的权利,根据本指南,寒武纪不授予其他任何明示或暗示的权利或许可。未免疑义,寒武纪不根据任何专利、版权、商标、商业秘密或任何其他寒武纪的知识产权或所有权对客户授予任何(明示或暗示的)权利或许可。 + +* 版权声明 + +* © 2024 中科寒武纪科技股份有限公司保留一切权利。 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_2_preface/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_2_preface/index.rst new file mode 100755 index 0000000..97744ca --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_2_preface/index.rst @@ -0,0 +1,112 @@ +更新历史 +========================================== +* **V1.3.2** + + **更新时间**: 2025年01月09日 + + **更新内容**: + +- ``fused_layer_norm`` 与 ``fused_rms_norm`` 的输入和输出支持更多的 ``stride`` 组合。请参考 :ref:`torch_mlu_ops.fused_rms_norm`、:ref:`torch_mlu_ops.fused_layer_norm` 章节。 + +* **V1.3.1** + + **更新时间**: 2024年12月30日 + + **更新内容**: + + - ``smooth_quant_group_gemm`` 与 ``fused_moe`` 算子支持w4w8混合量化。请参考 :ref:`torch_mlu_ops.smooth_quant_group_gemm` 和 :ref:`torch_mlu_ops.fused_moe` 章节。 + - 新增 ``dequant_from_paged_cache`` 算子,详细内容请参考 :ref:`torch_mlu_ops.dequant_from_paged_cache` 章节。 + - ``dequant_from_linear_cache`` 算子不再支持float32的 ``key`` 和 ``value``。 + - 新增 :ref:`gen_case` 功能。 + - 新增 :ref:`pt_example` 章节内容。 + +* **V1.3.0** + + **更新时间**: 2024年12月17日 + + **更新内容**: + + - ``group_gemm`` 与 ``smooth_quant_group_gemm`` 算子的 ``max_m`` 参数由可选修改为必选。请参考 :ref:`torch_mlu_ops.group_gemm` 和 :ref:`torch_mlu_ops.smooth_quant_group_gemm` 章节。 + - ``moe_softmax_topk`` 支持mask功能,请参考 :ref:`torch_mlu_ops.moe_softmax_topk` 章节。 + - 新增 ``dequant_from_linear_cache`` 算子,详细内容请参考 :ref:`torch_mlu_ops.dequant_from_linear_cache` 章节。 + +* **V1.2.3** + + **更新时间**: 2024年12月09日 + + **更新内容**: + + - ``single_query_cached_kv_attn`` 与 ``flash_attention`` 算子支持 ``head_size_qk != head_size_v``,请参考 :ref:`torch_mlu_ops.single_query_cached_kv_attn` 和 :ref:`torch_mlu_ops.flash_attention` 章节。 + - ``group_gemm`` 与 ``smooth_quant_group_gemm`` 算子支持bias,请参考 :ref:`torch_mlu_ops.group_gemm` 和 :ref:`torch_mlu_ops.smooth_quant_group_gemm` 章节。 + - ``matmul`` 删除原位输出参数,增加指定输出类型参数,请参考 :ref:`torch_mlu_ops.matmul` 章节。 + +* **V1.2.2** + + **更新时间**: 2024年11月08日 + + **更新内容**: + + - ``smooth_quant_group_gemm`` 算子新增int4 group量化功能,请参考 :ref:`torch_mlu_ops.smooth_quant_group_gemm` 章节。 + - allreduce类算子删除act_mode,请参考 :ref:`torch_mlu_ops.smooth_quant_matmul_allreduce` 和 :ref:`torch_mlu_ops.matmul_allreduce` 章节。 + - ``weight_only_quant_matmul`` 与 ``smooth_quant_matmul`` 添加use_hp_active参数,控制激活计算方式, 请参考 :ref:`torch_mlu_ops.weight_only_quant_matmul` 和 :ref:`torch_mlu_ops.smooth_quant_matmul` 章节。 + - ``active`` 算子新增quick_gelu和swish激活模式,请参考 :ref:`torch_mlu_ops.active` 章节。 + +* **V1.2.1** + + **更新时间**: 2024年11月04日 + + **更新内容**: + + - ``quant_to_linear_cache`` 算子新增group量化及int4量化功能,请参考 :ref:`torch_mlu_ops.quant_to_linear_cache` 章节。 + - 新增 ``moe_cast_gating`` 算子,请参考 :ref:`torch_mlu_ops.moe_cast_gating` 章节。 + - 新增 ``update_out_and_lse`` 算子,请参考 :ref:`torch_mlu_ops.update_out_and_lse` 章节。 + - ``fused_rope`` 算子支持int8/int4 kv cache。 + - ``matmul`` 与 ``batch_matmul`` 新增配置 ``trans_a``, ``trans_b`` 参数,请参考 :ref:`torch_mlu_ops.matmul` 和 :ref:`torch_mlu_ops.batch_matmul` 章节。 + - ``single_query_cached_kv_attn`` 支持int4 kv cache,支持返回lse。请参考 :ref:`torch_mlu_ops.single_query_cached_kv_attn` 章节。 + - 新增 ``single_query_mixed_cached_kv_attn`` 算子。 + - ``fused_layer_norm`` 与 ``fused_rms_norm`` 支持输出动态量化。请参考 :ref:`torch_mlu_ops.fused_rms_norm`、:ref:`torch_mlu_ops.fused_layer_norm` 章节。 + +* **V1.2.0** + + **更新时间**: 2024年10月24日 + + **更新内容**: + + - ``moe_softmax_topk`` 算子新增grouped_topk功能,请参考 :ref:`torch_mlu_ops.moe_softmax_topk` 章节。 + - ``moe_softmax_topk`` 算子不再支持原位功能,请参考 :ref:`torch_mlu_ops.moe_softmax_topk` 章节。 + - ``moe_gen_idx`` 算子不再支持原位功能,请参考 :ref:`torch_mlu_ops.moe_gen_idx` 章节。 + - 新增 ``fused rope`` 算子,请参考 :ref:`torch_mlu_ops.fused_rope` 章节。 + - ``matmul`` 算子支持INT8输入,请参考 :ref:`torch_mlu_ops.matmul` 章节。 + - 新增 ``batch_matmul`` 算子,请参考 :ref:`torch_mlu_ops.batch_matmul` 章节。 + +* **V1.1.4** + + **更新时间**: 2024年09月30日 + + **更新内容**: + + - 新增 ``offline_quant_to_paged_cache`` 算子,详见 :ref:`torch_mlu_ops.offline_quant_to_paged_cache` 章节。 + - 新增 ``moe_gen_idx`` 和 ``moe_expand_input`` 算子,详见 :ref:`torch_mlu_ops.moe_gen_idx` 和 `torch_mlu_ops.expand_input` 章节。 + - 新增 ``moe_combine_result`` 算子,详见 :ref:`torch_mlu_ops.moe_combine_result` 章节。 + - 新增 ``moe_gen_idx`` 算子,详见 :ref:`torch_mlu_ops.moe_gen_idx` 章节。 + - 删除 ``quant_matmul`` 算子,由 ``smooth_quant_matmul`` 和 ``weight_only_quant_matmul`` 实现其功能, 详见 :ref:`torch_mlu_ops.smooth_quant_matmul` 和 :ref:`torch_mlu_ops.weight_only_quant_matmul` 章节。 + - ``flash_attention`` 算子新增 ``block_tables`` 、 ``k/v_cache_quant_scale`` 参数。请参考 :ref:`torch_mlu_ops.flash_attention` 章节。 + - ``matmul`` 算子支持激活计算方式配置,请参考 :ref:`torch_mlu_ops.matmul` 章节。 + - FAQ章节新增对 ``flash_attention`` 算子的说明,请参考 :ref:`FAQ` 章节。 + - 新增 ``moe_active`` 算子,详见 :ref:`torch_mlu_ops.moe_active` 章节。 + +* **V1.1.3** + + **更新时间**: 2024年09月06日 + + **更新内容**: + + - BangTransformer更名为Torch-MLU-Ops, 定位PyTorch第三方算子库。 + - 原BangTransformer网络评测相关文档请参考《寒武纪BangTransformer用户手册》v1.1.2版本。 + - bt_ops的命名空间变化为torch_mlu_ops, 请参考 :ref:`torch_mlu_ops`。 + - torch_mlu_ops.single_query_cached_kv_attn算子支持per_token量化和per_channel量化。 + - torch_mlu_ops.fused_moe算子在非量化情况下支持EP模式,内部支持group gemm和allreduce并行。 + - 提供flash_attn_sq_mm_allreduce、matmul_allreduce、smooth_quant_matmul_allreduce通算融合算子,fused_moe支持通算融合。请参考 :ref:`torch_mlu_ops.flash_attn_sq_mm_allreduce`、 + :ref:`torch_mlu_ops.matmul_allreduce`、:ref:`torch_mlu_ops.smooth_quant_matmul_allreduce`、:ref:`torch_mlu_ops.fused_moe`。 + - torch_mlu_ops.flash_attention、torch_mlu_ops.single_query_cached_kv_attn、torch_mlu_ops.fused_rms_norm、torch_mlu_ops.fused_layer_norm支持原位输出。请参考 :ref:`torch_mlu_ops.flash_attention`、 + :ref:`torch_mlu_ops.single_query_cached_kv_attn`、:ref:`torch_mlu_ops.fused_rms_norm`、:ref:`torch_mlu_ops.fused_layer_norm`。 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_3_overview/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_3_overview/index.rst new file mode 100755 index 0000000..debed9b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_3_overview/index.rst @@ -0,0 +1,7 @@ +.. _简介: + +简介 +========================== + +.. toctree:: + tmo_overview diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_3_overview/tmo_overview.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_3_overview/tmo_overview.rst new file mode 100755 index 0000000..bf2cd6c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_3_overview/tmo_overview.rst @@ -0,0 +1,27 @@ +应用场景 +=============================== + +Torch-MLU-Ops是寒武纪设计和开发的PyTorch第三方算子库。对于使用PyTorch框架的开发者,通过Torch-MLU-Ops,能够便捷地使用这些自定义算子,进行算子的集成、评测和业务部署。 + +Torch-MLU-Ops已全量覆盖LLM(Large Language Model)推理场景下的常见算子。作为后端,已支持Cambricon vLLM、Cambricon TGI、Cambricon Stable Diffusion web UI、Cambricon ComfyUI以及Cambricon Diffusers。 + +.. _关键特性: + +关键特性 +=============================== +Torch-MLU-Ops在支持LLM模型部署时提供的关键特性如下所示: + +- PyTorch风格的算子API +- 支持LLM推理场景下的常见算子 +- Flash Attention支持Multi-head Attention、Multi-query Attention、Group-query Attention特性 +- 支持KV Cache INT8特性 +- Single Query Cached Attention算子支持Paged特性 +- 矩阵乘类算子支持Weight-only、AWQ、GPTQ、Smoothquant量化特性 +- 支持MOE场景下常见算子 +- 提供不同融合力度的融合算子 +- 提供通信——计算融合算子MC2 +- 支持Cambricon TGI调用 +- 支持Cambricon vLLM调用 +- 支持Cambricon Stable Diffusion web UI调用 +- 支持Cambricon ComfyUI调用 +- 支持Cambricon Diffusers调用 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_4_installation/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_4_installation/index.rst new file mode 100755 index 0000000..e555d7d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_4_installation/index.rst @@ -0,0 +1,7 @@ +.. _编译安装: + +安装和部署 +============================ + +.. toctree:: + tmo_installation diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_4_installation/tmo_installation.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_4_installation/tmo_installation.rst new file mode 100755 index 0000000..303a20e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_4_installation/tmo_installation.rst @@ -0,0 +1,165 @@ +.. _hostinstall: + +在主机上安装和部署 +========================================== + +环境要求 +-------------- + +- Torch-MLU-Ops在不同的操作系统上有不同的软件版本依赖。 + + 不同平台支持的软件版本如下: + + .. table:: Torch-MLU-Ops不同平台支持软件版本 + :widths: 20 20 20 + + +--------------------+----------------------+--------------------+ + | 平台 | 编译器版本 | CMake版本 | + +====================+======================+====================+ + | Ubuntu22.04 x86-64 | gcc 9.4.0或更高版本 | 3.10或更高版本 | + +--------------------+----------------------+--------------------+ + | Debian10.11 x86.64 | gcc 9.4.0或更高版本 | 3.10或更高版本 | + +--------------------+----------------------+--------------------+ + +- 编译、运行Torch-MLU-Ops同时还需要配置以下依赖。 + + .. table:: Torch-MLU-Ops配置依赖 + :name: cppdemorely + + +------------------+-------------------+---------+-------------+---------+-------------+ + | Torch-MLU-Ops | Torch-MLU | CNNL | CNNL_Extra | CNCL | CNToolkit | + +==================+===================+=========+=============+=========+=============+ + | v1.3.2 | v1.24.1 | v1.28.3 | v1.12.3 | v1.24.1 | v3.15.6 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.3.1 | v1.24.1 | v1.28.3 | v1.12.2 | v1.24.1 | v3.15.5 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.3.0 | v1.24.0 | v1.28.2 | v1.12.1 | v1.24.0 | v3.15.4 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.2.3 | v1.24.0 | v1.28.1 | v1.12.1 | v1.24.0 | v3.15.3 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.2.2 | v1.23.1 | v1.27.4 | v1.11.3 | v1.23.0 | v3.14.3 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.2.1 | v1.23.1 | v1.27.3 | v1.11.3 | v1.22.3 | v3.14.2 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.2.0 | v1.23.0 | v1.27.1 | v1.11.1 | v1.22.1 | v3.14.1 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.1.4 | v1.22.2 | v1.26.7 | v1.10.4 | v1.21.1 | v3.13.7 | + +------------------+-------------------+---------+-------------+---------+-------------+ + | v1.1.3 | v1.22.2 | v1.26.6 | v1.10.3 | v1.21.1 | v3.13.5 | + +------------------+-------------------+---------+-------------+---------+-------------+ + +- 此外运行Torch-MLU-Ops还依赖PyTorch环境。Torch-MLU-Ops已支持的PyTorch版本请参考。 + + .. table:: Torch-MLU-Ops支持的PyTorch版本 + :name: pytorchdepends + + +------------------+-------------------+ + | Torch-MLU-Ops | PyTorch | + +==================+===================+ + | v1.3.2 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.3.1 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.3.0 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.2.3 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.2.2 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.2.1 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.2.0 | v2.1, v2.4, v2.5 | + +------------------+-------------------+ + | v1.1.4 | v2.1, v2.3, v2.4 | + +------------------+-------------------+ + | v1.1.3 | v2.1, v2.3 | + +------------------+-------------------+ + + +通过docker镜像启动 +-------------------- + +寒武纪提供的PyTorch解决方案镜像已提供了Torch-MLU-Ops所需要的依赖,请参考PyTorch解决方案镜像的使用手册启动镜像。 + + +编译安装 +-------------------- + +Torch-MLU-Ops编译脚本依赖于NEUWARE_HOME环境变量,该环境变量指向寒武纪CNToolkit的安装路径,通常为 ``/usr/local/neuware`` 。 + +在运行示例程序前,您需要将 ``NEUWARE_HOME`` 中的 ``lib64`` 目录添加到 ``LD_LIBRARY_PATH`` 环境变量中,例如: + +:: + + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NEUWARE_HOME/lib64 + +如果您使用寒武纪提供的docker镜像,以上环境变量均已经被设置好了。 + +设置好 ``NEUWARE_HOME`` 环境变量后,您可以使用以下命令编译Torch-MLU-Ops。 + +从远端clone仓库,进入项目的主目录: + +:: + + cd torch_mlu_ops + +源码安装: + +:: + + pip install -e . + +wheel包安装: + +:: + + python setup.py bdist_wheel + pip install dist/torch_mlu_ops*.whl + +.. attention:: + + 如果安装过程中遇到权限问题,请获取相关权限或切换到拥有权限的user执行pip install 这步操作。 + +移除Torch-MLU-Ops(Optional): + +:: + + pip uninstall torch_mlu_ops + + +测试 +--------------------------------------- + +在测试前,请确保Torch-MLU-Ops已安装。可以通过 ``pip list`` 命令查询Torch-MLU-Ops的安装情况。若有如下类似打印,则表明安装成功。 + +:: + + torch-mlu-ops 1.1.0+pt21 + +Torch-MLU-Ops模块的导入方法请参考。 + +:: + + import torch_mlu_ops as tmo + +在 ``./tests/ops_pytest`` 目录下,提供了 ``ops`` 级别的测试示例程序,您可以参考如下命令进行测试。 + +:: + + cd tests/ops_pytest/ + ./run_test.sh + +或者单独测试某个测例,例如: + +:: + + python test_flash_attention.py + + +在 ``./tests/kernels_pytest`` 目录下,提供了 ``kernels`` 级别的测试示例程序,您可以参考如下命令进行测试。 + +:: + + cd tests/kernels_pytest + ./build.sh # 需要先编译 + ./run_test.sh diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_5_key_features/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_5_key_features/index.rst new file mode 100755 index 0000000..7578fb6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_5_key_features/index.rst @@ -0,0 +1,7 @@ +.. _重点特性: + +重点特性 +========================== + +.. toctree:: + tmo_key_features diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_5_key_features/tmo_key_features.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_5_key_features/tmo_key_features.rst new file mode 100755 index 0000000..0f49502 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_5_key_features/tmo_key_features.rst @@ -0,0 +1,2743 @@ +.. _torch_mlu_ops: + +torch_mlu_ops算子 +========================================== + +torch_mlu_ops算子为Torch-MLU-Ops设计和封装的PyTorch风格的算子API,能够以灵活的方式集成到社区或者客户的LLM推理引擎框架中,加速客户的LLM业务部署。 + +.. list-table:: torch_mlu_ops算子 + :widths: 2 4 + :header-rows: 1 + :class: longtable + + * - 序号 + - 算子名称 + + * - 1 + - :ref:`torch_mlu_ops.active` + * - 2 + - :ref:`torch_mlu_ops.apply_rotary` + * - 3 + - :ref:`torch_mlu_ops.attention_project` + * - 4 + - :ref:`torch_mlu_ops.batch_matmul` + * - 5 + - :ref:`torch_mlu_ops.copy_blocks` + * - 6 + - :ref:`torch_mlu_ops.copy_blocks_out_of_place` + * - 7 + - :ref:`torch_mlu_ops.dequant_from_linear_cache` + * - 8 + - :ref:`torch_mlu_ops.dequant_from_paged_cache` + * - 9 + - :ref:`torch_mlu_ops.ffn` + * - 10 + - :ref:`torch_mlu_ops.flash_attention` + * - 11 + - :ref:`torch_mlu_ops.flash_attn_sq_mm_allreduce` + * - 12 + - :ref:`torch_mlu_ops.fused_layer_norm` + * - 13 + - :ref:`torch_mlu_ops.fused_moe` + * - 14 + - :ref:`torch_mlu_ops.fused_norm_attention_project` + * - 15 + - :ref:`torch_mlu_ops.fused_norm_residual_ffn` + * - 16 + - :ref:`torch_mlu_ops.fused_rms_norm` + * - 17 + - :ref:`torch_mlu_ops.fused_rope` + * - 18 + - :ref:`torch_mlu_ops.group_gemm` + * - 19 + - :ref:`torch_mlu_ops.matmul` + * - 20 + - :ref:`torch_mlu_ops.matmul_allreduce` + * - 21 + - :ref:`torch_mlu_ops.moe_active` + * - 22 + - :ref:`torch_mlu_ops.moe_cast_gating` + * - 23 + - :ref:`torch_mlu_ops.moe_combine_result` + * - 24 + - :ref:`torch_mlu_ops.moe_expand_input` + * - 25 + - :ref:`torch_mlu_ops.moe_gen_idx` + * - 26 + - :ref:`torch_mlu_ops.moe_softmax_topk` + * - 27 + - :ref:`torch_mlu_ops.offline_quant_to_linear_cache` + * - 28 + - :ref:`torch_mlu_ops.offline_quant_to_paged_cache` + * - 29 + - :ref:`torch_mlu_ops.preload` + * - 30 + - :ref:`torch_mlu_ops.quant_to_linear_cache` + * - 31 + - :ref:`torch_mlu_ops.quant_to_paged_cache` + * - 32 + - :ref:`torch_mlu_ops.quantize` + * - 33 + - :ref:`torch_mlu_ops.reshape_linear_cache` + * - 34 + - :ref:`torch_mlu_ops.reshape_paged_cache` + * - 35 + - :ref:`torch_mlu_ops.single_query_cached_kv_attn` + * - 36 + - :ref:`torch_mlu_ops.single_query_mixed_cached_kv_attn` + * - 37 + - :ref:`torch_mlu_ops.smooth_quant_matmul` + * - 38 + - :ref:`torch_mlu_ops.smooth_quant_matmul_allreduce` + * - 39 + - :ref:`torch_mlu_ops.swap_blocks` + * - 40 + - :ref:`torch_mlu_ops.update_out_and_lse` + * - 41 + - :ref:`torch_mlu_ops.weight_only_quant_matmul` + + +详细算子接口说明请参考以下章节。 + +.. _torch_mlu_ops.active: + +torch_mlu_ops.active +------------------------------------------ + +.. code:: + + torch_mlu_ops.active(input, act_mode, is_gated, active_coef)-> torch.Tensor + +**功能描述** + +激活算子。 + +**公式说明** + +* 公式如下: + + .. math:: + + \begin{array}{lcl} + C = input.shape[-1]\\ + if is\_gated = True:\\ + output = active(input[..., :C//2]) * input[..., C//2:]\\ + else: \\ + output = active(input) + \end{array} + +* 公式中各参数描述如下: + + - ``input`` :active算子的输入。 + - ``is_gated``:是否是gated结构。 + +* 公式中各函数描述如下: + + - ``active`` :激活算子。 + +**参数说明** + +- ``input``:Tensor类型,必选。输入维度 ``[..., C]``。 其中 ``C`` 为通道的个数,在最低维。 +- ``act_mode``:string类型。可以是 ``silu`` , ``gelu`` , ``quick_gelu`` 或者 ``swish``。 +- ``is_gated``:Bool类型。当 ``is_gated=true`` 时,为gated激活。 +- ``active_coef``:FLOAT类型。默认值为1.0,只有在act_mode为swish的情况下才需要额外设置。 + +**返回值说明** + +- ``output``:激活后返回的tensor。 + +**规格限制** + +- 数据类型支持: + + - input:FP32、FP16、BF16 + - output:FP32、FP16、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_add_bias_activation.py`` 。 + + +.. _torch_mlu_ops.apply_rotary: + +torch_mlu_ops.apply_rotary +------------------------------------------ + +.. code:: + + torch_mlu_ops.apply_rotary(input,sin_cache,cos_cache,position_ids,cu_seqlens,interleaved,discrete,dynamic_ntk,max_seqlen)->Tensor + +**功能描述** + +计算Attention模块中将rotary embedding加在Q、K上的算子。 + +**公式说明** + +请参考相关的GPT positon embedding文献:https://kexue.fm/archives/8265。该算子实现的主要功能为查表和加操作。 + + +**参数说明** + +- ``input``:Tensor类型,必选。Q或者K,pack模式下,维度 ``[total_seqlen, head_num, head_size]``,非pack模式下,维度 ``[batch_size, total_seqlen, head_num, head_size]``,其中 ``head_num`` 表示transformer模块中的头数,``head_size`` 表示每个头的维度。 +- ``sin_cache``:Tensor类型,必选。输入计算rotary embedding的sin表,可能不是连续。当 ``dynamic_ntk=true`` 时,维度 ``[batch_size, rotary_seq_len, rotary_pos_emb_dim]``;当 ``dynamic_ntk=false`` 时,维度 ``[rotary_seq_len, rotary_pos_emb_dim]``。其中 ``rotary_seq_len`` 一般设置为大于等于 ``max_seqlen`` ,``rotary_pos_emb_dim`` 在full rotary情况下,一般等于 ``hidden_size``,在partial rotary情况下,需要根据网络进行配置。 +- ``cos_cache``:Tensor类型,必选。输入计算rotary embedding的cos表,可能不是连续。当 ``dynamic_ntk=true`` 时,维度 ``[batch_size, rotary_seq_len, rotary_pos_emb_dim]``,当 ``dynamic_ntk=false`` 时,维度 ``[rotary_seq_len, rotary_pos_emb_dim]``。 +- ``position_ids``:Tensor类型,可选。当 ``discrete=true``,维度 ``[total_seqlen]`` 当 ``discrete=false``,维度 ``[batch_size]`` 。 +- ``cu_seqlens``:Tensor类型,可选。在pack模式下,记录了input中每个batch的序列长度,维度 ``[batch + 1]``。表示当前batch的sequence长度与下一个batch的sequence长度的差值。如果不设置,则input中每个batch的序列长度为 ``max_seqlens``。 +- ``interleaved``:Bool类型,必选。表示rotary embedding的方式。当 ``interleaved=true`` 时,rotary embedding的方式为 ``cross`` ,当 ``interleaved=false`` 时,rotary embedding的方式为 ``fold`` 。 +- ``discrete``:Bool类型,必选。discrete表示position_ids是否是离散的。如果是离散的,需要给出每个token的完整的 ``position_ids``,否则只需给出起始token的 ``position_ids`` ,算子内部根据起始 ``position_ids`` 和 ``seqlen`` 确定每个token的 ``position_ids``。默认情况下设置 ``discrete=False``。 +- ``dynamic_ntk``:Bool类型,必选。``dynamic_ntk=true`` 时,表示所有的 ``batch`` 有不同的sin和cos值。当 ``dynamic_ntk=False`` 时,表示所有的 ``batch`` 有相同的sin和cos值。默认情况下设置 ``dynamic_ntk=False``。 +- ``max_seqlen``:Int类型,必选。context阶段,在pad模式下,等于 ``pad_seqlen`` ,pack模式下等于 ``cu_seqlens[-1]``;generate阶段,等于 ``valid_token.max()`` + 解码帧数。 + +**返回值说明** + +- ``output``:Tensor类型,必选。rotary embedding加到Q或者K上的输出,维度 ``[total_seqlen, head_num, head_size]``。其中 ``total_seqlen`` pad模式下等于 ``batch*pad_seqlen``, pack模式下等于 ``valid_seqlen`` 之和,即 ``cu_seqlens[-1]``。 + + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - sin_cache:FP16、BF16 + - cos_cache:FP16、BF16 + - position_ids:Int32 + - cu_seqlens:Int32 + - output:FP16、BF16 + +**备注** + +- 该算子中sin_cache/cos_cache的维度及生成方式和vLLM存在差异,具体表现在: + + - rope维度:vLLM中sin_cache/cos_cache的rope的维度是 ``rotary_pos_emb_dim / 2``, torch_mlu_ops中sin_cache/cos_cache的rope的维度是完整的rotary_pos_emb_dim; + - 生成方式: + + * vLLM中生成fold/cross旋转角度公式 : ``1.0 / (base ** (torch.arange(0, rotary_pos_emb_dim, 2) / rotary_pos_emb_dim))``; + * torch_mlu_ops中生成fold模式旋转角度公式 : ``1.0 / (base ** (torch.arange(0, rotary_pos_emb_dim, 1) % half_rotary_pos_emb_dim / rotary_pos_emb_dim))``, ``half_rotary_pos_emb_dim = rotary_pos_emb_dim // 2``; + * torch_mlu_ops中生成cross模式旋转角度公式 : ``1.0 / (base ** (torch.arange(0, rotary_pos_emb_dim, 1) // 2 * 2 / rotary_pos_emb_dim))``。 + +- 差异性带来的影响: + + - 需要适配旋转角度的生成公式。 + - 无论sin_cache/cos_cache是何种维度形式, 都是针对特定硬件平台的设计的。 + + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_apply_rotary.py`` 。 + + +.. _torch_mlu_ops.attention_project: + +torch_mlu_ops.attention_project +------------------------------------------ + +.. code:: + + torch_mlu_ops.attention_project(input,weight,bias,residual,alpha,beta)->Tensor + + +**功能描述** + +计算Attention模块中的Q、K、V算子。适用于GPT或者 ``hidden_size`` 大于2048的transformer网络。 + +**公式说明** + +* 公式如下: + + .. math:: + + \begin{array}{lcl} + Q = fc(input, weight\_q) + bias\_q\\ + K = fc(input, weight\_k) + bias\_k\\ + V = fc(input, weight\_v) + bias\_v\\ + \end{array} + + 当 ``residual`` 存在时: + + .. math:: + + \begin{array}{lcl} + Q_out = alpha*Q + beta*residual \\ + K_out = alpha*K + beta*residual\\ + V_out = alpha*V + beta*residual\\ + \end{array} + +* 公式中各参数描述如下: + + - ``input`` :attention_project算子的输入。 + - ``weight_q``、 ``weight_k`` 、 ``weight_v`` :Query、Key、Value矩阵的权重。 + - ``bias_q``、 ``bias_k`` 、 ``bias_v`` :Query、Key、Value矩阵的偏置。 + +* 公式中各函数描述如下: + + - ``fc`` :全连接层(fully-connected)的缩写。 + +**参数说明** + +- ``input``:Tensor类型,必选。``attention_project`` 的输入。维度 ``[batch_size, seq_lens, input_size]``。其中 ``batch_size`` 表示batch的大小,``seq_lens`` 表示sequence的长度,``input_size`` 表示输入的大小。 +- ``weight``:Tensor类型,必选。``attention_project`` 的矩阵权重参数。Q、K、V权重参数被pack在一起。在MHA场景下,维度 ``[3*hidden_size, input_size]``。在MQA场景下,维度 ``[(head_num+2*head_num_kv)*head_size, input_size]``。其中 ``hidden_size`` 表示隐层的大小。 +- ``bias``:Tensor类型,可选。``attention_project`` 的矩阵偏置参数。Q、K、V偏置参数被pack在一起。在MHA场景下,维度 ``[3*hidden_size]``。在MQA场景下,维度 ``[(head_num+2*head_num_kv)*head_size]``。其中 ``head_num_kv`` 表示MQA场景下kv的数量。 +- ``residual``: Tensor类型,可选。``attention_project`` 的参差参数。维度与 ``input`` 保持一致。 +- ``alpha`` : float类型。默认值为1.0。当 ``residual`` 存在时,作为Q、K、V的系数。 +- ``beta`` : float类型。默认值为1.0。当 ``residual`` 存在时,作为 ``residual`` 的系数。 + + +**返回值说明** + +- output: Tensor类型,必选。attention_project的输出。Q、K、V被pack在一起。在MHA场景下,维度[batch_size, seq_lens, 3*hidden_size]。在MQA场景下,维度[batch_size, seq_lens, (head_num+2*head_num_kv)*head_size]。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - weight:FP16、BF16 + - bias:FP16、BF16 + - output:FP16、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_attn_proj.py`` 。 + + +.. _torch_mlu_ops.batch_matmul: + +torch_mlu_ops.batch_matmul +------------------------------------------ + + torch_mlu_ops.batch_matmul(input, filter, residual, alpha, beta, a_scale, b_scale, trans_a, trans_b)-> torch.Tensor + +**功能描述** + +通用的批量矩阵乘法算子。 + +**公式说明** + +* 批量matmul标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + mul_out = batch_matmul(input, filter)\\ + residual = alpha * mul_out + beta * residual\\ + output = residual + \end{array} + +* 公式中各参数描述如下: + + - ``input`` :matmul的左矩阵。 + - ``filter``:matmul的右矩阵。 + - ``residual``:参差。 + - ``output`` :计算最终的输出。 + +* 公式中各函数描述如下: + + - ``batch_matmul`` :批量矩阵乘。 + +**参数说明** + +- ``input``:Tensor类型,必选。矩阵乘的左矩阵。维度 ``[batch, mat_m, mat_k]``。 +- ``filter``:Tensor类型,必选。矩阵乘的右矩阵。维度 ``[batch, mat_n, mat_k]``。 +- ``residual``:Tensor类型,可选。参差。维度 ``[batch, mat_m, mat_n]``,如果有,计算结束后residual会被改变。 +- ``alpha``:Float32类型,必选。矩阵乘法系数。 +- ``beta``:Float32类型,必选。参差的系数。 +- ``a_scale``:Float32类型,默认值为1.0。当input为INT8类型时对应的量化系数。 +- ``b_scale``:Float32类型,默认值为1.0。当weight为INT8类型时对应的量化系数。 +- ``trans_a``:Bool类型,默认值为False,标记input是否需要转置。 +- ``trans_b``:Bool类型,默认值为True,标记weight是否需要转置。 + +**返回值说明** + +- ``out``:Tensor类型,算子的输出,维度 ``[batch, mat_m, mat_n]``。 + +**规格限制** + +- 数据类型支持: + - input:FP16、BF16、FP32、INT8 + - filter:FP16、BF16、FP32、INT8 + - residual:FP16、BF16、FP32 + - out:FP16、BF16、FP32 + - alpha:FP32 + - beta:FP32 + - a_scale:FP32 + - b_scale:FP32 + - trans_a:BOOL + - trans_b:BOOL + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_batch_matmul.py`` 。 + + +.. _torch_mlu_ops.copy_blocks: + +torch_mlu_ops.copy_blocks +---------------------------------------------------------------- + +.. code:: + + torch_mlu_ops.copy_blocks(k_caches, v_caches, block_mapping) + +**功能描述** + +将 ``k_caches``、 ``v_caches`` 中源位置的BLOCK复制到指定位置。 + +**公式说明** + +无。 + +**参数说明** + +- ``k_caches``:Vector类型,必选。输入维度 ``[num_layers]``。其中 ``num_layers`` 表示layer数量, ``[num_heads, block_size, head_size]`` 整体大小作为每次需要copy的BLOCK大小。 ``num_blocks`` 表示BLOCK数量; ``block_size`` 表示输入序列中的特定块或子序列大小; ``num_heads`` 表示transformer模块中的头数; ``head_size`` 表示每个头的维度。 +- ``v_caches``:Vector类型,必选。输入维度 ``[num_layers]``。其中每个参数含义与 ``k_caches`` 一致。 +- ``block_mapping``:Map>类型,必选。第一个int64表示需要copy的源BLOCK位置(小于num_blocks),第二个int64表示将源BLOCK copy到的目标位置(小于num_blocks), ``vector`` 表示一个源BLOCK可以copy到多个目标位置。 + +**返回值说明** + +- 无,原位操作。 + +**规格限制** + +- 数据类型支持: + + - k_caches:FP16、FP32、BF16、INT8、UINT8、INT16、INT32、INT64 + - v_caches:FP16、FP32、BF16、INT8、UINT8、INT16、INT32、INT64 + - block_mapping:INT64 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_copy_blocks.py`` 。 + + +.. _torch_mlu_ops.copy_blocks_out_of_place: + +torch_mlu_ops.copy_blocks_out_of_place +------------------------------------------------------------------------------------------------------------ + +.. code:: + + torch_mlu_ops.copy_blocks_out_of_place(k_caches, v_caches, block_mapping)->Tuple(k_caches, v_caches) + +**功能描述** + +将 ``k_caches``、 ``v_caches`` 中源位置的BLOCK复制到指定位置。 + +**公式说明** + +无。 + +**参数说明** + +- ``k_caches``:Vector类型,必选。输入维度 ``[num_layers]``。其中 ``num_layers`` 表示layer数量, ``[num_heads, block_size, head_size]`` 整体大小作为每次需要copy的BLOCK大小。 ``num_blocks`` 表示BLOCK数量; ``block_size`` 表示输入序列中的特定块或子序列大小; ``num_heads`` 表示transformer模块中的头数; ``head_size`` 表示每个头的维度。 +- ``v_caches``:Vector类型,必选。输入维度 ``[num_layers]``。其中每个参数含义与 ``k_caches`` 一致。 +- ``block_mapping``:Map>类型,必选。第一个int64表示需要copy的源BLOCK位置(小于num_blocks),第二个int64表示将源BLOCK copy到的目标位置(小于num_blocks), ``vector`` 表示一个源BLOCK可以copy到多个目标位置。 + +**返回值说明** + +- k_caches, v_caches。 + +**规格限制** + +- 数据类型支持: + + - k_caches:FP16、FP32、BF16、INT8、UINT8、INT16、INT32、INT64 + - v_caches:FP16、FP32、BF16、INT8、UINT8、INT16、INT32、INT64 + - block_mapping:INT64 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_copy_blocks.py`` 。 + + +.. _torch_mlu_ops.dequant_from_linear_cache: + +torch_mlu_ops.dequant_from_linear_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.dequant_from_linear_cache(key,value,key_cache,value_cache,key_cache_scale,value_cache_scale,context_lengths,max_context_len,context_seq_offset,cache_bs_id,cache_seqlen_offset,quant_mode,quant_bit) + +**功能描述** + +将计算得到的K-V cache拼接到K-V上的反量化算子。 + + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[total_seq_len, head_num, head_size]``。 +- ``value``:Tensor类型,可选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[total_seq_len, head_num, head_size]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。当 ``quant_bit`` 为8时,维度为 ``[max_batch, head_num, cache_mem_len, head_size]``;当 ``quant_bit`` 为4时,维度为 ``[max_batch, head_num, cache_mem_len, head_size//2]``。其中 ``max_batch`` 要求大于等于 ``batch_size``,``cache_mem_len`` 为网络支持的最大sequences长度。 +- ``value_cache``:Tensor类型,可选。Transformer模块中保存 ``Value cache`` 的Tensor。当 ``quant_bit`` 为8时,维度为 ``[max_batch, head_num, cache_mem_len, head_size]``;当 ``quant_bit`` 为4时,维度为 ``[max_batch, head_num, cache_mem_len//2, head_size]``。 +- ``key_cache_scale``:Tensor类型,必选。Transformer模块中保存的 ``Key`` 的 ``scale`` 值。在context阶段,维度为 ``per-channel`` 量化 ``[head_num, head_size]`` 或者 ``per-head`` 量化 ``[max_batch, head_num, cache_mem_len]``。 +- ``value_cache_scale``:Tensor类型,可选。Transformer模块中保存的 ``Value`` 的 ``scale`` 值。在context阶段,维度为 ``per-channel`` 量化 ``[head_num, head_size]`` 或者 ``per-head`` 量化 ``[max_batch, head_num, cache_mem_len]``。 +- ``context_lengths``:Tensor类型,必选。为每个batch的sequences length长度。 +- ``max_context_len``:Int类型,必选。在context阶段,表示最大的sequence length。 +- ``context_seq_offset``:Tensor类型,可选。表示context的seq的累积偏移,需要保证每个batch的sequences不重合,维度 ``[batch_size]``。如果为空,则默认为context_lens的累加和。 +- ``cache_bs_id``:Tensor类型,可选。表示cache的batch的偏置,维度 ``[batch_size]``。如果为空,则默认值为{0, 1, 2 ... batch - 1}。 +- ``cache_seq_offset``:Tensor类型,可选。表示cache的seq_lens的偏置,维度 ``[batch_size]``。如果为空,则默认值为{0, 0, 0 ... 0}。 +- ``quant_mode``: Int类型,可选。表示量化的模式,可选0(表示per-channel)或者1(表示per-head),默认值为0。 +- ``quant_bit``:Int类型,可选。表示量化结果的位宽,可选8(表示int8)或4(表示int4),默认值为8。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16 + - value:FP16、BF16 + - key_cache:INT8 + - value_cache:INT8 + - key_cache_scale: FP32 + - value_cache_scale: FP32 + - context_lengths:INT32 + - max_context_len:INT32 + - context_seq_offset:INT32 + - cache_bs_id: INT32 + - cache_seqlen_offset: INT32 + - quant_mode: INT32 + - quant_bit: INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_dequant_from_linear_cache.py`` 。 + + +.. _torch_mlu_ops.dequant_from_paged_cache: + +torch_mlu_ops.dequant_from_paged_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.dequant_from_paged_cache(key,value,key_cache,value_cache,key_cache_scale,value_cache_scale,context_lengths,max_context_len,context_seq_offset,block_tables,quant_mode,quant_bit) + +**功能描述** + +将计算得到的K-V cache(Paged模式)拼接到K-V上的反量化算子。 + + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[total_seq_len, head_num, head_size]``。 +- ``value``:Tensor类型,可选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[total_seq_len, head_num, head_size]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。当 ``quant_bit`` 为8时,维度为 ``[block_num, head_num, block_size, head_size]``。 +- ``value_cache``:Tensor类型,可选。Transformer模块中保存 ``Value cache`` 的Tensor。当 ``quant_bit`` 为8时,维度为 ``[block_num, head_num, block_size, head_size]``。 +- ``key_cache_scale``:Tensor类型,必选。Transformer模块中保存的 ``Key`` 的 ``scale`` 值。在context阶段,维度为 ``per-channel`` 量化 ``[head_num, head_size]`` 或者 ``per-head`` 量化 ``[block_num, head_num, block_size]``。 +- ``value_cache_scale``:Tensor类型,可选。Transformer模块中保存的 ``Value`` 的 ``scale`` 值。在context阶段,维度为 ``per-channel`` 量化 ``[head_num, head_size]`` 或者 ``per-head`` 量化 ``[block_num, head_num, block_size]``。 +- ``context_lengths``:Tensor类型,必选。为每个batch的sequences length长度。 +- ``max_context_len``:Int类型,必选。在context阶段,表示最大的sequence length。 +- ``context_seq_offset``:Tensor类型,可选。表示context的seq的累积偏移,需要保证每个batch的sequences不重合,维度 ``[batch_size]``。如果为空,则默认为context_lens的累加和。 +- ``block_tables``:Tensor类型,可选。表示cache所在的block序号,维度 ``[batch_size, max_block_num]``。 +- ``quant_mode``: Int类型,可选。表示量化的模式,可选0(表示per-channel)或者1(表示per-head),默认值为0。 +- ``quant_bit``:Int类型,可选。表示量化结果的位宽,可选8(表示int8),默认值为8。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16 + - value:FP16、BF16 + - key_cache:INT8 + - value_cache:INT8 + - key_cache_scale: FP32 + - value_cache_scale: FP32 + - context_lengths:INT32 + - max_context_len:INT32 + - context_seq_offset:INT32 + - block_tables: INT32 + - quant_mode: INT32 + - quant_bit: INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_dequant_from_paged_cache.py`` 。 + + +.. _torch_mlu_ops.ffn: + +torch_mlu_ops.ffn +------------------------------------------ + +.. code:: + + torch_mlu_ops.ffn(input,up_fc_weight,up_fc_bias,down_proj_weight,down_proj_bias,gate_up_proj_weight,gate_up_proj_bias,act_mode)->Tensor + +**功能描述** + +计算transformer中的FFN(FeedForward)前馈神经网络模块的算子。适用于GPT或者 ``hidden_size`` 大于2048的transformer网络。 + +**公式说明** + +* FFN标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + fc\_output = fc(input, up\_fc\_weight) + up\_fc\_bias\\ + active\_output = activation(fc\_output)\\ + basic\_output = fc(active\_output, down\_proj\_weight) + down\_proj\_bias + \end{array} + +* Gated FFN标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + fc\_output = fc(input, up\_fc\_weight) + up\_fc\_bias\\ + active\_output = activation(fc\_output)\\ + gated_fc\_output = fc(input, gate\_up\_proj\_weight) + gate\_up\_proj\_bias\\ + elementwise\_product\_output = elementwise\_product(active\_output, gated\_fc\_output)\\ + gated\_output = fc(elementwise\_product\_output, down\_proj\_weight) + down\_proj\_bias\\ + \end{array} + +* 公式中各参数描述如下: + + - ``input`` :FFN模块的输入。 + - ``fc_output`` :FFN模块经过升维后的输出。 + - ``active_output`` :FFN模块经过激活后的输出。 + - ``basic_output`` :标准FFN模块经过降维后的输出。 + - ``gated_fc_output`` :gated FFN旁路经过升维后的输出。 + - ``elementwise_product_output`` :gated FFN主路和旁路经过elementwise product后的输出。 + - ``gated_output`` :gated FFN模块经过降维后的输出。 + + +* 公式中各函数描述如下: + + - ``fc`` :全连接层(fully-connected)的缩写。 + - ``activation`` :用户可配的激活函数,可选的激活函数:"relu"、"gelu"、"silu"。 + - ``elementwise_product`` :点积。 + +**参数说明** + +- ``input``:Tensor类型,必选。FFN算子的输入。维度[batch_size, seq_lens, hidden_size]。 +- ``up_fc_weight``:Tensor类型,必选。FFN算子中升维矩阵权重参数,维度[ffn_inner_size, hidden_size]。 +- ``up_fc_bias``:Tensor类型,可选。FFN算子中升维矩阵偏置参数,维度[ffn_inner_size]。其中ffn_inner_size表示FFN中间层的维度。 +- ``down_proj_weight``:Tensor类型,必选。FFN算子中降维矩阵权重参数,维度[hidden_size, ffn_inner_size]。 +- ``down_proj_bias``:Tensor类型,可选。FFN算子中降维矩阵偏置参数,维度[hidden_size]。 +- ``gate_up_proj_weight``:Tensor类型,可选。Gated FFN算子中升维矩阵权重参数,维度[hidden_size, ffn_inner_size]。 +- ``gate_up_proj_bias``:Tensor类型,可选。Gated FFN算子中升维矩阵偏置参数,维度[hidden_size]。 +- ``act_mode``:string类型,必选。FFN的激化函数类型,可选的激活函数:"relu"、"gelu"、"silu"、"none"。 + +**返回值说明** + +- output: Tensor类型,必选。ffn的输出。维度[batch_size, seq_lens, hidden_size]。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - up_fc_weight:FP16、BF16 + - up_fc_bias:FP16、BF16 + - down_proj_weight:FP16、BF16 + - down_proj_bias: FP16、BF16 + - gate_up_proj_weight: FP16、BF16 + - gate_up_proj_bias: FP16、BF16 + - output: FP16、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_ffn.py`` 。 + + +.. _torch_mlu_ops.flash_attention: + +torch_mlu_ops.flash_attention +------------------------------------------ + +.. code:: + + torch_mlu_ops.flash_attention(q,k,v,out,cu_seq_lens_q,cu_seq_lens_kv,alibi_slope,attn_bias,k_cache_quant_scale,v_cache_quant_scale,block_tables,max_seq_len_q,max_seq_len_kv,softmax_scale,is_causal,window_size_left,window_size_right,compute_dtype,return_lse) + +**功能描述** + +Context阶段的Flash Attention算子。 + +**公式说明** + +* flash attention标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + scores = batch\_matmul(q, k, transpose\_b=True) * softmax\_scale\\ + scores += ALiBi + attn\_bias + causal\_mask + window\_mask\\ + output\_lse = logsumexp(scores)\\ + attention = softmax(scores)\\ + output = batch\_matmul(attention, v)\\ + \end{array} + +* 公式中各参数描述如下: + + - ``q/k/v`` :flash attention模块的输入。 + - ``attn_bias``:flash attention模块的偏置。 + - ``ALiBi``:flash attention模块在线计算的ALiBi位置编码矩阵。 + - ``causal_mask``:如果是因果attention的,则需要添加mask。 + - ``window_mask``:如果设置了window的尺寸,则需要添加mask。 + - ``scores`` :flash attention模块第一个批量矩阵乘的输出。 + - ``output_lse``:flash_attention的lse输出。 + - ``attention`` :flash attention模块经过softmax后的输出。 + - ``output`` :flash attention模块第二个批量矩阵乘的输出,也是模块的最终输出。 + +* 公式中各函数描述如下: + + - ``batch_matmul`` :批量矩阵乘。 + - ``softmax`` :归一化指数函数。 + - ``logsumexp`` :对数求和指数函数。 + +**参数说明** + +- ``q``:Tensor类型,必选。flash attention算子的输入。维度 ``[total_q, q_head_num, head_size_qk]``。 +- ``k``:Tensor类型,必选。flash attention算子的输入。当block_tables为空时,维度 ``[total_k, kv_head_num, head_size_qk]``, 否则维度 ``[num_blocks, kv_head_num, block_size, head_size_qk]``。 +- ``k``:Tensor类型,必选。flash attention算子的输入。当block_tables为空时,维度 ``[total_k, kv_head_num, head_size_v]``, 否则维度 ``[num_blocks, kv_head_num, block_size, head_size_v]``。 +- ``out``:Tensor类型,可选。flash attention算子的原位输入。维度同input。 +- ``cu_seq_lens``:Tensor类型,可选。当输入q/k/v为pack模式时候记录每个batch的seq的位置,维度 ``[batch+1]``。 +- ``alibi_slope``:Tensor类型,可选。是一种相对位置编码方法,维度 ``[batch, q_head_num]``或者 ``[q_head_num]``,类型必须是float32。 +- ``attn_bias``:Tensor类型,可选。attention的偏置,维度 ``[batch, max_seq_len, max_seq_len]`` 或者 ``[batch, q_head_num, max_seq_len, max_seq_len]``。 +- ``k_cache_quant_scale``:Tensor类型,预留参数,必须为空。 +- ``v_cache_quant_scale``:Tensor类型,预留参数,必须为空。 +- ``block_tables``:Tensor类型,可选。block的id与kv_cache之间的映射关系,为空时表示非paged模式。维度 ``[bs, max_num_blocks_per_seq]``,max_num_blocks_per_seq为1时block_size=max(seq_len_kv)。 +- ``max_seq_len``:Int32类型,必选。算子输入的最大seq长度。 +- ``softmax_scale``:Float32类型,必选。softmax之前的系数。 +- ``is_causal``:Bool类型,标记是否是因果的,如果为因果的则计算过程中需要使用 ``causal_mask``。 +- ``window_size_left``:Int32类型, 左侧window size大小,无限制则设置为-1。实际参与计算的key和value长度为 ``window_size_left + 1``。 +- ``window_size_right``:Int32类型,右侧window size大小,无限制则设置为-1。实际参与计算的key和value长度为 ``window_size_right + 1``。 +- ``return_lse``:Bool类型,标记是否返回lse tensor。 + + +**返回值说明** + +- ``output``:Tensor类型,可选。flash attention的输出,当输入的out不为空时无此输出。维度 ``[total_q, q_head_num, head_size_v]``。 +- ``output_lse``:Tensor类型,可选。flash attention的输出,当return_lse为true时有此输出,维度 ``[batch, head, max_seq_len_q]``。 + + +**规格限制** + +- 数据类型支持: + + - q/k/v:FP16、BF16、FP32 + - cu_seq_lens:INT32 + - alibi_slope:FP32 + - attn_bias:FP16、BF16、FP32 + - block_tables:INT32 + - max_seq_len:INT32 + - softmax_scale:FP32 + - is_causal:BOOL + - window_size_left:INT32 + - window_size_right:INT32 + - output:FP16、BF16、FP32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_flash_attention.py``。 + + +.. _torch_mlu_ops.flash_attn_sq_mm_allreduce: + +torch_mlu_ops.flash_attn_sq_mm_allreduce +------------------------------------------ + +.. code:: + + torch_mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm, q, k, v, cu_seq_lens_q, cu_seq_lens_kv, + alibi_slope, attn_bias, smooth, weight, weight_scale, + bias, max_seq_len_q, max_seq_len_kv, softmax_scale, + is_causal, window_size_left, window_size_right, + compute_dtype, block_seq) + +**功能描述** + +Context阶段的Flash Attention、smooth_quant matmul与allreduce通信计算融合算子。 + +**公式说明** + +无 + +**参数说明** + +- ``cncl_comm`` :Int类型,cnclComm,用于归约通信。 +- ``q``:Tensor类型,必选。flash attention算子的输入。维度 ``[total_q, q_head_num, head_size]``。 +- ``k/v``:Tensor类型,必选。flash attention算子的输入。维度 ``[total_k, kv_head_num, head_size]``。 +- ``cu_seq_lens_q``:Tensor类型,可选。当输入q为pack模式时候记录每个batch的seq的位置,维度 ``[batch + 1]``。 +- ``cu_seq_lens_kv``:Tensor类型,可选。当输入k/v为pack模式时候记录每个batch的seq的位置,维度 ``[batch + 1]``。 +- ``alibi_slope``:Tensor类型,可选。是一种相对位置编码方法,维度 ``[batch, q_head_num]`` 或者 ``[q_head_num]``,类型必须是float32。 +- ``attn_bias``:Tensor类型,可选。attention的偏置,维度 ``[batch, max_seq_len, max_seq_len]`` 或者 ``[batch, q_head_num, max_seq_len, max_seq_len]``。 +- ``smooth``:Tensor类型,必选。平滑系数, 维度 ``[head_num_q * head_size]``。 +- ``weight``:Tensor类型,必选。, smooth_quant matmul的权重矩阵,维度 ``[head_num_q * head_size, head_num_q * head_size]``。 +- ``weight_scale``:Tensor类型,必选。, smooth_quant matmul的权重矩阵的量化参数,维度 ``[head_num_q * head_size]``。 +- ``bias``:Tensor类型,可选。, smooth_quant matmul的偏置,维度 ``[head_num_q * head_size]``。 +- ``max_seq_len_q``:Int32类型,必选。q的最大seq长度。 +- ``max_seq_len_kv``:Int32类型,必选。k/v的最大seq长度。 +- ``softmax_scale``:Float32类型,必选。softmax之前的系数。 +- ``is_causal``:Bool类型,标记是否是因果的,如果为因果的则计算过程中需要使用 ``causal_mask``。 +- ``window_size_left``:Int32类型, 左侧window size大小,无限制则设置为-1。 +- ``window_size_right``:Int32类型, 右侧window size大小,无限制则设置为-1。 +- ``compute_dtype``:torch类型,在计算softmax时使用。 +- ``block_seq``: Int类型,保留字段,设置0即可。 + + +**返回值说明** + +- ``output``:Tensor类型,必选。输出,维度 ``[total_q, q_head_num * head_size]``。 + + +**规格限制** + +- 数据类型支持: + + - cncl_comm: INT64 + - q/k/v:FP16、BF16、FP32 + - cu_seq_lens_q:INT32 + - cu_seq_lens_kv:INT32 + - alibi_slope:FP32 + - attn_bias:FP16、BF16、FP32 + - smooth:FP32 + - weight:INT8 + - weight_scale:FP32 + - bias:FP16、BF16、FP32 + - max_seq_len_q:INT32 + - max_seq_len_kv:INT32 + - softmax_scale:FP32 + - is_causal:BOOL + - window_size_left:INT32 + - window_size_right:INT32 + - block_seq:INT32 + - output:FP16、BF16、FP32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/allreduce_case/test_flash_attn_sq_mm_allreduce.py``。 + + +.. _torch_mlu_ops.fused_layer_norm: + +torch_mlu_ops.fused_layer_norm +------------------------------------------ + +.. code:: + + torch_mlu_ops.fused_layer_norm(input,residual,gamma,beta,bias,eps,store_output_before_norm,quant_scale,out,dynamic_quant) + + +**功能描述** + +layernorm归一化算子。 + +**公式说明** + +* layernorm标准结构公式如下: + + .. math:: + + \begin{aligned} + output = weight \times \frac{input - Mean(input)}{\sqrt{Variance(input) + eps}} + bias\\ + output = Quantize(output, quant_scale) + \end{aligned} + +* 公式中各参数描述如下: + + - ``input`` layernorm的输入。 + - ``Mean()`` 表示对input取均值。 + - ``Variance()`` 表示对input取方值。 + - ``Quantize()`` 表示对输出做量化。 + - ``weight`` 表示layernorm的 ``weight`` 参数。 + - ``bias`` 表示layernorm的 ``bias`` 参数。 + - ``eps`` 表示layernorm eps的值,一个极小的正数,避免除零,一般默认值设定为0.00001。 + - ``output`` layernorm的输出。 + - ``quant_scale`` 量化参数。 + +**参数说明** + +- ``input``:Tensor类型,必选。layernorm算子的输入。维度 ``[batch,seq_lens, hidden_size]``,支持batch维度和seq_lens维度的stride。 +- ``residual``:Tensor类型,可选。layernorm算子的残差输入。在pad模式下,维度 ``[batch,seq_lens, hidden_size]``,支持batch维度和seq_lens维度的stride。 +- ``gamma``:Tensor类型,可选。layernorm算子的 ``weight`` 。维度 ``[hidden_size]``。 +- ``beta``:Tensor类型,可选。layernorm算子的 ``bias`` 。维度 ``[hidden_size]``。 +- ``bias``:Tensor类型,可选。layernorm算子的残差的偏置输入。维度 ``[hidden_size]``。 +- ``eps``:Float类型,必选。layernorm算子的 ``eps`` 的值,一般设定为0.00001。 +- ``store_output_before_norm``:Bool类型,必选。layernorm算子是否输出 ``input+residual`` 的值。当为true时,第二个返回值输出 ``input+residual``,当为false时,无第二个返回值。 +- ``quant_scale``:Float类型,可选。量化参数。维度 ``[hidden_size]``。当 ``dynamic_quant`` 为False时表示离线量化参数,为True时表示在线量化时的平滑因子。当该参数不为空时,输入输出必须连续。 +- ``out``:Tensor类型,可选。layernorm算子的原位输出。维度同input,支持batch维度和seq_lens维度的stride。 +- ``dynamic_quant``:Bool类型,为False且 ``quant_scale`` 不为空时表示离线per-channel量化输出,否则为在线动态per-token量化输出。 + +**返回值说明** + +- ``output``:Tensor类型,可选。layernorm算子的输出,当输入参数out不为空时无此输出。维度 ``[batch_size,seq_lens, hidden_size]``,支持batch维度和seq_lens维度的stride。 +- ``residual_output``:Tensor类型,可选。layernorm算子的residual输出。当store_output_before_norm为true时,返回值输出 ``input+residual``,当store_output_before_norm为false时,无此返回值。在pad模式下,维度 ``[batch_size*seq_lens, hidden_size]``;在pack模式下,维度 ``[total_seqlen, hidden_size]``。 +- ``smooth_quant_scale``:Tensor类型,可选。维度 ``[batch_size*seq_lens]`` 。当 ``dynamic_quant`` 为True时输出的量化scale。 + + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - residual:FP16、BF16 + - beta:FP16、BF16 + - gamma:FP16、BF16 + - bias:FP16、BF16 + - eps:Float32 + - store_output_before_norm:Bool + - quant_scale:Float + - dynamic_quant:Bool + - output:FP16、BF16、int8 + - residual_output: FP16、BF16 + - smooth_quant_scale: FP32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_fusenorm.py`` 。 + +.. _torch_mlu_ops.fused_moe: + +torch_mlu_ops.fused_moe +------------------------------------------ + +.. code:: + + torch_mlu_ops.fused_moe(input, router_logit, w1, w2, bias1, bias2, + residual, input_smooth, act_smooth, w1_scale, + w2_scale, topk, renormalize, gated, act_mode, + start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag) + +**功能描述** + +mixtral of experts, FFN(FeedForward)前馈神经网络模块的平替层,支持通信计算融合模式。 + +**公式说明** + +无 + +**参数说明** + +- ``input``:Tensor类型,必选。MOEFFN算子的输入。维度 ``[batch_size, seq_lens, hidden_size]`` 或 ``[total_seq, hidden_size]``。 +- ``router_logit``:Tensor类型,必选。MOEFFN算子中expert的得分。维度 ``[batch_size, seq_lens, expert_num]`` 或 ``[batch_size, seq_lens, expert_num]``。 +- ``w1``:Tensor类型,必选。MOEFFN算子中升维矩阵权重参数。维度 ``[expert_size, ffn_inner_size*(1+gated), hidden_size]`` 或 ``[expert_size, ffn_inner_size*(1+gated), hidden_size / 2]``。 +- ``w2``:Tensor类型,必选。MOEFFN算子中降维矩阵权重参数。维度 ``[expert_size, hidden_size, ffn_inner_size]`` 或 ``[sub_expert_size, hidden_size, block, ffn_inner_size]`` 或 ``[expert_size, hidden_size, ffn_inner_size / 2]``。 +- ``bias1``:Tensor类型,可选。MOEFFN算子中升维矩阵偏置参数。维度 ``[expert_size, ffn_inner_size*(1+gated)]``。 +- ``bias2``:Tensor类型,可选。MOEFFN算子中降维矩阵偏置参数。维度 ``[expert_size, hidden_size]``。 +- ``residual``:Tensor类型,可选。MOEFFN算子的残差。维度 ``[batch_size, seq_lens, hidden_size]`` 或 ``[total_seq, hidden_size]``。 +- ``input_smooth``:Tensor类型,可选。MOEFFN算子smooth_quant模式时input的平滑系数。维度 ``[hidden_size]``。 +- ``act_smooth``:Tensor类型,可选。MOEFFN算子smooth_quant模式时激活操作之后的平滑系数。维度 ``[ffn_inner_size]``。 +- ``w1_scale``:Tensor类型,可选。MOEFFN算子量化模式时升维权重的量化参数。维度 ``[expert_size, ffn_inner_size*(1+gated)]`` 或 ``[w1_quant_group, expert_size, ffn_inner_size*(1+gated)]``。 +- ``w2_scale``:Tensor类型,可选。MOEFFN算子量化模式时降维权重的量化参数。维度 ``[expert_size, hidden_size]`` 或 ``[w2_quant_group, expert_size, hidden_size]``。 +- ``topk``:Int类型,必选。MOEFFN算子中,从每个token的expert_logits中选择topk个。 +- ``renormalize``:Bool类型,必选。MOEFFN算子中softmax结果后是否需要归一化。 +- ``gated``:Bool类型,必选。MOEFFN层中是否有gated结构。 +- ``act_mode``:string类型,必选。MOEFFN的激化函数类型,可选的激活函数:``gelu``、``silu``。 +- ``start_expert_id``:Int类型,可选。用于表明在专家并行模式下专家的起始位置。 +- ``block_n``:Int类型,可选。用于表明在通算融合模式下group_gemm和allreduce的拆分次数。 +- ``cncl_comm``:Int类型,可选。用于表明在通算融合模式下的通信句柄。 +- ``w1_quant_flag``:List类型,可选。MOEFFN算子w4w8混合量化模式时标记quant_group的bit位,数值必须是4或8,其它值会引发错误。长度 ``[expert_size * w1_quant_group]``。 +- ``w2_quant_flag``:List类型,可选。MOEFFN算子w4w8混合量化模式时标记quant_group的bit位,数值必须是4或8,其它值会引发错误。长度 ``[expert_size * w2_quant_group]``。 + + +**返回值说明** + +- ``output``: Tensor类型。 维度 ``[batch_size, seq_lens, hidden_size]`` 或 ``[total_seq, hidden_size]``。 + +**规格限制** + +- 参数限制: + + - input_smooth/act_smooth/w1_scale/w2_scale: 只有smooth quant pertoken模式启用,input采用pertoken量化,权重采用per-channel量化。 + +- 数据类型支持: + + - input:HALF、FP32、BF16 + - router_logit:HALF、FP32、BF16 + - w1:HALF、FP32、BF16、INT8、INT8(INT4x2) + - w2:HALF、FP32、BF16、INT8、INT8(INT4x2) + - bias1:HALF、FP32、BF16 + - bias2:HALF、FP32、BF16 + - residual:HALF、FP32、BF16 + - input_smooth:FP32 + - act_smooth:FP32 + - w1_scale:FP32 + - w2_scale:FP32 + - start_expert_id: INT32 + - block_n: INT32 + - cncl_comm: INT64 + - output:HALF、FP32、BF16 + - w1_quant_flag: INT + - w2_quant_flag: INT + +**备注** + +- w2维度是4D时,sub_expert_size * block 必须等于 expert_size, 并且要求ffn_inner_size是1024的整数倍。 +- group量化时,w1_quant_group * w1_quant_group_size 必须等于 hidden_size,w2_quant_group * w2_quant_group_size 必须等于 ffn_inner_size。 +- group量化时,w1_scale的形状必须是 (w1_quant_group, expert_size, ffn_inner_size*(1+gated)),w2_scale的形状必须是 (w2_quant_group, expert_size, hidden_size)。 +- w1_quant_group_size与w2_quant_group_size表示group量化场景下一个group block的元素数量。必须是128的整数倍。 +- w2维度是4D时不支持int4量化。 +- 通算融合模式不支持int4量化。 +- w1_quant_flag 与 w2_quant_flag存在时为w4w8混合量化模式,其元素只能是4或8。4表示该quant_group对应的权重是int4量化,8表示该quant_group对应的权重是int8量化。 +- w4w8混合量化模式要求w1与w2必须是1D。 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe.py`` 和 ``torch_mlu_ops/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py``。 + + +.. _torch_mlu_ops.fused_norm_attention_project: + +torch_mlu_ops.fused_norm_attention_project +------------------------------------------ + +.. code:: + + torch_mlu_ops.fused_norm_attention_project(input,q_weight,q_bias,k_weight,k_bias,v_weight,v_bias,norm_weight,norm_bias,eps,out_layout,head_size,norm_out)->Tuple(Tensor) + + +**功能描述** + +计算Attention模块中的Q、K、V算子。前融合 ``layernorm`` 的 ``attention_project``算子。 适用于GPT或者 ``hidden_size`` 小于2048的transformer网络。 + +**公式说明** + +* 公式如下: + + .. math:: + + \begin{array}{lcl} + normed = layernorm(input, norm_weight, norm_bias, eps) + Q = fc(normed, weight\_q) + bias\_q\\ + K = fc(normed, weight\_k) + bias\_k\\ + V = fc(normed, weight\_v) + bias\_v\\ + \end{array} + + +* 公式中各参数描述如下: + + - ``input`` :fused_norm_attention_project算子的输入。 + - ``normed`` : 经过 ``layernorm`` 计算后的输出。 + - ``weight_q``、 ``weight_k`` 、 ``weight_v`` 、 ``norm_weight`` :Query、Key、Value矩阵的权重。 + - ``bias_q``、 ``bias_k`` 、 ``bias_v`` 、 ``nrom_bias`` :Query、Key、Value矩阵的偏置。 + - ``eps`` : layernorm中的 ``eps`` 参数。 + +* 公式中各函数描述如下: + + - ``fc`` :全连接层(fully-connected)的缩写。 + - ``layernorm`` : layernorm归一化。 + +**参数说明** + +- ``input``:Tensor类型,必选。``fused_norm_attention_project`` 的输入。维度 ``[batch_size, seq_lens, input_size]``。其中 ``batch_size`` 表示batch的大小,``seq_lens`` 表示sequence的长度,``input_size`` 表示输入的大小。 +- ``q_weight``:Tensor类型,必选。``fused_norm_attention_project`` 的Q矩阵权重参数。维度 ``[hidden_size, input_size]``。 +- ``q_bias``:Tensor类型,可选。``fused_norm_attention_project` 的Q矩阵偏置参数。维度 ``[hidden_size]``。 +- ``k_weight``:Tensor类型,必选。``fused_norm_attention_project`` 的K矩阵权重参数。维度 ``[hidden_size, input_size]``。 +- ``k_bias``:Tensor类型,可选。``fused_norm_attention_project` 的K矩阵偏置参数。维度 ``[hidden_size]``。 +- ``v_weight``:Tensor类型,必选。``fused_norm_attention_project`` 的V矩阵权重参数。维度 ``[hidden_size, input_size]``。 +- ``v_bias``:Tensor类型,可选。``fused_norm_attention_project` 的V矩阵偏置参数。维度 ``[hidden_size]``。 +- ``norm_weight``:Tensor类型,可选。``fused_norm_attention_project`` 的layernorm权重参数。维度 ``[input_size]``。 +- ``norm_bias``:Tensor类型,可选。``fused_norm_attention_project` 的Vlayernorm偏置参数。维度 ``[input_size]``。 +- ``eps`` : float类型。layernorm的 ``eps`` 参数。 +- ``out_layout`` : String类型。允许的layout为 ``nthc`` 或者 ``nhtc``。 +- ``head_size`` : String类型。 ``output`` 的最后一维,当 ``out_layout`` 为 ``nhtc`` 。 +- ``norm_out`` : Bool类型。 当等于 ``True`` 时,返回 ``layernorm`` 的输出。 + + +**返回值说明** + +- output: Tensor类型,必选。fused_norm_attention_project的输出。根据入参的条件,返回不同的结果。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - q_weight:FP16、BF16 + - q_bias:FP16、BF16 + - k_weight:FP16、BF16 + - k_bias:FP16、BF16 + - v_weight:FP16、BF16 + - v_bias:FP16、BF16 + - norm_weight:FP16、BF16 + - norm_bias:FP16、BF16 + - output:FP16、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_fused_attn_proj.py`` 。 + + +.. _torch_mlu_ops.fused_norm_residual_ffn: + +torch_mlu_ops.fused_norm_residual_ffn +------------------------------------------ + +.. code:: + + torch_mlu_ops.fused_norm_residual_ffn(input,up_fc_weight,up_fc_bias,down_proj_weight,down_proj_bias,gate_up_proj_weight,gate_up_proj_bias,layernorm_weight,layernorm_bias,eps,act_mode,residual_is,alhpa,beta)->Tensor + +**功能描述** + +计算transformer中的FFN(FeedForward)前馈神经网络模块的算子。前融合layernorm算子,适用于GPT或者 ``hidden_size`` 小于2048的transformer网络。 + +**公式说明** + +* 前融合layernorm算子的FFN标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + normed = layernorm(input, norm_weight, norm_bias, eps) + fc\_output = fc(normed, up\_fc\_weight) + up\_fc\_bias\\ + active\_output = activation(fc\_output)\\ + basic\_output = fc(active\_output, down\_proj\_weight) + down\_proj\_bias + \end{array} + +* 前融合layernorm算子的Gated FFN标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + normed = layernorm(input, norm_weight, norm_bias, eps) + fc\_output = fc(normed, up\_fc\_weight) + up\_fc\_bias\\ + active\_output = activation(fc\_output)\\ + gated_fc\_output = fc(input, gate\_up\_proj\_weight) + gate\_up\_proj\_bias\\ + elementwise\_product\_output = elementwise\_product(active\_output, gated\_fc\_output)\\ + gated\_output = fc(elementwise\_product\_output, down\_proj\_weight) + down\_proj\_bias\\ + \end{array} + +* 公式中各参数描述如下: + + - ``input`` :FFN模块的输入。 + - ``normed`` : 经过 ``layernorm`` 计算后的输出。 + - ``fc_output`` :FFN模块经过升维后的输出。 + - ``active_output`` :FFN模块经过激活后的输出。 + - ``basic_output`` :标准FFN模块经过降维后的输出。 + - ``gated_fc_output`` :gated FFN旁路经过升维后的输出。 + - ``elementwise_product_output`` :gated FFN主路和旁路经过elementwise product后的输出。 + - ``gated_output`` :gated FFN模块经过降维后的输出。 + + +* 公式中各函数描述如下: + + - ``fc`` :全连接层(fully-connected)的缩写。 + - ``activation`` :用户可配的激活函数,可选的激活函数:"relu"、"gelu"、"silu"。 + - ``elementwise_product`` :点积。 + +**参数说明** + +- ``input``:Tensor类型,必选。 ``fused_norm_residual_ffn`` 算子的输入。维度 ``[batch_size, seq_lens, hidden_size]`` 。 +- ``up_fc_weight``:Tensor类型,必选。 ``fused_norm_residual_ffn`` 算子中升维矩阵权重参数,维度 ``[ffn_inner_size, hidden_size]`` 。 +- ``up_fc_bias``:Tensor类型,可选。 ``fused_norm_residual_ffn`` 算子中升维矩阵偏置参数,维度 ``[ffn_inner_size]`` 。其中ffn_inner_size表示FFN中间层的维度。 +- ``down_proj_weight``:Tensor类型,必选。 ``fused_norm_residual_ffn`` 算子中降维矩阵权重参数,维度 ``[hidden_size, ffn_inner_size]`` 。 +- ``down_proj_bias``:Tensor类型,可选。 ``fused_norm_residual_ffn`` 算子中降维矩阵偏置参数,维度 ``[hidden_size]`` 。 +- ``gate_up_proj_weight``:Tensor类型,可选。Gated FFN算子中升维矩阵权重参数,维度 ``[hidden_size, ffn_inner_size]`` 。 +- ``gate_up_proj_bias``:Tensor类型,可选。Gated FFN算子中升维矩阵偏置参数,维度 ``[hidden_size]`` 。 +- ``norm_weight``:Tensor类型,可选。 ``fused_norm_residual_ffn`` 的layernorm权重参数。维度 ``[input_size]``。 +- ``norm_bias``:Tensor类型,可选。 ``fused_norm_residual_ffn`` 的layernorm偏置参数。维度 ``[input_size]``。 +- ``eps`` : float类型。layernorm的 ``eps`` 参数。 +- ``act_mode``:string类型,必选。FFN的激化函数类型,可选的激活函数:"relu"、"gelu"、"silu"、"none"。 +- ``residual_is`` : string类型, 决定 ``residual`` 的来源,当 ``residual_is = input`` 时,表示来自 ``input`` 。 +- ``alpha`` : float类型。默认值为1.0。当 ``residual`` 存在时,作FNN模块的系数。 +- ``beta`` : float类型。默认值为1.0。当 ``residual`` 存在时,作为 ``residual`` 的系数。 + +**返回值说明** + +- output: Tensor类型,必选。 ``ffn`` 的输出。维度 ``[batch_size, seq_lens, hidden_size]`` 。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - up_fc_weight:FP16、BF16 + - up_fc_bias:FP16、BF16 + - down_proj_weight:FP16、BF16 + - down_proj_bias: FP16、BF16 + - gate_up_proj_weight: FP16、BF16 + - gate_up_proj_bias: FP16、BF16 + - norm_weight: FP16、BF16 + - norm_bias: FP16、BF16 + - output: FP16、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_fused_ffn.py`` 。 + + +.. _torch_mlu_ops.fused_rms_norm: + +torch_mlu_ops.fused_rms_norm +------------------------------------------ + +.. code:: + + torch_mlu_ops.fused_rms_norm(input,residual,gamma,beta,bias,eps,store_output_before_norm,quant_scale,out) + +**功能描述** + +RMSnorm归一化算子。 + +**公式说明** + +* RMSnorm标准结构公式如下: + + .. math:: + + \begin{aligned} + output = weight \times \frac{input}{\sqrt{RMS(input) + eps}} + bias\\ + output = Quantize(output, quant_scale) + \end{aligned} + +* 公式中各参数描述如下: + + - ``input``:RMSnorm的输入。 + - ``RMS()``:表示对input取均方(mean square)。 + - ``Quantize()`` 表示对输出做量化。 + - ``weight``:表示RMSnorm的 ``weight`` 参数。 + - ``bias``:表示RMSnorm的 ``bias`` 参数。 + - ``eps``:表示RMSnorm eps的值,一个极小的正数,避免除零,一般默认值设定为0.00001。 + - ``output``:RMSnorm的输出。 + +**参数说明** + +- ``input``:Tensor类型,必选。RMSnorm算子的输入。维度 ``[batch,seq_lens, hidden_size]``,支持batch维度和seq_lens维度的stride。 +- ``residual``:Tensor类型,可选。RMSnorm算子的残差输入。在pad模式下,维度 ``[batch,seq_lens, hidden_size]``,支持batch维度和seq_lens维度的stride。 +- ``gamma``:Tensor类型,可选。RMSnorm算子的 ``weight`` 。维度 ``[hidden_size]``。 +- ``beta``:Tensor类型,可选。RMSnorm算子的 ``bias`` 。维度 ``[hidden_size]``。 +- ``bias``:Tensor类型,可选。RMSnorm算子的残差的偏置输入。维度 ``[hidden_size]``。 +- ``eps``:Float类型,必选。RMSnorm算子的 ``eps`` 的值,一般设定为0.00001。 +- ``store_output_before_norm``:Bool类型,必选。RMSnorm算子是否输出 ``input+residual`` 的值。当为true时,第二个返回值输出 ``input+residual``,当为false时,无第二个返回值。 +- ``quant_scale``:Float类型,可选。量化参数。维度 ``[hidden_size]``。当 ``dynamic_quant`` 为False时表示离线量化参数,为True时表示在线量化时的平滑因子。当该参数不为空时,输入输出必须连续。 +- ``out``:Tensor类型,可选。RMSnorm算子的原位输出。维度同input,支持batch维度和seq_lens维度的stride。 +- ``dynamic_quant``:Bool类型,为False且 ``quant_scale`` 不为空时表示离线per-channel量化输出,否则为在线动态per-token量化输出。 + +**返回值说明** + +- ``output``:Tensor类型,可选。RMSnorm算子的输出,当输入参数out不为空时无此输出。维度 ``[batch, seq_lens, hidden_size]``,支持batch维度和seq_lens维度的stride。 +- ``residual_output``:Tensor类型,可选。RMSnorm算子的residual输出。当store_output_before_norm为true时,返回值输出 ``input+residual``,当store_output_before_norm为false时,无此返回值。在pad模式下,维度 ``[batch_size*seq_lens, hidden_size]``;在pack模式下,维度 ``[total_seqlen, hidden_size]``。 +- ``smooth_quant_scale``:Tensor类型,可选。维度 ``[batch_size*seq_lens]`` 。当 ``dynamic_quant`` 为True时输出的量化scale。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - residual:FP16、BF16 + - beta:FP16、BF16 + - gamma:FP16、BF16 + - bias:FP16、BF16 + - eps:Float32 + - store_output_before_norm:Bool + - quant_scale:Float + - output:FP16、BF16、int8 + - residual_output: FP16、BF16 + - smooth_quant_scale: FP32 + + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_fusenorm.py`` 。 + +.. _torch_mlu_ops.fused_rope: + +torch_mlu_ops.fused_rope +------------------------------------------ + +.. code:: + + torch_mlu_ops.fused_rope(input, key_cache_hp, value_cache_hp, sin_cache, cos_cache, position_id, gamma, beta, key_cache_lp, value_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, cache_bs_id_lp, cache_seq_offsets_lp, k_scale_hp, v_scale_hp, k_scale_lp, v_scale_lp, slot_mapping_hp, slot_mapping_lp, eps)-> None + +**功能描述** + +fold rope融合算子。 + +**公式说明** + +* 公式如下: + +该算子为apply rotary,layernorm,以及reshape linear/paged cache功能的融合。 + +**参数说明** + +- ``qkv``:Tensor类型,必选。输入维度 ``[batch, 1, head_num_q + head_num_k * 2, head_size]`` 。 +- ``key_cache_hp``:Tensor类型,必选。当 ``paged=False``时,维度为 ``(max_bs, head_num_k, max_decode_len, head_size)`` 否则 ``(num_blocks, head_num_k, block_size, head_size)`` 。 +- ``value_cache_hp``:Tensor类型,必选。当 ``paged=False``时,维度为 ``(max_bs, head_num_k, max_decode_len, head_size)`` 否则 ``(num_blocks, head_num_k, block_size, head_size)`` 。 +- ``sin_table``:Tensor类型,必选。输入维度 ``[rotary_seq, head_size]`` 。 +- ``cos_table``:Tensor类型,必选。输入维度 ``[rotary_seq, head_size]`` 。 +- ``position_id``:Tensor类型,必选。输入维度为 ``[batch]`` 。 +- ``gamma``:Tensor类型,必选。输入维度是 ``[head_size]`` 。 +- ``beta``:Tensor类型,必选。输入维度是 ``[head_size]`` 。 +- ``key_cache_lp``:Tensor类型,可选。当 ``paged=False``时,维度为 ``(max_bs, head_num_k, max_decode_len, head_size / 2)`` 否则 ``(num_blocks, head_num_k, block_size, head_size / 2)`` 。 +- ``value_cache_lp``:Tensor类型,可选。当 ``paged=False``时,维度为 ``(max_bs, head_num_k, max_decode_len / 2, head_size)`` 否则 ``(num_blocks, head_num_k, block_size / 2, head_size)`` 。 +- ``cache_bs_id_hp``:Tensor类型,可选,默认值为None。当前batch在cache中的位置,维度 ``[batch]`` 。 +- ``cache_seq_offsets_hp``:Tensor类型,可选,默认值为None。当前kv在cache中的位置,维度 ``[batch]`` 。 +- ``cache_bs_id_lp``:Tensor类型,可选,默认值为None。当前batch在cache中的位置,维度 ``[batch]`` 。 +- ``cache_seq_offsets_lp``:Tensor类型,可选,默认值为None。当前kv在cache中的位置,维度 ``[batch]`` 。 +- ``key_scale_hp``:Tensor类型,可选,默认值为None。key的量化参数,维度 ``[head_num_k, head_size]`` 。 +- ``value_scale_hp``:Tensor类型,可选,默认值为None。value的量化参数,维度 ``[head_num_k, head_size]`` 。 +- ``key_scale_lp``:Tensor类型,可选,默认值为None。key的量化参数,维度 ``[max_bs, head_num_k, group_num]`` 。 +- ``value_scale_lp``:Tensor类型,可选,默认值为None。value的量化参数,维度 ``[max_bs, head_num_k, group_num]`` 。 +- ``slot_mapping_hp``:Tensor类型,可选,默认值为None。paged_cahe时指示当前kv在cache中的位置,维度 ``[batch]`` 。 +- ``slot_mapping_lp``:Tensor类型,可选,默认值为None。paged_cahe时指示当前kv在cache中的位置,维度 ``[batch]`` 。 +- ``eps``:float32类型,可选。layernorm的eps参数,默认值为1e-5 。 + +**返回值说明** + +- ``qkv``:原位操作。 +- ``key_cache_hp``:原位操作。 +- ``value_cache_hp``:原位操作。 +- ``key_cache_lp``:原位操作。 +- ``value_cache_lp``:原位操作。 +- ``key_scale_lp``:原位操作。 +- ``value_scale_lp``:原位操作。 + +**规格限制** + +- 数据类型支持: + + - qkv:FP16、BF16 + - key_cache_hp:FP16、BF16、INT8 + - value_cache_hp:FP16、BF16、INT8 + - sin_table:FP16、BF16 + - cos_table:FP16、BF16 + - position_id:INT32 + - gamma:FP16、BF16 + - beta:FP16、BF16 + - key_cache_lp: INT64 + - value_cache_lp: INT4 + - cache_bs_id_hp:INT32 + - cache_seq_offsets_hp:INT32 + - cache_bs_id_lp:INT32 + - cache_seq_offsets_lp:INT32 + - key_scale_hp:FP32 + - value_scale_hp:FP32 + - key_scale_lp:FP32 + - value_scale_lp:FP32 + - slot_mapping_hp:INT32 + - slot_mapping_lp: INT32 + - eps:FP32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_fused_rope.py`` 。 + +.. _torch_mlu_ops.group_gemm: + +torch_mlu_ops.group_gemm +------------------------------------------ + +.. code:: + + torch_mlu_ops.group_gemm(a, b, m_list, expand_idx, c, alpha, beta, max_m, bias) -> torch.Tensor + +**功能描述** + +多组GEMM集成算子。 + +**公式说明** + +无 + +**参数说明** + +- ``a``:Tensor类型,必选。GEMM算子的输入矩阵。维度 ``[m, k]``。 ``expand_idx `` 存在时, ``total_m`` 等于 ``m_list`` 元素之和,否则 ``total_m`` 等于 ``m``。 +- ``b``:Tensor类型,必选。GEMM算子的权重矩阵。维度 ``[experts_num, n, k]``。 +- ``m_list``:Tensor类型,必选。用于指定每个子矩阵的m维度。维度 ``[experts_num]``。 +- ``expand_idx``:Tensor类型,可选。用于扩展输入矩阵。维度 ``[total_m]``。 +- ``c``:Tensor类型,可选。残差。维度 ``[total_m, n]``。 +- ``alpha``:Tensor类型,可选。a乘b的系数。维度 ``[experts_num]``。 +- ``beta``:Tensor类型,可选。c的系数。维度 ``[experts_num]``。 +- ``max_m``: int类型,必选。m_list中的最大值。由于m_list是device数据,获取真实的最大值会影响性能,建议提供m_list中的理论最大值。 +- ``bias``: Tensor类型,可选。GEMM算子的bias矩阵。维度 ``[experts_num, n]``。 + + +**返回值说明** + +- ``output``: Tensor类型。 维度 ``[total_m, n]``。 + +**规格限制** + +- 数据类型支持: + + - a:HALF、FP32、BF16、INT8 + - b:HALF、FP32、BF16、INT8 + - m_list: INT32 + - expand_idx: INT32 + - c:HALF、FP32、BF16、INT8 + - alhpa:FP32 + - beta:FP32 + - output:HALF、FP32、BF16 + - bias: HALF、FP32、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_group_gemm.py`` 。 + +.. _torch_mlu_ops.smooth_quant_group_gemm: + +torch_mlu_ops.smooth_quant_group_gemm +------------------------------------------ + +.. code:: + + torch_mlu_ops.smooth_quant_group_gemm(a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag) -> torch.Tensor + +**功能描述** + +多组GEMM集成量化功能算子。 + +**公式说明** + +无 + +**参数说明** + +- ``a``:Tensor类型,必选。GEMM算子的输入矩阵。维度 ``[m, k]``。 ``expand_idx `` 存在时, ``total_m`` 等于 ``m_list`` 元素之和,否则 ``total_m`` 等于 ``m``。 +- ``b``:Tensor类型,必选。GEMM算子的权重矩阵。维度 ``[experts_num, n, k]``。 +- ``m_list``:Tensor类型,必选。用于指定每个子矩阵的m维度。维度 ``[experts_num]``。 +- ``expand_idx``:Tensor类型,可选。用于扩展输入矩阵。维度 ``[total_m]``。 +- ``c``:Tensor类型,可选。残差。维度 ``[total_m, n]``。 +- ``alpha``:Tensor类型,可选。a乘b的系数。维度 ``[experts_num]``。 +- ``beta``:Tensor类型,可选。c的系数。维度 ``[experts_num]``。 +- ``a_scale``:Tensor类型,可选。输入的分token量化的scale系数。维度 ``[total_m]``。 +- ``b_scale``:Tensor类型,可选。权重的分channel量化的scale系数。维度 ``[experts_num, n]`` 非group量化场景,或 ``[quant_group, experts_num, n]`` group量化场景。 +- ``dtype``:dype类型, ``output`` 的数据类型,必须是 ``torch.half`` 、 ``torch.bfloat16`` 。 +- ``max_m``: int类型,必选。m_list中的最大值。由于m_list是device数据,获取真实的最大值会影响性能,建议提供m_list中的理论最大值。 +- ``bias``: Tensor类型,可选。GEMM算子的bias矩阵。维度 ``[experts_num, n]``。 +- ``quant_flag``:List类型,可选。w4w8混合量化模式时标记quant_group的bit位,数值必须是4或8,其它值会引发错误。长度 ``[experts_num * quant_group]``。 + +**返回值说明** + +- ``output``: Tensor类型。 维度 ``[total_m, n]``。 + +**规格限制** + +- 数据类型支持: + + - a:HALF、FP32、BF16、INT8 + - b:HALF、FP32、BF16、INT8、INT8(INT4x2) + - m_list: INT32 + - expand_idx: INT32 + - c:HALF、FP32、BF16、INT8 + - alhpa:FP32 + - beta:FP32 + - a_scale: FP32 + - b_scale: FP32 + - output:HALF、FP32、BF16 + - bias: HALF、FP32、BF16 + +**备注** + +- group量化场景,``quant_group * quant_group_size == k``,``quant_group_size`` 表示该场景下每个group的元素数量,``quant_group_size`` 可以取值128、256、512、1024。 +- group量化场景, b_scale的形状必须是(quant_group, experts_num, n)。 +- gather 模式不支持int4量化。 +- 4D 转数模式不支持int4量化。 +- 不支持通用量化模式的int4量化。 +- ``quant_flag`` 存在时为w4w8混合量化模式,其元素只能是4或8。4表示该quant_group对应的权重是int4量化,8表示该quant_group对应的权重是int8量化。 +- w4w8混合量化模式要求b必须是1D。 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_group_gemm.py`` 。 + +.. _torch_mlu_ops.matmul: + +torch_mlu_ops.matmul +------------------------------------------ + + torch_mlu_ops.matmul(input, filter, bias, residual, act_mode, alpha, beta, use_fast, approximate, out_dtype, a_scale, b_scale, trans_a, trans_b)-> torch.Tensor + +**功能描述** + +通用的矩阵乘法算子。 + +**公式说明** + +* 带后融合的matmul标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + mul_out = alpha * (matmul(input, filter) + bias) + beta * residual\\ + output = active(mul_out)\\ + \end{array} + +* 公式中各参数描述如下: + + - ``input`` :matmul的左矩阵。 + - ``filter``:matmul的右矩阵。 + - ``bias``:偏置。 + - ``residual``:参差。 + - ``output`` :计算最终的输出。 + +* 公式中各函数描述如下: + + - ``matmul`` :矩阵乘。 + - ``active``:激活计算。 + +**参数说明** + +- ``input``:Tensor类型,必选。矩阵乘的左矩阵。维度 ``[mat_m, mat_k]``。 +- ``filter``:Tensor类型,必选。矩阵乘的右矩阵。维度 ``[mat_n, mat_k]``。 +- ``bias``:Tensor类型,可选。矩阵乘的偏置。维度 ``[1, mat_n]``。 +- ``residual``:Tensor类型,可选。参差。维度 ``[mat_m, mat_n]``。 +- ``act_mode``:String类型,可选。后融合激活的类型,可选的激活函数: ``relu`` 、 ``gelu`` 、 ``silu``。 +- ``alpha``:Float32类型,必选。矩阵乘法系数。 +- ``beta``:Float32类型,必选。参差的系数。 +- ``use_fast``:Bool类型,必选。是否使用高性能的方式计算激活。 +- ``approximate``:Bool类型,必选。gelu是否使用更精确的近似计算方式。 +- ``out_dtype``:String类型,可选。如果输入为INT8类型,则必须指定输出类型。 +- ``a_scale``:Float32类型,默认值1.0。当input为INT8类型时对应的量化系数。 +- ``b_scale``:Float32类型,默认值1.0。当weight为INT8类型时对应的量化系数。 +- ``trans_a``:Bool类型,默认值为False,标记input是否需要转置。 +- ``trans_b``:Bool类型,默认值为True,标记weight是否需要转置。 + + +**返回值说明** + +- ``out``:Tensor类型,算子的输出,维度 ``[mat_m, mat_n]``。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16、FP32、INT8 + - filter:FP16、BF16、FP32、INT8 + - bias:FP16、BF16、FP32 + - residual:FP16、BF16、FP32 + - out:FP16、BF16、FP32 + - alpha:FP32 + - beta:FP32 + - use_fast:BOOL + - approximate:BOOL + - a_scale:FP32 + - b_scale:FP32 + - trans_a:BOOL + - trans_b:BOOL + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_matmul.py`` 。 + +.. _torch_mlu_ops.matmul_allreduce: + +torch_mlu_ops.matmul_allreduce +------------------------------------------ + +.. code:: + + torch_mlu_ops.matmul_allreduce(cncl_comm, input, filter, bias, residual, alpha, beta, block_m)-> torch.Tensor + +**功能描述** + +矩阵乘法与allreduce的通信计算融合算子。 + +**公式说明** + +无 + +**参数说明** + +- ``cncl_comm``:Int类型,必选。cnclComm,用于归约通信。 +- ``input``:Tensor类型,必选。矩阵乘的左矩阵。维度 ``[mat_m, mat_k]``。 +- ``filter``:Tensor类型,必选。矩阵乘的右矩阵。维度 ``[mat_n, mat_k]``。 +- ``bias``:Tensor类型,可选。矩阵乘的偏置。维度 ``[1, mat_n]``。 +- ``residual``:Tensor类型,可选。参差。维度 ``[mat_m, mat_n]``。 +- ``alpha``:Float32类型,必选。矩阵乘法系数。 +- ``beta``:Float32类型,必选。参差的系数。 +- ``block_m``:Int类型,可选。input在mat_m维度拆分的份数。 + +**返回值说明** + +- ``out``:Tensor类型,算子的输出,维度 ``[mat_m, mat_n]``。 + +**规格限制** + +- 数据类型支持: + + - cncl_comm: INT64 + - input:FP16、BF16、FP32 + - filter:FP16、BF16、FP32 + - bias:FP16、BF16、FP32 + - residual:FP16、BF16、FP32 + - out:FP16、BF16、FP32 + - alpha:FP32 + - beta:FP32 + - block_m: INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py`` 。 + +.. _torch_mlu_ops.moe_combine_result: + +torch_mlu_ops.moe_combine_result +------------------------------------- + +.. code:: + + torch_mlu_ops.moe_combine_result(input, reduce_weight, gather_ids, residual(optional), + cusum_token_count(optional), start_expert_id, expert_size, + bias(optional))->Tensor + + +**功能描述** + +融合算子。从不同的专家模型中收集计算结果 (gather) ,然后对这些结果进行加权组合 (reduce) 。 + +**公式说明** + +无。 + +**参数说明** + +- ``input``:Tensor类型,必选。算子的输入,维度 ``[num_tokens, hidden_size]`` 。其中 ``num_tokens = num_token * topk`` ,``num_token`` 表示token数量,``topk`` 表示每个token选取概率在前topk的专家,``hidden_size`` 表示隐向量维数。 +- ``reduce_weight``:Tensor类型,必选。token的权重信息,维度 ``[num_tokens]`` 。 +- ``gather_ids``:Tensor类型,必选。用于收集token的索引,维度 ``[num_tokens]`` 。 +- ``residual``:Tensor类型,可选,默认值为None。表示残差,维度 ``[num_token, hidden_size]`` 。 +- ``cusum_token_count``:Tensor类型,可选,默认值为None。token_count的前缀和,维度 ``[num_expert + 1]`` ,其中 ``num_expert`` 表示专家数量。 +- ``start_expert_id``:Int32类型,必选。表示EP模型下起始专家的位置信息。 +- ``expert_size``:Int32类型,必选。表示EP模型下实际使用的专家数量。 +- ``bias``:Tensor类型,可选,默认值为。表示偏置,维度 ``[num_expert, hidden_size]`` 。 + +**返回值说明** + +- output: Tensor类型,必选。维度为 ``[num_token, hidden_size]`` 。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、FP32、BF16 + - reduce_weight:FP32 + - gather_ids:INT32 + - residual:FP16、FP32、BF16 + - cusum_token_count:INT32 + - start_expert_id:INT32 + - expert_size:INT32 + - bias:FP16、FP32、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_combine_result.py`` 。 + +.. _torch_mlu_ops.moe_expand_input: + +torch_mlu_ops.moe_expand_input +------------------------------------------ + +.. code:: + + torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count, output, start_expert_id, expert_size)->Tensor + +**功能描述** + +根据 ``gather_idx`` 在 ``input`` 执行索引操作。其中 ``cusum_token_count`` 、 ``start_expert_id``、 ``expert_size`` 在EP模式下使用,用于确定当前专家处理token的index范围。 + +**公式说明** + +无。 + +**参数说明** + +- ``input``:Tensor类型,必选。算子的输入,输入维度 ``[token_num, hidden_size]``。其中 ``token_num`` 表示token数量, ``hidden_size`` 表示各个token的特征信息。 +- ``gather_idx``:Tensor类型,必选。输入维度 ``[token_num * topk]``, ``topk`` 表示每个token选取概率在前topk的专家进行计算。 +- ``cusum_token_count``:Tensor类型,可选,默认值为None。输入维度为 ``[expert_num + 1]`` ,其中 ``expert_num`` 表示专家的数量。 +- ``output``:Tensor类型,可选,默认值为None。算子的原位输出,维度为 ``[token_num * topk, hidden_size]`` 。 +- ``start_expert_id``:Int32类型,可选,默认值为0。表示EP模型下起始专家的位置信息。 +- ``expert_size``:Int32类型,可选,默认值为0。表示EP模式下实际使用的专家数量。 + +**返回值说明** + +- output,可选。维度为 ``[token_num * topk, hidden_size]``。 + +**规格限制** + +- 数据类型支持: + + - input:INT8,FP16,FP32,BF16 + - gather_idx:INT32 + - cusum_token_count:INT32 + - output:INT8,FP16,FP32,BF16 + - start_expert_id:INT32 + - expert_size:INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_expand_input.py`` 。 + +.. _torch_mlu_ops.moe_gen_idx: + +torch_mlu_ops.moe_gen_idx +------------------------------------------ + +.. code:: + + torch_mlu_ops.moe_gen_idx(expert_id, expert_num)->Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + +**功能描述** + +基于 ``expert_id`` 和 ``expert_num`` 生成 ``expand_idx`` 、 ``combine_idx`` 、 ``token_count`` 、 ``cusum_token_count`` 。 + +**公式说明** + +无。 + +**参数说明** + +- ``expert_id``:Tensor类型,必选。算子的输入,输入维度 ``[token_num, topk]``。其中 ``token_num`` 表示token数量, ``topk`` 表示每个token选取概率在前topk的专家进行计算。 +- ``expert_num``:Int32类型,必选。算子的输入,表示专家的数量。 + +**返回值说明** + +- outputs,Tuple[Tensor, Tensor, Tensor, Tensor]类型,可选。输出的4个Tensor分别是expand_idx、combine_idx、token_count和cusum_token_count,维度同上。 + +**规格限制** + +- 数据类型支持: + + - expert_id:INT32 + - expert_num:INT32 + - outputs:输出的4个Tensor类型均为INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_gen_idx.py`` 。 + +.. _torch_mlu_ops.moe_softmax_topk: + +torch_mlu_ops.moe_softmax_topk +------------------------------------------ + +.. code:: + + torch_mlu_ops.moe_softmax_topk(input,topk,num_expert_group,topk_group,normalize,mask,normed_by)->Tuple[torch.Tensor] + +**功能描述** + +MOE Route模块中的softmax topk融合算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``input``:Tensor类型,必选。算子的输入,输入维度 ``[..., expert_num]``。需要保证连续性。 +- ``topk``:INT32类型,每个token选取的专家数。 +- ``num_expert_group``:INT32类型,表示每个专家的分组个数。 +- ``topk_group``:INT32类型,每个专家组的topk个数。 +- ``normalize``:Bool类型,是否给 ``reduce_weight`` 添加归一化操作。 +- ``mask``:Tensor类型,可选。用于与softmax的计算结果相乘, 维度数量与input一致,后两维度大小与input一致。需要保证连续性。 +- ``normed_by``:Str类型,仅在normalize为True的时候使能,用于表示归一化的模式(topk_logit或softmax_logit)。 + +**返回值说明** + +- ``outputs`` : 可选,Tuple[Tensor, Tensor]类型。输出的两个Tensor分别是 ``reduce_weight`` 和 ``expert_id`` 。 + +**规格限制** + +- 数据类型支持: + + - input :Float,half,bfloat16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_softmax_topk.py`` 。 + + +.. _torch_mlu_ops.preload: + +torch_mlu_ops.preload +------------------------------------------ + +.. code:: + + torch_mlu_ops.preload(weight,size) + +**功能描述** + +预加载权重。 + +**公式说明** + +无。 + +**参数说明** + +- ``weight``:Tensor类型,必选。预加载的权重。 +- ``size``:int类型,必选。预加载的数据大小(字节)。 + +**返回值说明** + +无返回值。 + +**规格限制** + +无。 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_preload.py`` 。 + + +.. _torch_mlu_ops.quant_to_linear_cache: + +torch_mlu_ops.quant_to_linear_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.quant_to_linear_cache(key,value,key_cache,value_cache,key_cache_scale,value_cache_scale,context_lengths,max_context_len,packed,context_seq_offset,cache_bs_id,cache_seqlen_offset,quant_bit) + +**功能描述** + +将计算得到的K-V拼接到K-V cache上的INT8量化版本算子。 + + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[batch_size, max_context_len, head_num, head_size]``。 +- ``value``:Tensor类型,可选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[batch_size, max_context_len, head_num, head_size]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。维度为 ``[max_batch, head_num, cache_mem_len, head_size]``。其中max_batch要求大于batch_size,max_context_len为网络支持的最大sequences长度。 +- ``value_cache``:Tensor类型,可选。Transformer模块中保存 ``Value cache`` 的Tensor。维度为 ``[max_batch, head_num, cache_mem_len, head_size]``。 +- ``key_cache_scale``:Tensor类型,必选。Transformer模块中保存的 ``Key`` 的 ``scale`` 值。在context阶段,维度为 ``[batch_size, head_num, cache_mem_len]`` 或者 ``[batch_size, head_num, cache_mem_len, group_num]``。 +- ``value_cache_scale``:Tensor类型,可选。Transformer模块中保存的 ``Value`` 的 ``scale`` 值。在context阶段,维度为 ``[batch_size, head_num, cache_mem_len]`` 或者 ``[batch_size, head_num, cache_mem_len, group_num]``。 +- ``context_lengths``:Tensor类型,必选。为每批batch的sequences length长度。当 ``packed=false`` 时,维度为 ``[batch_size]``,每个batch里为对于的sequence length;当 ``packed=true`` 时,维度为 ``[batch_size+1]``,用sequence与sequence之间的长度差表示sequence长度。 +- ``max_context_len``:Int类型,必选。在context阶段,表示最大的sequence length。 +- ``packed``:Bool类型,必选。当 ``packed=true`` 时,表示开启pack模式,当 ``packed=false`` 时,表示关闭packed模式。 +- ``context_seq_offset``:Tensor类型,可选。表示context的seq的偏置,维度 ``[batch_size]``。 +- ``cache_bs_id``:Tensor类型,可选。表示cache的batch的偏置,维度 ``[batch_size]``。如果为空,则默认值为{0, 1, 2 ... batch - 1}。 +- ``cache_seqlen_offset``:Tensor类型,可选。表示cache的seq_lens的偏置,维度 ``[batch_size]``。如果为空,则默认值为{0, 0, 0 ... 0}。 +- ``quant_bit``:Int类型,可选。表示量化结果的位宽,可选8(表示int8)或4(表示int4),默认值为8。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16 + - value:FP16、BF16 + - key_cache:INT8 + - value_cache:INT8 + - key_cache_scale: FP32 + - value_cache_scale: FP32 + - context_lengths:INT32 + - max_context_len:INT32 + - packed:Bool + - context_seq_offset:INT32 + - cache_bs_id: INT32 + - cache_seqlen_offset: INT32 + - quant_bit : INT32 + - group_size: INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_quant_to_linear_cache.py`` 。 + +.. _torch_mlu_ops.quant_to_paged_cache: + +torch_mlu_ops.quant_to_paged_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.quant_to_paged_cache(key,value,key_cache,value_cache,key_cache_scale,value_cache_scale,slot_mapping) + +**功能描述** + +:ref:`torch_mlu_ops.reshape_paged_cache` 的INT8量化形式算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[num_tokens, head_num, head_size]``。 +- ``value``:Tensor类型,必选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[num_tokens, head_num, head_size]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。维度为 ``[num_blocks, head_num, block_size, head_size]``。 +- ``value_cache``:Tensor类型,必选。Transformer模块中保存 ``Value cache`` 的Tensor。维度为 ``[num_blocks, head_num, block_size, head_size]``。 +- ``key_cache_scale``:Tensor类型,必选。Transformer模块中保存 ``Key cache scale`` 的Tensor。维度为 ``[num_blocks, head_num, block_size]``。 +- ``value_cache_scale``:Tensor类型,必选。Transformer模块中保存 ``Value cache scale`` 的Tensor。维度为 ``[num_blocks, head_num, block_size]``。 +- ``slot_mapping``:Tensor类型,必选。表示Paged Attention中 ``K-V cache`` 存放位置。维度为 ``[num_tokens]``。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16, FP32 + - value:FP16、BF16, FP32 + - key_cache:INT8 + - value_cache:INT8 + - key_cache_scale:FP32 + - value_cache_scale:FP32 + - slot_mapping:INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/kernel_pytest/quant_paged_cache_kernel_test/quant_paged_cache_kernel_test.py``。 + +.. _torch_mlu_ops.offline_quant_to_linear_cache: + +torch_mlu_ops.offline_quant_to_linear_cache +------------------------------------------------- + +.. code:: + + torch_mlu_ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, context_lengths, max_context_len, quant_mode, packed, context_seq_offset, cache_bs_id, cache_seqlen_offset) + +**功能描述** + +:ref:`torch_mlu_ops.quant_to_linear_cache` 的离线量化形式算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。 如果使用pack模式, 维度为 ``[num_tokens, head_num_kv, head_size]``, 如果使用pad模式, 维度为 ``[batch, seqlen, head_num_kv, head_size]``。 +- ``value``:Tensor类型,可选。Transformer模块中计算得到的 ``Value``。 维度与key相同。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。维度为 ``[max_batch, head_num_kv, cache_mem_len, head_size]``。 +- ``value_cache``:Tensor类型,可选。Transformer模块中保存 ``Value cache`` 的Tensor。维度为与key_cache相同。 +- ``key_cache_quant_scale``:Tensor类型,必选。Transformer模块中保存 ``Key cache scale`` 的Tensor。当量化模式是per_channel时,维度为 ``[head_num_kv, head_size]``, 当量化模式是per_head时,维度是 ``[head_num_kv, cache_mem_len]``。 +- ``value_cache_quant_scale``:Tensor类型,可选。Transformer模块中保存 ``Value cache scale`` 的Tensor。维度与value_cache_quant_scale相同。 +- ``context_lengths``:Tensor类型,必选。表示context的长度。如果使用pack模式,维度为 ``[batch + 1]``,如果使用pad模式,维度为 ``[batch]``。 +- ``max_context_len``:Int类型,必选。表示context的最大长度。 +- ``quant_mode``:int类型,必选。表示量化模式。当该参数为True时,表示使用per_head量化,否则使用per_channel量化。 +- ``packed``:bool类型,必选。该值为True时,表示输入为pack模式,否则为pad模式。 +- ``context_seq_offset``:Tensor类型,可选。表示待量化的key或value的位置索引。维度为 ``[batch]``。 +- ``cache_bs_id``:Tensor类型,可选。表示key或value放入的cache的batch索引。维度为 ``[batch]``。 +- ``cache_seqlen_offset``:Tensor类型,可选。表示key或value放入cache的起始位置。维度为 ``[batch]``。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16, FP32 + - value:FP16、BF16, FP32 + - key_cache:INT8 + - value_cache:INT8 + - key_cache_quant_scale:FP32 + - value_cache_quant_scale:FP32 + - context_lengths:INT32 + - max_context_len:INT32 + - quant_mode:INT32 + - packed:BOOL + - context_seq_offset:INT32 + - cache_bs_id:INT32 + - cache_seqlen_offset:INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/kernel_pytest/test_offline_quant_to_linear_cache.py``。 + +.. _torch_mlu_ops.offline_quant_to_paged_cache: + +torch_mlu_ops.offline_quant_to_paged_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.offline_quant_to_paged_cache(key,value,key_cache_scale,value_cache_scale,slot_mapping,key_cache,value_cache) + +**功能描述** + +:ref:`torch_mlu_ops.quant_to_paged_cache` 的离线量化形式算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[num_tokens, head_num, head_size]``。 +- ``value``:Tensor类型,必选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[num_tokens, head_num, head_size]``。 +- ``key_cache_scale``:Tensor类型,必选。Transformer模块中保存 ``Key cache scale`` 的Tensor。维度为 ``[head_num, head_size]``。 +- ``value_cache_scale``:Tensor类型,必选。Transformer模块中保存 ``Value cache scale`` 的Tensor。维度为 ``[head_num, head_size]``。 +- ``slot_mapping``:Tensor类型,必选。表示Paged Attention中 ``K-V cache`` 存放位置。维度为 ``[num_tokens]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。维度为 ``[num_blocks, head_num, block_size, head_size]``。 +- ``value_cache``:Tensor类型,必选。Transformer模块中保存 ``Value cache`` 的Tensor。维度为 ``[num_blocks, head_num, block_size, head_size]``。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16, FP32 + - value:FP16、BF16, FP32 + - key_cache_scale:FP32 + - value_cache_scale:FP32 + - slot_mapping:INT32 + - key_cache:INT8 + - value_cache:INT8 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/kernel_pytest/test_offline_quant_to_paged_cache_kernel.py``。 + +.. _torch_mlu_ops.reshape_linear_cache: + +torch_mlu_ops.reshape_linear_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.reshape_linear_cache(key,value,key_cache,value_cache,context_lengths,max_context_len,packed,context_seq_offset,cache_bs_id,cache_seqlen_offset) + +**功能描述** + +将计算得到的K-V拼接到K-V cache上的算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[batch_size, max_context_len, head_num, head_size]``。 +- ``value``:Tensor类型,可选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[batch_size, max_context_len, head_num, head_size]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。维度为 ``[max_batch, head_num, cache_mem_len, head_size]``。其中 ``max_batch`` 要求大于 ``batch_size``,``max_context_len`` 为网络支持的最大sequences长度。 +- ``value_cache``:Tensor类型,可选。Transformer模块中保存 ``Value cache`` 的Tensor。维度为 ``[max_batch, head_num, cache_mem_len, head_size]``。 +- ``context_lengths``:Tensor类型,必选。为每批batch的sequences length长度。当 ``packed=false`` 时,维度为 ``[batch_size]``,每个batch里为对于的sequence length;当 ``packed=true`` 时,维度为 ``[batch_size+1]``,用sequence与sequence之间的长度差表示sequence长度。 +- ``max_context_len``:Int类型,必选。在context阶段,表示最大的sequence length。 +- ``packed``:Bool类型,必选。当 ``packed=true`` 时,表示开启pack模式,当 ``packed=false`` 时,表示关闭pack模式。 +- ``context_seq_offset``:Tensor类型,可选。表示context的seq的偏置,维度 ``[batch_size]``。 +- ``cache_bs_id``:Tensor类型,可选。表示cache的batch的偏置,维度 ``[batch_size]``。如果为空,则默认值为{0, 1, 2 ... batch - 1}。 +- ``cache_seqlen_offset``:Tensor类型,可选。表示cache的seq_lens的偏置,维度 ``[batch_size]``。如果为空,则默认值为{0, 0, 0 ... 0}。 + + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16 + - value:FP16、BF16 + - key_cache:FP16、BF16 + - value_cache:FP16、BF16 + - context_lengths:INT32 + - max_context_len:INT32 + - packed:Bool + - context_seq_offset:INT32 + - cache_bs_id: INT32 + - cache_seqlen_offset: INT32 + + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_reshape_linear_cache.py`` 。 + + +.. _torch_mlu_ops.reshape_paged_cache: + +torch_mlu_ops.reshape_paged_cache +------------------------------------------ + +.. code:: + + torch_mlu_ops.reshape_paged_cache(key,value,key_cache,value_cache,slot_mapping) + +**功能描述** + +:ref:`torch_mlu_ops.reshape_linear_cache` 的Paged形式算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``key``:Tensor类型,必选。Transformer模块中计算得到的 ``Key``。在context阶段,维度为 ``[num_tokens, head_num, head_size]``。 +- ``value``:Tensor类型,必选。Transformer模块中计算得到的 ``Value``。在context阶段,维度为 ``[num_tokens, head_num, head_size]``。 +- ``key_cache``:Tensor类型,必选。Transformer模块中保存 ``Key cache`` 的Tensor。维度为 ``[num_blocks, head_num, block_size, head_size]``。 +- ``value_cache``:Tensor类型,必选。Transformer模块中保存 ``Value cache`` 的Tensor。维度为 ``[num_blocks, head_num, block_size, head_size]``。 +- ``slot_mapping``:Tensor类型,必选。表示Paged Attention中 ``K-V cache`` 存放位置。维度为 ``[num_tokens]``。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - key:FP16、BF16 + - value:FP16、BF16 + - key_cache:FP16、BF16 + - value_cache:FP16、BF16 + - slot_mapping:INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_reshape_paged_cache.py`` 。 + + +.. _torch_mlu_ops.quantize: + +torch_mlu_ops.quantize +------------------------------------------ + +.. code:: + + torch_mlu_ops.quantize(x, scale, zero=None)-> torch.Tensor + +**功能描述** + +对输入tensor进行量化的算子。 + +**公式说明** + +* 公式如下: + + .. math:: + + \begin{array}{lcl} + scaled = x * scale\\ + output = scaled.round().clamp(-128, 127).to(torch.int8) + \end{array} + +* 公式中各参数描述如下: + + - ``x`` :quantize算子的输入。 + - ``scale``:量化过程的scale参数。 + - ``output``: quantize算子的输出。 + +**参数说明** + +- ``x`` :Tensor类型,必选。quantize算子的输入,维度是 ``[..., C]``,其中 ``C`` 是通道的维度。 +- ``scale`` :Tensor类型,必选。quantize算子的 ``scale``输入,维度是 ``[C]``。 +- ``zero`` :Tensor类型,必选。quantize算子的保留参数,填入 ``None``。 + +**返回值说明** + +- ``output``:Tensor类型,必选。算子的输出,维度和 ``x`` 一致。 + +**规格限制** + +- 数据类型支持: + + - x:FP32、FP16、BF16 + - scale:FP32 + - output:INT8 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_smooth_quant.py`` 。 + +.. _torch_mlu_ops.single_query_cached_kv_attn: + +torch_mlu_ops.single_query_cached_kv_attn +------------------------------------------ + +.. code:: + + torch_mlu_ops.single_query_cached_kv_attn(q_ori,k_cache,v_cache,out,block_tables,context_lens,k_cache_quant_scale,v_cache_quant_scale,alibi_slopes,max_context_len,windows_size_left,windows_size_right,softmax_scale,return_lse,kv_cache_quant_bit_size)->Tensor + +**功能描述** + +Decoder阶段的Single Query Cached KV Attention算子。 + +**公式说明** + +* Single Query Cached KV Attention标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + scores = batch\_matmul(q\_ori, k\_cache, transpose\_b=True) * softmax\_scale\\ + scores += ALiBi\\ + scores += causal\_mask\\ + attention = softmax(scores)\\ + output = batch\_matmul(attention, v\_cache)\\ + lse = logsumexp(scores)\\ + \end{array} + +* 公式中各参数描述如下: + + - ``q_ori`` :Single Query Cached KV Attention模块的输入。 + - ``ALiBi``:在线计算的ALiBi位置编码矩阵。 + - ``scores`` :Single Query Cached KV Attention模块第一个批量矩阵乘的输出。 + - ``causal_mask``:因果mask, 当seq_q>1时默认使能。 + - ``attention`` :Single Query Cached KV Attention模块经过sotmax后的输出。 + - ``output`` :Single Query Cached KV Attention模块第二个批量矩阵乘的输出,也是模块的最终输出。 + - ``lse`` :Single Query Cached KV Attention的lse输出。 + +* 公式中各函数描述如下: + + - ``batch_matmul`` :批量矩阵乘。 + - ``softmax`` :归一化指数函数。 + - ``logsumexp`` :对数求和指数函数。 + +**参数说明** + +- ``q_ori``:Tensor类型,必选。算子的输入。维度 ``[batch, seq_q, q_head_num, head_size_qk]``。 +- ``k_cache``:Tensor类型,必选。算子的输入。维度 ``[num_blocks, kv_head_num, block_size, head_size_qk]``。 +- ``v_cache``:Tensor类型,必选。算子的输入。维度 ``[num_blocks, kv_head_num, block_size, head_size_v]``。 +- ``out``:Tensor类型,可选。维度同input。 +- ``block_tables``:Tensor类型,必选。block的id与kv_cache之间的映射关系。维度 ``[bs, max_num_blocks_per_seq]``,非paged attention时,max_num_blocks_per_seq为1。 +- ``context_lens``:Tensor类型,必选。记录每个batch的实际seq长度,维度 ``[batch]``。 +- ``k_cache_quant_scale``:Tensor类型,可选。k_cache&v_cache为量化类型时k_cache的量化参数,维度 ``[num_blocks, kv_head_num, block_size]``表示per_token量化、 ``[num_blocks, kv_head_num, block_size, 1]`` 表示带group的per_token量化,但当前group_num必须为1, 以及 ``[kv_head_num, head_size_qk]`` 表示per_channel量化。 +- ``v_cache_quant_scale``:Tensor类型,可选。k_cache&v_cache为量化类型时v_cache的量化参数,维度同 ``k_cache_quant_scale``。 +- ``alibi_slope``:Tensor类型,可选。是一种相对位置编码方法,维度 ``[batch, q_head_num]``,类型必须是float32。 +- ``max_context_len``:Int64类型,必选。算子的最大seq长度。 +- ``window_size_left``:Int64类型,左侧窗口大小,未指定时设置-1。实际参与计算的key和value长度为 ``window_size_left + 1``。 +- ``window_size_right``:Int64类型,保留字段,设置-1即可。 +- ``softmax_scale``:Float32类型,必选。softmax之前的系数。 +- ``return_lse``:Bool类型,标记是否返回lse tensor。为True时当前仅支持seq_q=1。 +- ``kv_cache_quant_bit_size``:INT32类型,KV cache量化的位宽,可选8(表示int8)、4(表示int4)和-1(表示和input一样的数据类型),默认值为-1 + +**返回值说明** + +- ``output``:Tensor类型,可选。算子的输出,当out不为空时有此输出,维度 ``[batch, seq_q, q_head_num, head_size]``。 +- ``output_lse``:Tensor类型,可选。算子的lse输出,当 ``return_lse`` 为True时有此输出,维度 ``[batch, q_head_num, seq_q]``。 + +**规格限制** + +- 数据类型支持: + + - q_ori:FP16、BF16、FP32 + - k_cache:FP16、BF16、FP32、INT8、int4 + - v_cache:FP16、BF16、FP32、INT8、int4 + - block_tables:INT32 + - context_lens:INT32 + - k_cache_quant_scale:FP32 + - v_cache_quant_scale:FP32 + - alibi_slope:FP32 + - max_context_len:INT64 + - window_size_left:INT32 + - window_size_right:INT32 + - softmax_scale:FP32 + - output:FP16、BF16、FP32 + - output_lse:FP32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_single_query_cached_kv_attn.py`` 。 + +.. _torch_mlu_ops.smooth_quant_matmul: + +torch_mlu_ops.smooth_quant_matmul +------------------------------------------ + +.. code:: + + torch_mlu_ops.smooth_quant_matmul(a,a_scale,b,b_scale,dtype,bias,c,act_mode,alpha,beta,use_hp_active)-> torch.Tensor + +**功能描述** + +实现smooth_quant量化功能的GEMM算子。 + +**公式说明** + +* smooth_quant_matmul标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + if a_scale is not None: \\ + ab = matmul(a, b) * a_scale.outer(b_scale) + bias \\ + else: \\ + ab = matmul(a, b) * b_scale + bias \\ + acted = active(ab) \\ + output = alpha * acted + beta * c + \end{array} + + +* 公式中各参数描述如下: + + - ``a`` :矩阵乘运算的a矩阵。 + - ``a_scale`` : 输入的分token量化的scale系数 + - ``b`` :矩阵乘运算的b矩阵。 + - ``b_scale`` : 权重的分channel量化的scale系数 + - ``bias`` :偏置。 + - ``alpha`` :浮点标量,默认1.0f。 + - ``beta`` :浮点标量,默认1.0f。 + - ``c`` :第三个elemenwise加c矩阵。 + - ``ouput`` :GEMM返回值。 + +* 公式中各函数描述如下: + - ``matmul`` :矩阵乘操作。 + - ``active`` :激活操作。 + +**参数说明** + +- ``a``:Tensor类型,必选。GEMM算子的输入。维度 ``[M, K]``。 +- ``a_scale``:Tensor类型,可选。输入的分token量化的scale系数。维度 ``[N]``。 +- ``b``:Tensor类型,必选。GEMM算子的权重。维度 ``[N, K]``。 +- ``b_scale``:Tensor类型,可选。权重的分channel量化的scale系数。维度 ``[N]``。 +- ``dtype``:dype类型, ``output`` 的数据类型,必须是 ``torch.float`` 、 ``torch.half`` 、 ``torch.bfloat16`` 。 +- ``bias``:Tensor类型,可选。GEMM算子的偏置。维度 ``[N]``。 +- ``c`` : Tensor类型,可选。维度 ``[M, N]``。 +- ``act_mode`` : string类型,可选的激活函数:"relu"、"gelu"、"silu"、"none"。 +- ``alpha`` : float类型。默认值为1.0。 +- ``beta`` : float类型。默认值为1.0。 +- ``use_hp_active`` :bool类型,激活的计算方式,默认值为False。 + + +**返回值说明** + +- ``ouput``: Tensor类型, 维度 ``[M, N]``。 + +**规格限制** + + +- 数据类型支持: + + - a:INT8 + - a_scale:FP32 + - b: INT8 + - b_scale:FP32 + - bias:HALF、FP32、BF16 + - c:HALF、FP32、BF16 + - output:HALF、FP32、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_quant_ffn.py`` 。 + + +.. _torch_mlu_ops.smooth_quant_matmul_allreduce: + +torch_mlu_ops.smooth_quant_matmul_allreduce +------------------------------------------------ + +.. code:: + + torch_mlu_ops.smooth_quant_matmul_allreduce(cncl_comm, a, a_scale, b, b_scale, dtype, bias, c, alpha, beta, block_m) + +**功能描述** + +集成平滑量化功能的GEMM通信计算融合算子。 + +**公式说明** + +无。 + +**参数说明** + +- ``cncl_comm``:Int类型,必选。cnclComm,用于归约通信。 +- ``a``:Tensor类型,必选。GEMM算子的输入。维度 ``[M, K]``。 +- ``a_scale``:Tensor类型,必选。输入的scale系数。维度 ``[M]``。 +- ``b``:Tensor类型,必选。GEMM算子的权重。维度 ``[N, K]``。 +- ``b_scale``:Tensor类型,必选。权重的scale系数。维度 ``[N]``。 +- ``dtype``:Torch.dtype类型,必选。返回值类型,可选的类型: ``torch.half``、 ``torch.bfloat16``, ``torch.float32``。 +- ``bias``:Tensor类型,可选。GEMM算子的偏置。维度 ``[N]``。 +- ``c``:Tensor类型,可选。残差。维度 ``[M, N]``。 +- ``alpha``:float类型,可选。默认1.0f。 +- ``beta``:float类型,可选。默认1.0f。 +- ``block_m``:Int类型,可选。a在M维拆分的份数。 + +**返回值说明** + +- ``d``: Tensor类型, 维度 ``[M, N]``。 + +**规格限制** + +- 参数限制: + + - M/N/K: M<=INT32_MAX, NK<=INT32_MAX, K>=16, N>=16。 + - a_scale/b_scale: 维度分别是 ``[M]`` 和 ``[N]``。 + +- 数据类型支持: + + - cncl_comm:INT64 + - a:HALF、FP32、BF16、INT8 + - a_scale:FP32 + - b:INT8 + - b_scale:FP32 + - bias:HALF、FP32、BF16 + - c:HALF、FP32、BF16 + - block_m: INT32 + - d:HALF、FP32、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/allreduce_case/test_quant_matmul_all_reduce.py`` 。 + +.. _torch_mlu_ops.swap_blocks: + +torch_mlu_ops.swap_blocks +------------------------------------------ + +.. code:: + + torch_mlu_ops.swap_blocks(dst, src, block_mapping) + +**功能描述** + +将 ``src`` 中源位置的BLOCK复制到 ``dst`` 的指定位置,拷贝方向可以是:HostToDev、DevToHost、DevToDev。 + +**公式说明** + +无。 + +**参数说明** + +- ``src``:Vector类型,必选。输入维度 ``[num_layers]``。其中 ``num_layers`` 表示layer数量, ``[num_heads, block_size, head_size]`` 整体大小作为每次需要copy的BLOCK大小。 ``num_blocks`` 表示BLOCK数量; ``block_size`` 表示输入序列中的特定块或子序列大小; ``num_heads`` 表示transformer模块中的头数; ``head_size`` 表示每个头的维度。 +- ``dst``:Vector类型,必选。输入维度 ``[num_layers]``。其中每个参数含义与 ``src`` 一致。 +- ``block_mapping``:Map>类型,必选。第一个int64表示需要copy的源BLOCK位置(小于num_blocks),第二个int64表示将源BLOCK copy到的目标位置(小于num_blocks), ``vector`` 表示一个源BLOCK可以copy到多个目标位置。 + +**返回值说明** + +- dst,原位操作。 + +**规格限制** + +- 数据类型支持: + + - src:FP16、FP32、BF16、INT8、UINT8、INT16、INT32、INT64 + - dst:FP16、FP32、BF16、INT8、UINT8、INT16、INT32、INT64 + - block_mapping:INT64 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_swap_blocks.py`` 。 + +.. _torch_mlu_ops.weight_only_quant_matmul: + +torch_mlu_ops.weight_only_quant_matmul +------------------------------------------ + +.. code:: + + torch_mlu_ops.weight_only_quant_matmul(a,b,scale,zero,bias,c,act_mode,quant_bit_size,alpha,beta,use_hp_active)-> torch.Tensor + +**功能描述** + +实现仅权重量化功能的GEMM算子。 + +**公式说明** + +* weight_only_quant_matmul标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + ab = matmul(a, b) * scale + bias\\ + acted = active(ab)\\ + output = alpha * acted + beta * c + \end{array} + + +* 公式中各参数描述如下: + + - ``a`` :矩阵乘运算的a矩阵。 + - ``b`` :矩阵乘运算的b矩阵。 + - ``scale`` : 权重的分channel scale系数。 + - ``bias`` :偏置。 + - ``alpha`` :浮点标量,默认1.0f。 + - ``beta`` :浮点标量,默认1.0f。 + - ``c`` :第三个elemenwise加c矩阵。 + - ``ouput`` :GEMM返回值。 + +* 公式中各函数描述如下: + - ``matmul`` :矩阵乘操作。 + - ``active`` :激活操作。 + +**参数说明** + +- ``a``:Tensor类型,必选。GEMM算子的输入。维度 ``[M, K]``。 +- ``b``:Tensor类型,必选。GEMM算子的权重。当 ``quant_bit_size = 8`` ,维度 ``[N, K]``,当 ``quant_bit_size = 8`` ,维度 ``[N, K//2]``。 +- ``scale``:Tensor类型,可选。当group_wise量化时,维度为 ``[N, G]``,G表示Group_num,Group_size为K//Group_num,Group_size必须为64、128、256或512,类型必须与a相同。当per_channel量化时,维度为 ``[N]``,类型必须是float, +- ``zero``:Tensor类型,目前不开放,设置为None。 +- ``bias``:Tensor类型,可选。GEMM算子的偏置。维度 ``[N]``。 +- ``c`` : Tensor类型,可选。维度 ``[M, N]``。 +- ``act_mode`` : string类型,可选的激活函数:"relu"、"gelu"、"silu"、"none"。 +- ``quant_bit_size`` : int类型,权重 ``b`` 的位宽。 +- ``alpha`` : float类型。默认值为1.0。 +- ``beta`` : float类型。默认值为1.0。 +- ``use_hp_active`` : bool类型,激活的计算方式,默认值为False。 + + +**返回值说明** + +- ``ouput``: Tensor类型, 维度 ``[M, N]``。 + +**规格限制** + + +- 数据类型支持: + + - a:HALF、FP32、BF16、INT8 + - b: INT8 + - scale:当使用group_wise量化时,类型与a相同,使用per_channel量化时,类型为FLOAT32 + - bias:HALF、FP32、BF16 + - c:HALF、FP32、BF16 + - output:HALF、FP32、BF16 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_quant_ffn.py`` 。 + +.. _torch_mlu_ops.moe_active: + +torch_mlu_ops.moe_active +------------------------------------------ + +.. code:: + + torch_mlu_ops.moe_active(input, act_mode, is_gated, output, bias, cusumt_token_count, start_expert_id, expert_size)-> torch.Tensor + +**功能描述** + +激活算子。 + +**公式说明** + +* 公式如下: + +无。 + +**参数说明** + +- ``input``:Tensor类型,必选。输入维度 ``[..., C]``。 其中 ``C`` 为通道的个数,在最低维。 +- ``act_mode``:string类型。可以是 ``silu`` 或者 ``gelu`` 。 +- ``is_gated``:Bool类型。当 ``is_gated=true`` 时,为gated激活。 +- ``output``:Tensor类型,可选。当 ``is_gated=true`` 时,维度为 ``[..., C//2]``,否则与input形状一致。 +- ``bias``:Tensor类型,可选。输入维度是 ``[expert_num, C]`` 。其中 ``expert_num`` 为专家的总数。 +- ``cusum_token_count``:Tensor类型,可选,默认值为None。token_count的前缀和,维度 ``[num_expert + 1]`` 。 +- ``start_expert_id``:Int32类型,必选。表示EP模型下起始专家的位置信息。 +- ``expert_size``:Int32类型,必选。表示EP模型下实际使用的专家数量。 + +**返回值说明** + +- ``output``:激活后返回的tensor。 + +**规格限制** + +- 数据类型支持: + + - input:FP32、FP16、BF16 + - output:FP32、FP16、BF16 + - bias:FP32、FP16、BF16 + - cusumt_token_count:INT32 + - start_expert_id:INT32 + - expert_size:INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_add_bias_activation.py`` 。 + +.. _torch_mlu_ops.moe_cast_gating: + +torch_mlu_ops.moe_cast_gating +------------------------------------------ + +.. code:: + + torch_mlu_ops.moe_cast_gating(input, weight) -> torch.Tensor + +**功能描述** + +将input的数据类型强转成float类型,再和weight进行矩阵乘操作。 + +**公式说明** + +* 公式如下: + +无。 + +**参数说明** + +- ``input``:Tensor类型,必选。输入维度 ``[..., hidden_size]``。 其中 ``hidden_size`` 表示各个token的特征信息。 该输入必须连续。 +- ``weight``:Tensor类型,必选。输入维度 ``[expert_num, hidden_size]`` 。 其中 ``expert_num`` 表示专家数量。 该输入必须连续。 + +**返回值说明** + +- ``output``:Tensor类型。维度为 ``[..., expert_num]``。 + +**规格限制** + +- 数据类型支持: + + - input:FP16、BF16 + - weight:FP32 + - output:FP32 + +- 规模支持: + + - hidden_size:小于或等于16384 + - expert_num:小于或等于128 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_moe_cast_gating.py`` 。 + +.. _torch_mlu_ops.update_out_and_lse: + +torch_mlu_ops.update_out_and_lse +------------------------------------------ + +.. code:: + + torch_mlu_ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)-> None + +**功能描述** + +根据seq_offset, cu_seqs, block_cu_seqs,将block_out和block_lse的结果更新到out和lse中。 + +**公式说明** + +* 公式如下: + .. math:: + + \begin{array}{lcl} + out = out - sigmoid(block_lse - lse) * (out - block_out)\\ + lse = lse - logsigmoid(lse - block_lse)\\ + \end{array} + +**参数说明** + +- ``out``:Tensor类型,必选。输入维度 pad模式: ``[batch, max_seq_len, head_num, head_size]`` pack模式: ``[total_seq_lens, head_num, head_size]``。 +- ``lse``:Tensor类型,必选。输入维度 pad模式: ``[batch, head_num, max_seq_len]`` 不支持pack模式。 +- ``block_out``:Tensor类型,必选。输入维度 pad模式: ``[batch, block_seq_len, head_num, head_size]`` pack模式: ``[total_block_seq_lens, head_num, head_size]``。 +- ``block_lse``:Tensor类型,必选。输入维度 pad模式: ``[batch, head_num, block_seq_len]`` 不支持pack模式。 +- ``seq_offset``:Tensor类型,可选。输入维度是 ``[batch]`` 。 +- ``cu_seqs``:Tensor类型,可选。输入维度是 ``[batch + 1]`` ,仅限pack模式。 +- ``block_cu_seqs``:Tensor类型,可选。输入维度是 ``[batch + 1]`` ,仅限pack模式。 + +**返回值说明** + +无返回值。 + +**规格限制** + +- 数据类型支持: + + - out:FP32、FP16、BF16 + - lse:FP32、FP16、BF16 + - block_out:FP32、FP16、BF16 + - block_lse:FP32、FP16、BF16 + - seq_offset:INT32 + - cu_seqs:INT32 + - block_cu_seqs:INT32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_update_out_and_lse.py`` 。 + +.. _torch_mlu_ops.single_query_mixed_cached_kv_attn: + +torch_mlu_ops.single_query_mixed_cached_kv_attn +----------------------------------------------------------- + +.. code:: + + torch_mlu_ops.single_query_mixed_cached_kv_attn(q_ori,k_cache_lp,v_cache_lp,k_cache_hp,v_cache_hp,out,block_tables_lp,block_tables_hp,context_lens_lp,context_lens_hp,k_cache_quant_scale_lp,v_cache_quant_scale_lp,k_cache_quant_scale_hp,v_cache_quant_scale_hp,alibi_slopes,max_context_len_lp,max_context_len_hp,softmax_scale,return_lse,kv_cache_quant_bit_size,kv_cache_quant_bit_size_lp,kv_cache_quant_bit_size_hp)->Tensor + +**功能描述** + +Decoder阶段的context_seq分段量化,采用低精度与高精度融合的Single Query Cached KV Attention算子。 + +**公式说明** + +* context_seq分段量化的融合的Single Query Cached KV Attention标准结构公式如下: + + .. math:: + + \begin{array}{lcl} + lse1, attn\_out1 = single\_query\_cached\_kv\_attn(q\_ori, k\_cache\_lp, v\_cache\_lp, block\_tables\_lp,...)\\ + lse2, attn\_out2 = single\_query\_cached\_kv\_attn(q\_ori, k\_cache\_hp, v\_cache\_hp, block\_tables\_hp,...)\\ + output = update\_out\_and\_lse(lse1, attn\_out1, lse2, attn\_out2)\\ + \end{array} + +* 公式中各参数描述如下: + + - ``q_ori`` :Single Query Cached KV Attention模块的输入。 + - ``attn_out1``、 ``attn_out1`` :两段Single Query Cached KV Attention的输出。 + - ``lse1``、 ``lse2`` :两段Single Query Cached KV Attention的lse输出。 + - ``output``:最终的拼接输出。 + +* 公式中各函数描述如下: + + - ``single_query_cached_kv_attn`` :Single Query Cached KV Attention算子。 + - ``update_out_and_lse``: update_out_and_lse算子。 + +**参数说明** + +- ``q_ori``:Tensor类型,必选。算子的输入。维度 ``[batch, seq_q, q_head_num, head_size]``。 +- ``k_cache_lp``:Tensor类型,必选。算子的输入, 第一段低精度context对应的cache。维度 ``[num_blocks_lp, kv_head_num, block_size_lp, head_size]``。 +- ``v_cache_lp``:Tensor类型,必选。算子的输入, 第一段低精度context对应的cache。维度 ``[num_blocks_lp, kv_head_num, block_size_lp, head_size]``。 +- ``k_cache_hp``:Tensor类型,必选。算子的输入, 第二段高精度context对应的cache。维度 ``[num_blocks_hp, kv_head_num, block_size_hp, head_size]``。 +- ``v_cache_hp``:Tensor类型,必选。算子的输入, 第二段高精度context对应的cache。维度 ``[num_blocks_hp, kv_head_num, block_size_hp, head_size]``。 +- ``out``:Tensor类型,可选。维度同input。 +- ``block_tables_lp``:Tensor类型,必选。低精度部分block的id与kv_cache之间的映射关系。维度 ``[bs, max_num_blocks_per_seq_lp]``,非paged attention时,max_num_blocks_per_seq_lp为1。 +- ``block_tables_hp``:Tensor类型,必选。高精度部分block的id与kv_cache之间的映射关系。维度 ``[bs, max_num_blocks_per_seq_hp]``,非paged attention时,max_num_blocks_per_seq_hp为1。 +- ``context_lens_lp``:Tensor类型,必选。记录低精度部分每个batch的实际seq长度,维度 ``[batch]``。 +- ``context_lens_hp``:Tensor类型,必选。记录高精度部分每个batch的实际seq长度,维度 ``[batch]``。 +- ``k_cache_quant_scale_lp``:Tensor类型,可选。k_cache_lp&v_cache_lp为量化类型时k_cache_lp的量化参数,维度 ``[num_blocks_lp, kv_head_num, block_size_lp]``表示per_token量化、 ``[num_blocks_lp, kv_head_num, block_size_lp, 1]`` 表示带group的per_token量化,但当前group_num必须为1, 以及 ``[kv_head_num, head_size]`` 表示per_channel量化。 +- ``v_cache_quant_scale_lp``:Tensor类型,可选。k_cache_lp&v_cache_lp为量化类型时v_cache_lp的量化参数,维度同 ``k_cache_quant_scale``。 +- ``k_cache_quant_scale_hp``:Tensor类型,可选。k_cache_hp&v_cache_hp为量化类型时k_cache_hp的量化参数,维度 ``[num_blocks_hp, kv_head_num, block_size_hp]``表示per_token量化、 ``[num_blocks_hp, kv_head_num, block_size_hp, 1]`` 表示带group的per_token量化,但当前group_num必须为1, 以及 ``[kv_head_num, head_size]`` 表示per_channel量化。 +- ``v_cache_quant_scale_hp``:Tensor类型,可选。k_cache_hp&v_cache_hp为量化类型时v_cache_hp的量化参数,维度同 ``k_cache_quant_scale``。 +- ``alibi_slope``:Tensor类型,可选。是一种相对位置编码方法,维度 ``[batch, q_head_num]``,类型必须是float32。 +- ``max_context_len_lp``:Int64类型,必选。算子低精度部分的最大seq长度。 +- ``max_context_len_hp``:Int64类型,必选。算子高精度部分的最大seq长度。 +- ``softmax_scale``:Float32类型,必选。softmax之前的系数。 +- ``return_lse``:Bool类型,标记是否返回lse tensor。 +- ``kv_cache_quant_bit_size_lp``:INT32类型,低精度部分KV cache量化的位宽,可选8(表示int8)、4(表示int4)和-1(表示和input一样的数据类型)。 +- ``kv_cache_quant_bit_size_hp``:INT32类型,高精度部分KV cache量化的位宽,可选8(表示int8)、4(表示int4)和-1(表示和input一样的数据类型)。 + +**返回值说明** + +- ``output``:Tensor类型,可选。算子的输出,当out不为空时有此输出,维度 ``[batch, seq_q, q_head_num, head_size]``。 +- ``output_lse``:Tensor类型,可选。算子的lse输出,当 ``return_lse`` 为True时有此输出,维度 ``[batch, q_head_num, seq_q]``。 + +**规格限制** + +- 数据类型支持: + + - q_ori:FP16、BF16、FP32 + - k_cache_lp:FP16、BF16、FP32、INT8、int4 + - v_cache_lp:FP16、BF16、FP32、INT8、int4 + - k_cache_hp:FP16、BF16、FP32、INT8、int4 + - v_cache_hp:FP16、BF16、FP32、INT8、int4 + - block_tables_lp:INT32 + - block_tables_hp:INT32 + - context_lens_lp:INT32 + - context_lens_hp:INT32 + - k_cache_quant_scale_lp:FP32 + - v_cache_quant_scale_lp:FP32 + - k_cache_quant_scale_hp:FP32 + - v_cache_quant_scale_hp:FP32 + - alibi_slope:FP32 + - max_context_len_lp:INT64 + - max_context_len_hp:INT64 + - softmax_scale:FP32 + - output:FP16、BF16、FP32 + - output_lse:FP32 + +**使用示例** + +请参考 ``torch_mlu_ops/tests/ops_pytest/test_single_query_mixed_cached_kv_attn.py`` 。 diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_6_best_practice/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_6_best_practice/index.rst new file mode 100755 index 0000000..5355547 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_6_best_practice/index.rst @@ -0,0 +1,8 @@ +.. _最佳实践: + +最佳实践 +========================== + +.. toctree:: + + tmo_best_practice diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_6_best_practice/tmo_best_practice.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_6_best_practice/tmo_best_practice.rst new file mode 100755 index 0000000..219c10a --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_6_best_practice/tmo_best_practice.rst @@ -0,0 +1,200 @@ +.. _tmo_benchmark: + +Torch-MLU-Ops性能benchmark测试 +========================================== + +Torch-MLU-Ops benchmark测试脚本为用户提供了进行算子性能测试的便捷入口。 + +用户可通过以下命令获取各个参数的含义。 + +:: + + python3 benchmark_xxx.py --help + + +各个参数含义如下: + +- ``-h, --help``:显示帮助信息。 +- ``--repeat_times, REPEAT_TIMES``:配置测试重复次数。 +- ``--csv``:将报告数据输出到csv文件。 +- ``-o O``:当配置了 ``--csv`` 参数后,指定csv文件的输出目录。 + +测试命令示例如下: + +:: + + python3 benchmark_active.py --repeat_times 10 --csv -o './active/' + +支持如下算子: + + + .. table:: Torch-MLU-Ops benchmark测试支持算子 + :name: spporttedops + :class: longtable + :widths: 10 + + +------------------------------------+ + | op_name | + +====================================+ + | active | + +------------------------------------+ + | apply_rotary | + +------------------------------------+ + | attention_project | + +------------------------------------+ + | combine_result | + +------------------------------------+ + | ffn | + +------------------------------------+ + | flash_attn | + +------------------------------------+ + | fused_moe | + +------------------------------------+ + | fused_norm_attention_project | + +------------------------------------+ + | fused_norm_residual_ffn | + +------------------------------------+ + | fused_rms_norm | + +------------------------------------+ + | group_gemm | + +------------------------------------+ + | matmul | + +------------------------------------+ + | offline_quant_to_linear_cache | + +------------------------------------+ + | per_token_smooth_quantize | + +------------------------------------+ + | preload | + +------------------------------------+ + | quantize | + +------------------------------------+ + | quant_to_linear_cache | + +------------------------------------+ + | reshape_linear_cache | + +------------------------------------+ + | reshape_paged_cache | + +------------------------------------+ + | single_query_cached_kv_attn | + +------------------------------------+ + | smooth_quant_matmul | + +------------------------------------+ + | weight_only_quant_matmul | + +------------------------------------+ + + + +.. _gen_case: + +Torch-MLU-Ops gen case +========================================== + +Torch-MLU-Ops 通过把算子参数信息保存成pt文件实现gen case功能。 + +gen case功能提供4个环境变量,如下所示: + +- ``TMO_GEN_CASE``:值为1时开启dump case功能。 +- ``TMO_GEN_CASE_DUMP_DATA``:值为1时,所有参数均保存真值;``TMO_GEN_CASE_DUMP_DATA`` 不存在或为0时,则基础类型:int16、int32、int64类型的tensor,以及元素为整型的list保存真值。 +- ``TMO_GEN_CASE_OP_NAME``:指定需要dump case的算子名称,``torch_mlu_ops/_ops.py`` 文件的 ``__all__`` 列表记录的算子大部分已支持;如果不存在,则dump 程序运行过程中遇到的所有的TMO算子信息。 +- ``TMO_GEN_CASE_PATH``:指定case文件的输出目录,默认保存在当前目录下。 + + +使用示例: + +:: + + export TMO_GEN_CASE=1 + python test_group_gemm.py + bs: 1, seq_len: 5, k: 512, n: 512, experts_num: 8, topk: 2, dtype: torch.float16, idx_mode: False, has_bias: False testing... + [torch_mlu_ops] dump case ====> tmo_gen_case/group_gemm/group_gemm_1734940034732854.pt + bs: 1, seq_len: 5, k: 512, n: 512, experts_num: 8, topk: 2, dtype: torch.float16, idx_mode: True, has_bias: True testing... + [torch_mlu_ops] dump case ====> tmo_gen_case/group_gemm/group_gemm_1734940034736277.pt + +:: + + export TMO_GEN_CASE=1 + export TMO_GEN_CASE_DUMP_DATA=1 + export TMO_GEN_CASE_OP_NAME="group_gemm" + python test_moe.py + bs: 3, seq_len: 5, hidden_size: 512, inner_size: 768, experts_num: 8, start_expert_id: 4, expert_size: 2, topk: 2, gated: True, renormalize: True, data_type: torch.float16, act_mode: sil + u testing... + [torch_mlu_ops] dump case ====> tmo_gen_case/group_gemm/group_gemm_1735008239641861.pt + [torch_mlu_ops] dump case ====> tmo_gen_case/group_gemm/group_gemm_1735008239672238.pt + +.. _pt_example: + +Torch-MLU-Ops 运行pt格式测例 +========================================== + +使用 ``tests/ops_pytest/run_gen_case.py`` 文件运行pt格式的case。 + +:: + + python3 run_gen_case.py --help + + +各个参数含义如下: + +- ``-h, --help``:显示帮助信息。 +- ``--case_path CASE_PATH``:指定pt文件的路径。 +- ``--detail``:不运行case,只打印case信息。 + +测试命令示例如下: + +:: + + python run_gen_case.py --case_path tmo_gen_case/group_gemm/group_gemm_1735008239672238.pt + [torch_mlu_ops] run tmo_gen_case/group_gemm/group_gemm_1735008239672238.pt ... + [torch_mlu_ops] run tmo_gen_case/group_gemm/group_gemm_1735008239672238.pt successfully + +:: + + python run_gen_case.py --detail --case_path tmo_gen_case/group_gemm/group_gemm_1735008239672238.pt + op: group_gemm + dump_data: True + a: tensor([[-7.5928e-02, 0.0000e+00, -2.9156e+01, ..., -0.0000e+00, + 1.3725e+02, -3.6625e+02], + [-8.6609e-02, 4.9902e-01, 0.0000e+00, ..., -1.4185e-01, + -2.5859e+01, -5.4844e+01], + [ 4.5031e+01, 0.0000e+00, 8.7500e-01, ..., -9.3359e-01, + 1.3960e+03, -0.0000e+00], + ..., + [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, + 0.0000e+00, 0.0000e+00], + [ 2.0000e+00, -6.6250e+00, 5.1200e+02, ..., -6.5664e+00, + 5.1200e+02, 5.8984e+00], + [-5.1200e+02, -5.5352e+00, 2.0000e+00, ..., 0.0000e+00, + 0.0000e+00, 0.0000e+00]], device='mlu:0', dtype=torch.float16) + b: tensor([[[ 8.9502e-01, 2.5488e-01, -9.5947e-01, ..., 1.0635e+00, + 2.3848e+00, -1.8872e-01], + [-1.4636e-01, -1.2900e+00, -3.3154e-01, ..., 1.3516e+00, + 5.0244e-01, 6.2012e-01], + [ 1.3877e+00, 1.3916e-01, -6.0303e-01, ..., -1.6094e+00, + -5.7861e-01, 1.3838e+00], + ..., + [ 1.2350e-04, 4.5502e-02, 6.7041e-01, ..., 9.9316e-01, + -4.0845e-01, -7.1777e-01], + [-8.8818e-01, -1.3330e-01, -2.3779e-01, ..., -2.8662e-01, + 1.5908e+00, -7.5098e-01], + [-1.0459e+00, -1.2305e+00, 8.1934e-01, ..., -1.6914e+00, + -1.5176e+00, -1.4102e+00]], + + [[-1.8301e+00, 7.3975e-01, -6.4575e-02, ..., 4.7534e-01, + -2.6953e-01, -1.2490e+00], + [-1.9775e-01, -1.9775e+00, -1.2598e+00, ..., -1.8860e-01, + 1.3220e-01, 9.9219e-01], + [-8.6768e-01, -9.7949e-01, 2.9712e-01, ..., 1.0332e+00, + 7.2266e-02, -9.5068e-01], + ..., + [ 3.0615e-01, 2.4023e+00, 1.1484e+00, ..., 1.0166e+00, + 4.0161e-01, -1.0527e+00], + [ 3.8965e-01, -3.8867e-01, -1.0420e+00, ..., -1.3867e+00, + 1.1553e+00, -8.5938e-01], + [-8.1152e-01, 2.6904e-01, -1.5840e+00, ..., 1.0518e+00, + 2.1855e+00, -1.6035e+00]]], device='mlu:0', dtype=torch.float16) + m_list: tensor([6, 4], device='mlu:0', dtype=torch.int32) + expand_idx: None + c: None + alpha: None + beta: None + max_m: 15 + bias: None diff --git a/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_7_FAQ/index.rst b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_7_FAQ/index.rst new file mode 100755 index 0000000..c09199d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/docs/user_guide/source/tmo_7_FAQ/index.rst @@ -0,0 +1,32 @@ +.. _FAQ: + +FAQ +========================================== + + +Cambricon Torch-MLU-Ops和BangTransformer的关系 +------------------------------------------------------------------------ + +Torch-MLU-Ops从BangTransformer演化而来,保留了BangTransformer提供的PyTorch第三方算子API能力,重新命名为Torch-MLU-Ops。原先BangTransformer的网络评测功能由Cambricon vLLM继承。 + + +Cambricon Torch-MLU-Ops和Cambricon PyTorch的关系 +------------------------------------------------------------------------ + +Cambricon PyTorch为能够运行在寒武纪设备上的PyTorch框架组件。Cambricon Torch-MLU-Ops为寒武纪PyTorch第三方算子库。在运行时,Torch-MLU-Ops需要依赖Cambricon PyTorch环境。 + + +Cambricon Torch-MLU-Ops和Cambricon MLU-Ops的关系 +------------------------------------------------------------------------ + +Cambricon Torch-MLU-Ops为寒武纪PyTorch第三方算子库,Cambricon MLU-Ops为寒武纪C++算子库,两者无直接关系。 + + +Cambricon Torch-MLU-Ops中 flash_attention 算子和Cambricon FlashAttention以及Cambricon PyTorch源生SDPA算子的关系 +-------------------------------------------------------------------------------------------------------------------------- + +Cambricon Torch-MLU-Ops中提供的 flash_attention 算子为专门为推理场景下定制优化的算子,后端调用的是CNNL_Extra组件中的 ``cnnlScaledDotProductAttn_v3`` 算子。 + +Cambricon FlashAttention和Cambricon PyTorch源生SDPA算子均针对训练场景,后端调用的是CNNL_Extra组件中的 ``cnnlFlashAttnFwd`` 和 ``cnnlFlashAttnBwd`` 算子。 + +针对推理场景下的flash attention算子测试,应当使用Cambricon Torch-MLU-Ops中的 flash_attention 算子,以获得MLU计算卡上更好的性能表现。 diff --git a/torch_mlu_ops-v1.3.2/requirements.txt b/torch_mlu_ops-v1.3.2/requirements.txt new file mode 100644 index 0000000..210a0c1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/requirements.txt @@ -0,0 +1,16 @@ +expecttest +hypothesis +jinja2 +math +ninja +numpy==1.26.4 +packaging +pandas +pathlib +pytest +random +regex +setuptools==65.5.0 +shutil +tabulate +unittest diff --git a/torch_mlu_ops-v1.3.2/setup.py b/torch_mlu_ops-v1.3.2/setup.py new file mode 100644 index 0000000..46501d8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/setup.py @@ -0,0 +1,225 @@ +import os +import sys +import shutil +import torch +import subprocess +import platform +from pathlib import Path +from packaging.version import Version +from jinja2 import Environment, FileSystemLoader +import distutils.command.clean # pylint: disable=C0411 +import setuptools.command.install # pylint: disable=C0411 +from setuptools import setup, find_packages, distutils # pylint: disable=C0411 +from torch_mlu.utils.cpp_extension import BuildExtension, MLUExtension + + +def get_tmo_version(tmo_version_file): + if (not os.path.isfile(tmo_version_file)): + print("Failed to find version file: {0}".format(tmo_version_file)) + sys.exit(1) + with open(tmo_version_file, 'r') as f: + lines = f.readlines() + for line in lines: + if "TMO_VERSION" in line: + return Version(line.split("=")[1].strip()).base_version + raise RuntimeError(f"Can not find version from {tmo_version_file}.") + +def get_source_files(dir_path, include_kernel: bool = True): + source_list = [] + for file in Path(dir_path).rglob('*.cpp'): + source_list.append(file.as_posix()) + if include_kernel: + for file in Path(dir_path).rglob('*.mlu'): + source_list.append(file.as_posix()) + return source_list + +def get_include_dirs(dir_path): + out_dirs = [] + for root, dirs, files in os.walk(dir_path): + for dir in dirs: + out_dirs.append(os.path.join(root, dir)) + return out_dirs + +def _check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + + +class install(setuptools.command.install.install): + def run(self): + super().run() + +class Build(BuildExtension): + kernel_args = () + kernel_kwargs = {} + + @classmethod + def build_kernel(cls, *args, **kwargs): + cls.kernel_args = args + cls.kernel_kwargs = kwargs + pass + + @classmethod + def pre_run(cls, func, *args, **kwargs): + cls.build_kernel = func + cls.kernel_args = args + cls.kernel_kwargs = kwargs + + def run(self): + Build.build_kernel(*Build.kernel_args, **Build.kernel_kwargs) + super().run() + +class Clean(distutils.command.clean.clean): + def run(self): + super().run() + for root, dirs, files in os.walk('.'): + for dir in dirs: + if dir == '__pycache__' or dir == 'build': + shutil.rmtree(os.path.join(root, dir)) + for file in files: + if file.endswith('.pyc') or file.endswith('.pyo') or file.endswith('.so'): + os.remove(os.path.join(root, file)) + + +def main(): + BUILD_KERNELS_WITH_CMAKE = _check_env_flag('TMO_BUILD_KERNELS_WITH_CMAKE') + DEBUG_MODE = _check_env_flag('TMO_DEBUG_MODE') + CXX_FLAGS = [] + CNCC_FLAGS = [] + + # base_dir: bangtransformer + base_dir = os.path.dirname(os.path.abspath(__file__)) + csrc_dir = os.path.join(base_dir, "csrc") + include_dir_list = [csrc_dir] + + # get tmo ops version + tmo_version_file = os.path.join(base_dir, 'build.property') + torch_tmo_version = get_tmo_version(tmo_version_file) + tv = Version(torch.__version__) + torch_tmo_version = torch_tmo_version + "+pt" + f"{tv.major}{tv.minor}" + + # create _version.py + env = Environment(loader=FileSystemLoader(base_dir)) + tpl = env.get_template('version.tpl') + ver_cxt = {'VERSION': torch_tmo_version} + with open('torch_mlu_ops/_version.py', 'w', encoding='utf-8') as f: + f.write(tpl.render(ver_cxt)) + + # source files in source_list can not be absolute path + source_list = get_source_files("csrc", not BUILD_KERNELS_WITH_CMAKE) + include_dir_list.extend(get_include_dirs("csrc")) + library_dir_list = [] + + neuware_home = os.getenv("NEUWARE_HOME", "/usr/local/neuware") + neuware_home_include = os.path.join(neuware_home, "include") + neuware_home_lib = os.path.join(neuware_home, "lib64") + include_dir_list.append(neuware_home_include) + library_dir_list.append(neuware_home_lib) + + CXX_FLAGS += [ + '-fPIC', + '-Wall', + '-Werror', + '-Wno-error=deprecated-declarations', + ] + if DEBUG_MODE: + CXX_FLAGS += ['-Og', '-g', '-DDEBUG'] + + TMO_MEM_CHECK = _check_env_flag('TMO_MEM_CHECK') + if TMO_MEM_CHECK: + SANITIZER_FLAGS=["-O1", + "-g", + "-DDEBUG", + "-fsanitize=address", + "-fno-omit-frame-pointer",] + CXX_FLAGS.extend(SANITIZER_FLAGS) + CNCC_FLAGS.extend(SANITIZER_FLAGS) + + def build_kernel_with_cmake(*args, **kwargs): + if DEBUG_MODE: + os.environ['BUILD_MODE'] = 'debug' + cmd = os.path.join(csrc_dir, "kernels/build.sh") + res = subprocess.run(cmd, shell=True) + assert res.returncode == 0, 'failed to build tmo kernels by cmake' + + if BUILD_KERNELS_WITH_CMAKE: + Build.pre_run(build_kernel_with_cmake) + else: + CNCC_FLAGS += [ + '-fPIC', + '--bang-arch=compute_30', + '--bang-mlu-arch=mtp_592', + '--bang-mlu-arch=mtp_613', + '-Xbang-cnas', + '-fno-soft-pipeline', + '-Wall', + '-Werror', + '-Wno-error=deprecated-declarations', + '--no-neuware-version-check', + ] + if platform.machine() == "x86_64": + CNCC_FLAGS += ['-mcmodel=large'] + if DEBUG_MODE: + CNCC_FLAGS += ['-Og', '-g', '-DDEBUG'] + + # MLUExtension does not add include_dirs automatically for mlu files, + # so we have to add them in CNCC_FLAGS. + for dir in include_dir_list: + CNCC_FLAGS.append("-I" + dir) + + packages=find_packages( + include=( + "torch_mlu_ops", + ) + ) + + # Read in README.md for long_description + with open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f: + long_description = f.read() + + extra_link_args = ['-Wl,-rpath,$ORIGIN/'] + if not DEBUG_MODE: + extra_link_args += ['-Wl,--strip-all'] + if BUILD_KERNELS_WITH_CMAKE: + extra_link_args+=['-Wl,-whole-archive', + os.path.join(csrc_dir, "kernels/build/lib/libtmo_kernels.a"), + '-Wl,-no-whole-archive',] + + C = MLUExtension('torch_mlu_ops._C', + sources=[*source_list], + include_dirs=[*include_dir_list], + library_dirs=[*library_dir_list], + extra_compile_args={'cxx': CXX_FLAGS, + 'cncc': CNCC_FLAGS}, + extra_link_args=extra_link_args, + ) + ext_modules = [C] + + cmdclass={ + 'build_ext': Build, + 'clean': Clean, + 'install': install, + } + + install_requires=[ + "torch >= 2.1.0", + "torch_mlu >= 1.20.0", + "ninja", + "jinja2", + "packaging", + ] + + setup(name="torch_mlu_ops", + version=torch_tmo_version, + packages=packages, + description='BangTransformer Torch API', + long_description=long_description, + long_description_content_type="text/markdown", + url = 'http://www.cambricon.com', + ext_modules=ext_modules, + cmake_find_package=True, + cmdclass=cmdclass, + install_requires=install_requires, + python_requires=">=3.10") + +if __name__ == "__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/CMakeLists.txt b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/CMakeLists.txt new file mode 100644 index 0000000..6055833 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/CMakeLists.txt @@ -0,0 +1,101 @@ +cmake_minimum_required(VERSION 3.10) +set(CMAKE_C_COMPILER "gcc") +set(CMAKE_CXX_COMPILER "g++") + +project(kernel_test) +message(STATUS "project name: ${PROJECT_NAME}") + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib") +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/archive") +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "$ENV{NEUWARE_HOME}/cmake/modules") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +add_compile_options(-std=c++17 -O3 -g -fPIC -Wall -Werror -Wextra -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unknown-pragmas) + +link_directories($ENV{NEUWARE_HOME}/lib64) +link_libraries(m stdc++ dl pthread cnnl cnnl_extra cnrt cndrv "-Wl,-rpath,$ENV{NEUWARE_HOME}/lib64 -Wl,--disable-new-dtags") +include_directories($ENV{NEUWARE_HOME}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/ ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/kernels) + +function(find_torch) + # get the path and cxx11_abi flag + execute_process( + COMMAND python3 -c "import torch; print(torch.__path__[0], torch.compiled_with_cxx11_abi(), sep=';')" + RESULT_VARIABLE TORCH_NOT_FOUND + OUTPUT_VARIABLE TORCH_INFO + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + if(TORCH_NOT_FOUND) + return() + endif() + + list(GET TORCH_INFO 0 TORCH_PATH) + message(STATUS "torch path: ${TORCH_PATH}") + + list(GET TORCH_INFO 1 TORCH_CXX11_ABI) + message(STATUS "torch cxx11 abi: ${TORCH_CXX11_ABI}") + + set(Torch_DIR ${TORCH_PATH}/share/cmake/Torch PARENT_SCOPE) +endfunction() + +# import pytorch +find_torch() +message(STATUS "Torch_DIR: ${Torch_DIR}") +find_package(Torch QUIET) +find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") + +# import torch_mlu +execute_process( + COMMAND python3 -c "import torch_mlu.utils as mlu_utils;print(mlu_utils.cmake_prefix_path)" + OUTPUT_VARIABLE Torch_MLU_MODULE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# find ops library +execute_process( + COMMAND python3 -c "import torch_mlu_ops as ops;print(ops._utils.get_custom_op_library_path())" + RESULT_VARIABLE LIBOPS_NOT_FOUND + OUTPUT_VARIABLE LIBOPS_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE +) +if(LIBOPS_NOT_FOUND) + message(FATAL_ERROR "torch_mlu_ops not installed, can not find ops library.") +endif() + +set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ${Torch_MLU_MODULE_DIR}) +find_package(TorchMLU QUIET) +# torch_mlu throw [-Werror=sign-compare] error, so ignore it +add_compile_options(-Wno-sign-compare) +# TorchMLUConfig.cmake run will get TORCH_ATEN_LIBRARY-NOTFOUND, it will cause compile fail, remove it +string(REPLACE "TORCH_ATEN_LIBRARY-NOTFOUND" "" TORCH_MLU_LIBRARIES_MODIFIED "${TORCH_MLU_LIBRARIES}") + +execute_process( + COMMAND python3 -c "from distutils import sysconfig; print(sysconfig.get_python_inc())" + OUTPUT_VARIABLE PYTHON_INCLUDE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +if(Torch_FOUND AND ((TORCH_CXX11_ABI AND USE_CXX11_ABI) OR (NOT TORCH_CXX11_ABI AND NOT USE_CXX11_ABI))) + include_directories(${PYTHON_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} ${TORCH_MLU_INCLUDE_DIRS}) + link_libraries(${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} ${TORCH_MLU_LIBRARIES_MODIFIED}) + file(GLOB_RECURSE TEST_SRCS "src/*.cpp" RECURSE) + execute_process( + COMMAND python3 -c "import sysconfig;print(sysconfig.get_config_var('EXT_SUFFIX'))" + OUTPUT_VARIABLE TEST_LIB_NAME_SUFFIX + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + message("${PYTHON_EXTENSION}") + set(CMAKE_SHARED_LIBRARY_PREFIX "") + set(CMAKE_SHARED_LIBRARY_SUFFIX "") + set(TEST_LIB_NAME_PREFIX "btunittests") + string(APPEND TEST_LIB_NAME "${TEST_LIB_NAME_PREFIX}${TEST_LIB_NAME_SUFFIX}") + add_library(${TEST_LIB_NAME} SHARED ${TEST_SRCS}) + target_link_libraries(${TEST_LIB_NAME} ${LIBOPS_PATH} -lstdc++fs) +else() + message(STATUS "Torch not found, or torch abi is different with which you specified, will not build") + message(STATUS "if torch not found, please set env Torch_DIR to the directory containing TorchConfig.cmake") +endif() diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/README.md b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/README.md new file mode 100644 index 0000000..d9d7ec1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/README.md @@ -0,0 +1,17 @@ +## BT_OPS测试脚本使用方式 + +```bash +# 测试所有测例 +bash run_test.sh +``` + +```bash +# 测试单个测例 +python3 test_测例名称.py +``` + +- 必须在Torch-MLU-Ops docker容器内运行。 + +- 测试脚本的命名规则为 `test_测例名称.py`。 + +- 必须保证 Torch-MLU-Ops whl包正确安装。 diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/build.sh b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/build.sh new file mode 100755 index 0000000..ce8b499 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/build.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -e + +TOP_DIR="$( cd "$( dirname "$0" )" && pwd )" +cd ${TOP_DIR} + +rm -rf build > /dev/null 2>&1 +cmake $TOP_DIR -Bbuild $@ +cmake --build build -- -j32 diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/run_test.sh b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/run_test.sh new file mode 100755 index 0000000..d91b019 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/run_test.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +tmo_kernel_case=$(find "${SCRIPT_DIR}" -name "test_*.py") +coverage=${1} + +for sc in ${tmo_kernel_case} +do +echo -n "${sc} " +echo -n "Testing...." +if [ "${coverage}" = "coverage" ];then +coverage run -a ${sc} +else +python3 "${sc}" > "/tmp/$(basename ${sc}).log" 2>&1 +fi +if [ $? == 0 ];then +echo -e "\033[32m success \033[0m" +else +echo -e "\033[31m failed \033[0m" +fi +done +echo "End of pytest..." diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/alibi_slope_kernel_test.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/alibi_slope_kernel_test.cpp new file mode 100644 index 0000000..8b824cc --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/alibi_slope_kernel_test.cpp @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "kernels/generate_alibi_slope.mluh" +#include "register_api.h" +#include "utils.h" + +namespace tmo { +namespace kernel_test_api { +at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens, + const int batch, + const int tp_head_num, + const int tp_num, + const bool use_dynamic, + const int max_sequence_length) { + MLU_TENSOR_CHECK_FATAL(true_seq_lens); + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + + int head_num = tp_head_num * tp_num; + std::shared_ptr dev = std::make_shared(true_seq_lens.device()); + torch::Tensor output = torch::zeros( + {batch, head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false)); + + for (int tp_id = 0; tp_id < tp_num; tp_id++) { + torch::Tensor tp_slopes = torch::zeros( + {batch, tp_head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false)); + TMO_KERNEL_CHECK_FATAL(invokeGenerateAlibiSlope( + queue, tp_slopes.data_ptr(), true_seq_lens.data_ptr(), batch, tp_id * tp_head_num, + tp_head_num, head_num, max_sequence_length, use_dynamic)); + for (int batch_id = 0; batch_id < batch; batch_id++) { + CNRT_CHECK( + cnrtMemcpyAsync((float *)output.data_ptr() + batch_id * head_num + tp_id * tp_head_num, + (float *)tp_slopes.data_ptr() + batch_id * tp_head_num, + tp_head_num * sizeof(float), queue, cnrtMemcpyDevToDev)); + } + } + + return output; +} + +} // namespace kernel_test_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/create_cos_sin_table_test.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/create_cos_sin_table_test.cpp new file mode 100644 index 0000000..3d75140 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/create_cos_sin_table_test.cpp @@ -0,0 +1,83 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "kernels/create_cos_sin_table.mluh" +#include +#include +#include "register_api.h" +#include "utils.h" + +namespace { +constexpr int DYNAMIC_NTK_SCALING = 2; +} +namespace tmo { +namespace kernel_test_api { +at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached, + const at::Tensor &seq_lens, + const int max_position_embeddings, + const int batch_stride, + const int rotary_seq_len, + const int rotary_dim, + const int rotary_stride, + const float rotary_scaling, + const int rotary_scaling_type, + const float rotary_base, + const bool interleaved) { + MLU_TENSOR_CHECK_FATAL(rotary_emb_alpha_cached); + MLU_TENSOR_CHECK_FATAL(seq_lens); + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + cnnlDataType_t dtype = + torch2cnnlDataType(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype())); + + std::shared_ptr dev = + std::make_shared(rotary_emb_alpha_cached.device()); + int batch = seq_lens.size(0); + + torch::Tensor cos_sin_table = + rotary_scaling_type == DYNAMIC_NTK_SCALING + ? torch::zeros({batch, rotary_seq_len, rotary_stride}, + torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype())) + .device(*dev) + .requires_grad(false)) + : torch::zeros({rotary_seq_len, rotary_stride}, + torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype())) + .device(*dev) + .requires_grad(false)); + + size_t io_bytes = 0; + + cnrtNotifier_t time_begin; + cnrtNotifier_t time_end; + + CNRT_CHECK(cnrtNotifierCreate(&time_begin)); + CNRT_CHECK(cnrtNotifierCreate(&time_end)); + + CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue)); + TMO_KERNEL_CHECK_FATAL(invokeCreateCosSinTable( + queue, cos_sin_table.data_ptr(), + reinterpret_cast(rotary_emb_alpha_cached.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), max_position_embeddings, batch, batch_stride, + rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling, rotary_scaling_type, + interleaved, dtype)); + CNRT_CHECK(cnrtPlaceNotifier(time_end, queue)); + cnrtQueueSync(queue); + float usec = 0.0f; + CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec)); + print_info(usec, io_bytes); + + return cos_sin_table; +} + +} // namespace kernel_test_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/embedding_kernel_test.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/embedding_kernel_test.cpp new file mode 100644 index 0000000..6f0fb84 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/embedding_kernel_test.cpp @@ -0,0 +1,48 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "kernels/embedding.mluh" +#include "register_api.h" +#include "utils.h" + +namespace tmo { +namespace kernel_test_api { +at::Tensor embedding_test(const at::Tensor &input, + const at::Tensor &weight, + const int vocab_offset, + const int vocab_part) { + MLU_TENSOR_CHECK_FATAL(input); + MLU_TENSOR_CHECK_FATAL(weight); + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + std::shared_ptr dev = std::make_shared(input.device()); + int batch = input.size(0); + int seq = input.size(1); + int total_vocab_size = weight.size(0); + int hidden_size = weight.size(1); + cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(weight.dtype())); + torch::Tensor output = torch::zeros( + {batch, seq, hidden_size}, + torch::dtype(torch::typeMetaToScalarType(weight.dtype())).device(*dev).requires_grad(false)); + + TMO_KERNEL_CHECK_FATAL(invokeEmbedding(queue, weight.data_ptr(), input.data_ptr(), + output.data_ptr(), dtype, vocab_offset, vocab_part, + total_vocab_size, hidden_size, batch * seq)); + + return output; +} + +} // namespace kernel_test_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/operate_cu_seq_lens_kernel_test.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/operate_cu_seq_lens_kernel_test.cpp new file mode 100644 index 0000000..25ea5d7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/operate_cu_seq_lens_kernel_test.cpp @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "kernels/operate_cu_seq_lens.mluh" +#include "register_api.h" +#include "utils.h" + +namespace tmo { +namespace kernel_test_api { +void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens, + at::Tensor &sliced_cu_seq_lens, + int batch, + int parallel_num) { + MLU_TENSOR_CHECK_FATAL(cu_seq_lens); + MLU_TENSOR_CHECK_FATAL(sliced_cu_seq_lens); + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + + size_t io_bytes = 0; + + cnrtNotifier_t time_begin; + cnrtNotifier_t time_end; + + CNRT_CHECK(cnrtNotifierCreate(&time_begin)); + CNRT_CHECK(cnrtNotifierCreate(&time_end)); + + CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue)); + TMO_KERNEL_CHECK_FATAL(invokeSliceCuSeqlens(queue, (int *)cu_seq_lens.data_ptr(), + (int *)sliced_cu_seq_lens.data_ptr(), batch, + parallel_num)); + CNRT_CHECK(cnrtPlaceNotifier(time_end, queue)); + cnrtQueueSync(queue); + float usec = 0.0f; + CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec)); + print_info(usec, io_bytes); +} + +void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens, + int seq_len, + int parallel_num, + bool is_causal_mask, + bool is_kv_seq_len) { + MLU_TENSOR_CHECK_FATAL(gen_cu_seq_lens); + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + + size_t io_bytes = 0; + + cnrtNotifier_t time_begin; + cnrtNotifier_t time_end; + + CNRT_CHECK(cnrtNotifierCreate(&time_begin)); + CNRT_CHECK(cnrtNotifierCreate(&time_end)); + + CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue)); + TMO_KERNEL_CHECK_FATAL(invokeGenerateCuSeqlens(queue, (int *)gen_cu_seq_lens.data_ptr(), seq_len, + parallel_num, is_causal_mask, is_kv_seq_len)); + CNRT_CHECK(cnrtPlaceNotifier(time_end, queue)); + cnrtQueueSync(queue); + float usec = 0.0f; + CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec)); + print_info(usec, io_bytes); +} + +} // namespace kernel_test_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/per_head_quantize_kernel_test.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/per_head_quantize_kernel_test.cpp new file mode 100644 index 0000000..4ab8c57 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/per_head_quantize_kernel_test.cpp @@ -0,0 +1,67 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "kernels/quantize.mluh" +#include "register_api.h" +#include "utils.h" + +namespace tmo { +namespace kernel_test_api { +void per_head_quantize_test(at::Tensor &dst, + at::Tensor &scale, + const at::Tensor &src, + int bs, + int seq_len, + int head_num, + int head_size, + int dst_bs_stride, + int dst_seq_stride, + int dst_head_stride, + int src_bs_stride, + int src_seq_stride, + int src_head_stride) { + MLU_TENSOR_CHECK_FATAL(dst); + MLU_TENSOR_CHECK_FATAL(scale); + MLU_TENSOR_CHECK_FATAL(src); + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + + cnnlDataType_t dst_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(dst.dtype())); + cnnlDataType_t scale_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(scale.dtype())); + cnnlDataType_t src_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(src.dtype())); + + size_t io_bytes = + bs * seq_len * head_num * head_size * (dtype_size_map[src_dtype] + dtype_size_map[dst_dtype]); + + cnrtNotifier_t time_begin; + cnrtNotifier_t time_end; + + CNRT_CHECK(cnrtNotifierCreate(&time_begin)); + CNRT_CHECK(cnrtNotifierCreate(&time_end)); + + CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue)); + TMO_KERNEL_CHECK_FATAL(invokeMluQuantizePerHead( + queue, (void *)dst.data_ptr(), (void *)scale.data_ptr(), (void *)src.data_ptr(), dst_dtype, + scale_dtype, src_dtype, bs, seq_len, head_num, head_size, dst_bs_stride, dst_seq_stride, + dst_head_stride, src_bs_stride, src_seq_stride, src_head_stride)); + CNRT_CHECK(cnrtPlaceNotifier(time_end, queue)); + cnrtQueueSync(queue); + float usec = 0.0f; + CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec)); + print_info(usec, io_bytes); +} + +} // namespace kernel_test_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_api.h b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_api.h new file mode 100644 index 0000000..6d1a7be --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_api.h @@ -0,0 +1,82 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef TEST_KERNELS_PYTEST_REGISTER_API_H_ +#define TEST_KERNELS_PYTEST_REGISTER_API_H_ +#include +#include + +namespace tmo { +namespace kernel_test_api { +at::Tensor embedding_test(const at::Tensor &input, + const at::Tensor &weight, + const int vocab_offset, + const int vocab_part); + +at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens, + const int batch, + const int tp_head_num, + const int tp_num, + const bool use_dynamic, + const int max_sequence_length); + +at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached, + const at::Tensor &seq_lens, + const int max_position_embeddings, + const int batch_stride, + const int rotary_seq_len, + const int rotary_dim, + const int rotary_stride, + const float rotary_scaling, + const int rotary_scaling_type, + const float rotary_base, + const bool interleaved); + +void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens, + at::Tensor &sliced_cu_seq_lens, + int batch, + int parallel_num); + +void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens, + int seq_len, + int parallel_num, + bool is_causal_mask, + bool is_kv_seq_len); + +void per_head_quantize_test(at::Tensor &dst, + at::Tensor &scale, + const at::Tensor &src, + int bs, + int seq_len, + int head_num, + int head_size, + int dst_bs_stride, + int dst_seq_stride, + int dst_head_stride, + int src_bs_stride, + int src_seq_stride, + int src_head_stride); + +at::Tensor rotary_embedding_test(const at::Tensor &input, + const at::Tensor &sin_table, + const at::Tensor &cos_table, + const c10::optional &seq_offsets, + const c10::optional &cu_seq_lens, + const bool &interleaved, + const bool &discrete, + const bool &dynamic_ntk, + const bool &rope_2d, + const int &max_seq_len, + const bool &no_offset); + +} // namespace kernel_test_api +} // namespace tmo +#endif // TEST_KERNELS_PYTEST_REGISTER_API_H_ diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_function.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_function.cpp new file mode 100644 index 0000000..c1f97bf --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_function.cpp @@ -0,0 +1,29 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include "register_api.h" + +PYBIND11_MODULE(btunittests, m) { + m.def("embedding_test", &tmo::kernel_test_api::embedding_test, "Test embedding kernel function."); + m.def("alibi_slope_test", &tmo::kernel_test_api::alibi_slope_test, + "Test alibi_slope kernel function."); + m.def("create_cos_sin_table_test", &tmo::kernel_test_api::create_cos_sin_table_test, + "Test create_cos_sin_table_test kernel function."); + m.def("slice_cu_seq_lens_test", &tmo::kernel_test_api::slice_cu_seq_lens_test, + "Test slice_cu_seq_lens_test kernel function."); + m.def("generate_cu_seq_lens_test", &tmo::kernel_test_api::generate_cu_seq_lens_test, + "Test generate_cu_seq_lens_test kernel function."); + m.def("per_head_quantize_test", &tmo::kernel_test_api::per_head_quantize_test, + "Test per_head_quantize_test kernel function."); + m.def("rotary_embedding_test", &tmo::kernel_test_api::rotary_embedding_test, + "Test rotary_embedding kernel function."); +} diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/rotary_embedding_kernel_test.cpp b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/rotary_embedding_kernel_test.cpp new file mode 100644 index 0000000..28ec353 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/rotary_embedding_kernel_test.cpp @@ -0,0 +1,202 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include +#include +#include "kernels/rotary_embedding.mluh" +#include "register_api.h" +#include "utils.h" + +namespace tmo { +namespace kernel_test_api { +at::Tensor rotary_embedding_test(const at::Tensor &input, + const at::Tensor &sin_table, + const at::Tensor &cos_table, + const c10::optional &seq_offsets, + const c10::optional &cu_seq_lens, + const bool &interleaved, + const bool &discrete, + const bool &dynamic_ntk, + const bool &rope_2d, + const int &max_seq_len, + const bool &no_offset) { + MLU_TENSOR_CHECK_FATAL(input); + MLU_TENSOR_CHECK_FATAL(sin_table); + MLU_TENSOR_CHECK_FATAL(cos_table); + + bool has_seq_offsets = seq_offsets.has_value(); + bool has_cu_seq_lens = cu_seq_lens.has_value(); + int *seq_offsets_ptr = + has_seq_offsets ? reinterpret_cast(seq_offsets.value().data_ptr()) : nullptr; + int *cu_seq_lens_ptr = + has_cu_seq_lens ? reinterpret_cast(cu_seq_lens.value().data_ptr()) : nullptr; + + int batch = 0; + int total_seq_len = 0; + int head_size = input.size(-1); + + if (input.dim() == 3) { // pack mode + TORCH_CHECK(has_cu_seq_lens, + "input has 3 dims: (total_seq_len, head_num, head_size)," + " which means pack mode, cu_seqlens should not be None"); + total_seq_len = input.size(0); + batch = cu_seq_lens.value().size(0) - 1; + } else if (input.dim() == 4) { + TORCH_CHECK(!has_cu_seq_lens, + "input has 4 dims: (batch_size, seq_len, head_num, head_size)," + " which means pad mode, cu_seqlens should be None"); + TORCH_CHECK(max_seq_len == input.size(1), + "input has 4 dims: (batch_size, seq_len, head_num, head_size)," + " which means pad mode, max_seqlen must be equtals to input.size(1)"); + batch = input.size(0); + total_seq_len = batch * input.size(1); + } else { + TORCH_CHECK(false, "input only support 3 or 4 dims"); + } + + const int rope_seqlen = dynamic_ntk ? sin_table.size(1) : sin_table.size(0); + const int rope_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1); + +#if 0 + if (has_seq_offsets) { + if (discrete) { + CHECK_SHAPE(seq_offsets.value(), total_seq_len); + } else { + CHECK_SHAPE(seq_offsets.value(), batch); + } + } + if (dynamic_ntk) { + CHECK_SHAPE(sin_table, batch, rope_seqlen, rope_dim); + CHECK_SHAPE(cos_table, batch, rope_seqlen, rope_dim); + } else { + CHECK_SHAPE(sin_table, rope_seqlen, rope_dim); + CHECK_SHAPE(cos_table, rope_seqlen, rope_dim); + } +#endif + + TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous"); + if (dynamic_ntk) { + TORCH_CHECK(sin_table.stride(1) == cos_table.stride(1), + "sin_table second stride must be equal to cos_table second stride"); + } else { + TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0), + "sin_table first stride must be equal to cos_table second stride"); + } + + if (has_seq_offsets) { + TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous"); + } + + if (has_cu_seq_lens) { + TORCH_CHECK(cu_seq_lens.value().is_contiguous(), "cu_seq_lens must be contiguous"); + } + + // auto output = input; + // if (is_qkv) { + // } + + int dims = input.dim(); + int head_num = input.size(dims - 2); // head_num_qk + TORCH_CHECK(head_size <= 256, "only support input head_size <= 256"); + int head_dim = input.size(dims - 1); + int rotary_seq_len = dynamic_ntk ? sin_table.size(1) : sin_table.size(0); + int rotary_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1); + +#if 0 + if (rope_2d) { + total_seq_len = input.size(0); + rotary_seq_len *= 2; + } +#endif + int rotary_stride = dynamic_ntk ? sin_table.stride(1) : sin_table.stride(0); + + size_t input_seq_stride = input.stride(dims - 3); + size_t input_head_stride = input.stride(dims - 2); + + cnnlHandle_t handle = torch_mlu::getCurrentHandle(); + cnrtQueue_t queue; + cnnlGetQueue(handle, &queue); + +#if 0 + std::shared_ptr dev = std::make_shared(input.device()); + + std::cout + << " batch * max_seq_len : " << batch * max_seq_len + << " max_seq_len : " << max_seq_len + << " total_seq_len : " << total_seq_len + << " head_num : " << head_num + << " head_size : " << head_size << "\n"; + + // at::IntArrayRef output_shape = {batch, max_seq_len, head_num, head_size}; + // if (input.dim() == 3) { + // output_shape = {batch * max_seq_len, head_num, head_size}; + // } +#endif + + cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(input.dtype())); + auto output = input; + + size_t output_seq_stride = output.stride(dims - 3); + size_t output_head_stride = output.stride(dims - 2); + +#if 0 + std::cout << " <<<< batch: " << batch << "\n"; + std::cout << " <<<< max_seq_len: " << max_seq_len << "\n"; + std::cout << " <<<< total_seq_len: " << total_seq_len << "\n"; + std::cout << " <<<< head_num: " << head_num << "\n"; + std::cout << " <<<< head_size: " << head_size << "\n"; + std::cout << " <<<< rotary_seq_len_: " << rotary_seq_len << "\n"; + std::cout << " <<<< rotary_dim: " << rotary_dim << "\n"; + std::cout << " <<<< rotary_stride: " << rotary_stride << "\n"; + std::cout << " <<<< input_seq_stride: " << input_seq_stride << "\n"; + std::cout << " <<<< input_head_stride: " << input_head_stride << "\n"; + std::cout << " <<<< output_seq_stride: " << output_seq_stride << "\n"; + std::cout << " <<<< output_head_stride: " << output_head_stride << "\n"; + std::cout << " <<<< interleaved: " << interleaved << "\n"; + std::cout << " <<<< discrete: " << discrete << "\n"; + std::cout << " <<<< dynamic_ntk: " << dynamic_ntk << "\n"; +#endif + + size_t io_bytes = + (size_t)total_seq_len * head_num * (head_size + rotary_dim) * dtype_size_map[dtype] * 2; + io_bytes += (discrete && !no_offset) ? total_seq_len * sizeof(int) : batch * sizeof(int); + + cnrtNotifier_t time_begin; + cnrtNotifier_t time_end; + + CNRT_CHECK(cnrtNotifierCreate(&time_begin)); + CNRT_CHECK(cnrtNotifierCreate(&time_end)); + + CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue)); + if (rope_2d) { + TMO_KERNEL_CHECK_FATAL(invokeGlm6BRotaryEmbedding( + queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(), + seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, total_seq_len, head_num, head_size, + rotary_seq_len, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride, + output_head_stride, interleaved, dtype)); + } else { + TMO_KERNEL_CHECK_FATAL(invokeRotaryEmbedding( + queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(), + seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, head_num, head_size, rotary_seq_len, + rotary_dim, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride, + output_head_stride, interleaved, discrete, dynamic_ntk, dtype)); + } + CNRT_CHECK(cnrtPlaceNotifier(time_end, queue)); + cnrtQueueSync(queue); + float usec = 0.0f; + CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec)); + print_info(usec, io_bytes); + + return output; +} + +} // namespace kernel_test_api +} // namespace tmo diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/utils.h b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/utils.h new file mode 100644 index 0000000..84668bf --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/utils.h @@ -0,0 +1,133 @@ +/************************************************************************* + * Copyright (C) [2023-2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef TEST_KERNELS_PYTEST_UTILS_H_ +#define TEST_KERNELS_PYTEST_UTILS_H_ +#include +#include +#include +#include "aten/cnnl/cnnlHandle.h" +#include "common/utils.h" +#include "framework/core/MLUStream.h" +#include "framework/core/caching_allocator.h" +#include "framework/core/device.h" +#include "framework/core/mlu_guard.h" +#include "kernels/kernel_utils.h" + +namespace tmo { +static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) { + switch (dtype) { + case torch::kFloat32: + return CNNL_DTYPE_FLOAT; + case torch::kFloat16: + return CNNL_DTYPE_HALF; + case torch::kFloat64: + return CNNL_DTYPE_DOUBLE; + case torch::kInt8: + return CNNL_DTYPE_INT8; + case torch::kInt16: + return CNNL_DTYPE_INT16; + case torch::kInt32: + return CNNL_DTYPE_INT32; + case torch::kInt64: + return CNNL_DTYPE_INT64; + case torch::kUInt8: + return CNNL_DTYPE_UINT8; + case torch::kBool: + return CNNL_DTYPE_BOOL; + case torch::kBFloat16: + return CNNL_DTYPE_BFLOAT16; + default: + throw std::runtime_error("Unsupported torch::Dtype"); + } +} + +static constexpr int dtype_size_map[] = { + [CNNL_DTYPE_INVALID] = 0, + [CNNL_DTYPE_HALF] = 2, + [CNNL_DTYPE_FLOAT] = 4, + [CNNL_DTYPE_INT8] = 1, + [CNNL_DTYPE_INT16] = 2, + [CNNL_DTYPE_INT31] = 4, + [CNNL_DTYPE_INT32] = 4, + [CNNL_DTYPE_UINT8] = 1, + [CNNL_DTYPE_BOOL] = 1, + [CNNL_DTYPE_INT64] = 8, + [10] = 0, + [CNNL_DTYPE_UINT32] = 4, + [CNNL_DTYPE_UINT64] = 8, + [CNNL_DTYPE_UINT16] = 2, + [CNNL_DTYPE_DOUBLE] = 8, + [CNNL_DTYPE_COMPLEX_HALF] = 4, + [CNNL_DTYPE_COMPLEX_FLOAT] = 8, + [CNNL_DTYPE_BFLOAT16] = 2, +}; + +static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) { + switch (dtype) { + case CNNL_DTYPE_FLOAT: + return torch::kFloat32; + case CNNL_DTYPE_HALF: + return torch::kFloat16; + case CNNL_DTYPE_DOUBLE: + return torch::kFloat64; + case CNNL_DTYPE_INT8: + return torch::kInt8; + case CNNL_DTYPE_INT16: + return torch::kInt16; + case CNNL_DTYPE_INT32: + return torch::kInt32; + case CNNL_DTYPE_INT64: + return torch::kInt64; + case CNNL_DTYPE_UINT8: + return torch::kUInt8; + case CNNL_DTYPE_BOOL: + return torch::kBool; + case CNNL_DTYPE_BFLOAT16: + return torch::kBFloat16; + default: + throw std::runtime_error("Unsupported cnnlDataType_t"); + } +} + +static float getBandWidth() { + int card = -1; + CNRT_CHECK(cnrtGetDevice(&card)); + if (cndevInit(0) != CNDEV_SUCCESS) { + abort(); + } + cndevDDRInfo_t ddrinfo; + ddrinfo.version = CNDEV_VERSION_5; + if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) { + abort(); + } + double band_width = ddrinfo.bandWidth; + double band_width_decimal = ddrinfo.bandWidthDecimal; + do { + band_width_decimal /= 10; + } while (band_width_decimal > 1); + return float(band_width + band_width_decimal); +} + +static void print_info(float time_usec, size_t io_bytes) { + float io_bandwidth = getBandWidth(); + std::cout << "kernel time: " << time_usec << "us" << std::endl; + std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl; + std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl; +} + +#define MLU_TENSOR_CHECK_FATAL(tensor) \ + if (tensor.device().type() == c10::kCPU or tensor.device().type() == c10::kCUDA) { \ + throw std::runtime_error("Check failed: " #tensor " is not a MLU tensor."); \ + } + +} // namespace tmo +#endif // TEST_KERNELS_PYTEST_UTILS_H_ diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_alibi_slope_kernel.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_alibi_slope_kernel.py new file mode 100644 index 0000000..0571786 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_alibi_slope_kernel.py @@ -0,0 +1,75 @@ +import sys +import os +from unit_test import UnitTest +import pytest +import torch +import torch_mlu +import btunittests + +import math + +def alibi_slopes(batch, + num_heads, + true_seq_lens, + train_seq_len, + use_dynamic): + scale = 1.0 + if use_dynamic: + # dynamic ntk factor according to actual sequence length + a0 = 1.0 + # train_seq_len = 2048 + dynamic_seq_len = true_seq_lens # [batch, 1] + a = a0 * dynamic_seq_len / train_seq_len # [batch, 1] + a = a.masked_fill(a < 1.0, 1.0) # dynamic step 1: dynamic ntk scaling factor + scale = a ** (1.0 / (num_heads-1)) # dynamic step 2: coefficient b, for computation convenience + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32 + ) + if use_dynamic: + base = base / scale # dynamic step 3: divide b to alibi base + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + if use_dynamic: + slopes = slopes * scale # dynamic step 4: fix alibi bias m_h by multiplying b + + if closest_power_of_2 != num_heads: # todo: fix ntk when num_heads is not power of 2 + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32 + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32) + extra_slopes = torch.pow(extra_base, extra_powers) + if use_dynamic: + extra_slopes = extra_slopes.unsqueeze(0) + extra_slopes = extra_slopes.repeat(batch, 1) + slopes = torch.cat([slopes, extra_slopes], dim=-1) + + if not use_dynamic: + slopes = slopes.unsqueeze(0) + slopes = slopes.repeat(batch, 1) + + return slopes + +class TestAlibiSlopeKernel(UnitTest): + @pytest.mark.parametrize("batch, tp_head_num, tp_num, use_dynamic, max_sequence_length", [ + (1, 11, 1, False, 1024), + (1, 11, 2, False, 1024), + (8, 16, 1, False, 4096), + (8, 16, 2, False, 4096), + (1, 11, 1, True, 1024), + (1, 11, 2, True, 1024), + (8, 16, 1, True, 4096), + (8, 16, 2, True, 4096)]) + def test_alibi_slope(self, batch, tp_head_num, tp_num, use_dynamic, max_sequence_length): + torch.manual_seed(0) + true_seq_lens = torch.randint(0, 2 * max_sequence_length, (batch, 1), dtype=torch.int32) + output_base = alibi_slopes(batch, tp_head_num * tp_num, true_seq_lens, max_sequence_length, use_dynamic) + output_mlu = btunittests.alibi_slope_test(true_seq_lens.mlu(), batch, tp_head_num, tp_num, use_dynamic, max_sequence_length) + diff1 = super().diff1(output_mlu, output_base) + diff2 = super().diff2(output_mlu, output_base) + assert diff1 <= 1e-5 and diff2 <= 1e-5 + +if __name__ == '__main__': + exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml']) + exit(exit_code) diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_create_cos_sin_table_kernel.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_create_cos_sin_table_kernel.py new file mode 100644 index 0000000..163b432 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_create_cos_sin_table_kernel.py @@ -0,0 +1,109 @@ +import sys +import os +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(current_dir) +sys.path.append(parent_dir) +from unit_test import UnitTest +import pytest +import torch +import torch_mlu +import btunittests +import random +import math + +def create_cos_sin_table(rotary_base, rotary_seq_len_, rotary_stride_, rotary_scaling, rotary_dim, interleaved, datatype): + torch.manual_seed(0) + random.seed(0) + inv_freq=1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 2).mlu() / rotary_dim)) + t = torch.arange(rotary_seq_len_).mlu() + t = t / rotary_scaling + freqs = torch.outer(t, inv_freq).mlu() + cos_table = torch.cos(freqs).mlu() + sin_table = torch.sin(freqs).mlu() + if interleaved == 0: + cos_table = torch.cat((cos_table,cos_table), dim = 1).mlu() + sin_table = torch.cat((sin_table,sin_table), dim = 1).mlu() + emb = torch.cat((cos_table, sin_table), dim=-1).mlu() + else: + cos_tmp = torch.cat((cos_table[:,0].unsqueeze(1),cos_table[:,0].unsqueeze(1)), dim = 1).mlu() + emb = cos_tmp + for i in range(1, cos_table.size(1)): + cos_tmp = torch.cat((cos_table[:,i].unsqueeze(1),cos_table[:,i].unsqueeze(1)), dim = 1).mlu() + emb = torch.cat((emb,cos_tmp), dim = 1).mlu() + for i in range(0, cos_table.size(1)): + sin_tmp = torch.cat((sin_table[:,i].unsqueeze(1),sin_table[:,i].unsqueeze(1)), dim = 1).mlu() + emb = torch.cat((emb,sin_tmp), dim = 1).mlu() + return emb + +def get_ntk_alpha(true_seq_len, max_position_embeddings, rotary_base, rotary_dim): + context_value = math.log(true_seq_len / max_position_embeddings, 2) + 1 + ntk_alpha = 2 ** math.ceil(context_value) - 1 + ntk_alpha = max(ntk_alpha, 1) + rotary_b = rotary_base * (ntk_alpha ** (rotary_dim / (rotary_dim - 2))) + return ntk_alpha, rotary_b + +class TestCreateCosSinTableKernel(UnitTest): + @pytest.mark.parametrize("batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,\ + rotary_scaling_type, rotary_dim , interleaved, datatype", + [(8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.float32), + (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.float32), + (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float32), + (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float32), + (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float32), + (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float32), + (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float32), + (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float32), + (8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.bfloat16), + (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.bfloat16), + (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float16), + (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float16), + (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float16), + (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float16), + (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float16), + (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float16)]) + def test_all(self, batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling, + rotary_scaling_type, rotary_dim, interleaved ,datatype): + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name and datatype == torch.bfloat16: + datatype = torch.half + print("BFLOAT16 is not support on {}, use half instead".format(mlu_name)) + torch.manual_seed(0) + random.seed(0) + seq_lens = torch.randint(size = (batch,), low = 0, high = 512, + dtype = torch.int32).mlu() + ntk_alpha_list = [] + rotary_emb_alpha_cached = torch.rand((batch), dtype=torch.float32).mlu() + ref_rotary_emb_alpha_cached = rotary_emb_alpha_cached.clone() + output_mlu = btunittests.create_cos_sin_table_test(ref_rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch_stride, + rotary_seq_len, rotary_dim, rotary_stride, rotary_scaling, + rotary_scaling_type, rotary_base, interleaved) + if rotary_scaling_type != 2: + output_base = create_cos_sin_table(rotary_base, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype) + else: + max_seq_len, max_index = torch.max(seq_lens, dim = 0) + if max_seq_len > max_position_embeddings: + for i in range(batch): + ntk_alpha, rb = get_ntk_alpha(seq_lens[i], max_position_embeddings, rotary_base, rotary_dim) + rotary_emb_alpha_cached[i] = ntk_alpha + cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0) + cos_sin_table = torch.cat((cos_sin_table,cos_sin_table_tmp),dim = 0).mlu() if i != 0 else cos_sin_table_tmp + else: + ntk_alpha, rb = get_ntk_alpha(max_seq_len, max_position_embeddings, rotary_base, rotary_dim) + for i in range(batch): + rotary_emb_alpha_cached[i] = ntk_alpha + cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0) + cos_sin_table = torch.cat((cos_sin_table,cos_sin_table_tmp),dim = 0).mlu() if i != 0 else cos_sin_table_tmp + output_base = cos_sin_table + output_base_1 = rotary_emb_alpha_cached + diff1 = super().diff1(output_mlu, output_base) + diff2 = super().diff2(output_mlu, output_base) + assert diff1 < 0.003 and diff2 < 0.003 + if rotary_scaling_type == 2: + diff1 = super().diff1(ref_rotary_emb_alpha_cached, output_base_1) + diff2 = super().diff2(ref_rotary_emb_alpha_cached, output_base_1) + assert diff1 < 0.003 and diff2 < 0.003 + assert diff1 < 0.003 and diff2 < 0.003 + +if __name__ == '__main__': + exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml']) + exit(exit_code) diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_embedding_kernel.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_embedding_kernel.py new file mode 100644 index 0000000..1bdee6f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_embedding_kernel.py @@ -0,0 +1,46 @@ +import os +from unit_test import UnitTest +import pytest +import torch +import torch_mlu +import torch.nn as nn +import btunittests + +class TestEmbeddingKernel(UnitTest): + @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype", [ + (1, 1024, 10000, 128, torch.float16), + (6, 128, 20000, 4096, torch.float16), + (1, 10, 5000, 1024, torch.float32), + (5, 10, 10000, 128, torch.float32)]) + def test_all(self, batch, seq, vocab_size, hidden_size, dtype): + torch.manual_seed(0) + input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu() + embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size).to(dtype).mlu() + output_base = embedding_layer(input) + output_mlu = btunittests.embedding_test(input, embedding_layer.weight, 0, vocab_size) + diff1 = super().diff1(output_mlu, output_base) + diff2 = super().diff2(output_mlu, output_base) + assert diff1 == 0 and diff2 == 0 + + @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part", [ + (1, 1024, 10000, 128, torch.float16, 2000, 1000), + (6, 128, 20000, 4096, torch.float16, 10000, 10000), + (1, 10, 5000, 1024, torch.float32, 0, 2500), + (5, 10, 10000, 128, torch.float32, 100, 512)]) + def test_part(self, batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part): + torch.manual_seed(0) + input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu() + input_mask = torch.ones_like(input) * vocab_size + input = torch.where(input < vocab_offset, input_mask, input) + input = torch.where(input >= vocab_offset + vocab_part, input_mask, input) + embedding_layer = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=hidden_size, padding_idx=vocab_size).to(dtype).mlu() + output_base = embedding_layer(input) + part_weight = torch.Tensor(embedding_layer.weight)[vocab_offset : vocab_offset + vocab_part, :] + output_mlu = btunittests.embedding_test(input, part_weight, vocab_offset, vocab_part) + diff1 = super().diff1(output_mlu, output_base) + diff2 = super().diff2(output_mlu, output_base) + assert diff1 == 0 and diff2 == 0 + +if __name__ == '__main__': + exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml']) + exit(exit_code) diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_operate_cu_seq_lens_kernel.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_operate_cu_seq_lens_kernel.py new file mode 100644 index 0000000..c90581d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_operate_cu_seq_lens_kernel.py @@ -0,0 +1,92 @@ +import sys +import os +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(current_dir) +sys.path.append(parent_dir) +from unit_test import UnitTest +import pytest +import torch +import torch_mlu +import btunittests +import random + +class TestSliceCuSeqlensKernel(UnitTest): + @pytest.mark.parametrize("batch_size, parallel_num", [ + (8,3), + (16,16), + (32,16), + (64,16), + (128,16), + (256,16), + (512,16), + (1024,16), + (2048,16), + (4096,16), + ]) + def test_all(self, batch_size, parallel_num): + torch.manual_seed(0) + random.seed(0) + if batch_size < parallel_num: + return + cu_seq_lens = torch.randint(low=0, high=1024*1024, size=(batch_size+1,), dtype=torch.int32).mlu() + every = (batch_size + parallel_num - 1) // parallel_num + repeat = batch_size // every + remain = batch_size % every + loop = repeat + (remain != 0) + result = [] + for i in range(loop): + elem_num = 1 + (i == loop - 1 and remain != 0 and remain or every) + start_idx = i * every + end_idx = start_idx + elem_num + slice_tensor = cu_seq_lens[start_idx:end_idx] + adjusted_tensor = slice_tensor - slice_tensor[0] + result.append(adjusted_tensor) + sliced_cu_seq_lens = torch.cat(result) + sliced_cu_seq_lens_mlu = torch.zeros(batch_size + loop, dtype=torch.int32).mlu() + btunittests.slice_cu_seq_lens_test(cu_seq_lens, sliced_cu_seq_lens_mlu, batch_size, parallel_num) + sliced_cu_seq_lens_diff1 = super().diff1(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu) + sliced_cu_seq_lens_diff2 = super().diff2(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu) + assert sliced_cu_seq_lens_diff1 == 0 and sliced_cu_seq_lens_diff2 == 0 + +class TestGenerateCuSeqlensKernel(UnitTest): + @pytest.mark.parametrize("seq_len, parallel_num, is_causal_mask, is_kv_seq_len", [ + (2154, 4, False, False), + (3412, 3, True, False), + (996, 7, False, True), + (872, 4, False, False), + (634, 12, True, False), + (486, 4, False, True), + (125, 6, False, False), + ]) + def test_all(self, seq_len, parallel_num, is_causal_mask, is_kv_seq_len): + every = (seq_len + parallel_num - 1) // parallel_num + repeat = seq_len // every + remain = seq_len % every + loop = repeat + (remain != 0) + + generate_cu_seq_lens = torch.zeros(2 * loop, dtype=torch.int32).mlu() + + rep = seq_len // loop + rem = seq_len % loop + base = rep + (rem != 0) + + if is_causal_mask and is_kv_seq_len: + for i in range(loop): + generate_cu_seq_lens[2 * i + 1] = seq_len if i == loop - 1 else (i + 1) * base + elif not is_kv_seq_len: + for i in range(loop): + generate_cu_seq_lens[2 * i + 1] = seq_len - (loop - 1) * base if i == loop - 1 else base + elif not is_causal_mask and is_kv_seq_len: + for i in range(loop): + generate_cu_seq_lens[2 * i + 1] = seq_len + + generate_cu_seq_lens_mlu = torch.zeros(2 * loop, dtype=torch.int32).mlu() + btunittests.generate_cu_seq_lens_test(generate_cu_seq_lens_mlu, seq_len, parallel_num, is_causal_mask, is_kv_seq_len) + generate_cu_seq_lens_diff1 = super().diff1(generate_cu_seq_lens, generate_cu_seq_lens_mlu) + generate_cu_seq_lens_diff2 = super().diff2(generate_cu_seq_lens, generate_cu_seq_lens_mlu) + assert generate_cu_seq_lens_diff1 == 0 and generate_cu_seq_lens_diff2 == 0 + + +if __name__ == '__main__': + exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml']) + exit(exit_code) diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_per_head_quantize_kernel.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_per_head_quantize_kernel.py new file mode 100644 index 0000000..82f1ab8 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_per_head_quantize_kernel.py @@ -0,0 +1,82 @@ +import sys +import os +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(current_dir) +sys.path.append(parent_dir) +from unit_test import UnitTest +import pytest +import torch +import torch_mlu +import btunittests +import random + +def round_func(tensor, mode=1): + tensor = tensor.float() + fractional_part = tensor - tensor.floor() + even_mask = (tensor.floor() % 2 == 0) + if mode == 1: + rounded_tensor = torch.where(fractional_part == 0.5, + torch.where(even_mask, tensor.floor(), tensor.ceil()), + torch.round(tensor)) + elif mode == 2: + rounded_tensor = torch.where(fractional_part == 0.5, tensor + 0.5, torch.round(tensor)) + elif mode == 3: + rounded_tensor = torch.round(tensor) + return rounded_tensor + +def compute_with_tensor(src, batch, seq_len, head_num, head_size): + src = src.view(batch * seq_len * head_num * 2, head_size) + mask = torch.ones(src.size(0), dtype=bool) + for i in range(head_num, src.size(0), head_num * 2): + mask[i:i+head_num] = False + src = src[mask] + max_value_tensor, _ = torch.max(src.abs(), dim=1) + scale = (max_value_tensor / 127.0).view(batch, seq_len, head_num) + scale_ori_quant = 127 / max_value_tensor + temp_tensor = src * scale_ori_quant.view(scale_ori_quant.size(0),1) + dst = round_func(temp_tensor, 1).view(batch, seq_len, head_num, head_size) + return scale,dst + +class TestPerHeadQuantizeKernel(UnitTest): + @pytest.mark.parametrize("batch, seq_len, head_num, head_size, src_dtype", [ + (17,1436,31,21, torch.float32), + (17,1436,31,21, torch.float16), + (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float32), + (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float16), + (2,3,5,5, torch.float32) + ]) + def test_all(self, batch, seq_len, head_num, head_size, src_dtype): + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name and src_dtype == torch.bfloat16: + src_dtype = torch.half + print("BFLOAT16 is not support on {}, use half instead".format(mlu_name)) + torch.manual_seed(0) + random.seed(0) + hidden = head_num * head_size + seq_stride = 2 * hidden + scale_dtype = torch.float32 + dst_dtype = torch.int8 + dst_bs_stride = seq_len * hidden + dst_seq_stride = hidden + dst_head_stride = head_size + src_bs_stride = seq_len * seq_stride + src_seq_stride = seq_stride + src_head_stride = head_size + + src = torch.rand((batch, seq_len, head_num, 2 * head_size), dtype=src_dtype).mlu() + scale, dst = compute_with_tensor(src, batch, seq_len, head_num, head_size) + scale_mlu = torch.zeros((batch, seq_len, head_num), dtype=scale_dtype).mlu() + dst_mlu = torch.zeros((batch, seq_len, head_num, head_size), dtype=dst_dtype).mlu() + btunittests.per_head_quantize_test(dst_mlu, scale_mlu, src, batch, seq_len, head_num, head_size, dst_bs_stride, + dst_seq_stride, dst_head_stride, src_bs_stride, src_seq_stride, + src_head_stride) + dst_diff1 = super().diff1(dst, dst_mlu) + dst_diff2 = super().diff2(dst, dst_mlu) + scale_diff1 = super().diff1(scale, scale_mlu) + scale_diff2 = super().diff2(scale, scale_mlu) + assert dst_diff1 < 0.003 and dst_diff2 < 0.003 + assert scale_diff1 < 0.003 and scale_diff2 < 0.003 + +if __name__ == '__main__': + exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml']) + exit(exit_code) diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_rotary_embedding_kernel.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_rotary_embedding_kernel.py new file mode 100644 index 0000000..dd076f5 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/test_rotary_embedding_kernel.py @@ -0,0 +1,220 @@ +import sys +import os +from unit_test import UnitTest +import pytest +import torch +import torch_mlu +import btunittests + +class ApplyRotary(torch.nn.Module): + def __init__(self): + super().__init__() + + def rotate(self, x: torch.Tensor, interleaved: bool): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + y = torch.empty_like(x) + x1, x2 = x[..., ::2], x[..., 1::2] + y[..., ::2], y[..., 1::2] = -x2, x1 + return y + + def forward(self, + output: torch.Tensor, # [total_seqlen, num_heads, head_size] + input: torch.Tensor, # [total_seqlen, num_heads, head_size] + sin_cache: torch.Tensor, # [rope_seqlen, rotary_dim] / [batch, rope_seqlen, rotary_dim] + cos_cache: torch.Tensor, # + seq_offsets: torch.Tensor, # [batch] / [batch, max_seqlen] + cu_seq_lens: torch.Tensor, # [batch + 1] + interleaved: bool, + discrete: bool, + dynamic: bool, + max_seqlen: int, + no_offset:bool, + rope_2d:bool): + packed = input.dim() == 3 + rope_dim = sin_cache.shape[-1] + batch_size = cu_seq_lens.shape[0] - 1 if packed else input.shape[0] + + if not rope_2d: + for i in range(batch_size): + input_i = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i] + ouput_i = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i] + input_i = input_i[..., 0:rope_dim] + ouput_i = ouput_i[..., 0:rope_dim] + start_seq_idx = 0 if no_offset else seq_offsets[i] + sin_cache_i = sin_cache[i] if dynamic else sin_cache + cos_cache_i = cos_cache[i] if dynamic else cos_cache + seq = input_i.shape[0] + if discrete: + start_seq_idx = 0 + token_offset = seq_offsets[cu_seq_lens[i]:cu_seq_lens[i]+seq] + sin_cache_i = sin_cache_i[token_offset] + cos_cache_i = cos_cache_i[token_offset] + sin_cache_i = sin_cache_i[start_seq_idx:seq+start_seq_idx] + cos_cache_i = cos_cache_i[start_seq_idx:seq+start_seq_idx] + rot = self.rotate(input_i, interleaved) + + ouput_i[:] = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1) + output[..., rope_dim:] = input[..., rope_dim:] + else: + for i in range(batch_size): + input_i_left = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i] + ouput_i_left = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i] + input_i_left = input_i_left[..., 0:rope_dim] + ouput_i_left = ouput_i_left[..., 0:rope_dim] + sin_cache_i = sin_cache[i] if dynamic else sin_cache + cos_cache_i = cos_cache[i] if dynamic else cos_cache + seq = input_i_left.shape[0] + if discrete: + token_offset = seq_offsets[0][cu_seq_lens[i]:cu_seq_lens[i]+seq] + sin_cache_i = sin_cache_i[token_offset] + cos_cache_i = cos_cache_i[token_offset] + sin_cache_i = sin_cache_i[:seq] + cos_cache_i = cos_cache_i[:seq] + rot = self.rotate(input_i_left, interleaved) + + ouput_i_left[:] = rot * sin_cache_i.unsqueeze(1) + input_i_left * cos_cache_i.unsqueeze(1) + + input_i_right = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i] + ouput_i_right = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i] + input_i_right = input_i_right[..., rope_dim:] + ouput_i_right = ouput_i_right[..., rope_dim:] + sin_cache_i = sin_cache[i] if dynamic else sin_cache + cos_cache_i = cos_cache[i] if dynamic else cos_cache + seq = input_i_right.shape[0] + if discrete: + token_offset = seq_offsets[1][cu_seq_lens[i]:cu_seq_lens[i]+seq] + sin_cache_i = sin_cache_i[token_offset] + cos_cache_i = cos_cache_i[token_offset] + sin_cache_i = sin_cache_i[:seq] + cos_cache_i = cos_cache_i[:seq] + rot = self.rotate(input_i_right, interleaved) + + ouput_i_right[:] = rot * sin_cache_i.unsqueeze(1) + input_i_right * cos_cache_i.unsqueeze(1) + +class TestRotaryEmbeddingKernel(UnitTest): + @pytest.mark.parametrize( + "batch, max_seq_len, head_num_q, head_num_kv, head_size, discrete, rotary_dim, interleaved, rope_2d, dynamic_ntk, packed, no_offset, data_type", + [ + (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float32), + (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float32), + (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float32), + (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float32), + (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float16), + (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float16), + (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float16), + (4, 16, 8, 4, 128, True, 128, True, False, True, True, True, torch.float16), + (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float16), + (4, 16, 8, 4, 128, True, 128, True, False, False, True, True, torch.float16), + (4, 16, 8, 4, 128, True, 128, True, True, False, True, True, torch.float16), + (4, 16, 8, 4, 128, False, 128, False, False, True, True, False, torch.float32), + (4, 16, 8, 4, 128, True, 128, False, False, True, True, False, torch.float32), + (4, 16, 8, 4, 128, True, 128, False, True, True, True, False, torch.float32), + ]) + def testall( + self, + batch, + max_seq_len, + head_num_q, + head_num_kv, + head_size, + discrete, + rotary_dim, + interleaved, + rope_2d, + dynamic_ntk, + packed, + no_offset, + data_type): + torch.manual_seed(0) + if rope_2d: + interleaved = False + discrete = True + dynamic_ntk = False + no_offset = False + if head_num_q != 2 * head_num_kv: + head_num_q = 2 * head_num_kv + if rotary_dim %2 != 0: + rotary_dim -= 1 + if no_offset: + discrete = False + if rope_2d: + head_size = int(head_size if (head_size % 4 == 0) else (head_size // 4) * 4) + rotary_dim = int(head_size / 2); + + seq_lens = torch.randint(size=(batch, ), low=1, high=max_seq_len+1, dtype=torch.int32, device='mlu') + total_seq_len = batch * max_seq_len + if not packed: + seq_lens = torch.randint(size=(batch, ), low=max_seq_len, high=max_seq_len+1, dtype=torch.int32, device='mlu') + max_context_len = int(seq_lens.max()) + + cu_seq_lens = None + if packed: + cu_seq_lens = torch.cumsum(seq_lens, dim=-1, dtype=torch.int32) + cu_seq_lens = torch.nn.functional.pad(cu_seq_lens, (1,0), "constant", 0).mlu() + total_seq_len = int(cu_seq_lens[-1]) + + head_num_qkv = head_num_q + head_num_kv + head_num_qk = head_num_q + head_num_kv + + context_shape = (total_seq_len, head_num_qkv + 1, head_size) if packed else \ + (batch, max_seq_len, head_num_qkv + 1, head_size) + context = torch.randn(size=context_shape, dtype=data_type).mlu() + + qkv = context[..., 0 : head_num_qkv, :] + qk = context[..., 0 : head_num_qk, :] + seq_offsets = None + if rope_2d: + seq_offsets = torch.randint(size=(2, total_seq_len), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu') + else: + if discrete and not no_offset: + seq_offsets = torch.randint(size=(total_seq_len,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu') + elif not no_offset: + seq_offsets = torch.randint(size=(batch,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu') + + cos_cache = None + sin_cache = None + if dynamic_ntk: + cos_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu() + sin_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu() + else: + cos_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu() + sin_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu() + + base_output = torch.empty_like(qk) + + apply_rotary = ApplyRotary() + apply_rotary(base_output, qkv, sin_cache, cos_cache, \ + seq_offsets, cu_seq_lens, interleaved, discrete, dynamic_ntk, max_context_len, no_offset, rope_2d) + + output = btunittests.rotary_embedding_test( + qkv, + sin_cache, + cos_cache, + seq_offsets, + cu_seq_lens, + interleaved, + discrete, + dynamic_ntk, + rope_2d, + max_context_len, + no_offset) + + diff1 = super().diff1(output, base_output) + diff2 = super().diff2(output, base_output) + assert diff1 < 0.003 and diff2 < 0.003 + + +if __name__ == "__main__": + exit_code = pytest.main( + [ + "-vs", + os.path.abspath(__file__), + "--junitxml=" + + os.path.splitext(os.path.basename(__file__))[0] + + "_report.xml", + ] + ) + exit(exit_code) diff --git a/torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py new file mode 100644 index 0000000..d1b0027 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py @@ -0,0 +1,25 @@ +import sys +import os +build_lib_dir = os.path.dirname(os.path.abspath(__file__)) + "/build/lib" +sys.path.append(build_lib_dir) +import torch + +class UnitTest: + def diff1(self, result: torch.Tensor, baseline: torch.Tensor): + result = result.flatten().float().to('cpu') + baseline = baseline.flatten().float().to('cpu') + assert result.shape == baseline.shape + error = torch.abs(baseline - result) + denominator = torch.sum(torch.abs(baseline)).item() + eps = 0.0 if denominator > 0 else 1e-9 + diff1 = torch.sum(error) / (denominator + eps) + return diff1.item() + + def diff2(self, result: torch.Tensor, baseline: torch.Tensor): + result = result.flatten().float().to('cpu') + baseline = baseline.flatten().float().to('cpu') + error = torch.abs(baseline - result) + denominator = torch.sum(baseline**2).item() + eps = 0.0 if denominator > 0 else 1e-9 + diff2 = torch.sqrt(torch.sum(error**2) / (denominator + eps)) + return diff2.item() diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/README.md b/torch_mlu_ops-v1.3.2/tests/ops_pytest/README.md new file mode 100644 index 0000000..d9d7ec1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/README.md @@ -0,0 +1,17 @@ +## BT_OPS测试脚本使用方式 + +```bash +# 测试所有测例 +bash run_test.sh +``` + +```bash +# 测试单个测例 +python3 test_测例名称.py +``` + +- 必须在Torch-MLU-Ops docker容器内运行。 + +- 测试脚本的命名规则为 `test_测例名称.py`。 + +- 必须保证 Torch-MLU-Ops whl包正确安装。 diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_flash_attn_sq_mm_allreduce.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_flash_attn_sq_mm_allreduce.py new file mode 100755 index 0000000..1ab25b3 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_flash_attn_sq_mm_allreduce.py @@ -0,0 +1,147 @@ +import torch +import torch_mlu +import unittest +import math +import torch_mlu_ops as ops +import torch.multiprocessing as mp +import sys +import os + +work_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(work_dir)) +from common_utils import * + +def flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale, + bias, softmax_scale, is_causal, world_size = 1): + q_list = q.chunk(world_size, dim=2) + k_list = k.chunk(world_size, dim=2) + v_list = v.chunk(world_size, dim=2) + smooth_list = smooth.chunk(world_size, dim=0) + quant_weight_list = quant_weight.chunk(world_size, dim=1) + quant_weight_list = [w.contiguous() for w in quant_weight_list] + output1 = torch.zeros(q.size(0) * q.size(1), q.size(2) * q.size(3), dtype=q.dtype).mlu() + for i in range(world_size): + attn_output = ops.flash_attention(q_list[i], k_list[i], v_list[i], None, None, None, + None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1) + quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth_list[i], None) + output1 += ops.smooth_quant_matmul(quant_input, input_scale, + quant_weight_list[i], weight_scale, q.dtype, bias if i == 0 else None) + attn_output = ops.flash_attention(q, k, v, None, None, None, + None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1) + quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None) + output2 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias) + return output1, output2 + +def tp_flash_attn_sq_mm(rank, *args): + world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, base_output = args + q_cpu, k_cpu, v_cpu, smooth_cpu = q.cpu(), k.cpu(), v.cpu(), smooth.cpu() + quant_weight_cpu, weight_scale_cpu, bias_cpu = quant_weight.cpu(), weight_scale.cpu(), bias.cpu() + base_output_cpu = base_output.cpu() + setup(rank, world_size) + pg = get_default_group() + cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank) + head_num_q = q.size(2) + head_num_kv = k.size(2) + seq = q.size(1) + assert head_num_q % world_size == 0 + assert head_num_kv % world_size == 0 + head_num_q_tp = head_num_q // world_size + head_num_kv_tp = head_num_kv // world_size + q = q_cpu.mlu() + k = k_cpu.mlu() + v = v_cpu.mlu() + smooth = smooth_cpu.mlu() + quant_weight = quant_weight_cpu.mlu() + weight_scale = weight_scale_cpu.mlu() + # Note: only tp0 add bias + bias = bias_cpu.mlu() if rank == 0 else None + q_list = q.chunk(world_size, dim=2) + k_list = k.chunk(world_size, dim=2) + v_list = v.chunk(world_size, dim=2) + smooth_list = smooth.chunk(world_size, dim=0) + quant_weight_list = quant_weight.chunk(world_size, dim=1) + quant_weight_list = [w.contiguous() for w in quant_weight_list] + # test pad mode + output_pad = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_list[rank], k_list[rank], v_list[rank], + None, None, None, None, + smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq, + seq, softmax_scale, is_causal) + assertTensorsEqual(output_pad.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True) + # test pack mode + cu_seq_lens_q = torch.tensor([0, seq], dtype=torch.int32).mlu() + cu_seq_lens_k = torch.tensor([0, seq], dtype=torch.int32).mlu() + q_pack = q_list[rank].flatten(0, 1) + k_pack = k_list[rank].flatten(0, 1) + v_pack = v_list[rank].flatten(0, 1) + output_pack = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_pack, k_pack, v_pack, + cu_seq_lens_q, cu_seq_lens_k, None, None, + smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq, + seq, softmax_scale, is_causal) + assertTensorsEqual(output_pack.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True) + cleanup() + +class TestFlashAttnSqMMAllreduce(BtTestCase): + def op_impl_base(self, *args): + return super().op_impl_base(*args) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_flash_attn_sq_mm_split_seq due to ASan issues") + def test_flash_attn_sq_mm_split_seq(self): + batch, seq, head_num_q, head_num_kv, head_size, is_causal, block_seq = 16, 1024, 8, 1, 128, True, 4 + dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half + hidden_size = head_num_q * head_size + softmax_scale = 1 / math.sqrt(head_size) + qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu() + smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0 + bias = torch.zeros(hidden_size, dtype=dtype).mlu() + weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu() + quant_weight, weight_scale = QuantByRow(weight / smooth, 8) + q = qkv[:, :, :head_num_q, :] + k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :] + v = qkv[:, :, head_num_q+head_num_kv:, :] + attn_output = ops.flash_attention(q, k, v, None, None, None, + None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1) + quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None) + output1 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias) + output_list = [] + block_size = seq // block_seq + for i in range(block_seq): + start = i * block_size + end = seq if i == block_seq - 1 else (i + 1) * block_size + attn_output = ops.flash_attention(q[:,start:end], k[:,:end], v[:,:end], None, None, None, + None, None, end - start, end, softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1) + quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None) + out = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias) + output_list.append(out) + output2 = torch.cat(output_list, dim=0) + output2 = output2.reshape(block_seq, batch, block_size, hidden_size).transpose(0, 1).reshape(batch*seq, hidden_size) + assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.0045, use_MSE=True, use_RAE=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_flash_attn_sq_mm_allreduce(self): + world_size = min(torch_mlu.mlu.device_count(), 8) + batch, seq, head_num_q, head_num_kv, head_size, is_causal = 1, 8192, 64, 8, 128, True + dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half + for i in range(1): + # seq = random.randint(1, 32768) + # is_causal = bool(random.randint(0, 1)) + print("=============test[{}]: seq = {}, causal = {}===============".format(i, seq, is_causal)) + hidden_size = head_num_q * head_size + softmax_scale = 1 / math.sqrt(head_size) + qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu() + smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0 + bias = torch.randn(hidden_size, dtype=dtype).mlu() + weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu() + quant_weight, weight_scale = QuantByRow(weight / smooth, 8) + q = qkv[:, :, :head_num_q, :] + k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :] + v = qkv[:, :, head_num_q+head_num_kv:, :] + output1, output2 = flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale, + bias, softmax_scale, is_causal, world_size) + args = world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, output1 + mp.spawn(tp_flash_attn_sq_mm, args, nprocs=world_size, join=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestFlashAttnSqMMAllreduce)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py new file mode 100755 index 0000000..ba0fbf4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_matmul_all_reduce.py @@ -0,0 +1,109 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from itertools import product +import torch.multiprocessing as mp +import sys +import os + +work_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(work_dir)) +from common_utils import * + +def mm_split_k(rank, *args): + torch.manual_seed(0) + mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m = args + assert mat_k % world_size == 0, f"mat_k{mat_k} must be divisible by tp{world_size}" + block_k = mat_k // world_size + start_k = rank * block_k + end_k = start_k + block_k + + setup(rank, world_size) + pg = get_default_group() + cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank) + + alpha = 0.625 + beta = 1.0 if has_res else 0. + input = torch.randn((mat_m, mat_k), dtype=dtype, device='mlu') + residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') if has_res else None + weight = torch.randn((mat_n, mat_k), dtype=dtype, device="mlu") + bias = torch.randn(mat_n, dtype=dtype, device="mlu") if has_bias else None + pt_output = torch.matmul(input, weight.permute(1, 0)) + if has_bias: + pt_output = pt_output + bias + pt_output *= alpha + if has_res: + pt_output += beta * residual + if has_bias: + bias *= alpha + input = input[..., start_k:end_k].contiguous() + weight = weight[..., start_k:end_k].contiguous() + bias = bias if (bias is not None and rank == 0) else None + residual = residual if (residual is not None and rank == 0) else None + beta = beta if (residual is not None and rank == 0) else 0. + output = ops.matmul_allreduce(cncl_comm, input, weight, bias, residual, alpha, beta, block_m) + if rank == 0: + assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + cleanup() + + +class TestMatMulAllReduceOp(BtTestCase): + def op_impl_base(self, *args): + return super().op_impl_base(*args) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce due to ASan issues") + def test_matmul_allreduce(self): + device_n = torch_mlu.mlu.device_count() + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + + mat_m_list = [32] + mat_n_list = [256] + mat_k_list = [1680] + has_res_list = [False, True] + has_bias_list = [False, True] + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + block_m = 4 + args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, dtype_list) + + for mat_m, mat_n, mat_k, has_bias, has_res, dtype in args: + print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp={}, block_m={}, testing...".format( + mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m), flush=True) + param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m] + mp.spawn(mm_split_k, param, nprocs=device_n, join=True) + + @unittest.skip("not test") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce_random due to ASan issues") + def test_matmul_allreduce_random(self): + import random + random.seed(0) + device_n = torch_mlu.mlu.device_count() + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for i in range(10): + tp_num = random.randint(1, device_n) + mat_m = random.randint(1, 4096) + mat_n = random.choice([512, 1024, 2048, 4096]) + k_start = (1024 // tp_num) * tp_num + mat_k = random.randrange(k_start, 10240, tp_num) + has_res = random.choice([False, True]) + has_bias = random.choice([False, True]) + dtype = random.choice(dtype_list) + block_m = random.randint(1, 10) + assert mat_k % tp_num == 0, f"mat_k{mat_k} must be divisible by tp_num{tp_num}" + print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format( + mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m), flush=True) + param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m] + mp.spawn(mm_split_k, param, nprocs=tp_num, join=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestMatMulAllReduceOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py new file mode 100755 index 0000000..6603b9d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_moe_all_reduce.py @@ -0,0 +1,305 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from typing import Union, List, Tuple +from torch.nn.parameter import Parameter +from torch.nn import functional as F +import torch.multiprocessing as mp +import math +import sys +import os + +work_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(work_dir)) +from common_utils import * + +def moe_split_inner_size(rank, *args): + torch.manual_seed(0) + batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args + assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}" + block_k = inner_size // world_size + start_k = rank * block_k + end_k = start_k + block_k + + setup(rank, world_size) + pg = get_default_group() + cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank) + + scale_s = 0.01 # avoid the occurrence of inf + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype) + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype) + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=dtype) * scale_s + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=dtype) + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=dtype) * scale_s + bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=dtype) + + weight1 = weight1.view(expert_num, -1, inner_size, hidden_size) + weight1 = weight1[:, :, start_k:end_k] + weight1 = weight1.reshape(expert_num, -1, hidden_size).contiguous() + weight2 = weight2[..., start_k:end_k].contiguous() + if bias1 is not None: + bias1 = bias1.view(expert_num, -1, inner_size) + bias1 = bias1[..., start_k:end_k] + bias1 = bias1.reshape(expert_num, -1).contiguous() + residual = residual if (residual is not None and rank == 0) else None + + param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual, + None, None, None, None, topk, renormalize, gated, act_mode] + output = ops.fused_moe(*param) + all_reduce(output, ReduceOp.SUM, group=pg) + + param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual, + None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm] + output1 = ops.fused_moe(*param) + + new_inner_size = weight2.shape[-1] + block_e = 4096 // new_inner_size + if block_e * new_inner_size == 4096 and expert_num % block_e == 0: + param = [hidden_states, router_logit, weight1, weight2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(), + bias1, bias2, residual, + None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm] + output2 = ops.fused_moe(*param) + if rank == 0: + assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + if block_e * new_inner_size == 4096 and expert_num % block_e == 0: + assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + cleanup() + +def sq_moe_split_inner_size(rank, *args): + torch.manual_seed(0) + batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args + assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}" + block_k = inner_size // world_size + start_k = rank * block_k + end_k = start_k + block_k + + setup(rank, world_size) + pg = get_default_group() + cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank) + + scale_s = 0.1 # avoid the occurrence of inf + eps = 0.1 # Avoid the occurrence of nan + residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype) + bias1, bias2 = None, None + hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=dtype) + weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=dtype) + weight2 = torch.normal(0, 0.01, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=dtype) + router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) + input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps + act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps + weight1_shape, weight2_shape = weight1.shape, weight2.shape + weight1 = weight1 / input_smooth.unsqueeze(1) + weight2 = weight2 / act_smooth.unsqueeze(1) + quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8) + quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8) + quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) + w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) + quant_w1 = quant_w1.view(expert_num, -1, inner_size, hidden_size) + quant_w1 = quant_w1[:, :, start_k:end_k] + quant_w1 = quant_w1.reshape(expert_num, -1, hidden_size).contiguous() + quant_w2 = quant_w2[..., start_k:end_k].contiguous() + + if bias1 is not None: + bias1 = bias1.view(expert_num, -1, inner_size) + bias1 = bias1[..., start_k:end_k] + bias1 = bias1.reshape(expert_num, -1).contiguous() + if w1_scale is not None: + w1_scale = w1_scale.view(expert_num, -1, inner_size) + w1_scale = w1_scale[..., start_k:end_k] + w1_scale = w1_scale.reshape(expert_num, -1).contiguous() + if act_smooth is not None: + act_smooth = act_smooth[..., start_k:end_k].contiguous() + residual = residual if (residual is not None and rank == 0) else None + + param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual, + input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode] + output = ops.fused_moe(*param) + all_reduce(output, ReduceOp.SUM, group=pg) + + param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual, + input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, + block_n, cncl_comm] + output1 = ops.fused_moe(*param) + + new_inner_size = quant_w2.shape[-1] + block_e = 4096 // new_inner_size + if block_e * new_inner_size == 4096 and expert_num % block_e == 0: + param = [hidden_states, router_logit, quant_w1, + quant_w2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(), + bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk, + renormalize, gated, act_mode, 0, block_n, cncl_comm] + output2 = ops.fused_moe(*param) + if rank == 0: + assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + if block_e * new_inner_size == 4096 and expert_num % block_e == 0: + assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + cleanup() + + +class TestFusedMOEAllReduceOp(BtTestCase): + def op_impl_base(self, *args): + return super().op_impl_base(*args) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_fused_moe_allreduce due to ASan issues") + def test_single_fused_moe_allreduce(self): + print("test_single_fused_moe_allreduce") + device_n = min(torch_mlu.mlu.device_count(), 8) + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + block_n = 2 + + batch, seq, hidden_size, inner_size = 1, 1024, 8192, 8192 + expert_num, topk, gated, renormalize, act_mode, dtype = 8, 2, True, True, 'silu', torch.float16 + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \ +tp_num: {device_n}, block_n: {block_n}, testing...", flush=True) + param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n] + mp.spawn(moe_split_inner_size, param, nprocs=device_n, join=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_fused_moe_allreduce due to ASan issues") + def test_random_fused_moe_allreduce(self): + print("test_random_fused_moe_allreduce") + import random + random.seed(0) + device_n = min(torch_mlu.mlu.device_count(), 8) + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + + act_mode = 'gelu' + case_list = set() + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + while(len(case_list) < 10): + block_n = random.randint(-5, 5) + tp_num = random.randint(1, device_n) + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(512, 1024, 2) + if block_n != 0: + if hidden_size // abs(block_n) < 256: + continue + hidden_size = (hidden_size // block_n) * block_n + else: + hidden_size = 1024 * random.randint(1, 10) + k_start = (512 // tp_num) * tp_num + inner_size = random.randrange(k_start, 1024, tp_num * 2) + expert_num = random.randint(1, 32) + topk = random.randint(1,expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + dtype = random.choice(dtype_list) + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n) + if case in case_list: + continue + case_list.add(case) + print(f"random bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \ +tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True) + param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num] + mp.spawn(moe_split_inner_size, param, nprocs=tp_num, join=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_sq_fused_moe_allreduce due to ASan issues") + def test_sq_fused_moe_allreduce(self): + print("test_sq_fused_moe_allreduce") + device_n = min(torch_mlu.mlu.device_count(), 8) + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + block_n = 2 + batch, seq, hidden_size, inner_size = 5, 9, 8192, 8192 + expert_num, topk, gated, renormalize, act_mode, dtype = 20, 16, False, True, 'gelu', torch.float16 + print(f"sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \ +tp_num: {device_n}, block_n: {block_n}, testing...", flush=True) + param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n] + mp.spawn(sq_moe_split_inner_size, param, nprocs=device_n, join=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_allreduce due to ASan issues") + def test_random_sq_fused_moe_allreduce(self): + print("test_random_sq_fused_moe_allreduce") + import random + random.seed(0) + device_n = min(torch_mlu.mlu.device_count(), 8) + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + act_mode = 'gelu' + case_list = set() + while(len(case_list) < 10): + block_n = random.randint(-5, 5) + tp_num = random.randint(1, device_n) + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(256, 512, 2) + if block_n != 0: + if hidden_size // abs(block_n) < 256: + continue + hidden_size = (hidden_size // block_n) * block_n + else: + hidden_size = 1024 * random.randint(1, 10) + k_start = (512 // tp_num) * tp_num + inner_size = random.randrange(k_start, 1024, tp_num * 2) + expert_num = random.randint(1, 32) + topk = random.randint(1, expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + dtype = random.choice([ torch.float16]) + if torch_mlu.mlu.get_device_name() == 'MLU370': + dtype = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n) + if case in case_list: + continue + case_list.add(case) + print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \ +tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True) + param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num] + mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_with_4D_w2_allreduce due to ASan issues") + def test_random_sq_fused_moe_with_4D_w2_allreduce(self): + print("test_random_sq_fused_moe_with_4D_w2_allreduce") + import random + random.seed(0) + device_n = min(torch_mlu.mlu.device_count(), 8) + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + act_mode = 'gelu' + case_list = set() + while (len(case_list) < 10): + block_n = random.randint(-5, 5) + tp_num = random.randint(1, device_n) + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(512, 4096, 2) + if block_n != 0: + if hidden_size // abs(block_n) < 256: + continue + hidden_size = (hidden_size // block_n) * block_n + else: + hidden_size = 1024 * random.randint(1, 10) + inner_size_per_tp = random.choice([256, 512]) + inner_size = inner_size_per_tp * tp_num + expert_num_base = random.randint(1, 4) + expert_num_factor = 4096 // inner_size_per_tp + expert_num = expert_num_base * expert_num_factor + topk = random.randint(1, expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + dtype = random.choice([ torch.float16]) + if torch_mlu.mlu.get_device_name() == 'MLU370': + dtype = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n) + if case in case_list: + continue + case_list.add(case) + print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \ +tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True) + param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num] + mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestFusedMOEAllReduceOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_quant_matmul_all_reduce.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_quant_matmul_all_reduce.py new file mode 100755 index 0000000..b740551 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/allreduce_case/test_quant_matmul_all_reduce.py @@ -0,0 +1,157 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from torch.nn.parameter import Parameter +from torch.nn import functional as F +import torch.multiprocessing as mp +import sys +import os + +work_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(work_dir)) +from common_utils import * + +def compute_weight_only_scale(weight, quant_bit): + int_max = float(2 ** (quant_bit - 1) - 1) + weight_max = torch.max(torch.abs(weight), axis=1, keepdims=True) + weight_scale = torch.div(int_max, weight_max[0]) + weight_int = torch.mul(weight, weight_scale) + weight_int = weight_int.type(torch.int8) + weight_scale_recip = torch.div(weight_max[0], int_max).type(torch.float).squeeze() + return weight_int, weight_scale_recip + +def quant_mm_split_k(rank, *args): + torch.manual_seed(0) + M, N, K, has_bias, has_res, dtype, block_m, world_size = args + assert K % world_size == 0 + block_k = K // world_size + start_k = rank * block_k + end_k = start_k + block_k + + setup(rank, world_size) + pg = get_default_group() + cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank) + + a = torch.randn(M, K, device="mlu", dtype=dtype) + b = torch.randn(N, K, device="mlu", dtype=dtype) + b_int, gemm_output_scale = compute_weight_only_scale(b, 8) + c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None + bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None + torch_quant_matmul = QuantMatmul(b_int, None, None, None, None, gemm_output_scale, dtype) + pt_output = torch_quant_matmul(a).detach() + + a = a[..., start_k:end_k].contiguous() + b_int = b_int[..., start_k:end_k].contiguous() + bias = bias if (bias is not None and rank == 0) else None + c = c if (c is not None and rank == 0) else None + + param = [cncl_comm, a, None, None, b_int, None, None, bias, c, None, + None, gemm_output_scale, None, "half", "weight_only", "quantize_none", + "quantize_per_channel", 8, 1.0, 1.0, False, True, block_m] + # beta = beta if (c is not None and rank == 0) else 0. + output = ops._ops.quant_matmul_allreduce(*param) + if rank == 0: + assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + cleanup() + +def sq_mm_split_k(rank, *args): + torch.manual_seed(0) + M, N, K, has_bias, has_res, dtype, world_size, block_m = args + assert K % world_size == 0 + block_k = K // world_size + start_k = rank * block_k + end_k = start_k + block_k + + setup(rank, world_size) + pg = get_default_group() + cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank) + + a = torch.randint(-10, 10, (M, K), dtype=torch.int8).mlu() + b = torch.randint(-10, 10, (N, K), dtype=torch.int8).mlu() + c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None + bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None + a_scale = torch.randn(M, device="mlu", dtype=torch.float) + b_scale = torch.randn(N, device="mlu", dtype=torch.float) + torch_quant_matmul = QuantMatmul(b, bias, c, a_scale, b_scale, None, dtype) + pt_output = torch_quant_matmul(a).detach() + + a = a[..., start_k:end_k].contiguous() + b = b[..., start_k:end_k].contiguous() + bias = bias if (bias is not None and rank == 0) else None + c = c if (c is not None and rank == 0) else None + # beta = beta if (c is not None and rank == 0) else 0. + param = [cncl_comm, a, a_scale, b, b_scale, dtype, bias, c, 1.0, 1.0, block_m] + output = ops.smooth_quant_matmul_allreduce(*param) + if rank == 0: + assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True) + cleanup() + +class TestGptQuantMatmulOp(BtTestCase): + def op_impl_base(self, *args): + return super().op_impl_base(*args) + + # weight only, no bias, no residual, no activation + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_weight_only_matmul_allreduce due to ASan issues") + def test_single_weight_only_matmul_allreduce(self): + device_n = torch_mlu.mlu.device_count() + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + block_m = 8 + M, K, N = 32, 256, 128 + has_bias = False + has_res = False + dtype = torch.half + print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format( + M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True) + assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}" + param = [M, N, K, has_bias, has_res, dtype, block_m, device_n] + mp.spawn(quant_mm_split_k, param, nprocs=device_n, join=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_sq_matmul_allreduce due to ASan issues") + def test_single_sq_matmul_allreduce(self): + device_n = torch_mlu.mlu.device_count() + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + block_m = 1 + M, K, N = 2598, 1024, 1024 + has_bias = False + has_res = False + dtype = torch.half + print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format( + M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True) + assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}" + param = [M, N, K, has_bias, has_res, dtype, device_n, block_m] + mp.spawn(sq_mm_split_k, param, nprocs=device_n, join=True) + + @unittest.skip("not test") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_mm_allreduce due to ASan issues") + def test_random_sq_mm_allreduce(self): + import random + random.seed(0) + device_n = torch_mlu.mlu.device_count() + if device_n < 2: + print(f"device count is {device_n}, can not test communication between devices.", flush=True) + + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + for i in range(10): + tp_num = random.randint(1, device_n) + M = random.randint(1, 4096) + N = random.choice([512, 1024, 2048, 4096]) + k_start = (1024 // tp_num) * tp_num + K = random.randrange(k_start, 2048, tp_num) + has_res = random.choice([False, True]) + has_bias = random.choice([False, True]) + dtype = random.choice(dtype_list) + block_m = random.randint(1, 10) + + print("M={}, N={}, K={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format( + M, N, K, has_bias, has_res, dtype, tp_num, block_m), flush=True) + param = [M, N, K, has_bias, has_res, dtype, tp_num, block_m] + mp.spawn(sq_mm_split_k, param, nprocs=tp_num, join=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestGptQuantMatmulOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/common_utils.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/common_utils.py new file mode 100644 index 0000000..cab5bec --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/common_utils.py @@ -0,0 +1,734 @@ +import sys +sys_args = sys.argv +sys.argv = [sys_args.pop(0)] # prevent unittest printing help info + +import os +import torch +import torch_mlu +from torch.testing._internal.common_utils import TestCase +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from typing import List, Tuple, Optional +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group as get_default_group, all_reduce, ReduceOp +import torch.testing._internal.optests as optests +import random +import argparse +from abc import abstractmethod, ABC +import unittest +import torch_mlu_ops as tmo +import os + +act_mode_dict = {"relu": torch.nn.functional.relu, + "gelu": torch.nn.functional.gelu, + "silu": torch.nn.functional.silu} + +class BtTestCase(TestCase, ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + os.environ['TORCH_ALLOW_TF32_CNMATMUL_OVERRIDE'] = '0' + + @abstractmethod + def op_impl_base(self, *args): + pass + + @abstractmethod + def test_inductor(self): + pass + + def base_opcheck(self, interface_overload, args): + target_check = ["test_schema", "test_autograd_registration"] + if torch.__version__ >= '2.3.0': + target_check.append("test_faketensor") + target_status = {key: "SUCCESS" for key in target_check} + result = optests.opcheck(interface_overload, args, test_utils=target_check) + self.assertEqual(result, target_status,) + + def assertException(self, error_msg, func, *args, **kwinputs): + try: + func(*args, **kwinputs) + self.assertTrue(False) + except Exception as e: + if error_msg: + self.assertTrue(error_msg == str(e)) + else: + self.assertTrue(True) + + def assertTensorsEqual(self, + a, + b, + prec=None, + message='', + allow_inf=False, + use_MSE=False, + use_RAE=False, + use_RMA=False): + '''unittest.TestCase''' + if a.dtype == torch.bool: + a = a.float() + if b.dtype == torch.bool: + b = b.float() + epsilon = 1.0 / 16384 + self.assertEqual(a.size(), b.size(), message) + assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor." + if a.numel() > 0: + # check that NaNs are in the same locations + nan_mask = a != a + self.assertTrue(torch.equal(nan_mask, b != b), message) + diff = a - b + diff[nan_mask] = 0 + a = a.clone() + b = b.clone() + a[nan_mask] = 0 + b[nan_mask] = 0 + # inf check if allow_inf=True + if allow_inf: + inf_mask = (a == float("inf")) | (a == float("-inf")) + self.assertTrue(torch.equal(inf_mask, + (b == float("inf")) | (b == float("-inf"))), + message) + diff[inf_mask] = 0 + a[inf_mask] = 0 + b[inf_mask] = 0 + # TODO: implement abs on CharTensor + if diff.is_signed() and 'CharTensor' not in diff.type(): + diff = diff.abs() + if use_MSE: + diff = diff.abs().pow(2).sum() + a_pow_sum = a.pow(2).sum() + if diff <= (2 * epsilon) * (2 * epsilon): + diff = 0.0 + if a_pow_sum <= epsilon: + a_pow_sum = a_pow_sum + epsilon + diff = torch.div(diff, (a_pow_sum * 1.0)) + self.assertLessEqual(diff.sqrt(), prec, message) + elif use_RAE: + diff = diff.abs().sum() + a_sum = a.abs().sum() + if a_sum == 0: + self.assertEqual(a, b, message) + else: + diff = torch.div(diff, a_sum) + self.assertLessEqual(diff, prec, message) + elif use_RMA: + a_mean = a.abs().mean() + b_mean = b.abs().mean() + if a_mean == 0: + self.assertEqual(a, b, message) + else: + diff = torch.div((a_mean - b_mean).abs(), a_mean) + self.assertLessEqual(diff, prec, message) + else: + max_err = diff.max() + self.assertLessEqual(max_err, prec, message) + +def run_unittest(case) -> int: + parser = argparse.ArgumentParser() + parser.add_argument('-k', nargs='+', type=str, default="", help='specify case to run') + args = parser.parse_args(sys_args) + if args.k != "": + ret = unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromNames(args.k, case)) + else: + ret = unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(case)) + return not ret.wasSuccessful() + +class TMOTimer: + def __init__(self, repeat: int = 1): + self.repeat = repeat + + def __enter__(self): + self.notify_start = torch.mlu.Event(enable_timing=True) + self.notify_end = torch.mlu.Event(enable_timing=True) + self.notify_start.record() + + def __exit__(self, exc_type, exc_value, traceback): + self.notify_end.record() + self.notify_end.synchronize() + total_hardware_time = self.notify_start.hardware_time(self.notify_end) + self.average_hardware_time = total_hardware_time / self.repeat + +def QuantByRow(input: torch.Tensor, quant_bit: int, group_num: int=1): + input_shape = input.shape + if input.dim() > 2: + input = input.view(-1, input_shape[-1]) + if input.dim() == 1: + input = input.unsqueeze(0) + assert input.dim() == 2, "input must be 2-D tensor." + assert quant_bit == 4 or quant_bit == 8, "quant_bit must be 4 or 8." + assert group_num >= 1, "group_num >= 1." + int_max = float(2 ** (quant_bit - 1) - 1) + int_min = -float(2 ** (quant_bit - 1)) + group_size = input.size(-1) // group_num + input_v = input.view(input.size(0), group_num, group_size) if group_num > 1 else input + max, _ = input_v.abs().max(dim=-1, keepdim=True) + scale = max.to(torch.float) / int_max + quant_input = (input_v / scale).round().clamp(int_min, int_max).to(torch.int8).view(input.size()) + return quant_input.view(input_shape), scale.squeeze(-1) + +def QuantByTensor(input: torch.Tensor, quant_bit: int): + int_max = float(2 ** (quant_bit - 1) - 1) + int_min = -float(2 ** (quant_bit - 1)) + input_max = torch.max(torch.abs(input)) + input_scale = int_max / input_max + input_int = torch.mul(input, input_scale).round().clamp(int_min, int_max).to(torch.int8) + return input_int, input_scale + +def PairlyPackInt8(input): + assert input.dtype == torch.int8, "dtype of input must be int8." + assert input.dim() == 2 or input.dim() == 3, "input must be 2-D or 3-D tensor." + assert input.size(-1) % 2 == 0, "size(-1) of input must be even." + input_shape = list(input.shape) + input_flat = input.flatten() + d0 = input_flat[0::2].to(torch.uint8) + d1 = input_flat[1::2].to(torch.uint8) + dp = (d1 << 4) + (d0 & 0x0F) + input_shape[-1] = input_shape[-1] // 2 + return dp.to(torch.int8).reshape(input_shape) + +def UnpackInt4(input): + assert input.dtype == torch.int8, "dtype of input must be int8." + input_flat = input.flatten() + n = input_flat.size(0) + output = torch.zeros(n * 2, dtype=torch.int8, device=input.device) + high = input_flat >> 4 + low = input_flat << 4 + low = low >> 4 + output[0::2] = low + output[1::2] = high + + return output + +def smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias=None): + assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2" + assert a_scale.dim() == 1, "a_scale.dim() == 1" + assert a.size(0) == a_scale.size(0), "a.size(0) == a_scale.size(0)" + assert b.size(0) == b_scale.size(-1), "b.size(0) == b_scale.size(-1)" + m = a.size(0) + n = b.size(0) + a_k = a.size(1) + b_k = b.size(1) + if b_scale.dim() == 1: + b_scale = b_scale.unsqueeze(0) + quant_group = b_scale.size(0) + a = a.view(m, quant_group, -1).transpose(0, 1).contiguous() + if a_k == b_k * 2: + b = UnpackInt4(b) + b = b.view(n, quant_group, -1).transpose(0, 1).contiguous() + out = torch.zeros(m, n, dtype=torch.float, device=a.device) + for i in range(quant_group): + scale_mn = torch.matmul(a_scale.unsqueeze(1), b_scale[i].unsqueeze(0)) # (m, 1) x (1, n) = (m, n) + out += torch.einsum('mk,nk->mn', a[i].to(torch.float), b[i].to(torch.float)) * scale_mn + # out += smooth_quant_matmul(a[i], a_scale, b[i], b_scale[i], out_dtype) + out = out.to(out_dtype) + if bias is not None: + out += bias + return out + +def smooth_quant_matmul_w4w8_mixed(a, a_scale, b, b_scale, out_dtype, bias=None, quant_flag=None): + m = a.shape[0] + k = a.shape[1] + quant_group = b_scale.shape[0] + group_wise = k // quant_group + n = b_scale.shape[1] + b = b.view(n, -1) + a = a.view(m, quant_group, -1).transpose(0, 1).contiguous() + new_b = [] + start = 0 + end = 0 + for i in range(quant_group): + if quant_flag[i] == 4: + end += group_wise // 2 + new_b.append(UnpackInt4(b[:, start:end]).view(n, -1)) + else: + end += group_wise + new_b.append((b[:, start:end])) + start = end + new_b = torch.cat(new_b, 1) + b = new_b.view(n, quant_group, -1).transpose(0, 1).contiguous() + out = torch.zeros(m, n, dtype=torch.float, device=a.device) + for i in range(quant_group): + out += smooth_quant_matmul(a[i], a_scale, b[i], b_scale[i], out_dtype) + out = out.to(out_dtype) + if bias is not None: + out += bias + return out + +def weight_only_quant_matmul(a, b, scale, bias=None): + assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2" + assert scale.dim() == 1 or scale.dim() == 2, "scale.dim() == 1 or scale.dim() == 2" + assert b.size(0) == scale.size(0), "b.size(0) == b_scale.size(0)" + assert a.size(1) == b.size(1), "a.size(1) == b.size(1)" + if scale.dim() == 2: + group_size = b.size(1) // scale.size(1) + scale_bd = scale.unsqueeze(-1).repeat(1, 1, group_size).reshape(b.shape) + else: + scale_bd = scale.unsqueeze(-1) + b1 = b * scale_bd + out = torch.einsum('mk,nk->mn', a.to(torch.float), b1.to(torch.float)).to(a.dtype) + if bias is not None: + out += bias + return out + +def single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale, + v_cache_quant_scale, alibi_slopes, window_size_left, window_size_right, softmax_scale, return_lse): + q = q.float() + k_cache = k_cache.float() + v_cache = v_cache.float() + def masked_attention(query, key, value, alibi_slope, context_len, window_size_left, window_size_right, qk_scale) -> torch.Tensor: + # (num_heads, seq_q, seq_k) + qk = torch.einsum('qhd,hkd->hqk', query, key) + qk = qk * qk_scale + if alibi_slope is not None: + alibi_dist = torch.arange(0, context_len, dtype=torch.float32).mlu() + alibi = alibi_slope[:, None] * alibi_dist + qk = qk + alibi[:, None, :] + + _, seq_q, seq_k = qk.size() + if seq_q > 1: #causal mask + ml = torch.zeros((seq_q, seq_k - seq_q), dtype=qk.dtype).mlu() + ones = torch.ones((seq_q, seq_q), dtype=qk.dtype).mlu() * -torch.inf + mr = torch.triu(ones, diagonal=1) + mask = torch.cat((ml, mr), dim=-1) + qk = qk + mask + if window_size_left != -1 or window_size_right != -1: + mask_w = torch.full((seq_q, seq_k), -torch.inf, dtype=torch.float, device="mlu") + for qi in range(seq_q): + left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0 + right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k + mask_w[qi, left:right] = 0 + qk += mask_w + attention = torch.softmax(qk, dim = -1, dtype=qk.dtype) + qkv = torch.einsum('hqk,hkd->qhd', attention, value) + return qkv, qk + + if k_cache_quant_scale is not None and v_cache_quant_scale is not None: + if k_cache_quant_scale.dim() == 2: # per_channel: [kv_head_num, head_size] + k_cache_quant_scale = k_cache_quant_scale.reshape(1, k_cache_quant_scale.shape[0], 1, k_cache_quant_scale.shape[1]) + v_cache_quant_scale = v_cache_quant_scale.reshape(1, v_cache_quant_scale.shape[0], 1, v_cache_quant_scale.shape[1]) + elif k_cache_quant_scale.dim() == 3: # per_token: [num_blocks, k_head_num, block_size] + k_cache_quant_scale = k_cache_quant_scale.reshape(*k_cache_quant_scale.shape, 1) + v_cache_quant_scale = v_cache_quant_scale.reshape(*v_cache_quant_scale.shape, 1) + k_cache *= k_cache_quant_scale + v_cache *= v_cache_quant_scale + bs, seq_q, num_heads, head_size = q.size() + head_size_v = v_cache.size(-1) + num_blocks, num_kv_heads, block_size, _ = k_cache.size() + output = torch.zeros((bs, seq_q, num_heads, head_size_v), dtype=torch.float16) + lse = torch.zeros((bs, num_heads, seq_q), dtype=torch.float) + + assert (num_heads % num_kv_heads == 0) + head_repeats = num_heads // num_kv_heads + for bs_id in range(bs): + q_bs = q[bs_id] + context_len = int(context_lens[bs_id]) + if context_len == 0: + output[bs_id] = torch.zeros((seq_q, num_heads, head_size_v), device = q.device, dtype=output.dtype) + lse[bs_id] = lse[bs_id].fill_(-float('inf')) + else : + block_table = block_tables[bs_id] + table_end = (context_len + block_size - 1) // block_size + block_ids = block_table[0 : table_end] + keys, values = k_cache[block_ids], v_cache[block_ids] + + keys = torch.repeat_interleave(keys, head_repeats, dim=1) + keys = keys.transpose(1, 0).contiguous().view(num_heads, -1, head_size) + keys = keys[:, 0:context_len, :] + + values = torch.repeat_interleave(values, head_repeats, dim=1) + values = values.transpose(1, 0).contiguous().view(num_heads, -1, head_size_v) + values = values[:, 0:context_len, :] + + alibi_slope = alibi_slopes[bs_id] if alibi_slopes is not None else None + qkv, qk= masked_attention(q_bs, keys, values, alibi_slope, context_len, window_size_left, window_size_right, softmax_scale) + output[bs_id] = qkv + lse[bs_id] = torch.logsumexp(qk, dim = -1) + return (output, lse) if return_lse else output + + +def update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs): + # only pad + is_pack = out.dim() == 3 + new_out, new_lse = out.clone(), lse.clone() + batch, max_seq_len, block_seq_len = lse.shape[0], lse.shape[-1], block_lse.shape[-1] + + lse_bsh = lse.transpose(-2, -1).unsqueeze(dim=-1) + new_lse_bsh = new_lse.transpose(-2, -1).unsqueeze(dim=-1) + block_lse_bsh = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + if not is_pack: + for i in range(batch): + out_seq_offset = 0 if seq_offsets is None else seq_offsets[i] + out_i = out[i, out_seq_offset : out_seq_offset + block_seq_len] + lse_i = lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len] + block_out_i = block_out[i, :] + block_lse_i = block_lse_bsh[i, :] + new_out[i, out_seq_offset : out_seq_offset + block_seq_len] = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i) + new_lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len] = (lse_i - F.logsigmoid(lse_i - block_lse_i)) + else: + for i in range(batch): + block_i_begin = block_cu_seqs[i] + block_i_end = block_cu_seqs[i + 1] + block_i_lens = block_i_end - block_i_begin + out_i_begin = cu_seqs[i] + out_seq_offset = seq_offsets[i] + + block_out_i = block_out[block_i_begin : block_i_end] + block_lse_i = block_lse_bsh[i, 0 : block_i_lens] + out_i = out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens] + lse_i = lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens] + new_out_i = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i) + new_lse_i = (lse_i - F.logsigmoid(lse_i - block_lse_i)) + new_out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens] = new_out_i + new_lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens] = new_lse_i + + return (new_out, new_lse_bsh.squeeze(dim=-1).transpose(-2, -1)) + +class QuantMatmul(torch.nn.Module): + def __init__(self, weight, bias, residual, input_scale, weight_scale, gemm_output_scale, dtype, + alpha:float = 1.0, beta:float = 1.0, act_mode:str = 'none') -> None: + super().__init__() + self.dtype = dtype + self.weight = Parameter(weight.type(dtype)) + self.input_scale = input_scale + self.weight_scale = weight_scale + self.gemm_output_scale = gemm_output_scale + if bias is not None: + self.bias = Parameter(bias) + else: + self.bias = None + if residual is not None: + self.residual = Parameter(residual) + else: + self.residual = None + self.alpha = alpha + self.beta = beta + if act_mode == 'none': + self.act = None + else: + self.act = act_mode_dict[act_mode] + # d = (a * b + bias) * alpha + c * beta + # output = (input * weight + bias) * alpha + residual * beta + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = F.linear(input.type(self.dtype), self.weight, self.bias) + if self.input_scale is not None: + i_scale = self.input_scale.expand(self.weight_scale.shape[0], -1).transpose(0, 1) + output = torch.mul(output, i_scale) + if self.weight_scale is not None: + output = torch.mul(output, self.weight_scale) + if self.gemm_output_scale is not None: + output = torch.mul(output, self.gemm_output_scale) + output = torch.mul(output, self.alpha) + if self.residual is not None: + residual = torch.mul(self.residual, self.beta) + output = torch.add(output, residual) + if self.act is not None: + output = self.act(output) + return output + + +# for multiprocessing +def assertTensorsEqual( a, + b, + prec=None, + message='', + allow_inf=False, + use_MSE=False, + use_RAE=False, + use_RMA=False): + tc = TestCase() + if a.dtype == torch.bool: + a = a.float() + if b.dtype == torch.bool: + b = b.float() + epsilon = 1.0 / 16384 + tc.assertEqual(a.size(), b.size(), message) + assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor." + if a.numel() > 0: + # check that NaNs are in the same locations + nan_mask = a != a + tc.assertTrue(torch.equal(nan_mask, b != b), message) + diff = a - b + diff[nan_mask] = 0 + a = a.clone() + b = b.clone() + a[nan_mask] = 0 + b[nan_mask] = 0 + # inf check if allow_inf=True + if allow_inf: + inf_mask = (a == float("inf")) | (a == float("-inf")) + tc.assertTrue(torch.equal(inf_mask, + (b == float("inf")) | (b == float("-inf"))), + message) + diff[inf_mask] = 0 + a[inf_mask] = 0 + b[inf_mask] = 0 + # TODO: implement abs on CharTensor + if diff.is_signed() and 'CharTensor' not in diff.type(): + diff = diff.abs() + if use_MSE: + diff = diff.abs().pow(2).sum() + a_pow_sum = a.pow(2).sum() + if diff <= (2 * epsilon) * (2 * epsilon): + diff = 0.0 + if a_pow_sum <= epsilon: + a_pow_sum = a_pow_sum + epsilon + diff = torch.div(diff, (a_pow_sum * 1.0)) + tc.assertLessEqual(diff.sqrt(), prec, message) + elif use_RAE: + diff = diff.abs().sum() + a_sum = a.abs().sum() + if a_sum == 0: + tc.assertEqual(a, b, message) + else: + diff = torch.div(diff, a_sum) + tc.assertLessEqual(diff, prec, message) + elif use_RMA: + a_mean = a.abs().mean() + b_mean = b.abs().mean() + if a_mean == 0: + tc.assertEqual(a, b, message) + else: + diff = torch.div((a_mean - b_mean).abs(), a_mean) + tc.assertLessEqual(diff, prec, message) + else: + max_err = diff.max() + tc.assertLessEqual(max_err, prec, message) + +def setup(rank, world_size, backend='cncl'): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '3458' + dist.init_process_group(backend, rank=rank, world_size=world_size) + torch_mlu.mlu.set_device(rank) + +def cleanup(): + dist.barrier() + dist.destroy_process_group() + +def generate_token_count(num_expert, + total_token_count): + token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32) + sum = torch.sum(token_count, dim=-1) * 1.0 + token_count *= total_token_count / sum.item() + token_count = token_count.to(dtype=torch.int32) + cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32) + end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count + cusum_token_count[1:] = torch.cumsum(token_count, dim=-1) + cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count) + cusum_token_count[-1] = total_token_count + return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1] + +def generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, + quant_mode=False, offline=False, invalid_batch_size=0): + q_heads = 1 + total_heads = q_heads + num_heads * 2 + max_bs = batch_size + 1 + + context_lens = torch.randint(size=(batch_size, ), low=1, + high=cache_memory_len // 2, + dtype=torch.int32, device='mlu') + max_context_len = context_lens.max().item() + max_seq_offset = max_context_len // 3 + 1 + + cache_bs_id = random.sample([*range(0, batch_size)], batch_size) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + if invalid_batch_size > 0: + cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch_size)] = -1 + context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1, + high=(cache_memory_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + if packed > 0: + context = torch.randn((total_seqlen, total_heads, head_size), + dtype=torch.float, device='mlu') + context_seq_offsets = None + else: + context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size), + dtype=torch.float, device='mlu') + cu_context_lens = context_lens + context = context.to(dtype) + key = context[..., q_heads:q_heads + num_heads, :] + value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :] + + cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu') + cache_scale = None + if quant_mode: + cache = (cache - 0.5) * 256 + cache = cache.to(torch.int8) + if offline: + cache_scale = torch.randn((2, cache.shape[2], cache.shape[4]), dtype=torch.float, device='mlu') + else: + cache_scale = torch.randn((2, max_bs, num_heads, cache_memory_len), dtype=torch.float, device='mlu') + else: + cache = cache.to(dtype) + + block_size = 16 if "MLU3" not in torch.mlu.get_device_name() else max_context_len + min_blocks = (total_seqlen + block_size - 1) // block_size + num_blocks = min(min_blocks + 10, 2 * min_blocks) + num_slots = num_blocks * block_size + slot_mapping = random.sample(range(num_slots), total_seqlen.item()) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu() + slot_mapping[-1] = -1 + + return [key, value, cache[0], cache[1], cu_context_lens, max_context_len, + packed > 0, context_seq_offsets, cache_bs_id, cache_seq_offsets, + cache_scale, slot_mapping] + +def fused_moe(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], + residual: Optional[torch.Tensor], + input_smooth: Optional[torch.Tensor], + act_smooth: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + topk: int, + renormalized: bool, + gated: bool, + act_mode: str, + start_expert_id: int = 0, + block_n: int = 0, + cncl_comm: int = 0, + w1_quant_flag: Optional[List] = None, + w2_quant_flag: Optional[List] = None): + dtype = hidden_states.dtype + ori_input_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + tokens = hidden_states.size(0) + gating_output = gating_output.reshape(-1, gating_output.size(-1)) + residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None + expert_num = gating_output.size(-1) + expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1) + + per_token_sq = False + # check quant + check_list = [input_smooth, act_smooth, w1_scale, w2_scale] + if all(x is not None for x in check_list): + per_token_sq = True + + if not (all(x is None for x in check_list) or all(x is not None for x in check_list)): + raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present " + "and absent at the same time.") + + # softmax_topk + reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalized) + # gen_idx + expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num) + + if per_token_sq: + if torch.mlu.get_device_name() == 'MLU370': + expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx, + cusum_token_count, start_expert_id, expert_size) + quant_input, input_scale = tmo.moe_quantize(expand_hidden_states, + input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size]) + else: + quant_input, input_scale = tmo.moe_quantize(hidden_states, + input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx, + cusum_token_count[start_expert_id].unsqueeze(0)) + else: + expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx, + cusum_token_count, start_expert_id, expert_size) + + # group gemm + if per_token_sq: + gemm1_out = tmo.smooth_quant_group_gemm(quant_input, + w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, + input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag) + else: + gemm1_out = tmo.group_gemm(expand_hidden_states, + w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, + None, + None, + None, tokens) + # add_bias_active + act_out = tmo.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size) + + if per_token_sq: + quant_input, input_scale = tmo.moe_quantize(act_out, act_smooth, None, + token_count[start_expert_id:start_expert_id+expert_size]) + + if cncl_comm > 0: + raise ValueError("not support communication and computing fusion currently.") + else: + if per_token_sq: + gemm2_out = tmo.smooth_quant_group_gemm(quant_input, + w2, token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag) + else: + gemm2_out = tmo.group_gemm(act_out, + w2, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, tokens) + + output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx, + residual, cusum_token_count, start_expert_id, + expert_size, bias2) + return output.reshape(ori_input_shape) + +def min_mem_size(shape, stride): + if stride is None: + mem_size = 0 + mem_size += shape.numel() + else: + mem_size = 1 + for k,v in zip(shape, stride): + mem_size += (k - 1) * v + return mem_size + +def create_tensor(shape, dtype, is_contiguous, device, stride = None, mean=0, var=1, is_uniform=False, low=0, high=1): + if is_contiguous: + if dtype in (torch.int8, torch.uint8): + t = torch.randint(-128, 127, shape, device=device).to(dtype) + else: + if is_uniform: + t = torch.empty(shape, dtype=dtype, device=device).uniform_(low, high) + else: + t = torch.normal(mean, var, shape, dtype=dtype, device=device) + else: + mem_size = min_mem_size(shape, stride) + if dtype in (torch.int8, torch.uint8): + t = torch.randint(-128, 127, (mem_size,), device=device).to(dtype) + else: + if is_uniform: + t = torch.empty((mem_size,), dtype=dtype, device=device).uniform_(low, high) + else: + t = torch.normal(mean, var, (mem_size,), dtype=dtype, device=device) + t = t.as_strided(shape, stride) + return t + +def create_tensor_from_dic(dic:dict, mean=0, var=1, is_uniform=False, low=0, high=1): + if dic['data'] is None: + return None + shape = dic['shape'] + dtype = dic['dtype'] + is_contiguous = dic['is_contiguous'] + device = dic['device'] + stride = dic['stride'] + return create_tensor(shape, dtype, is_contiguous, device, stride, mean, var, is_uniform, low, high) + +def create_op_param(dic: dict): + if dic['type'] in (list, tuple): + return [create_op_param(elem) for elem in dic['data']] if dic['has_compound'] else dic['data'] + elif dic['type'] is dict: + return {k:create_op_param(v) for k,v in dic['data'].items()} + elif dic['type'] is torch.Tensor: + if dic['data'] is None: + return None + else: + if dic['dtype'] in (torch.int16, torch.int32, torch.int64): + return dic['data'] + else: + return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'], dic['device'], dic['stride']) + else: + return dic['data'] diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/run_gen_case.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/run_gen_case.py new file mode 100644 index 0000000..dd32964 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/run_gen_case.py @@ -0,0 +1,151 @@ +import os +os.environ['TMO_GEN_CASE'] = '0' + +import sys +sys_args = sys.argv +sys.argv = [sys_args.pop(0)] # prevent unittest printing help info + +import copy +import argparse +import torch +import torch_mlu +import importlib +import random +from common_utils import create_op_param + +def assert_tensor_equal(a: torch.Tensor, b: torch.Tensor, threshold): + assert a.size() == b.size() + a_ = a.cpu().reshape(-1).double() + b_ = b.cpu().reshape(-1).double() + nan_mask = a_ != a_ + assert torch.equal(nan_mask, b_ != b_), "tensor a and tensor b have different number of nan" + diff = a_ - b_ + diff[nan_mask] = 0 + a_[nan_mask] = 0 + b_[nan_mask] = 0 + eps = 1e-10 + diff1 = diff.abs().sum() / (a_.abs().sum() + eps) + diff2 = torch.sqrt((diff**2).sum() / ((a_**2).sum() + eps)) + print(f"[torch_mlu_ops] diff1: {diff1}, diff2: {diff2}") + assert diff1 <= threshold, f"diff1: {diff1} <= threshold: {threshold}" + assert diff2 <= threshold, f"diff2: {diff2} <= threshold: {threshold}" + +def check_value_equal(x, y, threshold): + assert type(x) == type(y) + if type(x) is torch.Tensor: + assert_tensor_equal(x, y, threshold) + elif type(x) is list or type(x) is tuple: + for i in range(len(x)): + check_value_equal(x[i], y[i], threshold) + elif x is not None: + assert x == y + +def check_equal(a, b, threshold): + assert type(a) == type(b) + if type(a) is tuple or type(a) is list: + assert len(a) == len(b) + for x, y in zip(a, b): + check_value_equal(x, y, threshold) + else: + check_value_equal(a, b, threshold) + +def get_base_obj(module_name, class_name): + try: + mod_ = importlib.import_module(module_name) + cls_ = getattr(mod_, class_name) + return cls_() + except ImportError as e: + print(f"Failed to import class '{class_name}' from module '{module_name}': {e}") + return None + except AttributeError as e: + print(f"Module '{module_name}' does not have a class named '{class_name}': {e}") + return None + +def get_tmo_func(func_name): + import torch_mlu_ops as tmo + return getattr(tmo, func_name) + + +op_map = { + "active": ["test_active", "TestActive", 0.004], + "apply_rotary": ["test_apply_rotary", "TestApplyRotaryOp", 0.003], + "attention_project": ["test_attn_proj", "TestAttnProjOp", 0.003], + "batch_matmul": ["test_batch_matmul", "TestBatchMatMulOp", 0.004], + "copy_blocks": ["test_copy_blocks", "TestCopyBlocksOp", 0], + "dequant_from_linear_cache": ["test_dequant_from_linear_cache", "TestDequantFromLinearCache", 0.001], + "dequant_from_paged_cache": ["test_dequant_from_paged_cache", "TestDequantFromPagedCache", 0.001], + "ffn": ["test_ffn", "TestFFNOp", 0.005], + "flash_attention": ["test_flash_attention", "TestFlashAttnOp", 0.005], + # "flash_attn_sq_mm_allreduce": ["test_flash_attn_sq_mm_allreduce", "TestFlashAttnSqMMAllreduce", 0.003], + "fused_layer_norm": ["test_fused_layernorm", "TestFuseLayerNormOp", 0.003], + "fused_moe": ["test_moe", "TestFusedMOEOp", 0.006], + "fused_norm_attention_project": ["test_fused_attn_proj", "TestFusedNormAttnProjOp", 0.003], + "fused_norm_residual_ffn": ["test_fused_ffn", "TestFusedNormResidualFFNoP", 0.003], + "fused_rms_norm": ["test_fused_rmsnorm", "TestFuseRmsNormOp", 0.0032], + "fused_rope": ["test_fused_rope", "TestFusedRopeOp", 0.003], + "group_gemm": ["test_group_gemm", "TestGroupGemmOp", 0.006], + "matmul": ["test_matmul", "TestMatMulOp", 0.003], + # "matmul_allreduce": ["test_matmul_all_reduce", "TestMatMulAllReduceOp", 0.006], + "moe_active": ["test_moe_add_bias_activation", "TestMoeActiveKernel", 0.003], + "moe_cast_gating": ["test_moe_cast_gating", "TestMoeCastGating", 0.0001], + "moe_combine_result": ["test_moe_combine_result", "TestCombineResult", 0.003], + "moe_expand_input": ["test_moe_expand_input", "TestExpandInput", 0], + "moe_gen_idx": ["test_moe_gen_idx", "TestGenIdx", 0], + "moe_quantize": ["test_smooth_quant", "TestSmoothQuantOp", 0.01], + "moe_softmax_topk": ["test_moe_softmax_topk", "TestSoftmaxTopkOp", 0.003], + "offline_quant_to_linear_cache": ["test_offline_quant_to_linear_cache", "TestOfflineQuantToLinearCache", 0.03], + "offline_quant_to_paged_cache": ["test_offline_quant_to_paged_cache", "TestOfflineQuantToPagedCache", 0.03], + "per_token_smooth_quantize": ["test_per_token_smooth_quantize", "TestPerTokenSmoothQuantizeOp", 0.003], + # "preload": ["test_preload", "TestPreloadOp", 1], + "quant_to_linear_cache": ["test_quant_to_linear_cache", "TestQuantToLinearCache", 0.003], + "quant_to_paged_cache": ["test_quant_to_paged_cache", "TestQuantToPagedCache", 0.009], + "quantize": ["test_quantize", "TestQuantizeOp", 0.003], + "reshape_linear_cache": ["test_reshape_linear_cache", "TestReshapeLinearCache", 0], + "reshape_paged_cache": ["test_reshape_paged_cache", "TestReshapePagedCacheOp", 0], + "single_query_cached_kv_attn": ["test_single_query_cached_kv_attn", "TestSingleQueryAttnOp", 0.003], + "single_query_mixed_cached_kv_attn": ["test_single_query_mixed_cached_kv_attn", "TestSingleQueryMixedKVAttnOp", 0.003], + "smooth_quant_group_gemm": ["test_smooth_quant_group_gemm", "TestSmoothQuantGroupGemmOp", 0.006], + "smooth_quant_matmul": ["test_smooth_quant_matmul", "TestSmoothQuantMatmulOp", 0.006], + # "smooth_quant_matmul_allreduce": ["test_quant_matmul_all_reduce", "TestGptQuantMatmulOp", 0.006], + "swap_blocks": ["test_swap_blocks", "TestSwapBlocksOp", 0], + "update_out_and_lse": ["test_update_out_and_lse", "TestUpdateOutAndLse", 0.005], + "weight_only_quant_matmul": ["test_weight_only_quant_matmul", "TestWeightOnlyQuantMatmulOp", 0.004], +} + + +def run_case(pt_case): + op_name = pt_case.pop('op') + op_obj = get_base_obj(op_map[op_name][0], op_map[op_name][1]) + if hasattr(op_obj, "run_gen_case"): + op_obj.run_gen_case(pt_case) + else: + dump_data = pt_case.pop('dump_data') + if dump_data: + params = pt_case + else: + params = dict() + for k,v in pt_case.items(): + params[k] = create_op_param(v) + params_bak = copy.deepcopy(params) + result_tmo = get_tmo_func(op_name)(**params) + result_base = op_obj.op_impl_base(*params_bak.values()) + check_equal(result_tmo, result_base, op_map[op_name][-1]) + +def main(): + random.seed(0) + torch.manual_seed(0) + parser = argparse.ArgumentParser() + parser.add_argument('--case_path', required=True, type=str, help='specify the case path') + parser.add_argument('--detail', action="store_true", help="show content of pt file") + args = parser.parse_args(args=sys_args) + pt_case = torch.load(args.case_path) + if args.detail: + for k,v in pt_case.items(): + print(f"{k}: {v}") + exit(0) + print(f"[torch_mlu_ops] run {args.case_path} ...") + run_case(pt_case) + print(f"[torch_mlu_ops] run {args.case_path} successfully") + +if __name__ == "__main__": + main() diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/run_test.sh b/torch_mlu_ops-v1.3.2/tests/ops_pytest/run_test.sh new file mode 100755 index 0000000..374083b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/run_test.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +tmo_ops_case=$(find "${SCRIPT_DIR}" -name "test*.py") +coverage=${1} + +for sc in ${tmo_ops_case} +do +echo -n "${sc} " +echo -n "Testing...." +if [ "${coverage}" = "coverage" ];then +coverage run -a ${sc} +else +python3 "${sc}" > "/tmp/$(basename ${sc}).log" 2>&1 +fi +if [ $? == 0 ];then +echo -e "\033[32m success \033[0m" +else +echo -e "\033[31m failed \033[0m" +fi +done +echo "End of pytest..." diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py new file mode 100644 index 0000000..bd71e01 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_active.py @@ -0,0 +1,88 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +from common_utils import * +import random +from itertools import product +import time +import os + +class TestActive(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + input = create_tensor_from_dic(dic['input']) + act_mode = dic['act_mode']['data'] + is_gated = dic['is_gated']['data'] + active_coef = dic['active_coef']['data'] + self.launch(input, act_mode, is_gated, active_coef) + + def launch(self, *args): + torch_out = self.op_impl_base(*args) + tmo_out = tmo.active(*args) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + + def op_impl_base(self, *args): + input, act_mode, is_gated, active_coef = args + channel = input.size(-1) + if act_mode == "gelu": + if is_gated: + out = torch.nn.functional.gelu(input[..., :channel//2]) + out *= input[..., channel//2:] + else: + out = torch.nn.functional.gelu(input) + else: + if act_mode == "silu": + active_coef = 1.0 + elif act_mode == "quick_gelu": + active_coef = 1.702 + + def swish(input, coef): + return input * torch.sigmoid(coef * input) + + if is_gated: + out = swish(input[..., :channel//2], active_coef) * input[..., channel//2:] + else: + out = swish(input, active_coef) + return out + + def test_active_random(self): + for _ in range(500): + dtype_list = [torch.float, torch.half] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + input_dtype = random.choice(dtype_list) + batch = random.randint(1, 10) + seq = random.randint(1, 2048) + hidden_size = random.randrange(2, 8192, 2) + is_gated = random.choice([True, False]) + act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish']) + if act_mode == 'silu': + active_coef = 1.0 + elif act_mode == 'quick_gelu': + active_coef = 1.702 + else: + active_coef = random.uniform(0, 1) + print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format( \ + [batch, seq, hidden_size], is_gated, act_mode, input_dtype), flush=True) + input = torch.randn(batch, seq, hidden_size, dtype=input_dtype, device="mlu") + self.launch(input, act_mode, is_gated, active_coef) + + def test_inductor(self): + input = torch.randn(3, 4, 12, dtype=torch.half, device="mlu") + output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu") + is_gated = True + act_mode = 'silu' + coef = 1.0 + args = (input, output, None, None, act_mode, is_gated, 0, 0, coef) + self.base_opcheck(torch.ops.torch_mlu_ops.active, args) + +if __name__ == '__main__': + exit(run_unittest(TestActive)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_apply_rotary.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_apply_rotary.py new file mode 100755 index 0000000..1281f59 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_apply_rotary.py @@ -0,0 +1,152 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +def gen_args(bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype): + cu_context_lens = None + total_seq_len = bs * seq_len + max_context_len = seq_len + if packed: + context_lens = torch.randint(size=(bs, ), low=1, high=seq_len+1, dtype=torch.int32, device='mlu') + total_seq_len = context_lens.sum().item() + max_context_len = context_lens.max().item() + cu_context_lens = torch.cumsum(context_lens, dim=-1, dtype=torch.int32) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0) + + context_shape = (total_seq_len, q_heads + kv_heads + 1, head_size) if packed else \ + (bs, seq_len, q_heads + kv_heads + 1, head_size) + context = torch.randn(size=context_shape, dtype=dtype).mlu() + qk = context[..., 0 : q_heads + kv_heads, :] + + position_id = None + if discrete: + position_id = torch.randint(0, max_context_len, size=(total_seq_len,), dtype=torch.int32, device="mlu") + else: + position_id = torch.randint(0, max_context_len, size=(bs,), dtype=torch.int32, device="mlu") + + rope_seqlen = seq_len * 2 + cos_shape = (bs, rope_seqlen, rope_dim) if dynamic_ntk else (rope_seqlen, rope_dim) + cos_cache = torch.randn(size=cos_shape, dtype=dtype, device="mlu") + sin_cache = torch.randn(size=cos_shape, dtype=dtype, device="mlu") + + return (qk, sin_cache, cos_cache, position_id, cu_context_lens, \ + interleaved, discrete, dynamic_ntk, max_context_len) + +class TestApplyRotaryOp(BtTestCase): + def op_impl_base(self, *args): + def rotate(x: torch.Tensor, interleaved: bool): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + y = torch.empty_like(x) + x1, x2 = x[..., ::2], x[..., 1::2] + y[..., ::2], y[..., 1::2] = -x2, x1 + return y + input, sin_cache, cos_cache, position_ids, cu_seqlen, interleaved, discrete, dynamic_ntk, max_seqlen = args + packed = input.dim() == 3 + rope_dim = sin_cache.shape[-1] + batch_size = cu_seqlen.shape[0] - 1 if packed else input.shape[0] + sin_cache_float = sin_cache.float() + cos_cache_float = cos_cache.float() + cu_seqlen_cpu = cu_seqlen.cpu() if cu_seqlen is not None else None + position_ids_cpu = position_ids.cpu() if position_ids is not None else None + for i in range(batch_size): + input_i = input[cu_seqlen_cpu[i] : cu_seqlen_cpu[i + 1]] if packed else input[i] + input_i = input_i[..., 0:rope_dim] + sin_cache_i = sin_cache_float[i] if dynamic_ntk else sin_cache_float + cos_cache_i = cos_cache_float[i] if dynamic_ntk else cos_cache_float + seq = input_i.shape[0] + if discrete: + if packed: + position_id_i = position_ids_cpu[cu_seqlen_cpu[i] : cu_seqlen_cpu[i + 1]] + else: + position_id_i = position_ids_cpu.view(batch_size, -1)[i] + sin_cache_i = sin_cache_i[position_id_i] + cos_cache_i = cos_cache_i[position_id_i] + else: + if position_ids_cpu is None: + sin_cache_i = sin_cache_i[:seq] + cos_cache_i = cos_cache_i[:seq] + else: + pos_id = position_ids_cpu[i].item() + sin_cache_i = sin_cache_i[pos_id : seq + pos_id] + cos_cache_i = cos_cache_i[pos_id : seq + pos_id] + rot = rotate(input_i.float(), interleaved) + output_i = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1) + input_i[:] = output_i.to(input.dtype) + return input + + def test_apply_rotary(self): + bs_list = [1, 8] + seq_len_list = [1, 128, 1024] + q_heads_list = [8, 32] + kv_heads_list = [1] + head_size_list = [128, 256] + rope_dim_list = [256, 128, 64, 24] + is_interleaved_list = [True, False] + discrete_list = [True, False] + dynamic_ntk_list = [True, False] + packed_list = [True, False] + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + args = product(bs_list, seq_len_list, q_heads_list, kv_heads_list, head_size_list, \ + rope_dim_list, is_interleaved_list, discrete_list, dynamic_ntk_list, packed_list, dtype_list) + for bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype in args: + print("bs: {}, seq_len: {}, q_heads: {}, kv_heads: {}, head_size: {}, " + "rope_dim: {}, interleaved: {}, discrete: {}, dynamic_ntk: {}, packed: {}, dtype: {} testing...".format( \ + bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype), flush=True) + if rope_dim > head_size: + print("rope_dim = {}, head_size = {},rope_dim should less than head_size".format(rope_dim, head_size)) + continue + + qk, sin_cache, cos_cache, position_id, cu_context_lens, \ + interleaved, _, _, max_context_len = gen_args(bs, seq_len, q_heads, kv_heads, head_size, + rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype) + + qk_base = qk.clone() + self.op_impl_base(qk_base, sin_cache, cos_cache, position_id, cu_context_lens, \ + interleaved, discrete, dynamic_ntk, max_context_len) + + qk_out = ops.apply_rotary(qk, sin_cache, cos_cache, position_id, cu_context_lens, \ + interleaved, discrete, dynamic_ntk, max_context_len) + + self.assertTensorsEqual(qk_out.cpu().float(), qk_base.cpu().float(), 0.003, use_MSE=True, use_RAE=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + bs, seq, head_num, head_size, dynamic_ntk, rotary_seqlen, rotary_dim, discrete = 1, 1024, 8, 128, False, 512, 128, True + + q = torch.randn(bs, seq, head_num, head_size, dtype=torch.half, device="mlu") + sin = torch.randn(rotary_seqlen, rotary_dim, dtype=torch.half, device="mlu") + cos = torch.randn(rotary_seqlen, rotary_dim, dtype=torch.half, device="mlu") + self.assertException("discrete must be false if position ids is null.", ops.apply_rotary, + q, sin, cos, None, None, False, discrete, dynamic_ntk, seq) + + discrete = False + self.assertException("max_seqlen must less than or equal to rope_seqlen.", ops.apply_rotary, + q, sin, cos, None, None, False, discrete, dynamic_ntk, seq) + position_ids = torch.zeros(bs, dtype=torch.int32, device="mlu") + self.assertException("max_seqlen must less than or equal to rope_seqlen.", ops.apply_rotary, + q, sin, cos, position_ids, None, False, discrete, dynamic_ntk, seq) + + def test_inductor(self): + is_interleaved_list = [True, False] + discrete_list = [True, False] + dynamic_ntk_list = [True, False] + packed_list = [True, False] + params = product(is_interleaved_list, discrete_list, dynamic_ntk_list, packed_list) + bs, seq_len, q_heads, kv_heads, head_size, rope_dim, dtype = 8, 1024, 8, 1, 256, 24, torch.float16 + for interleaved, discrete, dynamic_ntk, packed in params: + print(f"==== check apply_rotary interleaved: {interleaved}, discrete: {discrete}, dynamic_ntk: {dynamic_ntk}, packed: {packed} ====") + args = gen_args(bs, seq_len, q_heads, kv_heads, head_size, + rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.apply_rotary, args) + + +if __name__ == '__main__': + exit(run_unittest(TestApplyRotaryOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_attn_proj.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_attn_proj.py new file mode 100755 index 0000000..9601f1f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_attn_proj.py @@ -0,0 +1,44 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from torch.nn.parameter import Parameter +from torch.nn import functional as F + + +class TestAttnProjOp(BtTestCase): + def op_impl_base(self, *args): + input, weight, bias, residual, alpha, beta = args + proj = F.linear(input, weight, bias) + output = alpha * proj + beta * residual + return output + + def test_attn_proj(self): + N, T, input_size, hidden_size, alpha, beta = 32, 129, 2048, 4096, 0.5, 0.1 + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format( + N, T, input_size, hidden_size), flush=True) + input = torch.randn(N, T, input_size, dtype=dtype, device="mlu") + weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") + bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu") + residual = torch.randn(N, T, hidden_size * 3, dtype=dtype, device="mlu") + torch_out = self.op_impl_base(input, weight, bias, residual, alpha, beta) + tmo_out = ops.attention_project(input, weight, bias, residual, alpha, beta) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + N, T, input_size, hidden_size, dtype = 32, 129, 2048, 4096, torch.half + input = torch.randn(N, T, input_size, dtype=dtype, device="mlu") + weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") + bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu") + args = (input, weight, bias, None, None, None, + None, None, None, None, "nthc", 1, + 1e-5, 1., 0., False) + self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args) + +if __name__ == '__main__': + exit(run_unittest(TestAttnProjOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_batch_matmul.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_batch_matmul.py new file mode 100644 index 0000000..d7bdd7d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_batch_matmul.py @@ -0,0 +1,83 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +import numpy as np + +class TestBatchMatMulOp(BtTestCase): + def op_impl_base(self, *args): + a, b, c, alpha, beta, a_scale, b_scale, trans_a, trans_b = args + if trans_a: + a = a.transpose(1, 2) + if trans_b: + b = b.transpose(1, 2) + if c is None: + output_dtype = a.dtype + else: + output_dtype = c.dtype + if a_scale is not None: + a = torch.div(a, a_scale).to(output_dtype) + b = torch.div(b, b_scale).to(output_dtype) + output = alpha * torch.bmm(a, b) + if c is not None: + output += beta * c + c.copy_(output) + return output + + def test_batch_matmul(self): + batch_list = [5] + mat_m_list = [32] + mat_n_list = [256] + mat_k_list = [128] + has_res_list = [False, True] + trans_a_list = [False, True] + trans_b_list = [False, True] + + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + alpha = 0.625 + beta = 1.0 + args = product(batch_list, mat_m_list, mat_n_list, mat_k_list, has_res_list, dtype_list, trans_a_list, trans_b_list) + + for batch, mat_m, mat_n, mat_k, has_res, dtype, trans_a, trans_b in args: + print("batch={}, m={}, n={}, k={}, has_res={}, dtype={}, trans_a={}, trans_b={} testing...".format( + batch, mat_m, mat_n, mat_k, has_res, dtype, trans_a, trans_b), flush=True) + shape_a, shape_b = (batch, mat_m, mat_k), (batch, mat_k, mat_n) + if trans_a: + shape_a = (batch, mat_k, mat_m) + if trans_b: + shape_b = (batch, mat_n, mat_k) + input = torch.randn(shape_a, dtype=dtype, device='mlu') + weight = torch.randn(shape_b, dtype=dtype, device='mlu') + input8, a_scale = QuantByTensor(input, 8) + weight8, b_scale = QuantByTensor(weight, 8) + residual = torch.randn((batch, mat_m, mat_n), dtype=dtype, device='mlu') + res_bak = residual.clone() + tmo_output_int8 = torch.zeros((batch, mat_m, mat_n), dtype=dtype, device='mlu') + torch_output_int8 = tmo_output_int8.clone() + output = self.op_impl_base(input, weight, + residual if has_res else None, alpha, beta, 1.0, 1.0, trans_a, trans_b) + tmo_output = ops.batch_matmul(input, weight, + res_bak if has_res else None, alpha, beta, 1.0, 1.0, trans_a, trans_b) + + self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + if dtype != torch.bfloat16: + self.op_impl_base(input8, weight8, torch_output_int8, alpha, beta, a_scale.item(), b_scale.item(), trans_a, trans_b) + ops.batch_matmul(input8, weight8, tmo_output_int8, alpha, beta, a_scale.item(), b_scale.item(), trans_a, trans_b) + self.assertTensorsEqual(tmo_output_int8.cpu().float(), torch_output_int8.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + + def test_inductor(self): + batch, mat_m, mat_n, mat_k, alpha, beta, dtype = 6, 64, 256, 128, 0.8, 0.3, torch.float16 + a = torch.randn((batch, mat_m, mat_k), dtype=dtype, device='mlu') + b = torch.randn((batch, mat_n, mat_k), dtype=dtype, device='mlu') + c = torch.randn((batch, mat_m, mat_n), dtype=dtype, device='mlu') + args = (a, b, c, alpha, beta, 1.0, 1.0, False, True) + self.base_opcheck(torch.ops.torch_mlu_ops.batch_matmul, args) + +if __name__ == '__main__': + exit(run_unittest(TestBatchMatMulOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_copy_blocks.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_copy_blocks.py new file mode 100755 index 0000000..1e126f3 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_copy_blocks.py @@ -0,0 +1,167 @@ +import math +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +from typing import List, Tuple +import os +import copy + +class TestCopyBlocksOp(BtTestCase): + def create_kv_caches( + self, + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + seed: int + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = head_size**-0.5 + # vllm scale + # x = 16 // torch.tensor([], dtype=dtype).element_size() + # key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache_shape = (num_blocks, num_heads, block_size, head_size) + key_caches = [] + for _ in range(num_layers): + if dtype in {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}: + info = torch.iinfo(dtype) + key_cache = torch.randint(info.min, info.max, size=key_cache_shape, dtype=dtype).mlu() + else: + key_cache = torch.empty(size=key_cache_shape, dtype=dtype).mlu() + key_cache.uniform_(-scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] + for _ in range(num_layers): + if dtype in {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}: + info = torch.iinfo(dtype) + value_cache = torch.randint(info.min, info.max, size=value_cache_shape, dtype=dtype).mlu() + else: + value_cache = torch.empty(size=value_cache_shape, dtype=dtype).mlu() + value_cache.uniform_(-scale, scale) + value_caches.append(value_cache) + + return key_caches, value_caches + + def create_block_mapping(self, num_blocks, num_mappings, seed = 0): + random.seed(seed) + torch.random.manual_seed(seed) + assert 3 * num_mappings <= num_blocks + block_mapping = {} + src_blocks = random.sample(range(num_blocks), num_mappings) + remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) + dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) + for i in range(num_mappings): + src = src_blocks[i] + dst1 = dst_blocks[2 * i] + dst2 = dst_blocks[2 * i + 1] + block_mapping[src] = [dst1, dst2] + return block_mapping + + def op_impl_base(self, *args): + k_caches, v_caches, block_mapping = args + for src, dsts in block_mapping.items(): + srcs = [src for i in range(len(dsts))] + srcs_ind = torch.tensor(srcs, dtype=torch.int64) + dsts_ind = torch.tensor(dsts, dtype=torch.int64) + for key_cache in k_caches: + key_cache[dsts_ind] = key_cache[srcs_ind] + if v_caches is not None: + for value_cache in v_caches: + value_cache[dsts_ind] = value_cache[srcs_ind] + return (k_caches, v_caches) if v_caches is not None else k_caches + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_copy_blocks due to ASan issues") + def test_copy_blocks(self): + dtype_list = [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + num_tokens_list = [83] + num_heads_list = [8] + head_size_list = [64, 512] + num_blocks_list = [3600] + block_size_list = [8] + num_layers_list = [1, 6] + num_mappings_list = [128, 600] + seeds_list = [0] + only_key_cache_list = [True, False] + + args = product(num_tokens_list, num_heads_list, head_size_list, num_blocks_list, block_size_list, dtype_list, + num_layers_list, num_mappings_list, seeds_list) + for num_tokens, num_heads, head_size, num_blocks, block_size, dtype, num_layers, num_mappings, seed in args: + print("num_tokens: {}, num_heads: {}, head_size: {}, num_blocks: {}, block_size: {}, dtype: {}, num_layers: {}, \ +num_mappings: {}, seed: {}, testing...".format( + num_tokens, num_heads, head_size, num_blocks, block_size, dtype, num_layers, num_mappings, seed), flush=True) + block_mapping = self.create_block_mapping(num_blocks, num_mappings, seed) + only_key_cache = random.choice(only_key_cache_list) + # Create the KV caches. + key_caches, value_caches = self.create_kv_caches(num_blocks, block_size, + num_layers, num_heads, + head_size, dtype, seed) + # Clone the KV caches. + cloned_key_caches = [key_cache.cpu().clone() for key_cache in key_caches] + cloned_value_caches = [value_cache.cpu().clone() for value_cache in value_caches] + + # Call the copy blocks kernel. + if only_key_cache: + value_caches = None + cloned_value_caches = None + ops.copy_blocks(key_caches, value_caches, block_mapping) + + self.op_impl_base(cloned_key_caches, cloned_value_caches, block_mapping) + + # Compare the results. + for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): + self.assertTensorsEqual(key_cache.cpu().float(), cloned_key_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + if not only_key_cache: + for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): + self.assertTensorsEqual(value_cache.cpu().float(), cloned_value_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_inductor due to ASan issues") + def test_prevent(self): + print("test_copy_block: test_prevent...") + num_blocks, block_size, head_num, head_size, block_mapping = 384, 6, 32, 128, 128 + k_cache = torch.randn(num_blocks * head_num, block_size, head_size, dtype=torch.half, device="mlu") + v_cache = torch.randn(num_blocks * head_num, head_size, block_size + 1, dtype=torch.float, device="mlu") + key_caches = [k_cache,] + value_caches = None + block_mapping = self.create_block_mapping(num_blocks, block_mapping) + self.assertException("every layer k_cache must be 4d.", ops.copy_blocks, + key_caches, value_caches, block_mapping) + k_cache = k_cache.reshape(num_blocks, head_num, block_size, head_size) + key_caches = [k_cache, k_cache,] + value_caches = [v_cache] + self.assertException("k_caches size must equal to v_caches size if v_caches is not none.", + ops.copy_blocks, key_caches, value_caches, block_mapping) + key_caches = [k_cache] + value_caches = [v_cache] + self.assertException("the data type of k_caches and v_caches are not the same.", + ops.copy_blocks, key_caches, value_caches, block_mapping) + value_caches[0] = value_caches[0].to(torch.half) + self.assertException("every layer k_cache dim must equal to v_cache dim.", + ops.copy_blocks, key_caches, value_caches, block_mapping) + value_caches[0] = value_caches[0].reshape(num_blocks, head_num, head_size, block_size + 1) + self.assertException("the block_size of k_caches and v_caches are not the same.", + ops.copy_blocks, key_caches, value_caches, block_mapping) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_inductor due to ASan issues") + def test_inductor(self): + num_heads, head_size, num_blocks, block_size, num_layers, num_mappings = 8, 64, 384, 8, 1, 128 + key_caches, value_caches = self.create_kv_caches(num_blocks, block_size, + num_layers, num_heads, + head_size, torch.float16, 0) + block_mapping = self.create_block_mapping(num_blocks, num_mappings) + self.base_opcheck(torch.ops.torch_mlu_ops.copy_blocks, (key_caches, value_caches, block_mapping)) + +if __name__ == '__main__': + exit(run_unittest(TestCopyBlocksOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_linear_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_linear_cache.py new file mode 100755 index 0000000..244dfc3 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_linear_cache.py @@ -0,0 +1,280 @@ +import random +import unittest + +import numpy as np +import math +import torch +import torch_mlu_ops as ops + +from common_utils import * + +def gen_args(max_batch_size, + batch_size, + max_context_len, + head_num_q, + head_num_kv, + cache_mem_len, + head_size, + group_size, + use_seq_offset, + dtype, + quant_mode, + quant_bit, + has_value = True): + # Preprocess arguments + assert max_batch_size >= batch_size, \ + "max_batch_size should greater than or equal to batch_size." + assert cache_mem_len >= max_context_len, \ + "cache_mem_len should greater then or equal to max_context_len." + assert head_size % group_size == 0, \ + "head_size should be a multiply of groupwise." + total_heads = head_num_q + head_num_kv * 2 + max_seq_offset = cache_mem_len - max_context_len + # Generates key and cache from context + context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1, + dtype=torch.int32, device="mlu") + if use_seq_offset: + context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset, + dtype=torch.int32, device="mlu") + else: + context_paddings = torch.zeros_like(context_lens) + cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1) + total_seqlen = cu_context_lens[-1] + context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu") + context_seq_offset[1:] = cu_context_lens[:-1] + context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu") + key = context[..., head_num_q:head_num_q + head_num_kv, :] + value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :] + + # Generates key_cache and value_cache + cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu() + cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size], + dtype=torch.int32, device="mlu") + if quant_bit == 4: + key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size, + head_num_kv, cache_mem_len, head_size // 2), device="mlu") + value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size, + head_num_kv, cache_mem_len // 2, head_size), device="mlu") + key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8) + else: + cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size), + low=-128, high=127, dtype=torch.int32, device="mlu") + cache = cache.to(torch.int8) + key_cache, value_cache = cache[[0, 1]] + + # Generates key_cache_scale and value_cache_scale + if quant_mode == 0: # quant_mode == 0 is per channel + cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu") + else: # quant_mode != 1 (== 1 for extend) is per head + cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len), + dtype=torch.float, device="mlu") + key_cache_scale, value_cache_scale = cache_scale[[0, 1]] + + # Prepare arguments + if has_value == False: + value = None + value_cache = None + value_cache_scale = None + args = [key, value, key_cache, value_cache, key_cache_scale, value_cache_scale] + args += [context_lens, max_context_len, context_seq_offset if use_seq_offset else None, + cache_bs_id, cache_seq_offset] + args += [quant_mode, quant_bit] + return args + +class TestDequantFromLinearCache(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + key = create_tensor_from_dic(dic['key']) + value = create_tensor_from_dic(dic['value']) + key_cache = create_tensor_from_dic(dic['key_cache']) + value_cache = create_tensor_from_dic(dic['value_cache']) + key_cache_quant_scale = create_tensor_from_dic(dic['key_cache_quant_scale']) + value_cache_quant_scale = create_tensor_from_dic(dic['value_cache_quant_scale']) + context_lengths = dic['context_lengths']['data'] + max_context_len = dic['max_context_len']['data'] + context_seq_offset = dic['context_seq_offset']['data'] + cache_bs_id = dic['cache_bs_id']['data'] + cache_seq_offset = dic['cache_seq_offset']['data'] + quant_mode = dic['quant_mode']['data'] + quant_bit = dic['quant_bit']['data'] + + self.launch(key, value, key_cache, value_cache, key_cache_quant_scale, + value_cache_quant_scale, context_lengths, max_context_len, + context_seq_offset, cache_bs_id, cache_seq_offset, quant_mode, quant_bit) + + def launch(self, *args): + key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \ + context_lengths, max_context_len, context_seq_offset, cache_bs_id, \ + cache_seq_offset, quant_mode, quant_bit = args + if value is None or value_cache is None or value_cache_scale is None: + has_value = False + else: + has_value = True + ops.dequant_from_linear_cache(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, context_lengths, max_context_len, + context_seq_offset, cache_bs_id, cache_seq_offset, + quant_mode, quant_bit) + key_clone = key.clone() + if has_value: + value_clone = value.clone() + self.op_impl_base(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, context_lengths, max_context_len, + context_seq_offset, cache_bs_id, cache_seq_offset, quant_mode, + quant_bit) + self.assertTensorsEqual(key_clone.cpu().float(), key.cpu().float(), 0.001, use_MSE=True) + if has_value: + self.assertTensorsEqual(value_clone.cpu().float(), value.cpu().float(), 0.001, + use_MSE=True) + + def op_impl_base(self, *args): + def dequant_from_cache(quant_data: torch.Tensor, + scale_data: torch.Tensor, + quant_mode: int): + quant_data_fp32 = quant_data.clone().to(torch.float) + scale_data_fp32 = scale_data.clone().to(torch.float) + if quant_mode == 0: # per channel + scale_data_fp32 = scale_data[..., None, :] + else: # per head/token + scale_data_fp32 = scale_data[..., None] + dequant_data_fp32 = quant_data_fp32 * scale_data_fp32 + + return dequant_data_fp32 + + key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \ + context_lengths, max_context_len, context_seq_offset, cache_bs_id, \ + cache_seqlen_offset, quant_mode, quant_bit = args + batch_size = context_lengths.size(0) + if context_seq_offset is None: + cu_seq_offset = torch.cumsum(context_lengths, dim=-1) + context_seq_offset = torch.zeros_like(cu_seq_offset) + context_seq_offset[1:] = cu_seq_offset[:-1] + + total = 0 + cache_mem_len = key_cache.size(2) + for i in range(batch_size): + context_len = context_lengths[i].item() + seq_begin = context_seq_offset[i].item() + seq_end = seq_begin + context_len + total += context_len + cache_id = cache_bs_id[i] if cache_bs_id is not None else i + cache_seq_begin = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0 + cache_seq_end = cache_seq_begin + context_len + key_i = key[seq_begin:seq_end].transpose(1, 0) + if quant_bit == 4: + key_cache_i_temp = key_cache[cache_id, :, cache_seq_begin:cache_seq_end] + cache_size = list(key_cache_i_temp.size()) + cache_size[-1] *= 2 + key_cache_i = torch.zeros(cache_size, dtype=torch.int8, device="mlu") + key_cache_i[...,::2] = key_cache_i_temp << 4 >> 4 + key_cache_i[...,1::2] = key_cache_i_temp >> 4 + else: + key_cache_i = key_cache[cache_id, :, cache_seq_begin:cache_seq_end] + # We use negatice cache_seq_offset to skip unused batch + if cache_seq_begin < 0 or cache_seq_end > cache_mem_len: + continue + # dequant key from cache + if quant_mode == 0: + key_cache_scale_i = key_cache_scale + else: + key_cache_scale_i = key_cache_scale[cache_id, :, cache_seq_begin:cache_seq_end] + + dequant_key_i = dequant_from_cache(key_cache_i, key_cache_scale_i, quant_mode) + key_i[...] = dequant_key_i.to(key_i.dtype) + + # dequant value from cache + if not (value_cache is None or value is None or value_cache_scale is None): + value_i = value[seq_begin:seq_end].transpose(1, 0) + if quant_bit == 4: + pad_front = cache_seq_begin % 2 + pad_back = cache_seq_end % 2 + cache_seq_begin_temp = int(cache_seq_begin // 2) + cache_seq_end_temp = int(math.ceil(cache_seq_end / 2.0)) + + value_cache_i_temp = value_cache[cache_id, :, + cache_seq_begin_temp:cache_seq_end_temp] + cache_size = list(value_cache_i_temp.size()) + cache_size[-2] *= 2 + value_cache_i = torch.zeros(cache_size, dtype=torch.int8, device="mlu") + value_cache_i[...,::2,:] = value_cache_i_temp << 4 >> 4 + value_cache_i[...,1::2,:] = value_cache_i_temp >> 4 + if pad_front: + value_cache_i = value_cache_i[...,1:,:] + if pad_back: + value_cache_i = value_cache_i[...,:-1,:] + else: + value_cache_i = value_cache[cache_id, :, cache_seq_begin:cache_seq_end] + if quant_mode == 0: + value_cache_scale_i = value_cache_scale + else: + value_cache_scale_i = value_cache_scale[cache_id, :, + cache_seq_begin:cache_seq_end] + dequant_value_i = dequant_from_cache(value_cache_i, value_cache_scale_i, + quant_mode) + value_i[...] = dequant_value_i.to(value.dtype) + + def test_dequant_from_linear_cache(self): + test_cases = 100 + head_size_times = 16 + mlu_name = torch.mlu.get_device_name() + max_batch_size_list = torch.randint(low=32, high=64, size=[test_cases], + dtype=torch.int32) + batch_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32) + max_context_len_list = torch.randint(low=2, high=2048, size=[test_cases], + dtype=torch.int32) + head_num_q_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32) + head_num_kv_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32) + head_size_list = torch.randint(low=1, high=16, size=[test_cases], + dtype=torch.int32) * head_size_times + cache_mem_len_list = torch.randint(low=1024, high=2048, size=[test_cases], + dtype=torch.int32) * 2 + quant_mode_list = np.random.choice([0, 1], test_cases) + quant_bit_list = np.random.choice([4, 8], test_cases) + use_offset_list = np.random.choice([False, True], test_cases) + has_value_list = np.random.choice([False, True], test_cases) + dtype_list = [torch.half, torch.bfloat16] + dtype_list = dtype_list[:-1] if "MLU3" in mlu_name else dtype_list + dtype_list = np.random.choice(dtype_list, test_cases) + for i in range(test_cases): + max_batch_size = max_batch_size_list[i].item() + batch_size = batch_size_list[i].item() + head_num_q = head_num_q_list[i].item() + max_context_len = max_context_len_list[i].item() + head_num_kv = head_num_kv_list[i].item() + head_size = head_size_list[i].item() + cache_mem_len = cache_mem_len_list[i].item() + quant_mode = quant_mode_list[i] + quant_bit = quant_bit_list[i] + use_seq_offset = use_offset_list[i] + has_value = has_value_list[i] + dtype = dtype_list[i] + if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv \ + * head_size >= 2**31 - 1): + print("large tensor is not support on {}, skip".format(mlu_name)) + continue + + print("batch_size={}, head_num={}, head_size={}, max_context_len={}, quant_mode={}, " + "quant_bit={}, dtype={} testing...".format(batch_size, head_num_kv, head_size, + max_context_len, quant_mode, quant_bit, dtype)) + + torch.manual_seed(2766) + args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv, + cache_mem_len, head_size, head_size, use_seq_offset, dtype, quant_mode, + quant_bit, has_value) + self.launch(*args) + + def test_inductor(self): + max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \ + head_size, dtype = 16, 8, 1024, 16, 32, 2048, 128, torch.float16 + quant_mode, quant_bit = 0, 8 + args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv, + cache_mem_len, head_size, head_size, 1, dtype, quant_mode, quant_bit) + self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_linear_cache, args) + args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv, + cache_mem_len, head_size, head_size, 0, dtype, quant_mode, quant_bit) + self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_linear_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestDequantFromLinearCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_paged_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_paged_cache.py new file mode 100755 index 0000000..7d15775 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_dequant_from_paged_cache.py @@ -0,0 +1,286 @@ +import random +import unittest + +import numpy as np +import math +import torch +import torch_mlu_ops as ops + +from common_utils import * + +def gen_args(batch_size, + max_context_len, + head_num_q, + head_num_kv, + cache_mem_len, + head_size, + group_size, + block_size, + use_seq_offset, + dtype, + quant_mode, + quant_bit, + has_value = True): + # Preprocess arguments + assert cache_mem_len >= max_context_len, \ + "cache_mem_len should greater then or equal to max_context_len." + assert head_size % group_size == 0, \ + "head_size should be a multiply of groupwise." + total_heads = head_num_q + head_num_kv * 2 + max_seq_offset = cache_mem_len - max_context_len + max_block_num = int(math.ceil(max_context_len / block_size)) + total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size + block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size, + max_block_num) + # Generates key and cache from context + context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1, + dtype=torch.int32, device="mlu") + if use_seq_offset: + context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset, + dtype=torch.int32, device="mlu") + else: + context_paddings = torch.zeros_like(context_lens) + cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1) + total_seqlen = cu_context_lens[-1] + context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu") + context_seq_offset[1:] = cu_context_lens[:-1] + context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu") + key = context[..., head_num_q:head_num_q + head_num_kv, :] + value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :] + # Generates key_cache and value_cache + cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size), + low=-128, high=127, dtype=torch.int32, device="mlu") + cache = cache.to(torch.int8) + key_cache, value_cache = cache[[0, 1]] + + # Generates key_cache_scale and value_cache_scale + if quant_mode == 0: # quant_mode == 0 is per channel + cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu") + else: # quant_mode != 1 (== 1 for extend) is per head + cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size), + dtype=torch.float, device="mlu") + key_cache_scale, value_cache_scale = cache_scale[[0, 1]] + + # Prepare arguments + if has_value == False: + value = None + value_cache = None + value_cache_scale = None + args = [key, value, key_cache, value_cache, key_cache_scale, value_cache_scale] + args += [context_lens, max_context_len, context_seq_offset if use_seq_offset else None, + block_tables] + args += [quant_mode, quant_bit] + return args + +class TestDequantFromPagedCache(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + key = create_tensor_from_dic(dic['key']) + value = create_tensor_from_dic(dic['value']) + key_cache = create_tensor_from_dic(dic['key_cache']) + value_cache = create_tensor_from_dic(dic['value_cache']) + key_cache_quant_scale = create_tensor_from_dic(dic['key_cache_quant_scale']) + value_cache_quant_scale = create_tensor_from_dic(dic['value_cache_quant_scale']) + context_lengths = dic['context_lengths']['data'] + max_context_len = dic['max_context_len']['data'] + context_seq_offset = dic['context_seq_offset']['data'] + block_tables = dic['block_tables']['data'] + quant_mode = dic['quant_mode']['data'] + quant_bit = dic['quant_bit']['data'] + + self.launch(key, value, key_cache, value_cache, key_cache_quant_scale, + value_cache_quant_scale, context_lengths, max_context_len, + context_seq_offset, block_tables, quant_mode, quant_bit) + + def launch(self, *args): + key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \ + context_lengths, max_context_len, context_seq_offset, block_tables, \ + quant_mode, quant_bit = args + if value is None or value_cache is None or value_cache_scale is None: + has_value = False + else: + has_value = True + ops.dequant_from_paged_cache(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, context_lengths, max_context_len, + context_seq_offset, block_tables, quant_mode, quant_bit) + key_clone = key.clone() + if has_value: + value_clone = value.clone() + self.op_impl_base(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, context_lengths, max_context_len, + context_seq_offset, block_tables, quant_mode, quant_bit) + self.assertTensorsEqual(key_clone.cpu().float(), key.cpu().float(), 0.001, use_MSE=True) + if has_value: + self.assertTensorsEqual(value_clone.cpu().float(), value.cpu().float(), 0.001, + use_MSE=True) + + def op_impl_base(self, *args): + def dequant_from_cache(quant_data: torch.Tensor, + scale_data: torch.Tensor, + quant_mode: int): + quant_data_fp32 = quant_data.clone().to(torch.float) + scale_data_fp32 = scale_data.clone().to(torch.float) + if quant_mode == 0: # per channel [head_num, 1, head_size] + scale_data_fp32 = scale_data[..., None, :] + else: # per head/token [head_num, context_len, 1] + scale_data_fp32 = scale_data[..., None] + dequant_data_fp32 = quant_data_fp32 * scale_data_fp32 + + return dequant_data_fp32 + + key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \ + context_lengths, max_context_len, context_seq_offset, block_tables, \ + quant_mode, quant_bit = args + batch_size = context_lengths.size(0) + if context_seq_offset is None: + cu_seq_offset = torch.cumsum(context_lengths, dim=-1) + context_seq_offset = torch.zeros_like(cu_seq_offset) + context_seq_offset[1:] = cu_seq_offset[:-1] + + total_seqlen = 0 + block_size = key_cache.size(2) + for i in range(batch_size): + context_len = context_lengths[i].item() + seq_begin = context_seq_offset[i].item() + seq_end = seq_begin + context_len + total_seqlen += context_len + full_block_num = context_len // block_size + rem_token_num = context_len % block_size + key_i = key[seq_begin:seq_end].transpose(1, 0) + # [head_num, seq_num, head_size] + key_cache_i = torch.concat( + [key_cache[block_tables[i, j], ...] for j in range(full_block_num)] + + ([key_cache[block_tables[i, full_block_num], :, :rem_token_num, :]] \ + if rem_token_num > 0 else []), dim=-2 + ) + + # dequant key from cache + if quant_mode == 0: + # [head_num, head_size] + key_cache_scale_i = key_cache_scale + else: + # [head_num, seq_num] + key_cache_scale_i = torch.concat( + [key_cache_scale[block_tables[i, j],...] for j in range(full_block_num)] + + ([key_cache_scale[block_tables[i, full_block_num], :, :rem_token_num]] \ + if rem_token_num > 0 else []), dim=-1 + ) + + dequant_key_i = dequant_from_cache(key_cache_i, key_cache_scale_i, quant_mode) + key_i[...] = dequant_key_i.to(key_i.dtype) + # dequant value from cache + if not (value_cache is None or value is None or value_cache_scale is None): + value_i = value[seq_begin:seq_end].transpose(1, 0) + # [head_num, seq_num, head_size] + value_cache_i = torch.concat( + [value_cache[block_tables[i, j], ...] for j in range(full_block_num)] + + ([value_cache[block_tables[i, full_block_num], :, :rem_token_num, :]] \ + if rem_token_num > 0 else []), dim=-2 + ) + + # dequant value from cache + if quant_mode == 0: + # [head_num, head_size] + value_cache_scale_i = value_cache_scale + else: + # [head_num, seq_num] + value_cache_scale_i = torch.concat( + [value_cache_scale[block_tables[i, j],...] for j in range(full_block_num)] + + ([value_cache_scale[block_tables[i, full_block_num], :, :rem_token_num]] \ + if rem_token_num > 0 else []), dim=-1 + ) + dequant_value_i = dequant_from_cache(value_cache_i, value_cache_scale_i, quant_mode) + value_i[...] = dequant_value_i.to(value_i.dtype) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ or "MLU3" in torch.mlu.get_device_name(), + "Skipping test_prevent due to ASan issues or in MLU300 series.") + def test_prevent(self): + batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \ + head_size, block_size = 8, 1024, 16, 32, 2048, 128, 16 + dtype, quant_mode, quant_bit = torch.float16, 0, 8 + default_args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, 1, dtype, quant_mode, quant_bit) + args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, 1, torch.float32, quant_mode, quant_bit) + self.assertException("Tensor key type should be half or bfloat16.", + torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args) + args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, 1, dtype, 2, quant_bit) + self.assertException("quantization mode support 0 and 1.", + torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args) + args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, 1, dtype, quant_mode, 4) + self.assertException("quantization bit width only supports 8.", + torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args) + default_args[-5] = 10240 + self.assertException("max_context_len should smaller than or equal to block_size * max_block_num.", + torch.ops.torch_mlu_ops.dequant_from_paged_cache, *default_args) + default_args[-5] = max_context_len + default_args[1] = default_args[1].transpose(-1, -2) + self.assertException("Tensor value last dim must be contiguous.", + torch.ops.torch_mlu_ops.dequant_from_paged_cache, *default_args) + + @unittest.skipIf("MLU3" in torch.mlu.get_device_name(), + "Skipping test_dequant_from_paged_cache in MLU300 series") + def test_dequant_from_paged_cache(self): + test_cases = 100 + head_size_times = 16 + mlu_name = torch.mlu.get_device_name() + batch_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32) + max_context_len_list = torch.randint(low=2, high=1024, size=[test_cases], + dtype=torch.int32) + head_num_q_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32) + head_num_kv_list = torch.randint(low=1, high=8, size=[test_cases], dtype=torch.int32) + head_size_list = torch.randint(low=1, high=16, size=[test_cases], + dtype=torch.int32) * head_size_times + block_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32) + cache_mem_len_list = torch.randint(low=512, high=1024, size=[test_cases], + dtype=torch.int32) * 2 + quant_mode_list = np.random.choice([0, 1], test_cases) + quant_bit_list = np.random.choice([8], test_cases) + use_offset_list = np.random.choice([False, True], test_cases) + has_value_list = np.random.choice([False, True], test_cases) + dtype_list = [torch.half, torch.bfloat16] + dtype_list = np.random.choice(dtype_list, test_cases) + for i in range(test_cases): + batch_size = batch_size_list[i].item() + head_num_q = head_num_q_list[i].item() + max_context_len = max_context_len_list[i].item() + head_num_kv = head_num_kv_list[i].item() + head_size = head_size_list[i].item() + block_size = block_size_list[i].item() + cache_mem_len = cache_mem_len_list[i].item() + quant_mode = quant_mode_list[i] + quant_bit = quant_bit_list[i] + use_seq_offset = use_offset_list[i] + has_value = has_value_list[i] + dtype = dtype_list[i] + print("batch_size={}, head_num={}, head_size={}, max_context_len={}, quant_mode={}, " + "quant_bit={}, dtype={} testing...".format(batch_size, head_num_kv, head_size, + max_context_len, quant_mode, quant_bit, dtype)) + + torch.manual_seed(2766) + args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, use_seq_offset, dtype, quant_mode, + quant_bit, has_value) + self.launch(*args) + + @unittest.skipIf("MLU3" in torch.mlu.get_device_name(), + "Skipping test_dequant_from_paged_cache in MLU300 series") + def test_inductor(self): + batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \ + head_size, block_size = 8, 1024, 16, 32, 2048, 128, 16 + dtype, quant_mode, quant_bit = torch.float16, 0, 8 + args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, 1, dtype, quant_mode, quant_bit) + self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_paged_cache, args) + args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, + head_size, head_size, block_size, 0, dtype, quant_mode, quant_bit) + self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_paged_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestDequantFromPagedCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_faketensor.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_faketensor.py new file mode 100644 index 0000000..1e453cc --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_faketensor.py @@ -0,0 +1,299 @@ +import torch +import torch_mlu +from itertools import product +import torch_mlu_ops as ops +import random +import torch.testing._internal.optests as optests +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.testing._internal.common_utils import ( + run_tests, + TestCase, +) + +dtype_dict = { + torch.half: "half", + torch.bfloat16: "bfloat16", + torch.float: "float", +} + +fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) + +class FakeTensorTest(TestCase): + def test_matmul(self): + with fake_tensor_mode: + mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = 32, 256, 128, 0.8, 0.3, 'silu', True, True + trans_a_list = [True, False] + trans_b_list = [True, False] + dtype_list = [torch.half, torch.float] + args = product(trans_a_list, trans_b_list, dtype_list) + for trans_a, trans_b, dtype in args: + print(f"matmul... trans_a: {trans_a}, trans_b: {trans_b}, dtype: {dtype}") + shape_a, shape_b = (mat_m, mat_k), (mat_k, mat_n) + if trans_a: + shape_a = (mat_k, mat_m) + if trans_b: + shape_b = (mat_n, mat_k) + a = torch.randn(shape_a, dtype=dtype, device='mlu') + b = torch.randn(shape_b, dtype=dtype, device='mlu') + bias = torch.randn((mat_n), dtype=dtype, device='mlu') + c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') + output = torch.ops.torch_mlu_ops.matmul(a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b) + self.assertEqual(output.shape, (mat_m, mat_n)) + + a8 = torch.randint(-128, 127, shape_a).to(torch.int8).mlu() + b8 = torch.randint(-128, 127, shape_b).to(torch.int8).mlu() + a_scale = 280.8 + b_scale = 190.8 + str_dtype = dtype_dict[dtype] + output = torch.ops.torch_mlu_ops.matmul(a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b) + self.assertEqual(output.shape, (mat_m, mat_n)) + + def test_weight_only_quant_matmul(self): + with fake_tensor_mode: + M, K, N, group_num = 2, 256, 32, 4 + quant_bit_size, act_mode, use_hp_active, act_coef, alpha, beta = 8, 'none', True, 1., 0.8, 0.3 + group_quant_list = [True, False] + trans_a_list = [True, False] + trans_b_list = [True, False] + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + args = product(group_quant_list, dtype_list, trans_a_list, trans_b_list) + for group_quant, dtype, trans_a, trans_b in args: + print(f"weight_only_quant_matmul... group_quant: {group_quant}, dtype: {dtype}, trans_a: {trans_a}, trans_b: {trans_b}") + a = torch.randn((M, K), dtype=dtype).mlu() + b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu() + c = torch.randn(M, N, device="mlu", dtype=dtype) + bias = torch.randn(N, device="mlu", dtype=dtype) + group_wise_scale = torch.randn((N, group_num), device="mlu", dtype=dtype) + b_quant_layout = "quantize_group_wise" if group_quant else "quantize_per_channel" + b_scale = group_wise_scale if group_quant else None + gemm_output_scale = None if group_quant else torch.randn(N, device="mlu", dtype=torch.float) + a_scale, a_zero, b_zero, c_zero = None, None, None, None + c_scale, gemm_output_zero = None, None + quant_algo, a_quant_layout = "weight_only", "quantize_none" + str_dtype = dtype_dict[dtype] + output = torch.ops.torch_mlu_ops.quant_matmul(a, a_scale, a_zero, + b, b_scale, b_zero, + bias, c, c_scale, c_zero, + gemm_output_scale, gemm_output_zero, + str_dtype, None, quant_algo, + a_quant_layout, b_quant_layout, + quant_bit_size, act_mode, use_hp_active, act_coef, + alpha, beta, trans_a, trans_b,) + self.assertEqual(output.shape, (M, N)) + + def test_smooth_quant_matmul(self): + with fake_tensor_mode: + act_mode_list = ["none", "silu", "gelu"] + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + trans_a_list = [True, False] + trans_b_list = [True, False] + M, K, N = 2, 16, 32 + quant_bit_size, use_hp_active, act_coef, alpha, beta = 8, True, 1., 0.8, 0.3 + arg = product(act_mode_list, dtype_list, trans_a_list, trans_b_list) + for act_mode, dtype, trans_a, trans_b in arg: + print(f"smooth_quant_matmul... act_mode: {act_mode}, dtype: {dtype}, trans_a: {trans_a}, trans_b: {trans_b}") + a = torch.randint(-128, 127, (M, K), dtype=torch.int8).mlu() + b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu() + c = None + bias = torch.randn(N, device="mlu", dtype=dtype) + a_scale = torch.randn(M, device="mlu", dtype=torch.float) + b_scale = torch.randn(N, device="mlu", dtype=torch.float) + a_zero, b_zero, c_zero = None, None, None + c_scale, gemm_output_scale, gemm_output_zero = None, None, None + quant_algo, a_quant_layout, b_quant_layout = "smooth_quant", "quantize_per_token", "quantize_per_channel" + str_dtype = dtype_dict[dtype] + output = torch.ops.torch_mlu_ops.quant_matmul(a, a_scale, a_zero, + b, b_scale, b_zero, + bias, c, c_scale, c_zero, + gemm_output_scale, gemm_output_zero, + str_dtype, None, quant_algo, + a_quant_layout, b_quant_layout, + quant_bit_size, act_mode, use_hp_active, act_coef, + alpha, beta, trans_a, trans_b) + self.assertEqual(output.shape, (M, N)) + + def test_group_gemm(self): + with fake_tensor_mode: + batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5 + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + idx_list = [False, True] + has_bias_list = [True, False] + args = product( dtype_list, idx_list, has_bias_list) + for data_type, idx_mode, has_bias in args: + print(f"group_gemm... has_bias: {has_bias}, idx_mode: {idx_mode}, dtype: {data_type}") + bs = batch * seq + token_topk = bs * topk + expert_id = torch.randint(experts_num, (token_topk,), device="mlu") + sorted_expert_id, indices = expert_id.sort() + gather_idx = indices // topk + gather_idx = gather_idx.to(torch.int32) + token_count = torch.randint(0, token_topk, (experts_num,)).to(torch.int32) + a = torch.randn(bs, k, device="mlu", dtype=data_type) + if not idx_mode: + a = a[gather_idx] + b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type) + c = torch.randn(token_topk, n, device="mlu", dtype=data_type) + alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32) + beta = torch.randn(experts_num, device="mlu", dtype=torch.float32) + a_scale = None + b_scale = None + bias = None + if has_bias: + bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) + gather_idx_ = gather_idx if idx_mode else None + output = torch.ops.torch_mlu_ops.group_gemm(a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, bias, None, None, None, bs) + self.assertEqual(output.shape, (token_topk, n)) + + def test_smoothquant_group_gemm(self): + with fake_tensor_mode: + batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5 + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + idx_list = [False, True] + has_bias_list = [True, False] + args = product( dtype_list, idx_list, has_bias_list) + for data_type, idx_mode, has_bias in args: + print(f"smoothquant_group_gemm... has_bias: {has_bias}, idx_mode: {idx_mode}, dtype: {data_type}") + bs = batch * seq + token_topk = bs * topk + expert_id = torch.randint(experts_num, (token_topk,), device="mlu") + sorted_expert_id, indices = expert_id.sort() + gather_idx = indices // topk + gather_idx = gather_idx.to(torch.int32) + token_count = torch.randint(0, token_topk, (experts_num,)).to(torch.int32) + a8 = torch.randint(-128, 127, (bs, k)).to(torch.int8).mlu() + b8 = torch.randint(-128, 127, (experts_num, n, k)).to(torch.int8).mlu() + a_scale = torch.randn(token_topk, dtype=torch.float32).mlu() + b_scale = torch.randn(experts_num, n, dtype=torch.float32).mlu() + c = torch.randn(token_topk, n, device="mlu", dtype=data_type) + if not idx_mode: + a8 = a8[gather_idx] + alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32) + beta = torch.randn(experts_num, device="mlu", dtype=torch.float32) + bias = None + if has_bias: + bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) + gather_idx_ = gather_idx if idx_mode else None + str_dtype = dtype_dict[data_type] + output = torch.ops.torch_mlu_ops.group_gemm(a8, b8, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, bias, str_dtype, None, None, bs) + self.assertEqual(output.shape, (token_topk, n)) + + def test_moe_expand_input(self): + with fake_tensor_mode: + token_num, hidden_size, expert_num, topk, start_expert_id, expert_size = 2048, 4096, 32, 8, 3, 20 + dtype_list = [torch.half, torch.float, torch.int8, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half, torch.float, torch.int8] + for dtype in dtype_list: + print(f"moe_expand_input... token_num: {token_num}, expert_num: {expert_num}, topk: {topk}, dtype: {dtype}") + input = torch.randn(token_num, hidden_size, device='mlu').to(dtype) + gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,), dtype=torch.int32, device='mlu') + cusum_token_count = torch.zeros(expert_num + 1, dtype=torch.int32).mlu() + output=torch.ops.torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count, start_expert_id, expert_size) + self.assertEqual(output.shape, (token_num*topk, hidden_size)) + + def test_moe_gen_idx(self): + with fake_tensor_mode: + token_num, expert_num, topk = 2048, 32, 8 + print(f"moe_gen_idx... token_num: {token_num}, expert_num: {expert_num}, topk: {topk}") + expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu') + expand_idx, combine_idx, token_count, cusum_token_count =torch.ops.torch_mlu_ops.moe_gen_idx(expert_id, expert_num) + self.assertEqual(expand_idx.shape, (token_num * topk,)) + self.assertEqual(combine_idx.shape, (token_num * topk,)) + self.assertEqual(token_count.shape, (expert_num,)) + self.assertEqual(cusum_token_count.shape, (expert_num + 1,)) + + def test_moe_combine_result(self): + with fake_tensor_mode: + has_bias_list = [True, False] + num_tokens, hidden_size, num_expert, topk, start_expert_id = 1, 2048, 8, 2, 0 + expert_size_list = [5, 8] + dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half] + args = product(has_bias_list, expert_size_list, dtype_list) + for has_bias, expert_size, dtype in args: + print(f"moe_combine_result... has_bias: {has_bias}, expert_size: {expert_size}, dtype: {dtype}") + input = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device='mlu') + reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device='mlu') + gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32, device='mlu') + bias = None + residual = None + cusum_token_count = None + if has_bias: + bias = torch.randn((num_expert, hidden_size), dtype=dtype, device='mlu') + residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device='mlu') + if has_bias or expert_size < num_expert: + cusum_token_count =torch.zeros(num_expert + 1, dtype=torch.int32).mlu() + output=torch.ops.torch_mlu_ops.moe_combine_result(input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, expert_size, bias) + self.assertEqual(output.shape, (num_tokens, hidden_size)) + + def test_moe(self): + with fake_tensor_mode: + act_mode = 'gelu' + case_list = set() + while (len(case_list) < 100): + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(1024, 3072, 512) + inner_size = random.randrange(1024, 3072, 512) + expert_num = random.randint(1, 40) + topk = random.randint(1, expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + quant_mode = random.choice(["no_quant", "w4", "w8", "w4w8"]) + quant_wise = random.choice([128, 256, 512]) + data_type = random.choice([torch.bfloat16, torch.float16]) + if not torch_mlu.mlu.is_bf16_supported(): + data_type = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_mode, quant_wise, act_mode) + if case in case_list: + continue + case_list.add(case) + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_mode: {quant_mode}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True) + # if expert_size == -1: + # expert_size = expert_num + if quant_mode == "no_quant": + w1 = torch.randn((expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) + w2 = torch.randn((expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) + w1_quant_flag = None + w2_quant_flag = None + elif quant_mode == "w4": + w1_quant_group = hidden_size // quant_wise + w2_quant_group = inner_size // quant_wise + w1_quant_flag = None + w2_quant_flag = None + w1 = torch.randint(-128, 127, (expert_num, inner_size*(1+gated), hidden_size // 2), device="mlu", dtype=torch.int32).to(torch.int8) + w2 = torch.randint(-128, 127, (expert_num, hidden_size, inner_size // 2), device="mlu", dtype=torch.int32).to(torch.int8) + w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + elif quant_mode == "w8": + w1_quant_flag = None + w2_quant_flag = None + w1 = torch.randint(-128, 127, (expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=torch.int32).to(torch.int8) + w2 = torch.randint(-128, 127, (expert_num, hidden_size, inner_size), device="mlu", dtype=torch.int32).to(torch.int8) + w1_scale = torch.empty((expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + w2_scale = torch.empty((expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + elif quant_mode == "w4w8": + w1_quant_group = hidden_size // quant_wise + w2_quant_group = inner_size // quant_wise + w1_quant_flag = random.choices([4,8], k=expert_num * w1_quant_group) + w2_quant_flag = random.choices([4,8], k=expert_num * w2_quant_group) + w1_count = (sum(w1_quant_flag) // 4) * (quant_wise // 2) * inner_size*(1+gated) + w2_count = (sum(w2_quant_flag) // 4) * (quant_wise // 2) * hidden_size + w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8) + w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8) + w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + input_smooth = None if quant_mode == "no_quant" else torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05) + act_smooth = None if quant_mode == "no_quant" else torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05) + bias1, bias2 = None, None + output = torch.ops.torch_mlu_ops.fused_moe(hidden_states, router_logit, w1, w2, bias1, bias2, residual, + input_smooth, act_smooth, w1_scale, w2_scale, w1_quant_flag, + w2_quant_flag, topk, renormalize, gated, act_mode, 0, + 0, 0) + self.assertEqual(output.shape, hidden_states.shape) + + +if __name__ == "__main__": + if torch.__version__ >= '2.3.0': + run_tests() diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py new file mode 100755 index 0000000..46f596e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_ffn.py @@ -0,0 +1,65 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * + +class TestFFNOp(BtTestCase): + def op_impl_base(self, *args): + input, w1, bias1, w2, bias2, w3, bias3, act_mode = args + up = F.linear(input, w1, bias1) + act = act_mode_dict[act_mode](up.float()).to(input.dtype) + if w3 is not None: + gate = F.linear(input, w3, bias3) + act = act * gate + output = F.linear(act, w2, bias2) + return output + + def test_ffn(self): + input_size_list = [128, 256, 512] + hidden_size = 1024 + seq_len_list = [10, 16, 20] + bool_value_list = [True, False] + batch = 5 + dtype_list = [torch.half] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for input_size, seq_len, bool_value in zip(input_size_list, + seq_len_list, + bool_value_list): + print("input_size={}, seq_len={}, bias={}, gated={}, testing...".format( + input_size, seq_len, bool_value, bool_value), flush=True) + use_gate = bool_value + for dtype in dtype_list: + w1 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") + b1 = torch.randn((hidden_size), dtype=dtype, device="mlu") + w2 = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu") + b2 = torch.randn((input_size), dtype=dtype, device="mlu") + w3 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") if use_gate else None + b3 = torch.randn((hidden_size), dtype=dtype, device="mlu") if use_gate else None + input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu") + args = (input, w1, b1, w2, b2, w3, b3, 'silu') + output = self.op_impl_base(*args) + tmo_output1 = ops.ffn(*args) + self.assertTensorsEqual(output.cpu().float(), tmo_output1.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + # use matmul to implement ffn + f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1 + f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1 + pre_gemm_out = ops.matmul(input.view(-1, input_size), f1_weight, f1_bias, None, "none", 1.0, 0) + act_out = ops.active(pre_gemm_out, 'silu', use_gate) + tmo_output2 = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0) + tmo_output2 = tmo_output2.view(batch, seq_len, input_size) + self.assertTensorsEqual(output.cpu().float(), tmo_output2.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + + def test_inductor(self): + batch, seq_len, input_size, hidden_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half + input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu") + up_fc_weight = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") + down_proj_weight = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu") + args = (input, up_fc_weight, None, down_proj_weight, None, None, None, None, None, act_mode, "none", 1e-5, 1., 0.) + self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args) + +if __name__ == '__main__': + exit(run_unittest(TestFFNOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_flash_attention.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_flash_attention.py new file mode 100755 index 0000000..18c8ec2 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_flash_attention.py @@ -0,0 +1,357 @@ +import math +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +def gen_args(seq_q, seq_k, head_num_q, head_num_k, head_size, has_alibi, has_mask, is_causal, use_block, return_lse, dtype): + batch = len(seq_q) + max_seq_q = max(seq_q) + max_seq_k = max(seq_k) + cu_seq_len_q = [0] + cu_seq_len_k = [0] + for i in range(batch): + cu_seq_len_q.append(seq_q[i] + cu_seq_len_q[-1]) + cu_seq_len_k.append(seq_k[i] + cu_seq_len_k[-1]) + + softmax_scale = 1 / math.sqrt(head_size) + alibi_slope = None if has_alibi == False else torch.zeros((head_num_q)).uniform_(0, 0.1).to(torch.float32).mlu() + attn_bias = None if has_mask is False else torch.randn((batch, head_num_q, max_seq_q, max_seq_k), dtype=dtype).mlu() + total_seq_q = sum(seq_q) + total_seq_k = sum(seq_k) + q = torch.randn(total_seq_q, head_num_q, head_size, dtype=dtype, device="mlu") + block_tables = None + if use_block: + block_size = 16 + num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) + max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size + cache_shape = (num_blocks, head_num_k, block_size, head_size) + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + k = torch.randn(size=cache_shape, dtype=dtype).mlu() + v = torch.randn(size=cache_shape, dtype=dtype).mlu() + else: + k = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu") + v = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu") + tmo_out = torch.empty_like(q) + out_lse = torch.randn(batch, head_num_q, max_seq_q, dtype = torch.float, device="mlu") if return_lse else None + return (q, k, v, tmo_out, out_lse, + torch.tensor(cu_seq_len_q, dtype=torch.int32, device="mlu"), + torch.tensor(cu_seq_len_k, dtype=torch.int32, device="mlu"), + alibi_slope, attn_bias, None, None, + block_tables, max_seq_q, max_seq_k, + softmax_scale, is_causal, + -1, -1, "float", return_lse) + +def gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype): + batch = len(seq_q) + max_seq_q = max(seq_q) + max_seq_k = max(seq_k) + cu_seq_len_q = [0] + cu_seq_len_k = [0] + for i in range(batch): + cu_seq_len_q.append(seq_q[i] + cu_seq_len_q[-1]) + cu_seq_len_k.append(seq_k[i] + cu_seq_len_k[-1]) + cu_seq_len_q = torch.tensor(cu_seq_len_q, dtype=torch.int32, device="mlu") + cu_seq_len_k = torch.tensor(cu_seq_len_k, dtype=torch.int32, device="mlu") + + alibi_slope = None if has_alibi == False else torch.zeros((head_num_q)).uniform_(0, 0.1).to(torch.float32).mlu() + attn_bias = None if has_mask is False else torch.randn((batch, head_num_q, max_seq_q, max_seq_k), dtype=dtype).mlu() + total_seq_q = sum(seq_q) + total_seq_k = sum(seq_k) + q = torch.randn(total_seq_q, head_num_q, head_size, dtype=dtype, device="mlu") + block_tables = None + if use_block: + block_size = 16 + num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) + max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu() + v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu() + else: + k = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu") + v = torch.randn(total_seq_k, head_num_k, head_size_v, dtype=dtype, device="mlu") + return q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables + +class TestFlashAttnOp(BtTestCase): + def op_impl_base(self, *args): + q, k, v, out, cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias, max_seq_len_q, \ + max_seq_len_kv, softmax_scale, is_causal, window_size_left, window_size_right, \ + compute_dtype, return_lse, block_tables, k_cache_quant_scale, v_cache_quant_scale = args + is_pack = cu_seq_lens_q is not None + has_block_table = block_tables is not None + if has_block_table: + assert is_pack == True + batch = len(cu_seq_lens_q) - 1 if is_pack else q.size(0) + head_num_q = q.size(-2) + head_num_kv = k.size(-2) if block_tables is None else k.size(-3) + head_size = q.size(-1) + head_size_v = v.size(-1) + assert head_num_q >= head_num_kv and head_num_q % head_num_kv == 0 + group = head_num_q // head_num_kv + device = q.device + repeat_dim = -3 if has_block_table else -2 + k_bd = torch.repeat_interleave(k, group, dim=repeat_dim) + v_bd = torch.repeat_interleave(v, group, dim=repeat_dim) + out_list = [] + inf = 1e6 + lse = torch.zeros(batch, head_num_q, max_seq_len_q, dtype=torch.float) + for i in range(batch): + q_i = q[cu_seq_lens_q[i]:cu_seq_lens_q[i+1], ...] if is_pack else q[i] + if not has_block_table: + k_i = k_bd[cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1], ...] if is_pack else k[i] + v_i = v_bd[cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1], ...] if is_pack else v[i] + else: + block_table = block_tables[i] + context_len = cu_seq_lens_kv[i+1] - cu_seq_lens_kv[i] + block_size = k.size(-2) # (num_block, head_num, block_size, head_size) + table_end = (context_len + block_size - 1) // block_size + block_ids = block_table[0 : table_end] + keys, values = k[block_ids], v[block_ids] + keys = torch.repeat_interleave(keys, group, dim=1) #[num_block, head_q, block_size, head_size] + keys = keys.transpose(1, 0).contiguous().view(head_num_q, -1, head_size) #[head_q, num_block * blcke_size, head_size] + k_i = keys.transpose(1,0) #[num_block * blcke_size, head_q, head_size] + k_i = k_i[0:context_len, ...] #[seq_k, head_q, head_size] + values = torch.repeat_interleave(values, group, dim=1) + values = values.transpose(1, 0).contiguous().view(head_num_q, -1, head_size_v) + v_i = values.transpose(1,0) + v_i = v_i[0:context_len, ...] + qk = torch.einsum('qhd,khd->hqk', q_i, k_i).to(torch.float) * softmax_scale + seq_q, seq_k = q_i.size(0), k_i.size(0) + if alibi_slope is not None: + slope = alibi_slope.reshape(1, head_num_q, 1, 1) + slope_bias = torch.zeros(1, head_num_q, seq_q, seq_k).to(device=device) + if is_causal: + relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).to(device=device) + slope_bias = relative_pos * slope + else: + row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1) + col_idx = torch.arange(seq_k, dtype=torch.long) + relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).to(device=device) + slope_bias = -slope * relative_pos.to(dtype=slope.dtype) + qk += (slope_bias.squeeze(0)) + if is_causal: + assert seq_q <= seq_k, "seq_q <= seq_k if causal=True" + zeros = torch.zeros(seq_q, seq_k-seq_q, dtype=torch.float, device="mlu") + tri = torch.full((seq_q, seq_q), -inf, dtype=torch.float, device="mlu").triu(diagonal=1) + mask = torch.cat([zeros, tri], dim=1) # (q, k-q) + (q, q) => (q, k) + qk += mask + if window_size_left != -1 or window_size_right != -1: + mask_w = torch.full((seq_q, seq_k), -inf, dtype=torch.float, device="mlu") + for qi in range(seq_q): + left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0 + right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k + mask_w[qi, left:right] = 0 + qk += mask_w + + if attn_bias is not None: + qk += attn_bias[i][:, :seq_q, :seq_k] + if return_lse: + lse[i][:, :seq_q] = torch.logsumexp(qk, dim=-1) + attn = torch.softmax(qk, dim=-1, dtype=torch.float).to(q.dtype) + qkv = torch.einsum('hqk,khd->qhd', attn, v_i) + out_list.append(qkv) + attn_out = torch.cat(out_list, dim=0) + if is_pack == False: + attn_out = attn_out.view(q.size(0), q.size(1), q.size(2), head_size_v) + if return_lse: + attn_out = (attn_out, lse) + return attn_out + + def test_flash_attention(self): + seq_len_list = [((38, 64, 128), (38, 64, 128)), ((30, 40, 50), (60, 90, 120))] + head_num_list = [(32, 32), (32, 4)] + use_block_list = [False, True] + head_size_list = [(64, 64), (128, 512)] + alibi_list = [False, True] + mask_list = [False, True] + causal_list = [False, True] + dtype_list = [torch.half, torch.float] + window_size_list = [(-1, -1), (10, -1), (8, 8)] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + use_block_list = [False] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + args = product(seq_len_list, head_num_list, head_size_list, alibi_list, mask_list, causal_list, use_block_list, window_size_list, dtype_list) + for ((seq_q, seq_k), (head_num_q, head_num_k), (head_size, head_size_v), has_alibi, has_mask, is_causal, use_block, + (window_size_left, window_size_right), dtype) in args: + batch = len(seq_q) + print("batch={}, seq_lens_q={}, seq_lens_k={}, head_num_q={}, head_num_k={}, head_size={}, head_size_v= {}, has_alibi={}, \ +has_mask={}, is_causal={}, use_block={}, window_size_left={}, window_size_right={}, dtype={}, testing...".format( + batch, seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, is_causal, use_block, + window_size_left, window_size_right, dtype), flush=True) + max_seq_q = max(seq_q) + max_seq_k = max(seq_k) + params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype) + q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params + softmax_scale = 1 / math.sqrt(head_size) + + torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, False, + block_tables if use_block else None, None, None) + tmo_output = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, False, + block_tables if use_block else None, None, None) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + + if use_block: # test block_tables is [batch, 1] + block_size = max_seq_k + num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) #batch + max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size #1 + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu() + v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu() + torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, -1, -1, torch.float, False, + block_tables if use_block else None, None, None) + tmo_output = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, -1, -1, torch.float, False, + block_tables if use_block else None, None, None) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + + def test_flash_attention_lse(self): + seq_len_list =[((66, 77, 88), (77, 88, 99)), ((1024, 1024), (1024, 1024))] + head_num_q, head_num_k = 16, 8 + head_size_list = [(64, 64), (16, 128)] + is_causal = True + has_alibi = True + has_mask = True + window_size_list = [(-1, -1), (10, -1)] + dtype_list = [torch.half, torch.float] + use_block_list = [False, True] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + use_block_list = [False] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + args = product(seq_len_list, head_size_list, use_block_list, window_size_list, dtype_list) + for (seq_q, seq_k), (head_size, head_size_v), use_block, (window_size_left, window_size_right), dtype in args: + batch = len(seq_q) + print("batch={}, seq_lens_q={}, seq_lens_k={}, head_num_q={}, head_num_k={}, head_size={}, head_size_v={}, has_alibi={}, \ + has_mask={}, is_causal={}, use_block={}, window_size_left={}, window_size_right={}, dtype={}, testing...".format( + batch, seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, is_causal, use_block, + window_size_left, window_size_right, dtype), flush=True) + max_seq_q = max(seq_q) + max_seq_k = max(seq_k) + params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype) + q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params + softmax_scale = 1 / math.sqrt(head_size) + + torch_output, torch_output_lse = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, True, + block_tables if use_block else None, None, None) + tmo_output, tmo_output_lse = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, True, + block_tables if use_block else None, None, None) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_output_lse.cpu(), tmo_output_lse.cpu(), 0.0006, use_MSE=True, use_RAE=True) + + if use_block: #test block_table = [batch, 1] + block_size = max_seq_k + num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) #batch + max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size #1 + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu() + v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu() + torch_output, torch_output_lse = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, True, + block_tables if use_block else None, None, None) + tmo_output, tmo_output_lse = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, True, + block_tables if use_block else None, None, None) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_output_lse.cpu(), tmo_output_lse.cpu(), 0.0006, use_MSE=True, use_RAE=True) + + def test_flash_attention_inplace(self): + seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v = (66, 77, 88), (77, 88, 99), 16, 8, 64, 64 + use_block_list = [False, True] + softmax_scale = 1 / math.sqrt(head_size) + is_causal = True + has_alibi = True + has_mask = True + batch = len(seq_q) + max_seq_q = max(seq_q) + max_seq_k = max(seq_k) + cu_seq_len_q = [0] + cu_seq_len_k = [0] + window_size_left, window_size_right = -1, -1 + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + use_block_list = [False] + total_seq_q = sum(seq_q) + for use_block in use_block_list: + print("test_flash_attention_inplace: use_block: {}, testing....".format(use_block)) + params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, torch.half) + q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params + torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, False, + block_tables if use_block else None, None, None) + tmo_output = torch.empty((total_seq_q, head_num_q, head_size_v), dtype=torch.half, device="mlu") + ops.flash_attention(q, k, v, tmo_output, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, False, + block_tables if use_block else None, None, None) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + if use_block: #test block_table = [batch, 1] + block_size = max_seq_k + num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size)) #batch + max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size #1 + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=torch.float16).mlu() + v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=torch.float16).mlu() + torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, False, + block_tables if use_block else None, None, None) + ops.flash_attention(q, k, v, tmo_output, cu_seq_len_q, cu_seq_len_k, + alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale, + is_causal, window_size_left, window_size_right, torch.float, False, + block_tables if use_block else None, None, None) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + seq_q, seq_k, head_num_q, head_num_k, head_size, dtype = (66, 77, 88), (77, 88, 99), 16, 8, 64, torch.float16 + use_block_list = [False, True] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + use_block_list = [False] + alibi_list = [False, True] + mask_list = [False, True] + causal_list = [False, True] + return_lse_list = [False, True] + test_flags = product(use_block_list, alibi_list, mask_list, causal_list, return_lse_list) + for use_block, has_alibi, has_mask, is_causal, return_lse in test_flags: + print(f"==== use_block: {use_block}, has_alibi: {has_alibi}, has_mask: {has_mask}, is_causal: {is_causal} return_lse: {return_lse}====") + args = gen_args(seq_q, seq_k, head_num_q, head_num_k, head_size, has_alibi, has_mask, is_causal, use_block, return_lse, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.flash_attention, args) + +if __name__ == '__main__': + exit(run_unittest(TestFlashAttnOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_attn_proj.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_attn_proj.py new file mode 100755 index 0000000..deb8e03 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_attn_proj.py @@ -0,0 +1,94 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from torch.nn.parameter import Parameter +from torch.nn import functional as F +import torch.nn as nn + +class TestFusedNormAttnProjOp(BtTestCase): + def op_impl_base(self, *args): + input, q_weight, q_bias, k_weight, k_bias, v_weight, v_bias, norm_weight, \ + norm_bias, eps, out_layout, head_size, norm_out = args + input_size = input.size(-1) + layernorm = torch.nn.LayerNorm(input_size) + layernorm.eps = eps + layernorm.weight = nn.Parameter(norm_weight) + layernorm.bias = nn.Parameter(norm_bias) + layernorm_out = layernorm(input) + q_out = torch.matmul(layernorm_out, q_weight.permute(1, 0)) + q_bias + if k_weight is not None: + k_out = torch.matmul(layernorm_out, k_weight.permute(1, 0)) + k_bias + v_out = torch.matmul(layernorm_out, v_weight.permute(1, 0)) + v_bias + + if out_layout == 'nhtc': + batch, seq, _ = input.shape + hidden_size_q = q_weight.size(0) + q_head = hidden_size_q // head_size + q_out = q_out.reshape(batch, seq, q_head, head_size).transpose(1, 2) + if k_weight is not None: + hidden_size_kv = k_weight.size(0) + kv_head = hidden_size_kv // head_size + k_out = k_out.reshape(batch, seq, kv_head, head_size).transpose(1, 2) + v_out = v_out.reshape(batch, seq, kv_head, head_size).transpose(1, 2) + outs = (q_out,) if k_weight is None else (q_out, k_out, v_out,) + if norm_out is True: + outs += (layernorm_out,) + return outs + + def test_attn_proj(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + N, T, input_size, hidden_size, head_size, eps, alpha, beta = 4, 16, 512, 768, 64, 1e-5, 0.5, 0.3 + print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format( + N, T, input_size, hidden_size), flush=True) + input = torch.randn(N, T, input_size, dtype=dtype, device="mlu") + weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") + bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu") + norm_weight = torch.randn(input_size, dtype=dtype, device="mlu") + norm_bias = torch.randn(input_size, dtype=dtype, device="mlu") + residual = torch.randn(N, T, hidden_size, dtype=dtype, device="mlu") + weights = torch.chunk(weight, 3) + biass = torch.chunk(bias, 3) + # test pre_attn_proj + print("test pre_attn_proj...") + out_torch = self.op_impl_base(input, + weights[0], biass[0], + weights[1], biass[1], + weights[2], biass[2], + norm_weight, norm_bias, + eps, 'nthc', head_size, True) + out_tmo = ops.fused_norm_attention_project(input, + weights[0], biass[0], + weights[1], biass[1], + weights[2], biass[2], + norm_weight, norm_bias, + eps, 'nthc', head_size, True) + for o1, o2 in list(zip(out_torch, out_tmo)): + self.assertTensorsEqual(o1.cpu().float(), o2.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + N, T, input_size, hidden_size, head_size, eps, alpha, beta, dtype = 4, 16, 512, 768, 64, 1e-5, 0.5, 0.3, torch.half + input = torch.randn(N, T, input_size, dtype=dtype, device="mlu") + weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") + bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu") + norm_weight = torch.randn(input_size, dtype=dtype, device="mlu") + norm_bias = torch.randn(input_size, dtype=dtype, device="mlu") + weights = torch.chunk(weight, 3) + biass = torch.chunk(bias, 3) + args = (input, weights[0], biass[0], weights[1], biass[1], weights[2], biass[2], + norm_weight, norm_bias, None, "nhtc", head_size, eps, alpha, + beta, True) + self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args) + + args = (input, weights[0], biass[0], weights[1], biass[1], weights[2], biass[2], + norm_weight, norm_bias, None, "nthc", head_size, eps, alpha, + beta, True) + self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args) + +if __name__ == '__main__': + exit(run_unittest(TestFusedNormAttnProjOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_ffn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_ffn.py new file mode 100755 index 0000000..2292cba --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_ffn.py @@ -0,0 +1,89 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * + +class TestFusedNormResidualFFNoP(BtTestCase): + def op_impl_base(self, *args): + input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, gate_up_proj_weight, \ + gate_up_proj_bias, layernorm_weight, layernorm_bias, eps, act_mode, residual_is, \ + alpha, beta = args + hidden_size = input.size(-1) + layernorm = torch.nn.LayerNorm(hidden_size) + layernorm.weight = torch.nn.Parameter(layernorm_weight) + layernorm.bias = torch.nn.Parameter(layernorm_bias) + layernorm.eps = eps + act = act_mode_dict[act_mode] + + residual = input + norm_out = input + if layernorm_weight is not None: + norm_out = layernorm(input) + if residual_is == "normed_input": + residual = norm_out + up_fc_out = torch.matmul(norm_out, up_fc_weight.permute(1, 0)) + up_fc_bias + act_out = act(up_fc_out.float()).to(up_fc_out.dtype) + if gate_up_proj_weight is not None: + gate_up_proj_out = torch.matmul(norm_out, gate_up_proj_weight.permute(1, 0)) + gate_up_proj_bias + down_proj_out = torch.matmul(act_out * gate_up_proj_out, down_proj_weight.permute(1, 0)) + down_proj_bias + else: + down_proj_out = torch.matmul(act_out, down_proj_weight.permute(1, 0)) + down_proj_bias + if residual_is != 'none': + out = beta * residual + alpha * down_proj_out + else: + out = down_proj_out + return out + + def test_ffn(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + batch, seq_len, hidden_size, inner_size, alpha, beta = 4, 16, 512, 512, 0.5, 0.3 + eps, act_mode, residual_is = 1e-5, "silu", "input" + print(f"batch: {batch}, seq_len: {seq_len}, hidden_size: {hidden_size}, inner_size: {inner_size}, alpha: {alpha}, beta: {beta}, \ +eps: {eps}, act_mode: {act_mode}, residual_is: {residual_is}, dtype: {dtype} testing...", flush=True) + input = torch.randn((batch, seq_len, hidden_size), dtype=dtype, device="mlu") + layernorm_weight = torch.randn(hidden_size, dtype=dtype, device="mlu") + layernorm_bias = torch.normal(0, 0.1, (hidden_size,), dtype=dtype, device="mlu") + up_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu") + up_fc_bias = torch.normal(0, 0.1, (inner_size,), dtype=dtype, device="mlu") + gated_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu") + gated_fc_bias = torch.normal(0, 0.1, (inner_size,), dtype=dtype, device="mlu") + down_fc_weight = torch.randn((hidden_size, inner_size), dtype=dtype, device="mlu") + down_fc_bias = torch.normal(0, 0.1, (hidden_size,), dtype=dtype, device="mlu") + torch_output = self.op_impl_base(input, + up_fc_weight, up_fc_bias, + down_fc_weight, down_fc_bias, + gated_fc_weight, gated_fc_bias, + layernorm_weight, layernorm_bias, + eps, act_mode, residual_is, + alpha, beta) + tmo_output = ops.fused_norm_residual_ffn(input, + up_fc_weight, up_fc_bias, + down_fc_weight, down_fc_bias, + gated_fc_weight, gated_fc_bias, + layernorm_weight, layernorm_bias, + eps, act_mode, residual_is, + alpha, beta) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + + def test_inductor(self): + batch, seq_len, hidden_size, inner_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half + input = torch.randn((batch, seq_len, hidden_size), dtype=dtype, device="mlu") + layernorm_weight = torch.randn((hidden_size,), dtype=dtype, device="mlu") + layernorm_bias = torch.randn((hidden_size,), dtype=dtype, device="mlu") + up_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu") + gate_up_proj_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu") + down_proj_weight = torch.randn((hidden_size, inner_size), dtype=dtype, device="mlu") + up_fc_bias = torch.randn((inner_size,), dtype=dtype, device="mlu") + gate_up_proj_bias = torch.randn((inner_size,), dtype=dtype, device="mlu") + down_proj_bias = torch.randn((hidden_size,), dtype=dtype, device="mlu") + args = (input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, + gate_up_proj_weight, gate_up_proj_bias, layernorm_weight, layernorm_bias, act_mode, "normed_input", 1e-5, 1., 0.) + self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args) + +if __name__ == '__main__': + exit(run_unittest(TestFusedNormResidualFFNoP)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_layernorm.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_layernorm.py new file mode 100755 index 0000000..282fa42 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_layernorm.py @@ -0,0 +1,213 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from typing import Union, List, Tuple +from torch.nn.parameter import Parameter +from torch import Size +import os + +dtype_list = [torch.half, torch.float] +if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) +eps = 0.001 + +class TestFuseLayerNormOp(BtTestCase): + def op_impl_base(self, *args): + x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant = args + layernorm = torch.nn.LayerNorm(x.size(-1)) + layernorm.eps = eps + layernorm.weight = Parameter(gamma) + layernorm.bias = Parameter(beta) + x = x + bias if bias is not None else x + pro_input = x + residual if residual is not None else x + output = layernorm(pro_input) + if quant_scale is not None: + output = (output * quant_scale).round().clamp(-128, 127).to(torch.int8) + if out is None: + if store_output_before_norm: + return (output, pro_input) + else: + return output + else: + out.copy_(output) + return out + + def test_layernorm(self): + C = 128 + input_shape = (8, 8, 6, C) + print("test layernorm...") + for dtype in dtype_list: + torch.manual_seed(1) + input = torch.randn(input_shape, device="mlu", dtype=dtype) + residual = torch.randn(input_shape, device="mlu", dtype=dtype) + gamma = torch.randn(C, device="mlu", dtype=dtype) + beta = torch.randn(C, device="mlu", dtype=dtype) + tmo_out_0, tmo_out_1 = ops.fused_layer_norm(input, residual, gamma, + beta, None, eps, True, None, None, False) + torch_out_0, torch_out_1 = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False) + self.assertTensorsEqual(tmo_out_0.cpu().float(), torch_out_0.cpu().float(), 0.005, use_MSE=True) + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out_1.cpu().float(), 0.005, use_MSE=True) + # test output_inplace + tmo_output = torch.empty(input_shape, device="mlu", dtype=dtype) + torch_output = torch.empty_like(tmo_output) + ops.fused_layer_norm(input, residual, gamma, beta, None, eps, False, None, tmo_output, False) + self.op_impl_base(input, residual, gamma, beta, None, eps, False, None, torch_output, False) + self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.003, use_MSE=True) + #test input_stride and output-continguous + inputs_shape = (8, 8, 10, C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input = inputs[:, :, 0:6, :] + tmo_out = ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, None, False) + torch_out = self.op_impl_base(input, None, gamma, beta, None, eps, False, None, None, False) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True) + # has res + tmo_out, tmo_res = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, None, None, False) + torch_out, torch_res = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True) + self.assertTensorsEqual(tmo_res.cpu().float(), torch_res.cpu().float(), 0.003, use_MSE=True) + #res has stride + res_shape = (8, 8, 16, C) + res = torch.randn(res_shape, device="mlu", dtype=dtype) + residual = res[..., 0:6, :] + tmo_out, tmo_res = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, None, None, False) + torch_out, torch_res = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True) + self.assertTensorsEqual(tmo_res.cpu().float(), torch_res.cpu().float(), 0.003, use_MSE=True) + #output has different stride + outputs = torch.randn(8, 8, 12, C, dtype=dtype, device='mlu') + output = outputs[..., 0:6, :] + ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, output, False) + torch_out = self.op_impl_base(input, None, gamma, beta, None, eps, False, None, None, False) + self.assertTensorsEqual(output.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True) + + def test_squant_layernorm(self): + C = 128 + input_shape = (8, 8, 6, C) + print("test squant_layernorm...") + for dtype in dtype_list: + torch.manual_seed(1) + input = torch.randn(input_shape, device="mlu", dtype=dtype) + residual = torch.randn(input_shape, device="mlu", dtype=dtype) + quant_scale = torch.randn(C, device="mlu", dtype=torch.float) * 30 + gamma = torch.randn(C, device="mlu", dtype=dtype) + beta = torch.randn(C, device="mlu", dtype=dtype) + tmo_out_0, tmo_out_1 = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, quant_scale, None, False) + torch_out_0, torch_out_1 = self.op_impl_base(input, residual, gamma, beta, None, eps, True, quant_scale, None, False) + self.assertTensorsEqual(tmo_out_0.cpu().float(), torch_out_0.cpu().float(), 0.006, use_MSE=True) + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out_1.cpu().float(), 0.006, use_MSE=True) + + def test_layernorm_stride(self): + C = 128 + print("test layernorm stride input...") + for dtype in dtype_list: + gamma = torch.randn(C, device="mlu", dtype=dtype) + beta = torch.randn(C, device="mlu", dtype=dtype) + inputs_shape = (8, 8, 10, C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input = inputs[:, :, 0:6, :] + tmo_output = torch.empty(inputs.shape, device="mlu", dtype=dtype) + tmo_output = tmo_output.as_strided(input.shape, input.stride()) + torch_out_0 = torch.empty_like(tmo_output) + self.op_impl_base(input, None, gamma, beta, None, eps, False, None, torch_out_0, False) + ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, tmo_output, False) + self.assertTensorsEqual(tmo_output.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True) + # test assumption get wrong stride when dim = 1 + input = inputs[:, :, 8:9, :] + input = input.as_strided(input.shape, (10240, 1280, 0, 1)) + torch_out_0 = torch.empty_like(input) + self.op_impl_base(input, None, gamma, beta, None, eps, False, None, torch_out_0, False) + ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, input, False) + self.assertTensorsEqual(input.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True) + inputs_shape = (8, 8, 10, 2*C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input1 = inputs[:, :, 0:2, 0:C] + torch_out_0 = torch.empty_like(input1) + self.op_impl_base(input1, None, gamma, beta, None, eps, False, None, torch_out_0, False) + ops.fused_layer_norm(input1, None, gamma, beta, None, eps, False, None, input1, False) + self.assertTensorsEqual(input1.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True) + input2 = inputs[:, :, :, 0:C] + torch_out_0 = torch.empty_like(input2) + self.op_impl_base(input2, None, gamma, beta, None, eps, False, None, torch_out_0, False) + ops.fused_layer_norm(input2, None, gamma, beta, None, eps, False, None, input2, False) + self.assertTensorsEqual(input2.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True) + #test 3-dim input + inputs_shape = (8, 8, 2*C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input1 = inputs[:, 0:2, 0:C] + torch_out_0 = torch.empty_like(input1) + self.op_impl_base(input1, None, gamma, beta, None, eps, False, None, torch_out_0, False) + ops.fused_layer_norm(input1, None, gamma, beta, None, eps, False, None, input1, False) + self.assertTensorsEqual(input1.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True) + input2 = inputs[0:2, 0:2, 0:C] + torch_out_0 = torch.empty_like(input2) + self.op_impl_base(input2, None, gamma, beta, None, eps, False, None, torch_out_0, False) + ops.fused_layer_norm(input2, None, gamma, beta, None, eps, False, None, input2, False) + self.assertTensorsEqual(input2.cpu().float(), torch_out_0.cpu().float(), 0.0032, use_MSE=True) + + # 防呆测试 + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + print("test_prevent....") + func1 = ops.fused_layer_norm + batch, seq_len, hidden_size = 5, 12, 512 + input = torch.randn(hidden_size, dtype=torch.half, device='mlu') + residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + quant_scale = torch.randn(hidden_size, dtype=torch.half, device='mlu') + gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu') + beta = torch.randn(hidden_size, dtype=torch.half, device='mlu') + self.assertException("input.dim() >= 2.", + func1, input, None, gamma, beta, None, eps, False, None, None) + inputs = torch.randn(batch, seq_len, 2*hidden_size, dtype=torch.half, device='mlu') + input = inputs[..., ::2] + self.assertException("input last dim must be contiguous.", + func1, input, None, gamma, beta, None, eps, False, None, None) + + input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + self.assertException("layernorm mode need gamma and beta.", + func1, input, None, gamma, None, None, eps, False, None, None) + gamma = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu') + beta = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu') + self.assertException("layernorm mode, gamma and beta size must be hidden_size.", + func1, input, None, gamma, beta, None, eps, False, None, None) + + gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu') + beta = torch.randn(hidden_size, dtype=torch.half, device='mlu') + inputs = torch.randn(batch, seq_len, 12, hidden_size, dtype=torch.half, device='mlu') + input = inputs[..., 6:9, :] + self.assertException("quant_out is not support when input has stride.", + func1, input, None, gamma, beta, None, eps, False, quant_scale, None) + inputs = torch.randn(batch, 12, seq_len, hidden_size, dtype=torch.half, device='mlu') + input = inputs[:, 6:9, ...] + self.assertException("check the strides of input.", + func1, input, None, gamma, beta, None, eps, False, None, input) + input = inputs[..., 6:9, :] + outputs = torch.randn(batch, 2*seq_len, 3, hidden_size, dtype=torch.half, device='mlu') + output = outputs[:, :seq_len, ...] + self.assertException("check the strides of output.", + func1, input, None, gamma, beta, None, eps, False, None, output) + + def test_inductor(self): + print("test_inductor....") + batch, seq_len, hidden_size = 5, 12, 512 + input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + quant_scale = torch.randn(hidden_size, dtype=torch.float, device='mlu') + gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu') + beta = torch.randn(hidden_size, dtype=torch.half, device='mlu') + bias = torch.randn(hidden_size, dtype=torch.half, device='mlu') + residual_out = torch.empty_like(input) + + output = torch.zeros(input.shape, dtype=torch.int8, device='mlu') + args = (input, output, residual, gamma, beta, bias, quant_scale, residual_out, None, "layernorm", eps, True, False) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args) + args = (input, output, residual, gamma, beta, bias, quant_scale, None, None, "layernorm", eps, False, False) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args) + output = torch.empty_like(input) + args = (input, output, residual, gamma, beta, bias, None, None, None, "layernorm", eps, False, False) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args) + + +if __name__ == '__main__': + exit(run_unittest(TestFuseLayerNormOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rmsnorm.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rmsnorm.py new file mode 100644 index 0000000..7dd7595 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rmsnorm.py @@ -0,0 +1,202 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from typing import Union, List, Tuple +from torch.nn.parameter import Parameter +from torch import Size +import os + +dtype_list = [torch.half, torch.float] +if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) +eps = 0.001 + +class TestFuseRmsNormOp(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + x = create_tensor_from_dic(dic['x']) + residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual']) + gamma = None if dic['gamma']['data'] is None else create_tensor_from_dic(dic['gamma']) + beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta']) + bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias']) + eps = dic['eps']['data'] + store_output_before_norm = dic['store_output_before_norm']['data'] + quant_scale = None if dic['quant_scale']['data'] is None else create_tensor_from_dic(dic['quant_scale']) + out = None if dic['out']['data'] is None else create_tensor_from_dic(dic['out']) + dynamic_quant = dic['dynamic_quant']['data'] + self.launch(x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant) + + def launch(self, *args): + args = list(args) + base_out = None if args[-2] is None else torch.empty_like(args[-2]) + base_input = args[0].clone() if args[-2] is not None and args[0] is args[-2] else args[0] + tmo_out = ops.fused_rms_norm(*args) + args[0] = base_input + args[-2] = base_out + torch_out = self.op_impl_base(*args) + if tmo_out.__class__ in (list, tuple): + for o1, o2 in zip(tmo_out, torch_out): + self.assertTensorsEqual(o1.cpu().float(), o2.cpu().float(), 0.007, use_MSE=True) + else: + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.007, use_MSE=True) + + def op_impl_base(self, *args): + x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant = args + x = x + bias if bias is not None else x + pro_input = x + residual if residual is not None else x + store_input = pro_input + pro_input = pro_input.to(torch.float32) + variance = pro_input.pow(2).mean(-1, keepdim=True) + pro_input = pro_input * torch.rsqrt(variance + eps) + if gamma.dtype in [torch.float16, torch.bfloat16]: + pro_input = pro_input.to(gamma.dtype) + output = gamma * pro_input + if quant_scale is not None: + output = (output * quant_scale).round().clamp(-128, 127).to(torch.int8) + if out is None: + if store_output_before_norm: + return (output, store_input) + else: + return output + else: + out.copy_(output) + return out + + def test_rmsnorm(self): + C = 128 + input_shape = (8, 8, 6, C) + print("test rmsnorm...") + for dtype in dtype_list: + torch.manual_seed(1) + input = torch.randn(input_shape, device="mlu", dtype=dtype) + residual = torch.randn(input_shape, device="mlu", dtype=dtype) + gamma = torch.randn(C, device="mlu", dtype=dtype) + self.launch(input, residual, gamma, None, None, eps, True, None, None, False) + # test inplace output + output = torch.empty(input_shape, device="mlu", dtype=dtype) + self.launch(input, residual, gamma, None, None, eps, False, None, output, False) + inputs_shape = (8, 8, 10, C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input = inputs[:, :, 0:6, :] + self.launch(input, None, gamma, None, None, eps, False, None, None, False) + self.launch(input, residual, gamma, None, None, eps, True, None, None, False) + #res has stride + res_shape = (8, 8, 16, C) + res = torch.randn(res_shape, device="mlu", dtype=dtype) + residual = res[..., 0:6, :] + self.launch(input, residual, gamma, None, None, eps, True, None, None, False) + #output has different stride + outputs = torch.randn(8, 8, 12, C, dtype=dtype, device='mlu') + output = outputs[..., 0:6, :] + self.launch(input, residual, gamma, None, None, eps, False, None, output, False) + + def test_squant_rmsnorm(self): + C = 128 + input_shape = (8, 8, 6, C) + print("test squant_rmsnorm...") + for dtype in dtype_list: + torch.manual_seed(1) + input = torch.randn(input_shape, device="mlu", dtype=dtype) + residual = torch.randn(input_shape, device="mlu", dtype=dtype) + quant_scale = torch.randn(C, device="mlu", dtype=torch.float) * 30 + gamma = torch.randn(C, device="mlu", dtype=dtype) + self.launch(input, residual, gamma, None, None, eps, True, quant_scale, None, False) + # test one output + self.launch(input, None, gamma, None, None, eps, False, quant_scale, None, False) + + def test_rmsnorm_stride(self): + C = 128 + print("test rmsnorm stride input...") + for dtype in dtype_list: + gamma = torch.randn(C, device="mlu", dtype=dtype) + inputs_shape = (8, 8, 10, C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input = inputs[:, :, 0:6, :] + output = torch.empty(inputs.shape, device="mlu", dtype=dtype) + output = output.as_strided(input.shape, input.stride()) + self.launch(input, None, gamma, None, None, eps, False, None, output, False) + # test assumption get wrong stride when dim = 1 + input = inputs[:, :, 0 :1, :] + input = input.as_strided(input.shape, (10240, 1280, 0, 1)) + self.launch(input, None, gamma, None, None, eps, False, None, input, False) + inputs_shape = (8, 8, 10, 2*C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input1 = inputs[:, :, 0:2, 0:C] + self.launch(input1, None, gamma, None, None, eps, False, None, input1, False) + input2 = inputs[:, :, :, 0:C] + self.launch(input2, None, gamma, None, None, eps, False, None, input2, False) + #tset 3-dim input + inputs_shape = (8, 8, 2*C) + inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype) + input1 = inputs[:, 0:2, 0:C] + self.launch(input1, None, gamma, None, None, eps, False, None, input1, False) + input2 = inputs[0:2, 0:2, 0:C] + self.launch(input2, None, gamma, None, None, eps, False, None, input2, False) + + + # 防呆测试 + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func1 = ops.fused_rms_norm + batch, seq_len, hidden_size = 5, 12, 512 + input = torch.randn(hidden_size, dtype=torch.half, device='mlu') + quant_scale = torch.randn(hidden_size, dtype=torch.half, device='mlu') + gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu') + self.assertException("input.dim() >= 2.", + func1, input, None, gamma, None, None, eps, False, None, None) + + inputs = torch.randn(batch, seq_len, 2*hidden_size, dtype=torch.half, device='mlu') + input = inputs[..., ::2] + self.assertException("input last dim must be contiguous.", + func1, input, None, gamma, None, None, eps, False, None, None) + + input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + self.assertException("rmsnorm mode need gamma.", + func1, input, None, None, None, None, eps, False, None, None) + gamma = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu') + self.assertException("rmsnorm mode, gamma size must be hidden_size.", + func1, input, None, gamma, None, None, eps, False, None, None) + + gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu') + inputs = torch.randn(batch, seq_len, 12, hidden_size, dtype=torch.half, device='mlu') + input = inputs[..., 6:9, :] + self.assertException("quant_out is not support when input has stride.", + func1, input, None, gamma, None, None, eps, False, quant_scale, None) + inputs = torch.randn(batch, 12, seq_len, hidden_size, dtype=torch.half, device='mlu') + input = inputs[:, 6:9, ...] + self.assertException("check the strides of input.", + func1, input, None, gamma, None, None, eps, False, None, input) + input = inputs[..., 6:9, :] + outputs = torch.randn(batch, 2 * seq_len, 12, hidden_size, dtype=torch.half, device='mlu') + output = outputs[:, 0:seq_len, 6:9, :] + self.assertException("check the strides of output.", + func1, input, None, gamma, None, None, eps, False, None, output) + + + def test_inductor(self): + batch, seq_len, hidden_size = 5, 12, 512 + input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu') + quant_scale = torch.randn(hidden_size, dtype=torch.float, device='mlu') + gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu') + beta = torch.randn(hidden_size, dtype=torch.half, device='mlu') + bias = torch.randn(hidden_size, dtype=torch.half, device='mlu') + residual_out = torch.empty_like(input) + + output = torch.zeros(input.shape, dtype=torch.int8, device='mlu') + args = (input, output, residual, gamma, beta, bias, quant_scale, residual_out, None, "rmsnorm", eps, True, False) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args) + args = (input, output, residual, gamma, beta, bias, quant_scale, None, None, "rmsnorm", eps, False, False) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args) + + output = torch.empty_like(input) + args = (input, output, residual, gamma, beta, bias, None, None, None, "rmsnorm", eps, False, False) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args) + +if __name__ == '__main__': + run_unittest(TestFuseRmsNormOp) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rope.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rope.py new file mode 100644 index 0000000..a03e5b7 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_fused_rope.py @@ -0,0 +1,447 @@ +import torch +import unittest +from torch.nn import functional as F +import torch_mlu_ops as ops +from common_utils import * +import random + +def genSlotMapping(batch, block_size): + output = [] + for i in range(batch): + idx = random.randint(i * block_size, (i + 1) * block_size - 1) + output.append(idx) + + return output + +def rotate(x: torch.Tensor): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + +def quant(input: torch.Tensor): + input_fp32 = input.to(torch.float32) + max_value, _ = torch.max(input_fp32.abs(), dim=-1, keepdim=True) + scale = max_value / 7 + scaled_input = torch.round(input_fp32 / scale) + return scaled_input.to(torch.int8), scale[..., 0], input_fp32 / scale + +class TestFusedRopeOp(BtTestCase): + def op_impl_base(self, *args): + qkv, k_cache_hp, v_cache_hp, sin_cache, cos_cache, position_id, gamma, beta, \ + k_cache_lp, v_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, cache_bs_id_lp, \ + cache_seq_offsets_lp, k_scale_hp, v_scale_hp, k_scale_lp, v_scale_lp, slot_mapping_hp, \ + slot_mapping_lp, eps = args + mixed_cache = k_cache_lp is not None and v_cache_lp is not None + qkv = qkv.to(torch.float32) + sin_cache = sin_cache.to(torch.float32) + cos_cache = cos_cache.to(torch.float32) + gamma = gamma.to(torch.float32) + beta = beta.to(torch.float32) + if k_scale_hp is not None: + k_scale_hp = 1 / k_scale_hp + if v_scale_hp is not None: + v_scale_hp = 1 / v_scale_hp + if not mixed_cache and k_scale_hp is None: + k_cache_hp = k_cache_hp.to(torch.float32) + v_cache_hp = v_cache_hp.to(torch.float32) + discrete_batch_hp = cache_bs_id_hp is not None + discrete_batch_lp = cache_bs_id_lp is not None + paged_cache_hp = slot_mapping_hp is not None + paged_cache_lp = slot_mapping_lp is not None + head_size = qkv.shape[-1] + rope_dim = head_size + batch_size = qkv.shape[0] + head_qkv = qkv.shape[-2] + kv_heads = k_cache_hp.shape[1] + q_heads = head_qkv - 2 * kv_heads + head_qk = q_heads + kv_heads + group_num = 1 + group_size = head_size + if mixed_cache: + group_num = k_scale_lp.shape[-1] + group_size = int(head_size / group_num) + k_cache_lp_shape = list(k_cache_lp.shape) + k_cache_lp_shape[-1] *= 2 + k_cache_lp_int8 = UnpackInt4(k_cache_lp).reshape(k_cache_lp_shape) + v_cache_lp_shape = list(v_cache_lp.shape) + v_cache_lp_shape.append(2) + v_cache_lp_int8 = UnpackInt4(v_cache_lp).reshape(v_cache_lp_shape) + + qk = qkv[:, :, 0:q_heads + kv_heads].clone() + + for i in range(batch_size): + qk_i = qk[i] + sin_cache_i = sin_cache[position_id[i]:position_id[i] + 1] + cos_cache_i = cos_cache[position_id[i]:position_id[i] + 1] + sin_cache_i = sin_cache_i[:1] + cos_cache_i = cos_cache_i[:1] + rot = rotate(qk_i) + + qk_i[:] = rot * sin_cache_i.unsqueeze(1) + qk_i * cos_cache_i.unsqueeze(1) + + q = qk[:, :, 0:q_heads] + k = qk[:, :, q_heads:head_qk].contiguous().reshape(batch_size, kv_heads, head_size) + + qkv_q = qkv[:, :, 0:q_heads] + qkv_q[...] = q + + shape_k = k.shape + k = torch.reshape(k, (-1, shape_k[-1])) + k_norm = F.layer_norm(k, (head_size,), gamma, beta, eps) + k_norm = torch.reshape(k_norm, shape_k) + + k_out = k_norm.reshape(batch_size, 1, kv_heads, head_size) + v_out = qkv[:, :, head_qk:head_qkv].contiguous() + k_out_hp = k_out.clone() + v_out_hp = v_out.clone() + k_out_lp = None + v_out_lp = None + if k_scale_hp is not None and v_scale_hp is not None: + k_scale_hp = k_scale_hp.reshape(kv_heads, head_size) + v_scale_hp = v_scale_hp.reshape(kv_heads, head_size) + k_out_hp = (k_out * k_scale_hp).round().clamp(-128, 127).to(torch.int8) + v_out_hp = (v_out * v_scale_hp).round().clamp(-128, 127).to(torch.int8) + + if paged_cache_hp: + block_size = k_cache_hp.shape[2] + for i in range(batch_size): + if slot_mapping_hp[i] >= 0: + block_id = torch.div(slot_mapping_hp[i], block_size, rounding_mode='floor') + block_offset = slot_mapping_hp[i] % block_size + k_cache_hp[block_id, :, block_offset, :] = k_out_hp[i] + v_cache_hp[block_id, :, block_offset, :] = v_out_hp[i] + else: + for i in range(batch_size): + key_i = k_out_hp[i].transpose(1, 0) + value_i = v_out_hp[i].transpose(1, 0) + + cache_bs_id_hp_i = cache_bs_id_hp[i] if discrete_batch_hp else i + cache_seqlen_offset_hp_i = cache_seq_offsets_hp[i] + if cache_seqlen_offset_hp_i < 0 or cache_bs_id_hp_i < 0: + continue + + key_cache_hp_i = \ + k_cache_hp[cache_bs_id_hp_i, :, cache_seqlen_offset_hp_i:cache_seqlen_offset_hp_i + 1] + key_cache_hp_i[...] = key_i[...] + + value_cache_hp_i = \ + v_cache_hp[cache_bs_id_hp_i, :, cache_seqlen_offset_hp_i:cache_seqlen_offset_hp_i + 1] + value_cache_hp_i[...] = value_i[...] + + if mixed_cache: + for i in range(batch_size): + key_i = k_out[i].reshape(kv_heads, group_num, group_size) + value_i = v_out[i].reshape(kv_heads, group_num, group_size) + key_i_lp, key_scale_i_lp, _ = quant(key_i) + value_i_lp, value_scale_i_lp, scaled_input = quant(value_i) + + if paged_cache_lp: + block_size = k_cache_lp_int8.shape[2] + if slot_mapping_lp[i] >= 0: + block_id = torch.div(slot_mapping_lp[i], block_size, rounding_mode='floor') + block_offset_k = slot_mapping_lp[i] % block_size + block_offset_v = torch.div(block_offset_k, 2, rounding_mode='floor') + odd_even_offset = block_offset_k % 2 + k_cache_lp_int8[block_id, :, block_offset_k, :] = key_i_lp.reshape(kv_heads, head_size) + v_cache_lp_int8[block_id, :, block_offset_v, :, odd_even_offset] = value_i_lp.reshape(kv_heads, head_size) + k_scale_lp[block_id, :, block_offset_k, :] = key_scale_i_lp[...] + v_scale_lp[block_id, :, block_offset_k, :] = value_scale_i_lp[...] + else: + cache_bs_id_lp_i = cache_bs_id_lp[i] if discrete_batch_lp else i + cache_seqlen_offset_lp_i = cache_seq_offsets_lp[i] + v_seq_offset = torch.div(cache_seqlen_offset_lp_i, 2, rounding_mode='floor') + odd_even_offset = cache_seqlen_offset_lp_i % 2 + if cache_seqlen_offset_lp_i < 0 or cache_bs_id_lp_i < 0: + continue + key_cache_lp_i = \ + k_cache_lp_int8[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i] + key_cache_lp_i[...] = key_i_lp.reshape(kv_heads, head_size) + key_scale_lp_i = \ + k_scale_lp[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i] + key_scale_lp_i[...] = key_scale_i_lp[...] + + value_cache_lp_i = \ + v_cache_lp_int8[cache_bs_id_lp_i, :, v_seq_offset, :, odd_even_offset] + value_cache_lp_i[...] = value_i_lp.reshape(kv_heads, head_size) + value_scale_lp_i = \ + v_scale_lp[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i] + value_scale_lp_i[...] = value_scale_i_lp[...] + out = (qkv, k_cache_hp, v_cache_hp) + if mixed_cache: + k_cache_lp = PairlyPackInt8(k_cache_lp_int8.view(-1, head_size)).reshape(k_cache_lp.shape) + v_cache_lp_int8 = v_cache_lp_int8.transpose(2, 3) + s0,s1,s2,s3,s4 = v_cache_lp_int8.shape + v_cache_lp = PairlyPackInt8(v_cache_lp_int8.reshape(-1, s3 * s4)).reshape(s0, s1, s2, s3 * s4 // 2).transpose(2,3) + out += (k_cache_lp, v_cache_lp) + if k_scale_lp is not None: + out += (k_scale_lp, v_scale_lp) + return out + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "fused_rope not support MLU3XX device") + def test_fused_rope(self): + random_cases = 100 + random.seed(355) + eps = 1e-6 + for _ in range(random_cases): + if random_cases == 1: + bs = 512 + seq_len = 1 + q_heads = 8 + kv_heads = 1 + head_size = 128 + rope_dim = 128 + dtype = torch.float16 + need_quant_kv = False + mixed_cache = False + max_decode_len_hp = 128 + max_decode_len_lp = 128 + discrete_batch_hp = True + discrete_batch_lp = False + paged_cache_hp = False + paged_cache_lp = True + block_size_hp = 16 + block_size_lp = 128 + num_blocks_hp = int((bs * max_decode_len_hp + block_size_hp - 1) / block_size_hp) + num_blocks_lp = int((bs * max_decode_len_lp + block_size_lp - 1) / block_size_lp) + group_num = 1 + else: + bs = random.randint(1, 512) + seq_len = 1 + q_heads = random.randint(1, 32) + kv_heads = random.randint(1, 32) + head_size_list = [32, 64, 96, 128] + head_size = random.choice(head_size_list) + rope_dim = head_size + dtype_list = [torch.half, torch.bfloat16] + bool_list = [True, False] + block_size_list = [16, 32, 64, 128] + dtype = random.choice(dtype_list) + need_quant_kv = random.choice(bool_list) + mixed_cache = random.choice(bool_list) + max_decode_len_hp = random.randint(128, 1024) + max_decode_len_lp = random.randint(128, 1024) + discrete_batch_hp = random.choice(bool_list) + discrete_batch_lp = random.choice(bool_list) + paged_cache_hp = random.choice(bool_list) + paged_cache_lp = random.choice(bool_list) + block_size_hp = random.choice(block_size_list) + num_blocks_hp = int((bs * max_decode_len_hp + block_size_hp - 1) / block_size_hp) + block_size_lp = random.choice(block_size_list) + num_blocks_lp = int((bs * max_decode_len_lp + block_size_lp - 1) / block_size_lp) + group_num_list = [1, 2, 4, 8] + group_num = random.choice(group_num_list) + + if mixed_cache: + need_quant_kv = True + group_size = int(head_size / group_num) + if max_decode_len_lp % 2 != 0: + max_decode_len_lp = max_decode_len_lp - 1 + print("bs: {}, seq_len: {}, q_heads: {}, kv_heads: {}, head_size: {}, rope_dim: {}, " + "dtype: {}, mixed_cache: {}, quant_kv: {}, paged_cache_hp: {}, paged_cache_lp: {}, " + "discrete_batch_hp: {}, discrete_batch_lp: {}." + .format(bs, seq_len, q_heads, kv_heads, head_size, rope_dim, dtype, mixed_cache, + need_quant_kv, paged_cache_hp, paged_cache_lp, discrete_batch_hp, discrete_batch_lp)) + + if mixed_cache: + print("max_decode_len_hp: {}, max_decode_len_lp: {}, num_blocks_hp: {}, num_blocks_lp: {}, " + "block_size_hp: {}, block_size_lp: {}, group_num: {}, testing..." + .format(max_decode_len_hp, max_decode_len_lp, num_blocks_hp, num_blocks_lp, + block_size_hp, block_size_lp, group_num)) + + max_bs_hp = bs + 1 if discrete_batch_hp else bs + max_bs_lp = bs + 1 if discrete_batch_lp else bs + cache_size = 1 if need_quant_kv else 2 + cache_bytes_hp = num_blocks_hp * kv_heads * block_size_hp * head_size if paged_cache_hp else \ + max_bs_hp * kv_heads * max_decode_len_hp * head_size + cache_bytes_hp = cache_bytes_hp * cache_size + cache_bytes_lp = 0 + if mixed_cache: + cache_bytes_lp = num_blocks_lp * kv_heads * block_size_lp * head_size if paged_cache_lp else \ + max_bs_lp * kv_heads * max_decode_len_lp * head_size + + if cache_bytes_hp > 2**31 or cache_bytes_lp > 2**31: + print("cache bytes can not be larger than int32max. ") + continue + + input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size) + input = torch.randn(size=input_shape, dtype=dtype).mlu() + + if paged_cache_hp: + cache_shape_hp = (num_blocks_hp, kv_heads, block_size_hp, head_size) + else: + cache_shape_hp = (max_bs_hp, kv_heads, max_decode_len_hp, head_size) + if need_quant_kv: + k_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu') + k_cache_hp = (k_cache_hp - 0.5) * 256 + k_cache_hp = k_cache_hp.to(torch.int8) + v_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu') + v_cache_hp = (k_cache_hp - 0.5) * 256 + v_cache_hp = k_cache_hp.to(torch.int8) + k_scale_ops_hp = 1 / (torch.randn(size=(kv_heads, head_size), dtype=torch.float).abs().mlu() + 0.01) + v_scale_ops_hp = 1 / (torch.randn(size=(kv_heads, head_size), dtype=torch.float).abs().mlu() + 0.01) + else: + k_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu') + v_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu') + k_scale_ops_hp = None + v_scale_ops_hp = None + + k_cache_lp = None + v_cache_lp = None + k_scale_lp = None + v_scale_lp = None + if mixed_cache: + if paged_cache_lp: + k_scale_lp = torch.randn(size=(num_blocks_lp, kv_heads, block_size_lp, group_num), dtype=torch.float).mlu() + v_scale_lp = torch.randn(size=(num_blocks_lp, kv_heads, block_size_lp, group_num), dtype=torch.float).mlu() + cache_raw = torch.randn((num_blocks_lp * kv_heads * block_size_lp, head_size), dtype=dtype, device='mlu') + max_value = torch.amax(torch.abs(cache_raw)) + cache_raw = cache_raw * (7 / max_value) + cache_raw = cache_raw.to(torch.int8) + k_cache_lp = PairlyPackInt8(cache_raw).reshape(num_blocks_lp, kv_heads, block_size_lp, int(head_size / 2)) + cache_raw = torch.randn((int(num_blocks_lp * kv_heads * block_size_lp / 2), head_size * 2), dtype=dtype, device='mlu') + max_value = torch.amax(torch.abs(cache_raw)) + cache_raw = cache_raw * (7 / max_value) + cache_raw = cache_raw.to(torch.int8) + v_cache_lp = PairlyPackInt8(cache_raw).reshape(num_blocks_lp, kv_heads, int(block_size_lp / 2), head_size) + k_cache_lp_ref_shape = (num_blocks_lp, kv_heads, block_size_lp, head_size) + v_cache_lp_ref_shape = (num_blocks_lp, kv_heads, int(block_size_lp / 2), head_size, 2) + else: + k_scale_lp = torch.randn(size=(max_bs_lp, kv_heads, max_decode_len_lp, group_num), dtype=torch.float).mlu() + v_scale_lp = torch.randn(size=(max_bs_lp, kv_heads, max_decode_len_lp, group_num), dtype=torch.float).mlu() + cache_raw = torch.randn((max_bs_lp * kv_heads * max_decode_len_lp, head_size), dtype=dtype, device='mlu') + max_value = torch.amax(torch.abs(cache_raw)) + cache_raw = cache_raw * (7 / max_value) + cache_raw = cache_raw.to(torch.int8) + k_cache_lp = PairlyPackInt8(cache_raw).reshape(max_bs_lp, kv_heads, max_decode_len_lp, int(head_size / 2)) + cache_raw = torch.randn((int(max_bs_lp * kv_heads * max_decode_len_lp / 2), head_size * 2), dtype=dtype, device='mlu') + max_value = torch.amax(torch.abs(cache_raw)) + cache_raw = cache_raw * (7 / max_value) + cache_raw = cache_raw.to(torch.int8) + v_cache_lp = PairlyPackInt8(cache_raw).reshape(max_bs_lp,kv_heads, int(max_decode_len_lp / 2), head_size) + k_cache_lp_ref_shape = (max_bs_lp, kv_heads, max_decode_len_lp, head_size) + v_cache_lp_ref_shape = (max_bs_lp, kv_heads, int(max_decode_len_lp / 2), head_size, 2) + del cache_raw + + cache_bs_id_hp = None + cache_bs_id_lp = None + cache_seq_offsets_hp = None + cache_seq_offsets_lp = None + slot_mapping_hp = None + slot_mapping_lp = None + if not paged_cache_hp: + if discrete_batch_hp: + cache_bs_id_hp = random.sample([*range(0, max_bs_hp)], bs) + cache_bs_id_hp = torch.IntTensor(cache_bs_id_hp).mlu() + cache_seq_offsets_hp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_hp - 2, + dtype=torch.int32, device='mlu') + else: + slot_mapping_hp = random.sample([*range(-1, block_size_hp * num_blocks_hp)], bs) + slot_mapping_hp = torch.IntTensor(slot_mapping_hp).mlu() + + input_ref = input.clone() + k_cache_hp_ref = k_cache_hp.clone() + v_cache_hp_ref = v_cache_hp.clone() + k_cache_lp_ref = None + v_cache_lp_ref = None + k_scale_lp_ref = None + v_scale_lp_ref = None + if mixed_cache: + if not paged_cache_lp: + if discrete_batch_lp: + cache_bs_id_lp = random.sample([*range(0, max_bs_lp)], bs) + cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu() + cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2, + dtype=torch.int32, device='mlu') + else: + slot_mapping_lp = genSlotMapping(bs, block_size_lp) + slot_mapping_lp = torch.IntTensor(slot_mapping_lp).mlu() + k_cache_lp_ref = k_cache_lp.clone() + v_cache_lp_ref = v_cache_lp.clone() + k_scale_lp_ref = k_scale_lp.clone() + v_scale_lp_ref = v_scale_lp.clone() + + cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu() + sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu() + gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu() + beta = torch.randn(size=(head_size, ), dtype=dtype).mlu() + position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu') + + notify_start = torch.mlu.Event(enable_timing=True) + notify_end = torch.mlu.Event(enable_timing=True) + notify_start.record() + base_output = self.op_impl_base(input_ref, k_cache_hp_ref, v_cache_hp_ref, sin_table, \ + cos_table, position_id, gamma, beta, \ + k_cache_lp_ref, v_cache_lp_ref, cache_bs_id_hp, cache_seq_offsets_hp, \ + cache_bs_id_lp, cache_seq_offsets_lp, k_scale_ops_hp, v_scale_ops_hp, \ + k_scale_lp_ref, v_scale_lp_ref, slot_mapping_hp, slot_mapping_lp, eps) + notify_end.record() + notify_end.synchronize() + time = notify_start.hardware_time(notify_end) + print("baseline hw_time is: ", time, "us") + del input_ref, k_cache_hp_ref, v_cache_hp_ref, k_cache_lp_ref, v_cache_lp_ref, k_scale_lp_ref, v_scale_lp_ref + + notify_start.record() + loop = 1 + for _ in range(loop): + ops.fused_rope(input, k_cache_hp, v_cache_hp, sin_table, cos_table, position_id, \ + gamma, beta, k_cache_lp, v_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, \ + cache_bs_id_lp, cache_seq_offsets_lp, k_scale_ops_hp, v_scale_ops_hp, \ + k_scale_lp, v_scale_lp, slot_mapping_hp, slot_mapping_lp, eps) + + notify_end.record() + notify_end.synchronize() + time = notify_start.hardware_time(notify_end) / loop + print("hw time is: ", time, "us") + + print("check input diff \n") + self.assertTensorsEqual(input.cpu().float(), base_output[0].cpu().float(), 0.003, use_MSE=True) + print("pass \n") + print("check key cache hp diff \n") + self.assertTensorsEqual(k_cache_hp.cpu().float(), base_output[1].cpu().float(), 0.003, use_MSE=True) + print("pass \n") + print("check value cache hp diff \n") + self.assertTensorsEqual(v_cache_hp.cpu().float(), base_output[2].cpu().float(), 0.003, use_MSE=True) + print("pass \n") + if mixed_cache: + k_cache_lp_int8 = UnpackInt4(k_cache_lp).reshape(k_cache_lp_ref_shape) + v_cache_lp_int8 = UnpackInt4(v_cache_lp).reshape(v_cache_lp_ref_shape) + k_cache_lp_ref_int8 = UnpackInt4(base_output[3]).reshape(k_cache_lp_ref_shape) + v_cache_lp_ref_int8 = UnpackInt4(base_output[4]).reshape(v_cache_lp_ref_shape) + + print("check key cache lp diff \n") + self.assertTensorsEqual(k_cache_lp_int8.cpu().float(), k_cache_lp_ref_int8.cpu().float(), 1) + print("pass \n") + print("check key scale lp diff \n") + self.assertTensorsEqual(k_scale_lp.cpu().float(), base_output[5].cpu().float(), 0.003, use_MSE=True) + print("pass \n") + print("check value cache lp diff \n") + self.assertTensorsEqual(v_cache_lp_int8.cpu().float(), v_cache_lp_ref_int8.cpu().float(), 1) + print("pass \n") + print("check value scale lp diff \n") + self.assertTensorsEqual(v_scale_lp.cpu().float(), base_output[6].cpu().float(), 0.003, use_MSE=True) + print("pass \n") + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "fused_rope not support MLU3XX device") + def test_inductor(self): + bs, seq_len, q_heads, kv_heads, head_size, rope_dim, max_decode_len, dtype= 40, 1, 8, 1, 128, 128, 2048, torch.bfloat16 + input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size) + input = torch.randn(size=input_shape, dtype=dtype).mlu() + max_bs = bs + 1 + cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu() + sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu() + gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu() + beta = torch.randn(size=(head_size, ), dtype=dtype).mlu() + + cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu') + k_cache = cache[0] + v_cache = cache[1] + cache_bs_id = random.sample([*range(0, max_bs)], bs) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2, + dtype=torch.int32, device='mlu') + position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu') + args = (input, k_cache, v_cache, None, None, sin_table, cos_table, position_id, gamma, beta, None, + None, None, None, cache_bs_id, cache_seq_offsets, None, None, None, None, 1e-5) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_rope, args) + +if __name__ == "__main__": + exit(run_unittest(TestFusedRopeOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_group_gemm.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_group_gemm.py new file mode 100755 index 0000000..ed8eaad --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_group_gemm.py @@ -0,0 +1,122 @@ +import torch +from torch_mlu import mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +from torch.nn import functional as F + +def gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias=False): + bs = batch * seq + token_topk = bs * topk + expert_id = torch.randint(experts_num, (token_topk,), device="mlu") + sorted_expert_id, indices = expert_id.sort() + gather_idx = indices // topk + gather_idx = gather_idx.to(torch.int32) + token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32) + + a = torch.randn(bs, k, device="mlu", dtype=data_type) + if not idx_mode: + a = a[gather_idx] + b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type) + c = torch.randn(token_topk, n, device="mlu", dtype=data_type) + alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32) + beta = torch.randn(experts_num, device="mlu", dtype=torch.float32) + a_scale = None + b_scale = None + bias = None + if has_bias: + bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) + + gather_idx_ = gather_idx if idx_mode else None + return a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, data_type, bias + + +class TestGroupGemmOp(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + a = create_tensor_from_dic(dic['a']) + b = create_tensor_from_dic(dic['b']) + m_list = dic['m_list']['data'] + expand_idx = dic['expand_idx']['data'] + c = None if dic['c']['data'] is None else create_tensor_from_dic(dic['c']) + alpha = None if dic['alpha']['data'] is None else create_tensor_from_dic(dic['alpha']) + beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta']) + max_m = dic['max_m']['data'] + bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias']) + self.launch(a, b, m_list, expand_idx, c, alpha, beta, max_m, bias) + + def launch(self, *args): + total_m = args[2].sum().item() + torch_out = self.op_impl_base(*args) + tmo_out = ops.group_gemm(*args) + self.assertTensorsEqual(tmo_out.cpu().float()[0:total_m], torch_out.cpu().float()[0:total_m], 0.006, use_MSE=True) + + def op_impl_base(self, *args): + a, b, m_list, expand_idx, c, alpha, beta, max_m, bias = args + a = a.reshape(-1, a.size(-1)) + if expand_idx is not None: + a = a[expand_idx] + total_m = m_list.sum().item() + a_list = a[:total_m].split(tuple(m_list)) + + c_list = [] + if c is not None: + c = c.reshape(-1, c.size(-1)) + c_list = c[:total_m].split(tuple(m_list)) + + output_list = [] + for i in range(b.size(0)): # alpha*(a*b) + bias + beta*c + if (a_list[i].size(0) > 0): + gemm_out = torch.matmul(a_list[i], b[i].permute(1, 0)) + if alpha is not None: + gemm_out *= alpha[i] + if bias is not None: + gemm_out += bias[i] + if beta is not None and c_list != []: + gemm_out += c_list[i] * beta[i] + output_list.append(gemm_out) + real_res = torch.cat(output_list, dim=0) + output = torch.empty(a.shape[0], b.shape[1], device=real_res.device).to(real_res.dtype) + output[:total_m] = real_res + return output + + def test_group_gemm(self): + bs_list = [1, 3] + seq_list = [5, 8] + k_list = [512, 1024] + n_list = [512, 768, 2048] + expert_list = [8, 32] + topk_list = [2, 5] + dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half] + idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True] + has_bias_list = [True, False] + + args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list) + for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias in args: + print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, topk: {topk}, \ +dtype: {data_type}, idx_mode: {idx_mode}, has_bias: {has_bias} testing...", flush=True) + param = gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias) + torch_out = self.op_impl_base(*param[:7], batch * seq, param[10]) + tmo_out = ops.group_gemm(*param[:7], batch * seq, param[10]) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + + def test_inductor(self): + batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5 + dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half] + idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True] + has_bias_list = [True, False] + args = product( dtype_list, idx_list, has_bias_list) + for data_type, idx_mode, has_bias in args: + args = gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias) + args = list(args) + args[-2] = args[-1] # bias + args[-1] = None #dtype + args.extend([None, None, batch * seq]) #b_offset, max_m + self.base_opcheck(torch.ops.torch_mlu_ops.group_gemm, args) + +if __name__ == '__main__': + exit(run_unittest(TestGroupGemmOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_matmul.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_matmul.py new file mode 100755 index 0000000..5ecb0d6 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_matmul.py @@ -0,0 +1,193 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +import numpy as np + +class TestMatMulOp(BtTestCase): + def op_impl_base(self, *args): + a, b, bias, c, act_mode, alpha, beta, fast_act, approximate, d, \ + a_scale, b_scale, trans_a, trans_b = args + if a_scale is not None: + a = a / a_scale + if b_scale is not None: + b = b / b_scale + if trans_a: + a = a.transpose(0, 1) + if trans_b: + b = b.transpose(0, 1) + mul_out = alpha * torch.matmul(a, b) + if bias is not None: + mul_out += bias + if c is not None: + mul_out += beta * c + if act_mode in act_mode_dict.keys(): + active = act_mode_dict[act_mode] + mul_out = active(mul_out.float()).to(a.dtype) + return mul_out + + def test_matmul(self): + mat_m_list = [32] + mat_n_list = [256] + mat_k_list = [128] + has_res_list = [ False, True] + has_bias_list = [True, False] + trans_a_list = [False, True] + trans_b_list = [False, True] + act_mode_list = ['none', 'relu', 'gelu', 'silu'] + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + alpha = 0.625 + beta = 1.0 + args = product( mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, act_mode_list, dtype_list, trans_a_list, trans_b_list) + for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args: + torch.manual_seed(1) + print("m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format( + mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True) + if has_res : + beta = 1.0 + else : + beta = 0. + + shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n) + if trans_a: + shape_a = (mat_k, 4, mat_m) + if trans_b: + shape_b = (mat_n, 3, mat_k) + input0 = torch.randn(shape_a, dtype=dtype, device='mlu') + weight0 = torch.randn(shape_b, dtype=dtype, device='mlu') + input = input0[:, 1, :] + weight = weight0[:, 0, :] + bias = torch.randn((mat_n), dtype=dtype, device='mlu') + residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') + + output = self.op_impl_base(input, + weight, + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b) + tmo_output = ops.matmul(input, + weight, + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b) + tmo_output_contiguous = ops.matmul(input.contiguous(), weight.contiguous(), + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b) + + if act_mode == 'gelu': + tmo_output_high = ops.matmul(input.contiguous(), weight.contiguous(), + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, True, None, 1.0, 1.0, trans_a, trans_b) + self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + + self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + + # @unittest.skip("not test") + def test_matmul_int8(self): + mat_m_list = [32] + mat_n_list = [256] + mat_k_list = [128] + has_res_list = [True, False] + has_bias_list = [True, False] + trans_a_list = [True, False] + trans_b_list = [True, False] + act_mode_list = ['none', 'relu', 'silu', 'gelu'] + dtype_list = [torch.half, torch.float] + alpha = 0.625 + beta = 1.0 + args = product( mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, act_mode_list, dtype_list, trans_a_list, trans_b_list) + + for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args: + print("int8 test: m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format( + mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True) + torch.manual_seed(1) + if has_res : + beta = 1.0 + else : + beta = 0. + shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n) + if trans_a: + shape_a = (mat_k, 4, mat_m) + if trans_b: + shape_b = (mat_n, 3, mat_k) + input0 = torch.randn(shape_a, dtype=dtype, device='mlu') + weight0 = torch.randn(shape_b, dtype=dtype, device='mlu') + input = input0[:, 1, :] + weight = weight0[:, 0, :] + input8, a_scale = QuantByTensor(input, 8) + weight8, b_scale = QuantByTensor(weight, 8) + bias = torch.randn((mat_n), dtype=dtype, device='mlu') + residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') + + output = self.op_impl_base(input8, weight8, + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, False, None, a_scale, b_scale, trans_a, trans_b) + tmo_output = ops.matmul(input8, weight8, + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b) + tmo_output_contiguous = ops.matmul(input8.contiguous(), weight8.contiguous(), + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b) + + self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + if act_mode == 'gelu': + tmo_output_high = ops.matmul(input8.contiguous(), weight8.contiguous(), + alpha * bias if has_bias else None, + residual if has_res else None, + act_mode, alpha, beta, False, True, dtype, a_scale, b_scale, trans_a, trans_b) + self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = 32, 256, 128, 0.8, 0.3, 'silu', True, True + trans_a_list = [True, False] + trans_b_list = [True, False] + dtype_list = [torch.half, torch.float] + args = product(trans_a_list, trans_b_list, dtype_list) + for trans_a, trans_b, dtype in args: + print("trans_a: {}, trans_b: {}, dtype: {}".format(trans_a, trans_b, dtype)) + shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n) + if trans_a: + shape_a = (mat_k, 4, mat_m) + if trans_b: + shape_b = (mat_n, 3, mat_k) + input0 = torch.randn(shape_a, dtype=dtype, device='mlu') + weight0 = torch.randn(shape_b, dtype=dtype, device='mlu') + a = input0[:, 1, :] + b = weight0[:, 0, :] + bias = torch.randn((mat_n), dtype=dtype, device='mlu') + c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') + args = (a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b) + self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args) + + a8, a_scale = QuantByTensor(a, 8) + b8, b_scale = QuantByTensor(b, 8) + str_dtype = "half" + if dtype == torch.float: + str_dtype = "float" + elif dtype == torch.bfloat16: + str_dtype = "bfloat16" + args = (a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b) + self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args) + + + +if __name__ == '__main__': + exit(run_unittest(TestMatMulOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe.py new file mode 100755 index 0000000..c9f4c59 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe.py @@ -0,0 +1,946 @@ +import torch +import unittest +import torch_mlu_ops as ops +from common_utils import * +from typing import Union, List, Tuple +from torch.nn.parameter import Parameter +from torch.nn import functional as F +from torch_mlu import mlu + +class TestFusedMOEOp(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + hidden_states = create_tensor_from_dic(dic['hidden_states']) + gating_output = create_tensor_from_dic(dic['gating_output']) + w1 = create_tensor_from_dic(dic['w1']) + w2 = create_tensor_from_dic(dic['w2']) + if dic['w1']['dtype'] is not torch.int8: + w1 *= 0.1 + w2 *= 0.1 + bias1 = None if dic['bias1']['data'] is None else create_tensor_from_dic(dic['bias1']) + bias2 = None if dic['bias2']['data'] is None else create_tensor_from_dic(dic['bias2']) + residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual']) + input_smooth = None if dic['input_smooth']['data'] is None else create_tensor_from_dic(dic['input_smooth'], is_uniform=True, low=0.01, high=0.05) + act_smooth = None if dic['act_smooth']['data'] is None else create_tensor_from_dic(dic['act_smooth'], is_uniform=True, low=0.01, high=0.05) + w1_scale = None if dic['w1_scale']['data'] is None else create_tensor_from_dic(dic['w1_scale'], is_uniform=True, low=-0.05, high=0.05) + w2_scale = None if dic['w2_scale']['data'] is None else create_tensor_from_dic(dic['w2_scale'], is_uniform=True, low=-0.05, high=0.05) + topk = dic['topk']['data'] + renormalize = dic['renormalize']['data'] + gated = dic['gated']['data'] + act_mode = dic['act_mode']['data'] + start_expert_id = dic['start_expert_id']['data'] + block_n = dic['block_n']['data'] + cncl_comm = dic['cncl_comm']['data'] + w1_quant_flag = dic['w1_quant_flag']['data'] + w2_quant_flag = dic['w2_quant_flag']['data'] + self.launch(hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, + act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, + start_expert_id, block_n, 0, w1_quant_flag, w2_quant_flag) + + def launch(self, *args): + base = self.op_impl_base(*args) + tmo_res = tmo.fused_moe(*args) + self.assertTensorsEqual(tmo_res.cpu().float(), base.cpu().float(), 0.03, use_MSE=True) + + def op_impl_base(self, *args): + hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, \ + act_smooth, w1_scale, w2_scale, topk, renormalize, \ + gated, act_mode, start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag = args + if w2.dim() == 4: + w2 = w2.transpose(1, 2).reshape(-1, w2.size(1), w2.size(-1)) + expert_num = gating_output.size(-1) + expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1) + def router(hidden_states, router_logit): + router_logit = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1) + topk_logit, expert_id = torch.topk(router_logit, k=topk, dim=1) + if renormalize: + topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1) + sorted_expert_id, indices = expert_id.int().flatten().sort() + nk = indices.size(0) + token_cout = torch.bincount(sorted_expert_id, minlength=expert_num).cpu() + expand_idx = indices.int() // topk + combine_idx = torch.zeros((nk,), dtype=torch.int, device="mlu") + combine_idx.scatter_(0, indices, torch.arange(nk, dtype=torch.int, device="mlu")) + input_expand = hidden_states[expand_idx] + input_list = input_expand.split(tuple(token_cout)) + + input_scale_list = [] + input_split_result = [] + if input_smooth is not None: + # do smooth quant on input + idx = 0 + for i in range(expert_num): + if i >= start_expert_id and i < start_expert_id + expert_size: + if (input_list[i].size(0) > 0): + temp = input_list[i] * input_smooth[idx] + input_split_result.append(temp) + idx += 1 + else: + input_split_result.append(input_list[i]) + quant_input, input_scale = QuantByRow(torch.cat(input_split_result, dim=0), 8) + input_list = quant_input.split(token_cout.tolist()) + input_scale_list = input_scale.split(token_cout.tolist()) + return input_list, input_scale_list, topk_logit.flatten().view(nk , 1), combine_idx + + dtype = hidden_states.dtype + if w1_quant_flag is None: + inner_size = w1.size(1) // 2 if gated else w1.size(1) + else: + inner_size = w1_scale.size(2) // 2 if gated else w1_scale.size(2) + hidden_size = w2.size(1) if w2_quant_flag is None else w2_scale.size(2) + input = hidden_states.view(-1, hidden_states.size(-1)) + gating_output = gating_output.view(-1, gating_output.size(-1)) + input_list, input_scale_list, reduce_weight, combine_idx = router(input, gating_output) + output_list = [] + idx = 0 + need_quant = len(input_scale_list) != 0 + if need_quant and w1_scale.dim() == 3: + w1_scale = w1_scale.transpose(0, 1).contiguous() + w2_scale = w2_scale.transpose(0, 1).contiguous() + if w1_quant_flag is not None: + w1_quant_group = w1_scale.size(1) + quant_wise = hidden_size // w1_quant_group + w1_quant_flag = torch.tensor(w1_quant_flag).view(-1, w1_quant_group) + w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated) + w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0) + if w2_quant_flag is not None: + w2_quant_group = w2_scale.size(1) + quant_wise = inner_size // w2_quant_group + w2_quant_flag = torch.tensor(w2_quant_flag).view(-1, w2_quant_group) + w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size + w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0) + for i in range(expert_num): # start_expert_id : start_expert_id + expert_size + if i >= start_expert_id and i < start_expert_id + expert_size: + if (input_list[i].size(0) > 0): + if need_quant: + if w1_quant_flag is None: + gemm1_out = smooth_quant_matmul(input_list[i], input_scale_list[i], + w1[idx], w1_scale[idx], dtype) + else: + gemm1_out = smooth_quant_matmul_w4w8_mixed(input_list[i], input_scale_list[i], + w1[w1_offset_cu[idx]:w1_offset_cu[idx+1]], w1_scale[idx], dtype, + quant_flag = w1_quant_flag[idx]) + else: + gemm1_out = torch.matmul(input_list[i], w1[idx].permute(1, 0)) + act_in = gemm1_out[:, :inner_size].float() + gate = gemm1_out[:, inner_size:] + act = act_mode_dict[act_mode] + gemm1_out = act(act_in).to(dtype=dtype) if gated == False else \ + act(act_in).to(dtype=dtype) * gate + if need_quant: + quant_gemm1_out, gemm1_out_scale = QuantByRow(gemm1_out * act_smooth[idx], 8) + if w2_quant_flag is None: + gemm2_out = smooth_quant_matmul(quant_gemm1_out, gemm1_out_scale, w2[idx], w2_scale[idx], dtype) + else: + gemm2_out = smooth_quant_matmul_w4w8_mixed(quant_gemm1_out, gemm1_out_scale, + w2[w2_offset_cu[idx]:w2_offset_cu[idx+1]], w2_scale[idx], dtype, + quant_flag = w2_quant_flag[idx]) + else: + gemm2_out = torch.matmul(gemm1_out, w2[idx].permute(1, 0)) + output_list.append(gemm2_out) + idx += 1 + else: + output_list.append(torch.zeros_like(input_list[i])) + output = torch.cat(output_list, dim=0)[combine_idx].float() * reduce_weight + output = output.reshape(-1, topk, hidden_size).sum(dim=1).to(dtype=dtype) + if residual is not None: + output = output + residual.view(input.shape) + return output.view(hidden_states.shape) + + def test_fused_moe(self): + print("test_fused_moe") + batch, seq, hidden_size, inner_size = 3, 5, 512, 768 + dtype_list = [torch.half] + if mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', dtype + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) + bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) + + start_expert_id_list = [0, 1, 3, 4, 5, 6] + expert_size_list = [8, 4, 3, 2, 3, 2] + for i in range(len(start_expert_id_list)): + start_expert_id = start_expert_id_list[i] + expert_size = expert_size_list[i] + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +start_expert_id: {start_expert_id}, expert_size: {expert_size}, topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + torch_out = self.op_impl_base(hidden_states, + router_logit, + weight1[start_expert_id:start_expert_id+expert_size], + weight2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id, + 0, 0, None, None) + # (N, T, C) + tmo_out_1 = ops.fused_moe(hidden_states, + router_logit, + weight1[start_expert_id:start_expert_id+expert_size], + weight2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id) + # (N*T, C) + tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), + router_logit.view(-1, expert_num), + weight1[start_expert_id:start_expert_id+expert_size], + weight2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual.view(-1, hidden_size) if residual is not None else None, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id).reshape(batch, seq, hidden_size) + tmo_out_3 = fused_moe(hidden_states, + router_logit, + weight1[start_expert_id:start_expert_id+expert_size], + weight2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id) + + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_3.cpu().float(), 0.006, use_MSE=True) + + def test_pertoken_quant_fused_moe_tp(self): + print("test_pertoken_quant_fused_moe_tp") + batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 + expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16 + if not mlu.is_bf16_supported(): + data_type = torch.float16 + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + self.__run_sq_case(batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode) + + def test_moe_tp2_mixed_ep4_no_quant(self): + print("test_moe_tp2_mixed_ep4_no_quant") + tp_num = 2 + ep_num = 4 + expert_num = 32 + expert_size = expert_num // ep_num + batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 + topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16 + if not mlu.is_bf16_supported(): + data_type = torch.float16 + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + residual = None + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) + bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) + w1 = weight1.reshape(ep_num * expert_size, tp_num, (inner_size * (1 + gated)) // tp_num, hidden_size) + w2 = weight2.reshape(ep_num * expert_size, hidden_size, tp_num, inner_size // tp_num) + + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + tmo_out1 = torch.zeros_like(hidden_states) + tmo_out2 = torch.zeros_like(hidden_states) + torch_out = torch.zeros_like(hidden_states) + for tp_idx in range(tp_num): + w1_curr_tp = w1[:, tp_idx, ...] + w2_curr_tp = w2[:, :, tp_idx, :] + for ep_idx in range(ep_num): + start_expert_id = ep_idx * expert_size + w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous() + w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous() + tmo_out1 += ops.fused_moe(hidden_states, + router_logit, + w1_curr_tp_and_ep, + w2_curr_tp_and_ep, + bias1, + bias2, + residual, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id) + tmo_out2 += fused_moe(hidden_states, + router_logit, + w1_curr_tp_and_ep, + w2_curr_tp_and_ep, + bias1, + bias2, + residual, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id) + torch_out += self.op_impl_base(hidden_states, + router_logit, + w1_curr_tp_and_ep, + w2_curr_tp_and_ep, + bias1, + bias2, + residual, + None, + None, + None, + None, + topk, + renormalize, + gated, + act_mode, + start_expert_id, + 0, 0, None, None) + tmo_out2 = tmo_out2.reshape(batch, seq, hidden_size) + + self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + self.assertTensorsEqual(tmo_out1.cpu().float(), tmo_out2.cpu().float(), 0.006, use_MSE=True) + + + def test_moe_tp2_mixed_ep4_quant(self): + print("test_moe_tp2_mixed_ep4_quant") + tp_num = 2 + ep_num = 4 + expert_num = 32 + expert_size = expert_num // ep_num + batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 + topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16 + if not mlu.is_bf16_supported(): + data_type = torch.float16 + scale_s = 0.1 # avoid the occurrence of inf + eps = 0.01 # Avoid the occurrence of nan + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + residual = None + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s + bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) + + input_smooth = torch.randn(expert_num, hidden_size, device='mlu', dtype=torch.float32).abs() + eps + act_smooth = torch.randn(expert_num, (1+gated)*inner_size, device='mlu', dtype=torch.float32).abs() + eps + weight1_shape, weight2_shape = weight1.shape, weight2.shape + weight1 = weight1 / input_smooth.unsqueeze(1) + weight2 = weight2 / act_smooth.unsqueeze(1) + quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1_shape[-1]), 8) + quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2_shape[-1]), 8) + quant_w1, quant_w2 = quant_w1.view(weight1.shape), quant_w2.view(weight2.shape) + w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) + torch_out = torch.zeros_like(hidden_states) + for _ in range(tp_num): + w1_scale_tp = w1_scale.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num) + act_smooth_tp = act_smooth.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num) + + w1 = quant_w1.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num, hidden_size) + w2 = quant_w2.reshape(ep_num, expert_size, hidden_size, tp_num, inner_size*(1+gated)//tp_num) + input_smooth = input_smooth.reshape(ep_num, expert_size, hidden_size) + act_smooth = act_smooth.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num) + w1_scale = w1_scale.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num) + w2_scale = w2_scale.reshape(ep_num, expert_size, hidden_size) + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + tmo_out = torch.zeros_like(hidden_states) + tmo_out1 = torch.zeros_like(hidden_states) + torch_out = torch.zeros_like(hidden_states) + for tp_idx in range(tp_num): + w1_curr_tp = w1[:, :, tp_idx, ...] + w2_curr_tp = w2[:, :, :, tp_idx, :] + act_smooth_tp = act_smooth[:, :, tp_idx, :] + w1_scale_tp = w1_scale[:, :, tp_idx, :] + for ep_idx in range(ep_num): + start_expert_id = ep_idx * expert_size + w1_curr_tp_and_ep = w1_curr_tp[ep_idx].contiguous() + w2_curr_tp_and_ep = w2_curr_tp[ep_idx].contiguous() + input_smooth_curr_ep = input_smooth[ep_idx].contiguous() + act_smooth_curr_ep = act_smooth_tp[ep_idx].contiguous() + w1_scale_curr_ep = w1_scale_tp[ep_idx].contiguous() + w2_scale_curr_ep = w2_scale[ep_idx].contiguous() + tmo_out += ops.fused_moe(hidden_states, # [batch, seq, hidden_size] + router_logit, # [batch, seq, expert_num] + w1_curr_tp_and_ep, # [expert_size, inner_size*(1+gated)//tp_num, hidden_size] + w2_curr_tp_and_ep, # [expert_size, hidden_size, inner_size*(1+gated)//tp_num] + bias1, + bias2, + residual, + input_smooth_curr_ep, # [expert_size, hidden_size] + act_smooth_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num] + w1_scale_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num] + w2_scale_curr_ep, # [expert_size, hidden_size] + topk, + renormalize, + gated, + act_mode, + start_expert_id) + tmo_out1 += fused_moe(hidden_states, + router_logit, + w1_curr_tp_and_ep, + w2_curr_tp_and_ep, + bias1, + bias2, + residual, + input_smooth_curr_ep, + act_smooth_curr_ep, + w1_scale_curr_ep, + w2_scale_curr_ep, + topk, + renormalize, + gated, + act_mode, + start_expert_id) + torch_out += self.op_impl_base(hidden_states, + router_logit, + w1_curr_tp_and_ep, + w2_curr_tp_and_ep, + bias1, + bias2, + residual, + input_smooth_curr_ep, + act_smooth_curr_ep, + w1_scale_curr_ep, + w2_scale_curr_ep, + topk, + renormalize, + gated, + act_mode, + start_expert_id, + 0, 0, None, None) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True) + self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True) + + def test_smq_fused_moe_random_tp(self): + print("test_smq_fused_moe_random_tp") + import random + random.seed(0) + act_mode = 'gelu' + case_list = set() + for i in range(10): + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(512, 2048, 2) + inner_size = random.randrange(512, 2048, 2) + expert_num = random.randint(1, 40) + topk = random.randint(1,expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + data_type = random.choice([torch.bfloat16, torch.float16]) + if not mlu.is_bf16_supported(): + data_type = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode) + if case in case_list: + continue + case_list.add(case) + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + self.__run_sq_case(*case) + + def __run_sq_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode='gelu', start_expert_id=0, expert_size=-1): + if expert_size == -1: + expert_size = expert_num + scale_s = 0.01 # avoid the occurrence of inf + eps = 0.1 # Avoid the occurrence of nan + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s + residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps + act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps + bias1, bias2 = None, None + weight1_shape, weight2_shape = weight1.shape, weight2.shape + weight1 = weight1 / input_smooth.unsqueeze(1) + weight2 = weight2 / act_smooth.unsqueeze(1) + quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8) + quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8) + quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) + w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) + torch_out = self.op_impl_base(hidden_states, + router_logit, + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + w1_scale[start_expert_id:start_expert_id+expert_size], + w2_scale[start_expert_id:start_expert_id+expert_size], + topk, + renormalize, + gated, + act_mode, + start_expert_id, + 0, 0, None, None) + # (N, T, C) + tmo_out_1 = ops.fused_moe(hidden_states, + router_logit, + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + w1_scale[start_expert_id:start_expert_id+expert_size], + w2_scale[start_expert_id:start_expert_id+expert_size], + topk, + renormalize, + gated, + act_mode, + start_expert_id) + # # (N*T, C) + tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), + router_logit.view(-1, expert_num), + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual.view(-1, hidden_size) if residual is not None else None, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + w1_scale[start_expert_id:start_expert_id+expert_size], + w2_scale[start_expert_id:start_expert_id+expert_size], + topk, + renormalize, + gated, + act_mode, + start_expert_id).view(batch, seq, hidden_size) + tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size), + router_logit.view(-1, expert_num), + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual.view(-1, hidden_size) if residual is not None else None, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + w1_scale[start_expert_id:start_expert_id+expert_size], + w2_scale[start_expert_id:start_expert_id+expert_size], + topk, + renormalize, + gated, + act_mode, + start_expert_id).view(batch, seq, hidden_size) + + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True) + self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True) + self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True) + + def test_fused_moe_with_4D_w2(self): + print("test_fused_moe_with_4D_w2") + batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 + dtype_list = [torch.half] + if mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, True, True, 'silu', dtype + hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type) + router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) + residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) + weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) + bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) + + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + torch_out = self.op_impl_base(hidden_states, router_logit, + weight1, weight2, + bias1, bias2, residual, None, None, None, None, + topk, renormalize, gated, act_mode, 0, 0, 0, None, None) + # (N, T, C) + tmo_out_1 = ops.fused_moe(hidden_states, router_logit, + weight1, weight2, + bias1, bias2, residual, None, None, None, None, + topk, renormalize, gated, act_mode, 0) + # (N*T, C) + tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), + router_logit.view(-1, expert_num), + weight1, weight2.view(8, 4, hidden_size, inner_size).transpose(1,2).contiguous(), + bias1, bias2, + residual.view(-1, hidden_size) if residual is not None else None, + None, None, None, None, + topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size) + + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + + def test_moe_tp2_mixed_ep4_with_4D_w2(self): + print("test_moe_tp2_mixed_ep4_with_4D_w2") + tp_num, ep_num = 2, 4 + batch, seq, hidden_size, inner_size = 3, 5, 8192, 2048 + expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16 + assert inner_size % tp_num == 0 + assert expert_num % ep_num == 0 + expert_size = expert_num // ep_num + assert 4096 % inner_size == 0 + block_e = 4096 // inner_size + assert expert_size % block_e == 0 + if not mlu.is_bf16_supported(): + data_type = torch.float16 + hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type) + router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) + weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) + weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) + residual = None + bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type) + bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type) + w1 = weight1.reshape(expert_num, tp_num, -1, hidden_size) + w2 = weight2.reshape(expert_num, hidden_size, tp_num, -1) + + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True) + tmo_out = torch.zeros_like(hidden_states) + torch_out = torch.zeros_like(hidden_states) + for tp_idx in range(tp_num): + new_inner_size = w2.shape[-1] + w1_curr_tp = w1[:, tp_idx, ...] + w2_curr_tp = w2[:, :, tp_idx, :] + for ep_idx in range(ep_num): + start_expert_id = ep_idx * expert_size + w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous() + w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous() + torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep, + w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(), + bias1, bias2, residual, None, None, None, None, + topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None) + tmo_out += ops.fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep, + w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(), + bias1, bias2, residual, None, None, None, None, + topk, renormalize, gated, act_mode, start_expert_id) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + + def test_sq_fused_moe_with_4D_w2(self): + print("test_sq_fused_moe_with_4D_w2") + batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 + expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16 + if not mlu.is_bf16_supported(): + data_type = torch.float16 + assert 4096 % inner_size == 0 + block_e = 4096 // inner_size + assert expert_num % block_e == 0 + scale_s = 0.01 # avoid the occurrence of inf + eps = 0.1 # Avoid the occurrence of nan + hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type) + router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32) + weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type) + weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type) + residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + bias1, bias2 = None, None + input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps + act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps + weight1_shape, weight2_shape = weight1.shape, weight2.shape + weight1 = weight1 / input_smooth.unsqueeze(1) + weight2 = weight2 / act_smooth.unsqueeze(1) + quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8) + quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8) + quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) + w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) + torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1, quant_w2, + bias1, bias2, residual, input_smooth, act_smooth, + w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, 0, 0, None, None) + # (N, T, C) + tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1, quant_w2, + bias1, bias2, residual, input_smooth, act_smooth, + w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0) + # # (N*T, C) + tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size), + router_logit.view(-1, expert_num), + quant_w1, + quant_w2.view(-1,block_e,hidden_size, inner_size).transpose(1,2).contiguous(), + bias1, bias2, + residual.view(-1, hidden_size) if residual is not None else None, + input_smooth, act_smooth, + w1_scale, w2_scale, topk, renormalize, + gated, + act_mode, + 0).view(batch, seq, hidden_size) + + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True) + + def test_sq_fused_moe_random_tp_quant_grouped(self): + print("test_sq_fused_moe_random_tp_quant_grouped") + import random + random.seed(0) + act_mode = 'gelu' + case_list = set() + while (len(case_list) < 100): + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(512, 2048, 128) + inner_size = random.randrange(512, 2048, 512) + expert_num = random.randint(1, 40) + topk = random.randint(1, expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + quant_bit = random.choice([4, 8]) + data_type = random.choice([torch.bfloat16, torch.float16]) + if not mlu.is_bf16_supported(): + data_type = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode) + if case in case_list: + continue + case_list.add(case) + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True) + self.__run_quant_grouped_case(*case) + + def __run_quant_grouped_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode='gelu', start_expert_id=0, expert_size=-1, quant_group_size=128): + def get_quant_group(n, quant_group_size): + quant_group = n // quant_group_size + return quant_group if quant_group >= 1 else 1 + + if expert_size == -1: + expert_size = expert_num + scale_s = 0.01 # avoid the occurrence of inf + eps = 0.1 # Avoid the occurrence of nan + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s + residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps + act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps + bias1, bias2 = None, None + weight1_shape, weight2_shape = weight1.shape, weight2.shape + weight1 = weight1 / input_smooth.unsqueeze(1) + weight2 = weight2 / act_smooth.unsqueeze(1) + w1_quant_group = get_quant_group(hidden_size, quant_group_size) + w2_quant_group = get_quant_group(inner_size, quant_group_size) + quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), quant_bit, w1_quant_group) + quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), quant_bit, w2_quant_group) + quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) + if quant_bit == 4: + quant_w1, quant_w2 = PairlyPackInt8(quant_w1), PairlyPackInt8(quant_w2) + w1_scale = w1_scale.view(expert_num, -1, w1_quant_group).permute(2, 0, 1).contiguous() + w2_scale = w2_scale.view(expert_num, -1, w2_quant_group).permute(2, 0, 1).contiguous() + # split scale and transpose + def extract_scale(scale, start_expert, expert_size): + return scale[:, start_expert:start_expert+expert_size, :].contiguous() + + torch_out = self.op_impl_base(hidden_states, + router_logit, + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + extract_scale(w1_scale, start_expert_id, expert_size), + extract_scale(w2_scale, start_expert_id, expert_size), + topk, + renormalize, + gated, + act_mode, + start_expert_id, + 0, 0, None, None) + # (N, T, C) + tmo_out_1 = ops.fused_moe(hidden_states, + router_logit, + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + extract_scale(w1_scale, start_expert_id, expert_size), + extract_scale(w2_scale, start_expert_id, expert_size), + topk, + renormalize, + gated, + act_mode, + start_expert_id) + tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size), + router_logit.view(-1, expert_num), + quant_w1[start_expert_id:start_expert_id+expert_size], + quant_w2[start_expert_id:start_expert_id+expert_size], + bias1, + bias2, + residual.view(-1, hidden_size) if residual is not None else None, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + extract_scale(w1_scale, start_expert_id, expert_size), + extract_scale(w2_scale, start_expert_id, expert_size), + topk, + renormalize, + gated, + act_mode, + start_expert_id).view(batch, seq, hidden_size) + + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True) + self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True) + + def __run_w4w8_mixed_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode='gelu', start_expert_id=0, expert_size=-1): + if expert_size == -1: + expert_size = expert_num + w1_quant_group = hidden_size // quant_wise + w2_quant_group = inner_size // quant_wise + w1_quant_flag = torch.randint(1, 3, (expert_num, w1_quant_group), dtype=torch.int32) * 4 + w2_quant_flag = torch.randint(1, 3, (expert_num, w2_quant_group), dtype=torch.int32) * 4 + w1_count = (w1_quant_flag.sum().item() // 4) * (quant_wise // 2) * inner_size*(1+gated) + w2_count = (w2_quant_flag.sum().item() // 4) * (quant_wise // 2) * hidden_size + w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8) + w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8) + hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) + router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32) + input_smooth = torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05) + act_smooth = torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05) + bias1, bias2 = None, None + w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05) + w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated) + w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0) + w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size + w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0) + + # split scale and transpose + def extract_scale(scale, start_expert, expert_size): + return scale[:, start_expert:start_expert+expert_size, :].contiguous() + + params = [hidden_states, router_logit, + w1[w1_offset_cu[start_expert_id]:w1_offset_cu[start_expert_id+expert_size]], + w2[w2_offset_cu[start_expert_id]:w2_offset_cu[start_expert_id+expert_size]], + bias1, bias2, residual, + input_smooth[start_expert_id:start_expert_id+expert_size], + act_smooth[start_expert_id:start_expert_id+expert_size], + extract_scale(w1_scale, start_expert_id, expert_size), + extract_scale(w2_scale, start_expert_id, expert_size), + topk, renormalize, gated, act_mode, start_expert_id, 0, 0, + w1_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist(), + w2_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist()] + torch_out = self.op_impl_base(*params) + # (N, T, C) + tmo_out_1 = ops.fused_moe(*params) + tmo_out_2 = fused_moe(*params) + self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True) + self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True) + + def test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed(self): + print("test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed") + import random + random.seed(0) + act_mode = 'gelu' + case_list = set() + while (len(case_list) < 100): + batch = random.randint(1, 10) + seq = random.randint(1, 10) + hidden_size = random.randrange(1024, 3072, 512) + inner_size = random.randrange(1024, 3072, 512) + expert_num = random.randint(1, 40) + topk = random.randint(1, expert_num) + gated = random.choice([True, False]) + renormalize = random.choice([True, False]) + quant_wise = random.choice([128, 256, 512]) + data_type = random.choice([torch.bfloat16, torch.float16]) + if not mlu.is_bf16_supported(): + data_type = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode) + if case in case_list: + continue + case_list.add(case) + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True) + self.__run_w4w8_mixed_case(*case, 0, -1) + + def test_sq_fused_moe_single_quant_group(self): + print("test_sq_fused_moe_single_quant_group") + import random + random.seed(0) + act_mode = 'gelu' + batch = 9 + seq = 10 + hidden_size = 1664 + inner_size = 512 + expert_num = 15 + topk = 2 + gated = False + renormalize = False + quant_bit = 4 + data_type = torch.float16 + case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode) + print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \ +topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True) + self.__run_quant_grouped_case(*case) + + def test_inductor(self): + batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024 + expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', torch.float16 + start_expert_id, expert_size = 0, 8 + hidden_states = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type) + router_logit = torch.randn(batch * seq, expert_num, device="mlu", dtype=torch.float32) + residual = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type) + weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) + weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) + args = (hidden_states, router_logit, + weight1[start_expert_id:start_expert_id+expert_size], + weight2[start_expert_id:start_expert_id+expert_size], + None, None, residual, None, None, None, None, None, None, + topk, renormalize, gated, act_mode, 0, 0, 0) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args) + + eps = 1e-5 + input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps + act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps + weight1_shape, weight2_shape = weight1.shape, weight2.shape + weight1 = weight1 / input_smooth.unsqueeze(1) + weight2 = weight2 / act_smooth.unsqueeze(1) + quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8) + quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8) + quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape) + w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1) + args = (hidden_states, router_logit, + quant_w1, quant_w2, None, None, residual, + input_smooth, act_smooth, w1_scale, w2_scale, + None, None, topk, renormalize, gated, act_mode, 0, 0, 0) + self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args) + +if __name__ == '__main__': + exit(run_unittest(TestFusedMOEOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_add_bias_activation.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_add_bias_activation.py new file mode 100755 index 0000000..07ff771 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_add_bias_activation.py @@ -0,0 +1,242 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +from common_utils import * +import random +from itertools import product +import time +import os + +def gen_data(num_expert, + total_tokens, + inner_size, + output_stride, + dtype, + is_gated, + has_bias, + is_ep): + ci = inner_size * (1 + is_gated) + input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu') + cusum_token_count, token_count = generate_token_count(num_expert, total_tokens) + output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu') + output.as_strided(output.size(), (output_stride, 1)) + start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0 + expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert + bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None + return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size + +class TestMoeActiveKernel(BtTestCase): + def op_impl_base(self, *args): + input, act_mode, is_gated, output, bias, cusum_token_count, start_expert_id, expert_size = args + act_fun = torch.nn.functional.gelu if act_mode == 'gelu' else torch.nn.functional.silu + total_token_num = input.size(0) + inner_size = input.size(1) // 2 if is_gated else input.size(1) + input_ = input.clone() + if bias is not None: + deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id] + input_ = input_[:deal_token_num, :] + token_count = cusum_token_count[1:] - cusum_token_count[:-1] + token_count_ = token_count[start_expert_id:start_expert_id+expert_size].tolist() + input_list = list(input_.split(token_count_)) + for i in range(expert_size): + input_list[i] += bias[i] + input_ = torch.cat(input_list, dim=0) + if cusum_token_count.size(0) - 1 != expert_size: + pad = torch.zeros(total_token_num-deal_token_num, input_.size(-1)).to(input_.dtype).mlu() + input_ = torch.cat((input_, pad), dim=0) + acted = act_fun(input_[:, :inner_size]) + acted = acted * input_[:, inner_size:] if is_gated else acted + if output is None: + return acted + else: + return output.copy_(acted) + + # 功能测试 + def test_functional(self): + num_expert_list = [5, 32] + total_tokens_list = [64, 1024] + inner_size_list = [1024, 8192] + dtype_list = [torch.half, torch.float32] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + is_gated_list = [True, False] + is_ep_list = [True, False] + # change True when test + has_bias_list = [False, False, True] + act_mode_list = ['silu', 'gelu'] + args = product(num_expert_list, total_tokens_list, inner_size_list, dtype_list, + is_gated_list, is_ep_list, has_bias_list, act_mode_list) + + for num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode in args: + print("===============================================================================") + print(f"num_expert: {num_expert}, total_tokens: {total_tokens}") + print(f"inner_size: {inner_size}, dtype: {dtype}, is_gated: {is_gated}") + print(f"is_ep: {is_ep}, has_bias: {has_bias}, act_mode: {act_mode}") + + input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None + base_output = torch.empty_like(output) + ops.moe_active(input, + act_mode, + is_gated, + output, + bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size) + deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id] + output = output[:deal_token_num, :] + self.op_impl_base(input, + act_mode, + is_gated, + base_output, + bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size) + base_output = base_output[:deal_token_num, :] + self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inplace(self): + num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \ + 32, 80, 1024, torch.half, True, False, False, 'gelu' + input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + input_bak = input.clone() + # test output inpalce with stride + output = input.as_strided((total_tokens, inner_size), (inner_size * 2, 1)) + base_output = torch.empty_like(output) + ops.moe_active(input, + act_mode, + is_gated, + output, + bias, + None, + start_expert_id, + expert_size) + self.op_impl_base(input_bak, act_mode, is_gated, base_output, bias, None, start_expert_id, expert_size) + self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + # 随机遍历测试 + def test_random(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + + for i in range(500): + num_expert = random.randint(1, 64) + total_tokens = random.randint(1, 32768) + inner_size = random.randint(1, 8192) + dtype = random.sample(dtype_list, 1)[0] + is_gated = random.sample([True, False], 1)[0] + is_ep = random.sample([True, False], 1)[0] + # change True when test + has_bias = random.sample([False, True], 1)[0] + act_mode = random.sample(['gelu', 'silu'], 1)[0] + print("===============================================================================") + print(f"[{i}]: num_expert: {num_expert}, total_tokens: {total_tokens}") + print(f" inner_size: {inner_size}, dtype: {dtype}, is_gated: {is_gated}") + print(f" is_ep: {is_ep}, has_bias: {has_bias}, act_mode: {act_mode}") + + input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None + base_output = torch.empty_like(output) + ops.moe_active(input, + act_mode, + is_gated, + output, + bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size) + deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id] + output = output[:deal_token_num, :] + self.op_impl_base(input, + act_mode, + is_gated, + base_output, + bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size) + base_output = base_output[:deal_token_num, :] + self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_single(self): + num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \ + 32, 64, 8192, torch.float, True, False, False, 'gelu' + + input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None + base_output = torch.empty_like(output) + ops.moe_active(input, + act_mode, + is_gated, + output, + bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size) + deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id] + output = output[:deal_token_num, :] + + self.op_impl_base(input, act_mode, is_gated, base_output, bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, + expert_size) + base_output = base_output[:deal_token_num, :] + self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func = ops.moe_active + num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \ + 32, 64, 8192, torch.float, True, False, False, 'gelu' + + input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None + act_mode = 'abc' + self.assertException("act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.", func, + input, act_mode, is_gated, output, bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, expert_size) + act_mode = 'gelu' + self.assertException("input.dim() >= 2", func, + input.reshape(-1), act_mode, is_gated, output, bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, expert_size) + is_gated = True + self.assertException("in_channel % 2 == 0 if is_gated is true", func, + torch.randn(total_tokens, inner_size * 2 - 1, dtype=dtype, device='mlu'), + act_mode, is_gated, output, bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + start_expert_id, expert_size) + + def test_inductor(self): + num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \ + 32, 64, 8192, torch.float, True, False, True, 'gelu' + + input, bias, _, cusum_token_count, output, start_expert_id, expert_size = \ + gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep) + bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None + args = (input, output, bias_real, + cusum_token_count.mlu() if has_bias or is_ep else None, + act_mode, is_gated, start_expert_id, expert_size) + self.base_opcheck(torch.ops.torch_mlu_ops.active, args) + +if __name__ == '__main__': + random.seed(0) + torch.manual_seed(0) + exit(run_unittest(TestMoeActiveKernel)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_cast_gating.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_cast_gating.py new file mode 100644 index 0000000..0608c33 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_cast_gating.py @@ -0,0 +1,77 @@ +import torch +import unittest +import torch_mlu_ops as ops +import random +from common_utils import * +from itertools import product +import copy +from typing import Optional + +class TestMoeCastGating(BtTestCase): + def op_impl_base(self, *args): + input, weight = args + input = input.to(torch.float) + output = torch.matmul(input, weight.permute(1, 0)) + return output + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device") + def test_moe_cast_gating_random(self): + for _ in range(1000): + total_seq = random.randint(1, 32768) + hidden_size = random.randint(1, 16384) + expert_num = random.randint(1, 128) + input_dtype_list = [torch.half] + if torch_mlu.mlu.is_bf16_supported(): + input_dtype_list.append(torch.bfloat16) + input_dtype = random.choice(input_dtype_list) + weight_dtype = torch.float + print("total_seqlen={}, hidden_size={}, expert_num={}, input_dtype={}, testing...".format( + total_seq, hidden_size, expert_num, input_dtype)) + input = torch.randn(total_seq, hidden_size, dtype=input_dtype, device="mlu") + weight = torch.randn(expert_num, hidden_size, dtype=weight_dtype, device="mlu") + + tmo_out = ops.moe_cast_gating(input, weight) + torch_out = self.op_impl_base(input, weight) + + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 1e-4, + use_MSE=True, use_RAE=True) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func = ops.moe_cast_gating + input = torch.randn(1024, 8192, dtype=torch.half, device="mlu") + input = input.as_strided(input.shape, (100, 1)) + weight = torch.randn(128, 8192, dtype=torch.float, device="mlu") + self.assertException("input must be contiguous", func, input, weight) + input = input.contiguous() + weight = weight.as_strided(weight.shape, (100, 1)) + self.assertException("weight must be contiguous", func, input, weight) + weight = weight.contiguous() + weight = weight.reshape(1, 128, 8192) + self.assertException("weight.dim() == 2", func, input, weight) + weight = torch.randn(128, 2048, dtype=torch.float, device="mlu") + self.assertException("input.size(-1) == weight.size(-1)", func, input, weight) + weight = torch.randn(128, 8192, dtype=torch.half, device="mlu") + self.assertException("weight type need be torch::kFloat32", func, input, weight) + weight = weight.to(torch.float) + input = input.to(torch.float) + self.assertException("input type need be torch::kFloat16 or torch::kBFloat16", func, input, weight) + input = torch.randn(1024, 16388, dtype=torch.half, device="mlu") + weight = torch.randn(128, 16388, dtype=torch.float, device="mlu") + self.assertException("hidden_size > 0 && hidden_size <= 16384", func, input, weight) + input = torch.randn(1024, 16384, dtype=torch.half, device="mlu") + weight = torch.randn(129, 16384, dtype=torch.float, device="mlu") + self.assertException("expert_num > 0 && expert_num <= 128", func, input, weight) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device") + def test_inductor(self): + m, hidden_size, expert_num, input_dtype = 1024, 4096, 32, torch.half + input = torch.randn(m, hidden_size, dtype=input_dtype, device="mlu") + weight = torch.randn(expert_num, hidden_size, dtype=torch.float, device="mlu") + args = (input, weight) + self.base_opcheck(torch.ops.torch_mlu_ops.moe_cast_gating, args) + +if __name__ == "__main__": + if "MLU3" not in torch.mlu.get_device_name(): + exit(run_unittest(TestMoeCastGating)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_combine_result.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_combine_result.py new file mode 100644 index 0000000..b5a6c19 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_combine_result.py @@ -0,0 +1,242 @@ +import torch +import unittest +import torch_mlu_ops as ops +import random +from common_utils import * +from itertools import product +import copy +from typing import Optional + +def generate_token_count(num_expert, + total_token_count): + token_count = torch.randint(low=1, high=1024, size=(num_expert, ), \ + dtype=torch.int32).to(dtype=torch.float32) + sum = torch.sum(token_count, dim=-1) * 1.0 + token_count *= total_token_count / sum.item() + token_count = token_count.to(dtype=torch.int32) + cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32) + end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count + cusum_token_count[1:] = torch.cumsum(token_count, dim=-1) + cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count) + cusum_token_count[-1] = total_token_count + return cusum_token_count + + +def gen_case(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + has_bias, + has_residual, + dtype, + device): + input = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device=device) + reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device) + gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32, device=device) + bias = None + residual = None + + cusum_token_count = None + + if has_bias: + bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) + if has_residual: + residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) + + if has_bias or expert_size < num_expert: + cusum_token_count = generate_token_count(num_expert, num_tokens * topk) + cusum_token_count = cusum_token_count.to(device=device) + return input, reduce_weight, gather_ids, residual, bias, cusum_token_count + +class TestCombineResult(BtTestCase): + def op_impl_base(self, *args): + input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, \ + expert_size, bias = args + input = input.to(dtype=torch.float32).cpu() + reduce_weight = reduce_weight.cpu() + gather_ids = gather_ids.cpu() + bias = None + + dtype = input.dtype + num_expert = expert_size + hidden_size = input.shape[1] + + if bias is not None: + bias = bias.to(dtype=torch.float32).cpu() + num_expert = bias.shape[0] + if cusum_token_count is not None: + num_expert = cusum_token_count.shape[0] - 1 + cusum_token_count = cusum_token_count.cpu() + + if cusum_token_count is not None and expert_size < num_expert: + gathered_input = input[gather_ids - cusum_token_count[start_expert_id].item()] + else: + gathered_input = input[gather_ids] + + if bias is not None and cusum_token_count is not None: + for i in range(start_expert_id, start_expert_id + expert_size): + gathered_input[cusum_token_count[i] : cusum_token_count[i+1]] += bias[i] + gathered_input = gathered_input.reshape(*reduce_weight.shape, hidden_size) + + if cusum_token_count is not None: + filtered_ids = (gather_ids >= cusum_token_count[start_expert_id]) * \ + (gather_ids < cusum_token_count[start_expert_id + expert_size]) + filtered_ids = filtered_ids.to(dtype=torch.float32) + reduce_weight = reduce_weight * filtered_ids.reshape(reduce_weight.shape) + + gathered_input *= reduce_weight.reshape(*reduce_weight.shape, -1) + output = torch.sum(gathered_input, dim=1, keepdim=False) + if residual is not None: + residual = residual.to(dtype=torch.float32).cpu() + output += residual + return output.to(dtype=dtype) + + def test_random_case(self): + torch.manual_seed(444) + test_cases = 200 + num_tokens_list = torch.randint(low=1, high=2048, size=(test_cases, ), dtype=torch.int32) + topk_list = torch.randint(low=1, high=33, size=(test_cases, ), dtype=torch.int32) + hidden_size_list = torch.randint(low=256, high=8193, size=(test_cases, ), dtype=torch.int32) + num_expert_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32) + num_expert_list = torch.maximum(topk_list, num_expert_list) + expert_size_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32) + expert_size_list = torch.minimum(expert_size_list, num_expert_list) + start_expert_id_list = torch.randint(low=0, high=129, size=(test_cases, ), dtype=torch.int32) + start_expert_id_list = torch.minimum(start_expert_id_list, num_expert_list - expert_size_list) + start_expert_id_list = torch.maximum(start_expert_id_list, torch.zeros(test_cases, dtype=torch.int32)) + has_bias_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + has_residual_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + dtype_list = torch.randint(low=0, high=10, size=(test_cases, ), dtype=torch.int32) + dtype_map = [torch.half, torch.bfloat16, torch.float] + + device = 'mlu' + mlu_name = torch.mlu.get_device_name() + is_mlu370 = "MLU3" in mlu_name + max_num_tokens = 128 * 1024 + for i in range(test_cases): + num_tokens = num_tokens_list[i].item() + topk = topk_list[i].item() + hidden_size = hidden_size_list[i].item() + num_expert = num_expert_list[i].item() + expert_size = expert_size_list[i].item() + start_expert_id = start_expert_id_list[i].item() + has_bias = False # has_bias_list[i].item() + has_residual = has_residual_list[i].item() + + topk = min(topk, (int)(max_num_tokens / num_tokens)) + + if dtype_list[i].item() < 4: + dtype = dtype_map[0] + elif dtype_list[i].item() < 9: + dtype = dtype_map[1] + else: + dtype = dtype_map[2] + + if is_mlu370 and dtype is torch.bfloat16: + continue + + inputs = gen_case(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + has_bias, + has_residual, + dtype, + device) + input = inputs[0] + reduce_weight = inputs[1] + gather_ids = inputs[2] + residual = inputs[3] + bias = inputs[4] + cusum_token_count = inputs[5] + + print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, " + "start_expert_id={}, has_bias={}, has_residual={}, dtype={}, testing...".format( + num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \ + has_bias, has_residual, dtype)) + + golden_output = self.op_impl_base(input, reduce_weight, gather_ids, residual, + cusum_token_count, start_expert_id, expert_size, None) + + output = ops.moe_combine_result(input, reduce_weight, gather_ids, residual, + cusum_token_count, start_expert_id, expert_size) + + self.assertTensorsEqual(output.cpu().float(), golden_output.cpu().float(), 0.003, + "golden_output must equal output", True, True, True, True) + + + def test_perf_case(self): + num_tokens_list = [1, 72, 512] + hidden_size_list = [2048, 4096, 5120, 8192] + # [num_expert, topk, start_expert_id, expert_size] + expert_options_list = [[1, 1, 0, 1], [8, 2, 0, 8], [32, 5, 0, 32], [32, 5, 24, 8]] + has_residual_list = [True, False] + dtype_list = [torch.half, torch.bfloat16] + + device = 'mlu' + mlu_name = torch.mlu.get_device_name() + is_mlu370 = "MLU3" in mlu_name + args = product(num_tokens_list, hidden_size_list, expert_options_list,\ + has_residual_list, dtype_list) + for num_tokens, hidden_size, expert_options, has_residual, dtype in args: + num_expert = expert_options[0] + topk = expert_options[1] + start_expert_id = expert_options[2] + expert_size = expert_options[3] + has_bias = False + + if is_mlu370 and dtype is torch.bfloat16: + continue + + torch.manual_seed(444) + inputs = gen_case(num_tokens, + topk, + hidden_size, + num_expert, + expert_size, + has_bias, + has_residual, + dtype, + device) + input = inputs[0] + reduce_weight = inputs[1] + gather_ids = inputs[2] + residual = inputs[3] + bias = inputs[4] + cusum_token_count = inputs[5] + + golden_output = self.op_impl_base(input, reduce_weight, gather_ids, residual, + cusum_token_count, start_expert_id, expert_size, None) + + notify_start = torch.mlu.Event(enable_timing=True) + notify_end = torch.mlu.Event(enable_timing=True) + notify_start.record() + loop = 10 + for _ in range(loop): + output = ops.moe_combine_result(input, reduce_weight, gather_ids, residual, + cusum_token_count, start_expert_id, expert_size) + notify_end.record() + notify_end.synchronize() + time = notify_start.hardware_time(notify_end) / loop + + print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, " + "start_expert_id={}, has_bias={}, has_residual={}, dtype={}, time={:.1f}".format( + num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \ + has_bias, has_residual, dtype, time)) + + self.assertTensorsEqual(output.cpu().float(), golden_output.cpu().float(), 0.003, + "golden_output must equal output", True, True, True, True) + + def test_inductor(self): + num_tokens, hidden_size, has_bias, has_residual, dtype = 1, 2048, False, True, torch.float16 + num_expert, topk, start_expert_id, expert_size = 8, 2, 0, 8 + input, reduce_weight, gather_ids, residual, \ + bias, cusum_token_count = gen_case(num_tokens, topk, hidden_size, num_expert, expert_size, + has_bias, has_residual, dtype, 'mlu') + args = (input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, expert_size, bias) + self.base_opcheck(torch.ops.torch_mlu_ops.moe_combine_result, args) + +if __name__ == '__main__': + exit(run_unittest(TestCombineResult)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_expand_input.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_expand_input.py new file mode 100644 index 0000000..b57388c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_expand_input.py @@ -0,0 +1,132 @@ +import torch +import torch_mlu_ops +import unittest +import torch_mlu_ops +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +from common_utils import * +import random +from itertools import product +from typing import Tuple, Optional +import os + +class TestExpandInput(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + input = create_tensor_from_dic(dic['input']) + gather_idx = dic['gather_idx']['data'] + cusum_token_count = dic['cusum_token_count']['data'] + start_expert_id = dic['start_expert_id']['data'] + expert_size = dic['expert_size']['data'] + self.launch(input, gather_idx, cusum_token_count, start_expert_id, expert_size) + + def launch(self, *args): + cusum_token_count = args[2] + start_expert_id = args[3] + expert_size = args[4] + real_token_count = None if cusum_token_count is None else cusum_token_count[start_expert_id+expert_size-1] - cusum_token_count[start_expert_id] + base_out = self.op_impl_base(*args) + tmo_out = torch_mlu_ops.moe_expand_input(*args) + self.assertTensorsEqual(base_out[:real_token_count].cpu().float(), tmo_out[:real_token_count].cpu().float(), + 0.00, use_MSE=True, use_RAE=True) + + def op_impl_base(self, *args): + input, gather_idx, cusum_token_count, start_expert_id, expert_size = args + if cusum_token_count is None: + return input[gather_idx] + else: + idx = gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id + expert_size]] + return input[idx] + + def get_tensor(self, token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype): + input = torch.randn(token_num, hidden_size, device='mlu').to(dtype) + gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,), dtype=torch.int32, device='mlu') + cusum_token_count, _ = generate_token_count(expert_num, token_num * topk) + cusum_token_count = cusum_token_count.to('mlu') + use_all_experts = expert_num == expert_size + if use_all_experts: + cusum_token_count = None + real_token_count = token_num * topk + else: + real_token_count = cusum_token_count[start_expert_id+expert_size-1] - cusum_token_count[start_expert_id] + return input, gather_idx, cusum_token_count, real_token_count + + def test_kernel_random(self): + for i in range(100): + token_num = random.randint(1, 2048) + hidden_size = random.randint(1, 4096) + expert_num = random.randint(1, 32) + topk = random.randint(1, expert_num) + start_expert_id = random.randint(0, expert_num-1) + expert_size = random.randint(1, expert_num-start_expert_id) + dtype_list = [torch.half, torch.float, torch.int8] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + dtype = random.sample(dtype_list, 1)[0] + print("===============================================================================") + print(f"[{i}]: token_num: {token_num}, hidden_size: {hidden_size}, expert_num: {expert_num}") + print(f" topk: {topk}, start_expert_id: {start_expert_id}, expert_size: {expert_size}, dtype: {dtype}") + input, gather_idx, cusum_token_count, real_token_count = self.get_tensor(token_num, + hidden_size, expert_num, topk, start_expert_id, expert_size, dtype) + base_expand_hidden_states = self.op_impl_base(input, gather_idx, None, 0, 0) + expand_hidden_states = torch_mlu_ops.moe_expand_input(input, gather_idx) + self.assertTensorsEqual(base_expand_hidden_states.cpu().float(), expand_hidden_states.cpu().float(), + 0.00, use_MSE=True, use_RAE=True) + del base_expand_hidden_states, expand_hidden_states + base_expand_hidden_states = self.op_impl_base(input, gather_idx, cusum_token_count, start_expert_id, expert_size)[:real_token_count] + expand_hidden_states = torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count, + start_expert_id, expert_size)[:real_token_count] + self.assertTensorsEqual(base_expand_hidden_states.cpu().float(), expand_hidden_states.cpu().float(), + 0.00, use_MSE=True, use_RAE=True) + del base_expand_hidden_states, expand_hidden_states + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func = torch_mlu_ops.moe_expand_input + token_num, hidden_size, topk, expert_num, start_expert_id, expert_size, dtype = 5, 5, 6, 32, 0, 16, torch.half + input, gather_idx, cusum_token_count, _ = self.get_tensor(token_num, hidden_size, expert_num, \ + topk, start_expert_id, expert_size, dtype) + input = torch.randn(token_num, 1, hidden_size).to(dtype).to('mlu') + self.assertException("input dim must be equal to 2.", func, input, gather_idx) + input = torch.randn(token_num, hidden_size).to(dtype).to('mlu') + gather_idx = torch.randint(0, token_num, (token_num * topk, 1)).to(torch.int32).to('mlu') + self.assertException("gather_idx dim must be equal to 1.", func, input, gather_idx) + gather_idx = gather_idx.reshape(token_num * topk) + self.assertException("input tensor must on mlu.", func, input.cpu(), gather_idx) + self.assertException("gather_idx must on mlu.", func, input, gather_idx.cpu()) + self.assertException("data type of gather_idx must be int32.", func, input, gather_idx.to(torch.int64)) + input = torch.randn(token_num, hidden_size).to(dtype).to('mlu') + gather_idx = torch.randint(0, token_num, (8,)).to(torch.int32).to('mlu') + self.assertException("expand_token_num % token_num == 0.", func, input, gather_idx) + gather_idx = torch.randint(1, token_num, (token_num * topk,)).to(torch.int32).to('mlu') + self.assertException("cusum_token_count must on mlu.", func, input, gather_idx, cusum_token_count.cpu(), + start_expert_id, expert_size) + self.assertException("data type of cusum_token_count must be int32.", func, input, gather_idx, cusum_token_count.to(torch.int64), + start_expert_id, expert_size) + start_expert_id = -1 + self.assertException("start_expert_id >=0 && start_expert_id < expert_num.", + func, input, gather_idx, cusum_token_count, start_expert_id, expert_size) + start_expert_id = expert_num + self.assertException("start_expert_id >=0 && start_expert_id < expert_num.", + func, input, gather_idx, cusum_token_count, start_expert_id, expert_size) + start_expert_id = 16 + expert_size = 17 + self.assertException("start_expert_id + expert_size <= expert_num.", + func, input, gather_idx, cusum_token_count, start_expert_id, expert_size) + + def test_inductor(self): + dtype = torch.float16 + token_num, hidden_size, expert_num, topk, start_expert_id, expert_size = 64, 128, 16, 8, 3, 10 + input, gather_idx, cusum_token_count, _ = self.get_tensor(token_num, hidden_size, expert_num, \ + topk, start_expert_id, expert_size, dtype) + args = (input, gather_idx, cusum_token_count, start_expert_id, expert_size) + self.base_opcheck(torch.ops.torch_mlu_ops.moe_expand_input, args) + +if __name__ == "__main__": + random.seed(0) + torch.manual_seed(0) + exit(run_unittest(TestExpandInput)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_gen_idx.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_gen_idx.py new file mode 100644 index 0000000..e035a76 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_gen_idx.py @@ -0,0 +1,62 @@ +import torch +import torch_mlu_ops +import unittest +import torch_mlu_ops +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +from common_utils import * +import random +from itertools import product +from typing import Tuple, Optional +import os + +def gen_args(index: int = 0): + token_num = random.randint(1, 32768) + expert_num = random.randint(1, 256) + topk = random.randint(1, expert_num) + expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu') + print("===============================================================================") + print(f"[{index}]: token_num: {token_num}, expert_num: {expert_num}, topk: {topk}") + return expert_id, expert_num + +class TestGenIdx(BtTestCase): + def op_impl_base(self, *args): + expert_id, expert_num = args + token_num, topk = expert_id.size(0), expert_id.size(1) + sorted_expert_id, indices = expert_id.int().flatten().sort() + expand_idx_out = indices.int() // topk + combine_idx_out = torch.zeros((token_num * topk,), dtype=torch.int, device="mlu") + combine_idx_out.scatter_(0, indices, torch.arange(token_num * topk, dtype=torch.int, device="mlu")) + token_count_out = torch.bincount(sorted_expert_id, minlength=expert_num) + cusum_token_count_out = torch.cat((torch.tensor([0]).to('mlu'), torch.cumsum(token_count_out, dim=0))).to(torch.int32) + return tuple([expand_idx_out, combine_idx_out, token_count_out, cusum_token_count_out]) + + def test_kernel_random(self): + for i in range(1500): + expert_id, expert_num = gen_args(i) + base_gather_expand_idx, base_gather_combine_idx, base_token_count, base_cusum_token_count = self.op_impl_base(expert_id, expert_num) + gather_expand_idx_out, gather_combine_idx_out, token_count_out, cusum_token_count_out = \ + torch_mlu_ops.moe_gen_idx(expert_id, expert_num) + self.assertTensorsEqual(base_gather_expand_idx.cpu(), gather_expand_idx_out.cpu(), 0.00, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_gather_combine_idx.cpu(), gather_combine_idx_out.cpu(), 0.00, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_token_count.cpu(), token_count_out.cpu(), 0.00, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_cusum_token_count.cpu(), cusum_token_count_out.cpu(), 0.00, use_MSE=True, use_RAE=True) + + def test_inductor(self): + args = gen_args() + self.base_opcheck(torch.ops.torch_mlu_ops.moe_gen_idx, args) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func = torch_mlu_ops.moe_gen_idx + token_num, expert_num, topk = 128, 32, 16 + expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk, 1)).to(torch.int32).to('mlu') + self.assertException("expert_id dim must be equal to 2.", func, expert_id, expert_num) + expert_id = expert_id.reshape(token_num, topk) + self.assertException("data type of expert_id must be int32.", func, expert_id.to(torch.int64), expert_num) + +if __name__ == "__main__": + random.seed(0) + torch.manual_seed(0) + exit(run_unittest(TestGenIdx)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_softmax_topk.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_softmax_topk.py new file mode 100644 index 0000000..d985a69 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_moe_softmax_topk.py @@ -0,0 +1,200 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +from common_utils import * +import random +from itertools import product +import time +import numpy as np +import os + +class TestSoftmaxTopkOp(BtTestCase): + def op_impl_base(self, *args): + input, topk, normalize, num_expert_group, topk_group, origin_mask, normed_by = args + softmax = torch.softmax(input.float(), dim=-1) + if num_expert_group <= 1: + if origin_mask is not None: + softmax = softmax * origin_mask + reduce_weight, expert_id = torch.topk(softmax, k=topk, dim=-1) + if normalize: + if normed_by == "topk_logit": + reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True) + if normed_by == "softmax_logit": + reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True) + return reduce_weight, expert_id + else: + group_size = softmax.shape[-1] // num_expert_group + new_shape = softmax.shape[:-1] + (num_expert_group, group_size) + group_data = softmax.view(new_shape) + group_max_value = group_data.max(dim=-1).values + group_idx = torch.topk(group_max_value, k=topk_group, dim=-1)[1] + mask_shape = softmax.shape[:-1] + (num_expert_group,) + mask = torch.zeros((mask_shape), dtype = torch.bool, device = group_idx.device) + mask.scatter_(-1, group_idx, True) + mask = mask.unsqueeze(-1).expand(new_shape) + masked_data = group_data.masked_fill(~mask, 0.0) + masked_data = masked_data.reshape(softmax.shape) + reduce_weight, expert_id = torch.topk(masked_data, k=topk, dim=-1) + if normalize: + if normed_by == "topk_logit": + reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True) + if normed_by == "softmax_logit": + reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True) + return reduce_weight, expert_id + + # 接口测试 + def test_interface(self): + num_token, num_expert, topk, num_expert_group, topk_group, normalize = 1024, 160, 6, 10, 5, False + input = torch.randn(num_token, num_expert, dtype=torch.half, device='mlu') + mask = None + normed_by = "topk_logit" + base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by) + reduce_weight, expert_id = torch_mlu_ops.moe_softmax_topk(input, topk, normalize, num_expert_group, topk_group) + base_expert_id, _ = base_expert_id.sort() + expert_id, _ = expert_id.sort() + self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + + # 随机遍历测试 + def test_kernel_random(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for i in range(1000): + num_batch = random.randint(1, 64) + num_mask = random.randint(1, 1024) + num_expert = random.randint(1, 512) + factors = [i for i in range(1, num_expert + 1) if num_expert % i == 0] + num_expert_group = random.choice(factors) + topk_group = random.randint(1, num_expert_group) + topk = random.randint(1, num_expert / num_expert_group * topk_group) + normalize = random.sample([True, False], 1)[0] + dtype = random.sample(dtype_list, 1)[0] + group_invalid = random.sample([True, False], 1)[0] + if group_invalid: + num_expert_group = -1 + normed_by = random.choice(["topk_logit", "softmax_logit"]) + print("===============================================================================") + print(f"[{i}]: num_batch: {num_batch}, num_mask: {num_mask}, num_expert: {num_expert}, num_expert_group: {num_expert_group}") + print(f" topk_group: {topk_group}, topk: {topk}, normalize: {normalize}, dtype: {dtype}, normed_by: {normed_by}") + input = torch.randn(num_batch, num_mask, num_expert, dtype=dtype).mlu() + mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = dtype).mlu() + if num_expert_group > 1: + mask = None + base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by) + reduce_weight, expert_id = torch.ops.torch_mlu_ops.moe_softmax_topk(input, + topk, + num_expert_group, + topk_group, + normalize, + mask, + normed_by) + # softmax后的值可能因数值差异较小,造成topk的值存在顺序上的差异,例如: + # base_reduce_weight[N, 10:12] = [0.012, 0.011] + # reduce_weight[N, 10:12] = [0.011, 0.012] + # base_expert_id[N, 10:12] = [7, 8] + # expert_id[N, 10:12] = [8, 7] + # 这种情况产生的顺序上的差异不是错误的,但这样会造成结果对比错误,因此需要对结果先排序再对比结果 + base_expert_id, _ = base_expert_id.sort() + expert_id, _ = expert_id.sort() + self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + # 防呆测试 + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func = torch.ops.torch_mlu_ops.moe_softmax_topk + num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 2, 1024, 160, 0, 1, 1, True, "abc_logit" + input = torch.randn(num_batch, num_mask, num_expert, dtype = torch.half, device='mlu') + input_permute = input.permute(0, 2, 1) + mask = torch.randint(0, 2, (num_mask, num_expert), dtype = torch.half, device='mlu') + mask_permute = mask.permute(1, 0) + self.assertException("input must be contiguous.", + func, input_permute, topk, num_expert_group, topk_group, normalize, mask, normed_by) + self.assertException("mask must be contiguous.", + func, input, topk, num_expert_group, topk_group, normalize, mask_permute, normed_by) + self.assertException("normed_by must be 'topk_logit' or 'softmax_logit'", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + input = torch.randn(num_expert, dtype = torch.half, device = 'mlu') + normed_by = "softmax_logit" + self.assertException("input.dim() >= 2", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + input = torch.randn(num_batch, num_mask, num_expert, dtype = torch.half, device='mlu') + self.assertException("topk > 0 && topk <= num_expert", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + topk = 5 + self.assertException("the dim of mask should be the same as input", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + mask = torch.randint(0, 2, (1, num_mask, num_expert - 1), dtype = torch.half, device='mlu') + self.assertException("the last dim of mask should be the same as the last dim of input", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + mask = torch.randint(0, 2, (1, num_mask - 1, num_expert), dtype = torch.half, device='mlu') + self.assertException("the penultimate dim of mask should be the same as the penultimate dim of input", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + mask = torch.randint(0, 2, (2, num_mask, num_expert), dtype = torch.half, device='mlu') + self.assertException("the product of all but the lower two dimensions of mask is 1", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = torch.float, device='mlu') + self.assertException("the dtype of mask should be the same as input", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = torch.half, device='mlu') + num_expert_group = 8 + self.assertException("if num_expert_group > 1, mask should be None", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + mask = None + topk = 160 + self.assertException("topk <= (num_expert / num_expert_group) * topk_group", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + + num_expert_group = 11 + self.assertException("num_expert % num_expert_group == 0", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + + num_expert_group = 8 + topk_group = 9 + self.assertException("topk_group > 0 && topk_group <= num_expert_group", + func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + + # 单条测例 + def test_single(self): + num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 32, 16, 128, 34, 1, 1, True, "softmax_logit" + input = torch.randn(2, num_batch, num_mask, num_expert, dtype = torch.float32, device='mlu') + mask = torch.randint(0, 2, (1, 1, num_mask, num_expert), dtype = torch.float32, device='mlu') +# mask = None + base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by) + reduce_weight, expert_id = torch.ops.torch_mlu_ops.moe_softmax_topk(input, + topk, + num_expert_group, + topk_group, + normalize, + mask, + normed_by) + base_expert_id, _ = base_expert_id.sort() + expert_id, _ = expert_id.sort() + self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 32, 16, 128, 34, 4, 3, True, "softmax_logit" + input = torch.randn(num_batch, num_mask, num_expert, dtype=torch.half, device='mlu') + mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = torch.half, device='mlu') + normed_by = "softmax_logit" + if num_expert_group > 1: + mask = None + args = (input, topk, num_expert_group, topk_group, normalize, mask, normed_by) + self.base_opcheck(torch.ops.torch_mlu_ops.moe_softmax_topk, args) + +if __name__ == '__main__': + random.seed(0) + torch.manual_seed(0) + exit(run_unittest(TestSoftmaxTopkOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_linear_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_linear_cache.py new file mode 100755 index 0000000..21f38f0 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_linear_cache.py @@ -0,0 +1,194 @@ +import torch +import unittest +import torch_mlu_ops as ops +import random +from common_utils import * + +def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype): + args = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, True, True) + cache_scale = args[10] + args = args[0:10] + args.insert(4, cache_scale[0]) + args.insert(5, cache_scale[1]) + args.insert(8, False) + return args + +class TestOfflineQuantToLinearCache(BtTestCase): + def op_impl_base(self, *args): + def quant2Int8(input_fp: torch.Tensor, + scale_fp: torch.tensor, + quant_mode: torch.int64): + head_num = input_fp.size(0) + seq = input_fp.size(1) + head_size = input_fp.size(2) + input_fp32 = input_fp.to(torch.float32) + if quant_mode == 0: # per_channel + scale = scale_fp.reshape((head_num, 1, head_size)) + else: + scale = scale_fp.reshape((head_num, seq, 1)) + scaled_context = input_fp32 / scale + rounded = torch.round(scaled_context) + clipped = torch.clip(rounded, -128, 127) + return clipped.to(torch.int8) + + key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, \ + context_lengths, max_context_len, quant_mode, packed, context_seq_offset, cache_bs_id, \ + cache_seqlen_offset = args + batch_size = context_lengths.size(0) - 1 if packed else context_lengths.size(0) + for i in range(batch_size): + if packed: + key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) + value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None + context_len_i = context_lengths[i+1] - context_lengths[i] + context_seq_offset_i = 0 + else: + key_i = key[i].transpose(1, 0) + value_i = value[i].transpose(1, 0) if value is not None else None + context_len_i = context_lengths[i] + context_seq_offset_i = context_seq_offset[i] if context_seq_offset is not None else 0 + cache_bs_id_i = cache_bs_id[i] if cache_bs_id is not None else i + cache_seq_begin = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0 + if cache_bs_id_i < 0 or cache_seq_begin < 0: + continue + cache_seq_end = cache_seq_begin + context_len_i + + # quant key to int8 + key_i = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i] + if quant_mode == 0: + key_cache_scale_i = key_cache_quant_scale + else: + key_cache_scale_i = key_cache_quant_scale[:, cache_seq_begin:cache_seq_end] + quant_key_i = quant2Int8(key_i, key_cache_scale_i, quant_mode) + key_cache_i = key_cache[cache_bs_id_i, :, cache_seq_begin:cache_seq_end] + key_cache_i[...] = quant_key_i + + # quant value to int8 + if value_cache is not None and value is not None and value_cache_quant_scale is not None: + value_i = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i] + if quant_mode == 0: + value_cache_scale_i = value_cache_quant_scale + else: + value_cache_scale_i = value_cache_quant_scale[:, cache_seq_begin:cache_seq_end] + quant_value_i = quant2Int8(value_i, value_cache_scale_i, quant_mode) + value_cache_i = value_cache[cache_bs_id_i, :, cache_seq_begin:cache_seq_end] + value_cache_i[...] = quant_value_i + return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale) + + def test_offline_quant_to_linear_cache(self): + test_cases = 100 + bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + num_heads_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32) + head_size_list *= 16 + cache_memory_len_list = torch.randint(low=2, high=1024, size=(test_cases, ), dtype=torch.int32) + packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + mode_list= torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32) + dtype_map = [torch.half, torch.bfloat16, torch.float32] + + for i in range(test_cases): + q_heads = 1 + batch_size = bs_list[i].item() + invalid_batch = batch_size // 10 + num_heads = num_heads_list[i].item() + head_size = head_size_list[i].item() + cache_memory_len = cache_memory_len_list[i].item() + packed = packed_list[i].item() + quant_mode = mode_list[i].item() + total_heads = q_heads + num_heads * 2 + dtype = dtype_map[dtype_list[i]] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name and dtype == torch.bfloat16: + dtype = torch.half + print("BFLOAT16 is not support on {}, use half instead".format(mlu_name)) + + print("batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, mode ={}, dtype={} testing...".format( + batch_size, num_heads, head_size, cache_memory_len, packed > + 0, quant_mode, dtype)) + + torch.manual_seed(1) + max_bs = batch_size + 1 + context_lens = torch.randint(size=(batch_size, ), low=1, + high=cache_memory_len // 2, + dtype=torch.int32, device='mlu') + max_context_len = context_lens.max().item() + max_seq_offset = max_context_len // 3 + 1 + context_seq_offsets = torch.randint(size=(batch_size, ), + low=0, high=max_seq_offset, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch_size, ), low=0, + high=(cache_memory_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1 + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + if packed > 0: + context = torch.randn((total_seqlen, total_heads, head_size), + dtype=torch.float, device='mlu') + else: + context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size), + dtype=torch.float, device='mlu') + context = context.to(dtype) + key = context[..., q_heads:q_heads + num_heads, :] + value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :] + + # prepare key_scale and value_scale + if quant_mode == 0 : # per_channel + cache_scale = torch.randn((2, num_heads, head_size), dtype=torch.float, device='mlu') + else: + cache_scale = torch.randn((2, num_heads, cache_memory_len), dtype=torch.float, device='mlu') + key_cache_scale = cache_scale[0] + value_cache_scale = cache_scale[1] + + cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu') + cache = (cache - 0.5) * 256 + cache = cache.to(torch.int8) + + + key_cache = cache[0] + value_cache = cache[1] + + ref_cache = cache.clone() + ref_key_cache = ref_cache[0] + ref_value_cache = ref_cache[1] + + cache_bs_id = random.sample([*range(0, max_bs)], batch_size) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch + + if packed > 0: + ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, cu_context_lens, max_context_len, + quant_mode, packed > 0, None, cache_bs_id, + cache_seq_offsets) + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, key_cache_scale, + value_cache_scale, cu_context_lens, max_context_len, + quant_mode, packed > 0, None, cache_bs_id, + cache_seq_offsets) + else: + ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, context_lens, max_context_len, + quant_mode, packed > 0, context_seq_offsets, cache_bs_id, + cache_seq_offsets) + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, key_cache_scale, + value_cache_scale, context_lens, max_context_len, + quant_mode, packed > 0, context_seq_offsets, cache_bs_id, + cache_seq_offsets) + # for debug + cache = cache.cpu().flatten() + ref_cache = ref_cache.cpu().flatten() + diff = cache - ref_cache + diff = diff.abs() + assert torch.max(diff) < 2, "ref_cache must equal cache or absolute values differ by 1 due to round_mode!" + + def test_inductor(self): + batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16 + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_linear_cache, args) + + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 0, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_linear_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestOfflineQuantToLinearCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_paged_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_paged_cache.py new file mode 100644 index 0000000..e2c323d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_offline_quant_to_paged_cache.py @@ -0,0 +1,209 @@ +import torch +import unittest +import torch_mlu_ops as ops +import random +from common_utils import * +import numpy as np + +def quant2int8(input_fp: torch.Tensor, + scale_fp: torch.Tensor): + input_fp32 = input_fp.to(torch.float32) + scaled_input = input_fp32 / scale_fp + rounded = torch.round(scaled_input) + clipped = torch.clip(rounded, -128, 127) + return clipped.to(torch.int8) + +def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype): + args = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, True, True) + cache_scale = args[10] + slot_mapping = args[11] + args = args[0:4] + args.insert(2, cache_scale[0]) + args.insert(3, cache_scale[1]) + args.insert(4, slot_mapping) + return args +class TestOfflineQuantToPagedCache(BtTestCase): + def op_impl_base(self, *args): + k, v, k_cache_scale, v_cache_scale, slot_mapping, k_cache, v_cache = args + tokens_num = k.shape[0] + block_size = k_cache.shape[2] + for i in range(tokens_num): + if slot_mapping[i] >= 0: + key_i = k[i] + block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + key_cache_i = k_cache[block_id, :, block_offset, :] + quant_key_i = quant2int8(key_i, k_cache_scale) + key_cache_i[...] = quant_key_i + if v is not None: + value_i = v[i] + value_cache_i = v_cache[block_id, :, block_offset, :] + quant_value_i = quant2int8(value_i, v_cache_scale) + value_cache_i[...] = quant_value_i + return (k_cache, v_cache) if v is not None else k_cache + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device") + def test_offline_quant_to_paged_cache(self): + test_cases = 10 + token_list = torch.randint(low=1, high=512, size=(test_cases, ), dtype=torch.int32) + head_num_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + head_size_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + block_size_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32) + dtype_map = [torch.bfloat16, torch.half, torch.float32] + only_quant_key_list = [True, False] + + for i in range(test_cases): + tokens_num = token_list[i].item() + head_num = head_num_list[i].item() + head_size = head_size_list[i].item() + block_size = block_size_list[i].item() + dtype = dtype_map[dtype_list[i]] + print("tokens_num={}, head_num={}, head_size={}, block_size={}, dtype={} testing...".format( + tokens_num, head_num, head_size, block_size, dtype)) + np.random.seed(1) + only_quant_key = random.choice(only_quant_key_list) + key_data = np.random.uniform(-1, 1, size=[tokens_num, head_num, head_size]) + key_cache_scale_data = np.random.uniform(-10, 10, size=[head_num, head_size]) + key = torch.tensor(key_data, dtype=dtype, device="mlu") + key_cache_scale = torch.tensor(key_cache_scale_data, dtype=torch.float32, device="mlu") + value_data = np.random.uniform(-0.25, 0.25, size=[tokens_num, head_num, head_size]) + value_cache_scale_data = np.random.uniform(-10, 10, size=[head_num, head_size]) + value = torch.tensor(value_data, dtype=dtype, device="mlu") + value_cache_scale = torch.tensor(value_cache_scale_data, dtype=torch.float32, device="mlu") + + min_blocks = (int)((tokens_num + block_size - 1) / block_size) + blocks_num = min(min_blocks + 10, 2 * min_blocks) + num_slots = blocks_num * block_size + slot_mapping = random.sample(range(num_slots), tokens_num) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu") + slot_mapping[-1] = -1 # test mask + + torch.manual_seed(0) + key_cache = torch.randint(-128, 127, (blocks_num, head_num, block_size, head_size), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (blocks_num, head_num, block_size, head_size), dtype=torch.int8).mlu() + + #python base result + key_cache_base = key_cache.clone() + value_cache_base = value_cache.clone() + if only_quant_key: + value, value_cache_scale, value_cache_base, value_cache = None, None, None, None + self.op_impl_base(key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache_base, value_cache_base) + #mlu result + ops.offline_quant_to_paged_cache(key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + #compute diff + baseline_key_cache = key_cache_base.cpu().flatten() + mlu_key_cache = key_cache.cpu().flatten() + key_cache_diff = (mlu_key_cache - baseline_key_cache).abs() + assert torch.max(key_cache_diff) < 2, "key_cache_diff exceed threshold" + if not only_quant_key: + baseline_value_cache = value_cache_base.cpu().flatten() + mlu_value_cache = value_cache.cpu().flatten() + value_cache_diff = (mlu_value_cache - baseline_value_cache).abs() + assert torch.max(value_cache_diff) < 2, "value_cache_diff exceed threshold" + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_large_tensor(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + print("offline_quant_to_paged_cache: test_large_tensor...") + head_num = 16 + head_size = 128 + token_nums = 200 + block_size = 16 + block_nums = ((2**32 - 1) // 1 // head_num // head_size // block_size) + num_slots = block_nums * block_size + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums] + key = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu") + value = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu") + key_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu") + value_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu") + key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu") + + #python base result + key_cache_base = key_cache.clone() + value_cache_base = value_cache.clone() + self.op_impl_base(key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache_base, value_cache_base) + #mlu result + ops.offline_quant_to_paged_cache(key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + #compute diff + baseline_key_cache = key_cache_base.cpu().flatten() + mlu_key_cache = key_cache.cpu().flatten() + key_cache_diff = (mlu_key_cache - baseline_key_cache).abs() + assert torch.max(key_cache_diff) < 2, "key_cache_diff exceed threshold" + + baseline_value_cache = value_cache_base.cpu().flatten() + mlu_value_cache = value_cache.cpu().flatten() + value_cache_diff = (mlu_value_cache - baseline_value_cache).abs() + assert torch.max(value_cache_diff) < 2, "value_cache_diff exceed threshold" + + block_nums = (2**32 // 1 // head_num // head_size // block_size) + num_slots = block_nums * block_size + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums] + key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu") + self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.offline_quant_to_paged_cache, + key, value, key_cache_scale, value_cache_scale, slot_mapping, key_cache, value_cache) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + print("offline_quant_to_paged_cache: test prevent...") + head_num = 16 + head_size = 128 + token_nums = 200 + block_size = 16 + block_nums = 20 + num_slots = block_nums * block_size + dtype = torch.half + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums] + key = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu") + value = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu") + key_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu") + value_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu") + key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache = None + self.assertException("v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().", + ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu") + value = value.reshape(token_nums, head_num, head_size, 1) + self.assertException("dim of v must be 3", + ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + value = value.squeeze() + value_cache = value_cache.reshape(block_nums, head_num, block_size, head_size, 1) + self.assertException("dim of v_cache must be 4", + ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + value_cache = value_cache.squeeze() + value_cache_scale = value_cache_scale.reshape(head_num, head_size, 1) + self.assertException("dim of v_cache_scale must be 2", + ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + value_cache_scale = value_cache_scale.squeeze() + value = value.as_strided(size=(token_nums, head_num, head_size), stride=(head_num, 1, head_size)) + self.assertException("v last dim must be contiguous.", + ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + value = value.as_strided(size=(token_nums, head_num, head_size), stride=(head_num, token_nums, 1)) + self.assertException("v second dim must be contiguous.", + ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale, + slot_mapping, key_cache, value_cache) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device") + def test_inductor(self): + batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16 + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_paged_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestOfflineQuantToPagedCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_per_token_smooth_quantize.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_per_token_smooth_quantize.py new file mode 100755 index 0000000..c88f534 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_per_token_smooth_quantize.py @@ -0,0 +1,72 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as tmo +from common_utils import * +import random + + +class TestPerTokenSmoothQuantizeOp(BtTestCase): + def op_impl_base(self, *args): + x, smooth, zero, m_list = args + x_shape = x.size() + scale_shape = x.size()[0:-1] + if m_list is None: + smoothed = x * smooth + else: + input_list = x.split(tuple(m_list)) + experts = len(input_list) + result = [] + for i in range(experts): + result.append(input_list[i] * smooth[i]) + smoothed = torch.concat(result, dim=0) + output, scale = QuantByRow(smoothed, 8) + return output.reshape(x_shape), scale.reshape(scale_shape) + + def test_random_case(self): + torch.manual_seed(0) + case_list = set() + while(len(case_list) < 200): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + dtype = random.choice(dtype_list) + has_group = random.choice([False, True]) + if has_group: + experts = random.randint(1, 40) + m_list = torch.randint(1, 100, (experts,), device="mlu", dtype=torch.int32) + ci = m_list.sum().item() + else: + experts = None + ci = random.randint(1, 4096) + co = random.randint(1, 4096) + case = (experts, ci, co, dtype) + if case in case_list: + continue + else: + case_list.add(case) + x = torch.randn(ci, co, device="mlu", dtype=dtype) + if has_group: + scale = torch.randn(experts, co, device="mlu", dtype=torch.float32) + else: + scale = torch.randn(co, device="mlu", dtype=torch.float32) + print("experts={}, ci={}, co={}, dtype={}, testing...".format(experts, ci, co, dtype), flush=True) + param = (x, scale, None, m_list if has_group else None) + tmo_output, tmo_scale = tmo.per_token_smooth_quantize(*param) + torch_output, torch_scale = self.op_impl_base(*param) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_scale.cpu().float(), tmo_scale.cpu().float(), 0.01, use_MSE=True, use_RAE=True) + + def test_inductor(self): + m_list = torch.randint(1, 100, (8,), device="mlu", dtype=torch.int32) + total_m = m_list.sum().item() + x = torch.randn(total_m, 1024, device="mlu", dtype=torch.half) + scale = torch.randn(8, 1024, device="mlu", dtype=torch.float32) + output = torch.empty(x.size(), dtype=torch.int8, device="mlu") + output_scale = torch.empty(x.size()[:-1], dtype=torch.float32, device="mlu") + args = (x, scale, output, output_scale, None, m_list, None, None, 'per_token', True) + self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args) + + +if __name__ == '__main__': + exit(run_unittest(TestPerTokenSmoothQuantizeOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_preload.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_preload.py new file mode 100644 index 0000000..e1a882e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_preload.py @@ -0,0 +1,21 @@ +import torch +import unittest +import torch_mlu_ops as ops +from common_utils import * + +class TestPreloadOp(BtTestCase): + def op_impl_base(self, *args): + wegiht, size = args + return super().op_impl_base(*args) + + def test_preload(self): + weight = torch.randn((1024, 8, 5, 1024)).half().mlu() + ops.preload(weight, weight.element_size() * weight.numel()) + torch.mlu.synchronize() + + def test_inductor(self): + weight = torch.randn((1024, 8, 5, 1024)).half().mlu() + self.base_opcheck(torch.ops.torch_mlu_ops.preload, (weight, weight.element_size() * weight.numel())) + +if __name__ == '__main__': + exit(run_unittest(TestPreloadOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_ffn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_ffn.py new file mode 100755 index 0000000..6c14134 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_ffn.py @@ -0,0 +1,137 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from torch.nn import functional as F +from torch.nn.parameter import Parameter + +active = torch.nn.functional.silu + +def torch_ffn(input, w1, scale1, bias1, w2, scale2, bias2, is_gated): + tmp_input = input.flatten(0, -2).to(torch.float) + inner_size = w1.size(0) // (1 + is_gated) + imm = weight_only_quant_matmul(tmp_input, w1, scale1, bias1) + acted = active(imm[:, :inner_size]) + acted = acted * imm[:, inner_size:] if is_gated else acted + out = weight_only_quant_matmul(acted, w2, scale2, bias2) + return out.reshape(input.shape) + +def torch_smooth_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, input_smooth, act_smooth, is_gated): + inner_size = w1.size(0) // (1 + is_gated) + tmp_input = input.flatten(0, -2).to(torch.float) + quant_input, input_scale = QuantByRow(tmp_input.flatten(0, -2) * input_smooth, 8) + imm = smooth_quant_matmul(quant_input, input_scale, w1, scale1, input.dtype, bias1) + acted = active(imm[:, :inner_size]) + acted = acted * imm[:, inner_size:] if is_gated else acted + quant_acted, acted_scale = QuantByRow(acted.flatten(0, -2) * act_smooth, 8) + out = smooth_quant_matmul(quant_acted, acted_scale, w2, scale2, input.dtype, bias2) + return out.reshape(input.shape) + +def tmo_weight_only_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit): + tmp_input = input.flatten(0, -2) + imm = ops.weight_only_quant_matmul(tmp_input, w1, scale1, None, bias1, None, 'silu', quant_bit, True) + out = ops.weight_only_quant_matmul(imm, w2, scale2, None, bias2, None, "none", quant_bit) + return out.reshape(input.shape) + +def tmo_weight_only_group_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit): + tmp_input = input.flatten(0, -2) + imm = ops.weight_only_quant_matmul(tmp_input, w1, scale1, None, bias1, None, 'none', quant_bit) + acted = active(imm) + out = ops.weight_only_quant_matmul(acted, w2, scale2, None, bias2, None, "none", quant_bit) + return out.reshape(input.shape) + +def tmo_weight_only_quant_gated_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit): + tmp_input = input.flatten(0, -2) + imm = ops.weight_only_quant_matmul(tmp_input, w1, scale1, None, bias1, None, 'none', quant_bit) + acted = ops.active(imm, 'silu', True) + out = ops.weight_only_quant_matmul(acted, w2, scale2, None, bias2, None, "none", quant_bit) + return out.reshape(input.shape) + +def tmo_pertoken_smooth_quant_gated_ffn(input, w1, scale1, bias13, w2, scale2, bias2, input_smooth, act_smooth, dtype): + tmp_input = input.flatten(0, -2) + quant_input, input_scale = ops.per_token_smooth_quantize(tmp_input, input_smooth, None) + imm = ops.smooth_quant_matmul(quant_input, input_scale, w1, scale1, dtype, bias13) + acted = ops.active(imm, 'silu', True) + quant_acted, acted_scale = ops.per_token_smooth_quantize(acted, act_smooth, None) + out = ops.smooth_quant_matmul(quant_acted, acted_scale, w2, scale2, dtype, bias2) + return out.reshape(input.shape) + +def init_tensors(batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=torch.half, group_num=1): + sigma = 0.1 + eps = 0.1 # Avoid the occurrence of nan + torch.manual_seed(1) + input_smooth = torch.randn(hidden_size, dtype=torch.float, device="mlu").abs() + eps + act_smooth = torch.randn(inner_size, dtype=torch.float, device="mlu").abs() + eps + w1 = torch.randn((1 + is_gated) * inner_size, hidden_size, dtype=dtype, device="mlu") * sigma + w1 = w1 / input_smooth if with_smooth else w1 + bias1 = torch.randn((1 + is_gated) * inner_size, dtype=dtype, device="mlu") * sigma + w2 = torch.randn(hidden_size, inner_size, dtype=dtype, device="mlu") * sigma + w2 = w2 / act_smooth if with_smooth else w2 + bias2 = torch.randn(hidden_size, dtype=dtype, device="mlu") * sigma + input = torch.randn(batch, seq, hidden_size, dtype=dtype, device="mlu") + quant_w1, scale1 = QuantByRow(w1, quant_bit, group_num) + quant_w2, scale2 = QuantByRow(w2, quant_bit, group_num) + return input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, input_smooth, act_smooth + +dtype_list = [torch.half] +if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) +class TestQuantFFN(BtTestCase): + def op_impl_base(self, *args): + return super().op_impl_base(*args) + + def test_weight_only_quant_ffn(self): + for dtype in dtype_list: + print("test_weight_only_quant_ffn...") + batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 768, 8, False, False + input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq, + hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype) + torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, False) + tmo_out = tmo_weight_only_quant_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, quant_bit) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + + def test_weight_only_quant_gated_ffn(self): + for dtype in dtype_list: + print("test_weight_only_quant_gated_ffn...") + batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 768, 4, True, False + input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq, + hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype) + quant_w1_int4 = PairlyPackInt8(quant_w1) + quant_w2_int4 = PairlyPackInt8(quant_w2) + torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, True) + tmo_out = tmo_weight_only_quant_gated_ffn(input, quant_w1_int4, scale1, bias1, quant_w2_int4, scale2, bias2, quant_bit) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.005, use_MSE=True, use_RAE=True) + + def test_weight_only_group_quant_ffn(self): + for dtype in dtype_list: + print("test_weight_only_group_quant_ffn...") + batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 512, 8, False, False + input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq, + hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype, group_num=8) + torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, is_gated) + tmo_out = tmo_weight_only_group_quant_ffn(input, quant_w1, scale1.to(dtype=input.dtype), bias1, + quant_w2, scale2.to(dtype=input.dtype), bias2, quant_bit) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.05, use_MSE=True, use_RAE=True) + + def test_pertoken_smooth_quant_ffn(self): + for dtype in dtype_list: + print("test_pertoken_smooth_quant_ffn...") + batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype = 3, 5, 512, 768, 8, True, True, dtype + input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, input_smooth, act_smooth = \ + init_tensors(batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype) + torch_out = torch_smooth_quant_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, + input_smooth, act_smooth, is_gated) + tmo_out = tmo_pertoken_smooth_quant_gated_ffn(input, quant_w1, scale1, bias1, quant_w2, + scale2, bias2, input_smooth, act_smooth, dtype) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.05, use_MSE=True, use_RAE=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestQuantFFN)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_linear_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_linear_cache.py new file mode 100755 index 0000000..dcf8e68 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_linear_cache.py @@ -0,0 +1,275 @@ +import torch +import unittest +import torch_mlu_ops as ops +import numpy as np +import random +import math +from common_utils import * +from typing import Optional + +def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, quant_bit): + args = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, True) + cache_scale = args[10] + args = args[0:10] + args.insert(4, cache_scale[0]) + args.insert(5, cache_scale[1]) + args.insert(12, quant_bit) + return args + +class TestQuantToLinearCache(BtTestCase): + def op_impl_base(self, *args): + key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, \ + context_lengths, max_context_len, packed, context_seq_offset, cache_bs_id, cache_seqlen_offset, \ + quant_bit = args + def quant(context: torch.Tensor, + quant_bit: int, + group_size: int): + # context:[head_num, seq, head_size] + head_num = context.shape[0] + head_size = context.shape[-1] + if group_size != head_size: + context = context.reshape(head_num, -1, group_size) + context_fp32 = context.to(torch.float32) + max_value, _ = torch.max(context_fp32.abs(), dim=-1, keepdim=True) + int_max = float(2 ** (quant_bit - 1) - 1) + scale = max_value / int_max + scaled_context = context_fp32 / scale + return scaled_context.reshape(head_num, -1, head_size), scale[..., 0] + + batch_size = context_lengths.shape[0] - 1 if packed else context_lengths.shape[0] + head_num = key.shape[-2] + head_size = key.shape[-1] + + if key_cache_quant_scale.dim() == 3: + group_size = head_size + else: + group_size = head_size // key_cache_quant_scale.size(-1) + for i in range(batch_size): + if packed: + key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) + value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None + context_len_i = context_lengths[i+1] - context_lengths[i] + context_seq_offset_i = 0 + else: + key_i = key[i].transpose(1, 0) + value_i = value[i].transpose(1, 0) if value is not None else None + context_len_i = context_lengths[i] + context_seq_offset_i = context_seq_offset[i] + cache_bs_id_i = cache_bs_id[i] + cache_seqlen_offset_i = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0 + if cache_bs_id_i < 0 or cache_seqlen_offset_i < 0: + continue + key_cache_i = \ + key_cache[cache_bs_id_i, :, \ + cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i] + key_cache_scale_i = \ + key_cache_quant_scale[cache_bs_id_i, :, \ + cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i] + + # key_i[head_num, context_len[i], head_size] + key_i = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i] + float_key_i, key_scale_i = quant(key_i, quant_bit, group_size) + key_cache_scale_i = key_cache_scale_i.reshape(key_cache_scale_i.shape[0], -1) + key_cache_scale_i[...] = key_scale_i + rounded = torch.round(float_key_i) + clipped = torch.clip(rounded, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1) + quant_key_i = clipped.to(torch.int8) + if quant_bit == 4: + quant_key_flat = quant_key_i.flatten() + d0 = quant_key_flat[0::2].to(torch.uint8) + d1 = quant_key_flat[1::2].to(torch.uint8) + dp = (d1 << 4) + (d0 & 0x0F) + quant_key_i = dp.to(torch.int8).reshape(head_num, -1, head_size // 2) + key_cache_i[...] = quant_key_i + + if value_cache is not None and value is not None: + value_cache_scale_i = \ + value_cache_quant_scale[cache_bs_id_i, :, \ + cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i] + value_i = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i] + float_value_i, value_scale_i = quant(value_i, quant_bit, group_size) + value_cache_scale_i = value_cache_scale_i.reshape(value_cache_scale_i.shape[0], -1) + value_cache_scale_i[...] = value_scale_i + rounded = torch.round(float_value_i) + clipped = torch.clip(rounded, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1) + quant_value_i = clipped.to(torch.int8) + if quant_bit == 8: + value_cache_i = \ + value_cache[cache_bs_id_i, :, \ + cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i] + else: + value_cache_i = \ + value_cache[cache_bs_id_i, :, \ + cache_seqlen_offset_i // 2:math.ceil((cache_seqlen_offset_i + context_len_i) / 2)] + if cache_seqlen_offset_i % 2 == 1: + front_vec = value_cache[cache_bs_id_i, :, cache_seqlen_offset_i // 2, :] + front_low_bits = front_vec & (0x0F) + front_low_bits_expand = front_low_bits.unsqueeze(1) # [head_num, 1, head_size] + quant_value_i = torch.cat((front_low_bits_expand, quant_value_i), dim=1) + if (cache_seqlen_offset_i + context_len_i) % 2 == 1: + back_vec = value_cache[cache_bs_id_i, :, math.ceil((cache_seqlen_offset_i + context_len_i) / 2) - 1, :] + back_high_bits = (back_vec >> 4) & (0x0F) + back_high_bits_expand = back_high_bits.unsqueeze(1) # [head_num, 1, head_size] + quant_value_i = torch.cat((quant_value_i, back_high_bits_expand), dim=1) + value_temp = quant_value_i.reshape(head_num, -1, 2, head_size) + quant_value_flat = value_temp.permute(0, 1, 3, 2).flatten() + v0 = quant_value_flat[0::2].to(torch.uint8) + v1 = quant_value_flat[1::2].to(torch.uint8) + vp = (v1 << 4) + (v0 & 0x0F) + quant_value_i = vp.to(torch.int8).reshape(head_num, -1, head_size) + value_cache_i[...] = quant_value_i + return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale) + + def int8_to_int4(self, + input): + input_flat = input.flatten() + size = input_flat.size(0) + output = torch.zeros(size * 2, dtype=torch.int8, device=input.device) + high = input_flat >> 4 + low = input_flat << 4 + low = low >> 4 + output[0::2] = low + output[1::2] = high + return output + + def test_quant_to_linear_cache(self): + test_cases = 100 + bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + num_heads_list = torch.randint(low=1, high=32, size=(test_cases, ), dtype=torch.int32) + head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32) + head_size_list *= 16 + cache_memory_len_list = torch.randint(low=8, high=512, size=(test_cases, ), dtype=torch.int32) + cache_memory_len_list = cache_memory_len_list * 2 + packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32) + dtype_map = [torch.half, torch.bfloat16, torch.float] + quant_bit_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + quant_bit_map = [4, 8] + + for i in range(test_cases): + batch_size = bs_list[i].item() + invalid_batch = batch_size // 10 + num_heads = num_heads_list[i].item() + head_size = head_size_list[i].item() + group_size_factors = [i for i in range(4, head_size + 1) if head_size % i == 0] # group_size should > 1 + group_size = random.choice(group_size_factors) + cache_memory_len = cache_memory_len_list[i].item() + packed = packed_list[i].item() + dtype = dtype_map[dtype_list[i]] + quant_bit = quant_bit_map[quant_bit_list[i]] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name and dtype == torch.bfloat16: + dtype = torch.half + print("BFLOAT16 is not support on {}, use half instead".format(mlu_name)) + + print("case num={}, batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, dtype={}, quant_bit={}, group_size={} testing...".format( + i, batch_size, num_heads, head_size, cache_memory_len, packed > 0, dtype, quant_bit, group_size)) + + max_bs = batch_size + 1 + context_lens = torch.randint(size=(batch_size, ), low=1, + high=cache_memory_len // 4, + dtype=torch.int32, device='mlu') + context_lens = context_lens * 2 + max_context_len = context_lens.max().item() + max_seq_offset = max_context_len // 3 + 1 + context_seq_offsets = torch.randint(size=(batch_size, ), + low=0, high=max_seq_offset, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch_size, ), low=0, + high=(cache_memory_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid + + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + if packed > 0: + key = torch.randn((total_seqlen, num_heads, head_size), + dtype=torch.float, device='mlu') + value = torch.randn((total_seqlen, num_heads, head_size), + dtype=torch.float, device='mlu') + else: + key = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size), + dtype=torch.float, device='mlu') + value = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size), + dtype=torch.float, device='mlu') + key = key.to(dtype) + value = value.to(dtype) + if quant_bit == 8 and group_size == head_size: + key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu() + if quant_bit == 8 and group_size != head_size: + key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu() + if quant_bit == 4 and group_size == head_size: + key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size // 2), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len // 2, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu() + if quant_bit == 4 and group_size != head_size: + key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size // 2), dtype=torch.int8).mlu() + value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len // 2, head_size), dtype=torch.int8).mlu() + key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu() + value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu() + ref_key_cache = key_cache.clone() + ref_value_cache = value_cache.clone() + ref_key_cache_scale = key_cache_scale.clone() + ref_value_cache_scale = value_cache_scale.clone() + + cache_bs_id = random.sample([*range(0, max_bs)], batch_size) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch + + if packed > 0: + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, ref_key_cache_scale, + ref_value_cache_scale, cu_context_lens, max_context_len, + packed > 0, None, cache_bs_id, cache_seq_offsets, + quant_bit) + ops.quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, cu_context_lens, max_context_len, + packed > 0, None, cache_bs_id, cache_seq_offsets, + quant_bit) + else: + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, ref_key_cache_scale, + ref_value_cache_scale, context_lens, max_context_len, + packed > 0, context_seq_offsets, cache_bs_id, + cache_seq_offsets, quant_bit) + ops.quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, context_lens, max_context_len, + packed > 0, context_seq_offsets, cache_bs_id, + cache_seq_offsets, quant_bit) + if quant_bit == 8: + self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0.003, + "key_cache must equal ref_key_cache", True, True, True, True) + self.assertTensorsEqual(value_cache.cpu().float(), ref_value_cache.cpu().float(), 0.003, + "value_cache must equal ref_value_cache", True, True, True, True) + else: + key_cache_int8 = self.int8_to_int4(key_cache) + ref_key_cache_int8 = self.int8_to_int4(ref_key_cache) + value_cache_int8 = self.int8_to_int4(value_cache) + ref_value_cache_int8 = self.int8_to_int4(ref_value_cache) + + key_cache_diff = (key_cache_int8.cpu() - ref_key_cache_int8.cpu()).abs() + assert torch.max(key_cache_diff) < 2, "ref_key_cache must equal key_cache or absolute values differ by 1 due to round_mode!" + value_cache_diff = (value_cache_int8.cpu() - ref_value_cache_int8.cpu()).abs() + assert torch.max(value_cache_diff) < 2, "ref_value_cache must equal value_cache or absolute values differ by 1 due to round_mode!" + + self.assertTensorsEqual(key_cache_scale.cpu().float(), ref_key_cache_scale.cpu().float(), 0.003, + "key_cache_scale must equal ref_key_cache_scale", True, True, True, True) + self.assertTensorsEqual(value_cache_scale.cpu().float(), ref_value_cache_scale.cpu().float(), 0.003, + "value_cache_scale must equal ref_value_cache_scale", True, True, True, True) + + def test_inductor(self): + batch_size, num_heads, head_size, cache_memory_len, dtype, quant_bit = 4, 8, 64, 128, torch.float16, 8 + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype, quant_bit) + self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_linear_cache, args) + + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 0, dtype, quant_bit) + self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_linear_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestQuantToLinearCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_paged_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_paged_cache.py new file mode 100644 index 0000000..90b9429 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quant_to_paged_cache.py @@ -0,0 +1,213 @@ +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +import copy + +class TestQuantToPagedCache(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + k = create_tensor_from_dic(dic['k'], is_uniform=True, low=-1, high=1) + v = create_tensor_from_dic(dic['v'], is_uniform=True, low=-0.25, high=0.25) + k_cache = create_tensor_from_dic(dic['k_cache']) + v_cache = create_tensor_from_dic(dic['v_cache']) + k_cache_quant_scale = create_tensor_from_dic(dic['k_cache_quant_scale']) + v_cache_quant_scale = create_tensor_from_dic(dic['v_cache_quant_scale']) + slot_mapping = dic['slot_mapping']['data'] + self.launch(k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping) + + def launch(self, *args): + v = args[1] + args_bak = copy.deepcopy(args) + torch_out = self.op_impl_base(*args_bak) + tmo_out = ops.quant_to_paged_cache(*args) + self.assertTensorsEqual(torch_out[0].cpu().float(), tmo_out[0].cpu().float(), 9e-3, + use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_out[1].cpu().float(), tmo_out[1].cpu().float(), 9e-3, + use_MSE=True, use_RAE=True) + if v is not None: + self.assertTensorsEqual(torch_out[2].cpu().float(), tmo_out[2].cpu().float(), 3e-3, + use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_out[3].cpu().float(), tmo_out[3].cpu().float(), 3e-3, + use_MSE=True, use_RAE=True) + + def op_impl_base(self, *args): + k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping = args + # input_fp[head_num, head_size] + def quant(input_fp: torch.Tensor): + input_fp32 = input_fp.to(torch.float32) + max_value, _ = torch.max(input_fp32.abs(), dim=-1, keepdim=True) + scale = max_value / 127.0 + scaled_input = input_fp32 / scale + return scaled_input.to(torch.int8), scale[..., 0] + + tokens_num = k.shape[0] # [token_num, head_num, head_size] + block_size = k_cache.shape[2] + for i in range(tokens_num): + if slot_mapping[i] >= 0: + key_i = k[i] + block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + key_cache_i = k_cache[block_id, :, block_offset, :] + key_cache_scale_i = k_cache_quant_scale[block_id, :, block_offset] + quant_key_i, key_scale_i = quant(key_i) + key_cache_i[...] = quant_key_i + key_cache_scale_i[...] = key_scale_i + if v is not None: + value_i = v[i] + value_cache_i = v_cache[block_id, :, block_offset, :] + value_cache_scale_i = v_cache_quant_scale[block_id, :, block_offset] + quant_value_i, value_scale_i = quant(value_i) + value_cache_i[...] = quant_value_i + value_cache_scale_i[...] = value_scale_i + return (k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale) if v is not None else (k_cache, k_cache_quant_scale) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device") + def test_quant_to_paged_cache(self): + token_nums = random.randint(1, 2048) + head_num_kv = random.randint(1, 128) + head_size = random.randint(1, 1024) + block_size = random.randint(1, 50) + min_blocks = (int)((token_nums + block_size - 1) / block_size) + block_nums = min(min_blocks + 10, 2 * min_blocks) + num_slots = block_nums * block_size + slot_mapping = random.sample(range(num_slots), token_nums) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu() + slot_mapping[-1] = -1 # test mask + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + only_quant_key_list = [True, False] + + for _ in range(100): + print("test_quant_to_paged_cache...") + dtype = random.choice(dtype_list) + only_quant_key = random.choice(only_quant_key_list) + key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1) + key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu") + key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu") + if only_quant_key: + value, value_cache, value_cache_scale = None, None, None + else: + value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25) + value_cache = torch.zeros_like(key_cache) + value_cache_scale = torch.zeros_like(key_cache_scale) + self.launch(key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, slot_mapping) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_large_tensor(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + print('quant_to_paged_cache: test_large_tensor') + head_num_kv = 16 + head_size = 128 + token_nums = 20 + block_size = 16 + block_nums = ((2**32 - 1) // 1 // head_num_kv // head_size // block_size) + num_slots = block_nums * block_size + key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1) + value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25) + key_cache_torch = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache_torch = torch.zeros_like(key_cache_torch) + key_cache_scale_torch = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu") + value_cache_scale_torch = torch.zeros_like(key_cache_scale_torch) + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums] + key_cache_tmo = torch.zeros_like(key_cache_torch) + value_cache_tmo = torch.zeros_like(value_cache_torch) + key_cache_scale_tmo = torch.zeros_like(key_cache_scale_torch) + value_cache_scale_tmo = torch.zeros_like(value_cache_scale_torch) + self.op_impl_base(key, value, key_cache_torch, value_cache_torch, key_cache_scale_torch, value_cache_scale_torch, slot_mapping) + ops.quant_to_paged_cache(key, value, key_cache_tmo, value_cache_tmo, key_cache_scale_tmo, value_cache_scale_tmo, slot_mapping) + self.assertTensorsEqual(key_cache_torch.cpu(), key_cache_tmo.cpu(), 1) + self.assertTensorsEqual(value_cache_torch.cpu(), value_cache_tmo.cpu(), 1) + self.assertTensorsEqual(key_cache_scale_torch.cpu().float(), key_cache_scale_tmo.cpu().float(), 3e-3, + use_MSE=True, use_RAE=True) + self.assertTensorsEqual(value_cache_scale_torch.cpu().float(), value_cache_scale_tmo.cpu().float(), 3e-3, + use_MSE=True, use_RAE=True) + + block_nums = (2**32 // 1 // head_num_kv // head_size // block_size) + num_slots = block_nums * block_size + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums] + key_cache_torch = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache_torch = torch.zeros_like(key_cache_torch) + key_cache_scale_torch = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu") + value_cache_scale_torch = torch.zeros_like(key_cache_scale_torch) + key_cache_tmo = torch.zeros_like(key_cache_torch) + value_cache_tmo = torch.zeros_like(value_cache_torch) + key_cache_scale_tmo = torch.zeros_like(key_cache_scale_torch) + value_cache_scale_tmo = torch.zeros_like(value_cache_scale_torch) + self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.quant_to_paged_cache, + key, value, key_cache_tmo, value_cache_tmo, key_cache_scale_tmo, value_cache_scale_tmo, slot_mapping) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + token_nums = random.randint(1, 2048) + head_num_kv = random.randint(1, 128) + head_size = random.randint(1, 1024) + block_size = random.randint(1, 50) + min_blocks = (int)((token_nums + block_size - 1) / block_size) + block_nums = min(min_blocks + 10, 2 * min_blocks) + num_slots = block_nums * block_size + slot_mapping = random.sample(range(num_slots), token_nums) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu() + slot_mapping[-1] = -1 # test mask + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + dtype = random.choice(dtype_list) + + print("quant_to_paged_cache: test_prevent...") + key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1) + value = None + key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache = torch.zeros_like(key_cache) + key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu") + value_cache_scale = torch.zeros_like(key_cache_scale) + self.assertException("v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().", + ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, slot_mapping) + value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25) + value = value.as_strided(size=(token_nums, head_num_kv, head_size), stride=(1, token_nums, token_nums * head_num_kv)) + self.assertException("v last dim must be contiguous.", + ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, slot_mapping) + value = value.as_strided(size=(token_nums, head_num_kv, head_size), stride=(head_size, head_num_kv, 1)) + self.assertException("v second dim must be contiguous.", + ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale, + value_cache_scale, slot_mapping) + + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device") + def test_inductor(self): + token_nums = 20 + head_num_kv = 30 + head_size = 20 + block_size = 10 + min_blocks = (int)((token_nums + block_size - 1) / block_size) + block_nums = min(min_blocks + 10, 2 * min_blocks) + num_slots = block_nums * block_size + slot_mapping = random.sample(range(num_slots), token_nums) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu() + slot_mapping[-1] = -1 # test mask + + dtype = torch.half + key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1) + value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25) + key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu") + value_cache = torch.zeros_like(key_cache) + key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu") + value_cache_scale = torch.zeros_like(key_cache_scale) + args = (key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, slot_mapping) + self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_paged_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestQuantToPagedCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py new file mode 100755 index 0000000..aeb2995 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_quantize.py @@ -0,0 +1,46 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as tmo +from common_utils import * +import random + + +class TestQuantizeOp(BtTestCase): + def op_impl_base(self, *args): + x, smooth, zero = args + return (x * smooth).round().clamp(-128.0, 127.0).to(torch.int8) + + def test_random_case(self): + torch.manual_seed(0) + case_list = set() + while(len(case_list) < 100): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + dtype = random.choice(dtype_list) + ci = random.randint(1, 4096) + co = random.randint(1, 4096) + case = (ci, co) + if case in case_list: + continue + else: + case_list.add((ci, co)) + x = torch.randn(ci, co, device="mlu", dtype=dtype) + scale = torch.randn(co, device="mlu", dtype=torch.float32) + print("ci={}, co={}, dtype={}, testing...".format(ci, co, dtype), flush=True) + param = (x, scale, None) + tmo_output = tmo.quantize(*param) + torch_output = self.op_impl_base(*param) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True) + + def test_inductor(self): + x = torch.randn(16,128, 1024, device="mlu", dtype=torch.half) + scale = torch.randn(1024, device="mlu", dtype=torch.float32) + output = torch.empty(x.size(), dtype=torch.int8, device="mlu") + args = (x, scale, output, torch.Tensor(), None, None, None, None, 'per_token', False) + self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args) + + +if __name__ == '__main__': + exit(run_unittest(TestQuantizeOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_linear_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_linear_cache.py new file mode 100755 index 0000000..ef64037 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_linear_cache.py @@ -0,0 +1,179 @@ +import torch +import unittest +import torch_mlu_ops as ops +import random +from common_utils import * +from typing import Optional + +def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype): + args = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype) + return args[0:10] + +class TestReshapeLinearCache(BtTestCase): + def op_impl_base(self, *args): + key, value, key_cache, value_cache, context_lengths, max_context_len, packed, \ + context_seq_offset, cache_bs_id, cache_seqlen_offset = args + + batch_size = context_lengths.shape[0] - 1 if packed else context_lengths.shape[0] + for i in range(batch_size): + if packed: + key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) + value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None + context_len_i = context_lengths[i+1] - context_lengths[i] + context_seq_offset_i = 0 + else: + key_i = key[i].transpose(1, 0) + value_i = value[i].transpose(1, 0) if value is not None else None + context_len_i = context_lengths[i] + context_seq_offset_i = context_seq_offset[i] + cache_bs_id_i = cache_bs_id[i] + cache_seqlen_offset_i = cache_seqlen_offset[i] + if cache_seqlen_offset_i < 0 or cache_bs_id_i < 0: + continue + + key_cache_i = \ + key_cache[cache_bs_id_i, :, \ + cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i] + + key_cache_i[...] = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i] + if value_cache is not None and value is not None: + value_cache_i = \ + value_cache[cache_bs_id_i, :, \ + cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i] + value_cache_i[...] = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i] + return (key_cache, value_cache, ) if value_cache is not None else key_cache + + def test_reshape_linear_cache(self): + test_cases = 100 + bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + num_heads_list = torch.randint(low=1, high=32, size=(test_cases, ), dtype=torch.int32) + head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32) + head_size_list *= 16 + cache_memory_len_list = torch.randint(low=16, high=1024, size=(test_cases, ), dtype=torch.int32) + packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + dtype_list = torch.randint(low=0, high=4, size=(test_cases, ), dtype=torch.int32) + dtype_map = [torch.int8, torch.half, torch.bfloat16, torch.float] + + for i in range(test_cases): + q_heads = 1 + batch_size = bs_list[i].item() + invalid_batch = batch_size // 10 + num_heads = num_heads_list[i].item() + head_size = head_size_list[i].item() + cache_memory_len = cache_memory_len_list[i].item() + packed = packed_list[i].item() + total_heads = q_heads + num_heads * 2 + dtype = dtype_map[dtype_list[i]] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name and dtype == torch.bfloat16: + dtype = torch.half + print("BFLOAT16 is not support on {}, use half instead".format(mlu_name)) + + print("batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, dtype={} testing...".format( + batch_size, num_heads, head_size, cache_memory_len, packed > 0, dtype)) + + max_bs = batch_size + 1 + context_lens = torch.randint(size=(batch_size, ), low=1, + high=cache_memory_len // 2, + dtype=torch.int32, device='mlu') + # print(context_lens) + max_context_len = context_lens.max().item() + max_seq_offset = max_context_len // 3 + 1 + context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1, + high=(cache_memory_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1 + + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + total_seqlen = cu_context_lens[-1] + if packed > 0: + context = torch.randn((total_seqlen, total_heads, head_size), + dtype=torch.float, device='mlu') + else: + context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size), + dtype=torch.float, device='mlu') + cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu') + context = context.to(dtype) + cache = cache.to(dtype) + ref_cache = cache.clone() + key = context[..., q_heads:q_heads + num_heads, :] + value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :] + key_cache = cache[0] + value_cache = cache[1] + ref_key_cache = ref_cache[0] + ref_value_cache = ref_cache[1] + + cache_bs_id = None + cache_bs_id = random.sample([*range(0, max_bs)], batch_size) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch + + if packed > 0: + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, cu_context_lens, + max_context_len, packed > 0, None, + cache_bs_id, cache_seq_offsets) + ops.reshape_linear_cache(key, value, key_cache, value_cache, cu_context_lens, + max_context_len, packed > 0, None, + cache_bs_id, cache_seq_offsets) + else: + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, context_lens, + max_context_len, packed > 0, context_seq_offsets, + cache_bs_id, cache_seq_offsets) + ops.reshape_linear_cache(key, value, key_cache, value_cache, context_lens, + max_context_len, packed > 0, context_seq_offsets, + cache_bs_id, cache_seq_offsets) + self.assertTensorsEqual(cache.cpu().float(), ref_cache.cpu().float(), 0, "ref_cache must equal cache", + True, True, True, True) + + def test_reshape_linear_key_cache(self): + batch_size, num_heads, head_size, cache_memory_len = 2, 2, 16, 128 + print("[test_reshape_linear_key_cache] batch_size={}, num_heads={}, head_size={}, cache_memory_len={} testing...".format( + batch_size, num_heads, head_size, cache_memory_len)) + + max_bs = batch_size + 1 + context_lens = torch.randint(size=(batch_size, ), low=1, + high=cache_memory_len // 2, + dtype=torch.int32, device='mlu') + max_context_len = context_lens.max().item() + max_seq_offset = max_context_len // 3 + 1 + context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset, + dtype=torch.int32, device='mlu') + cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1, + high=(cache_memory_len - max_context_len) // 3 + 1, + dtype=torch.int32, device='mlu') + cu_context_lens = torch.cumsum(context_lens, dim=-1) + cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32) + + context = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size), + dtype=torch.float, device='mlu') + key_cache = torch.randn((max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu') + ref_key_cache = key_cache.clone() + key = context[..., :] + + cache_bs_id = None + cache_bs_id = random.sample([*range(0, max_bs)], batch_size) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() + + self.op_impl_base(key, None, ref_key_cache, None, context_lens, + max_context_len, False, context_seq_offsets, + cache_bs_id, cache_seq_offsets) + ops.reshape_linear_cache(key, None, key_cache, None, context_lens, + max_context_len, False, context_seq_offsets, + cache_bs_id, cache_seq_offsets) + self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0, "ref_cache must equal cache", + True, True, True, True) + + + def test_inductor(self): + batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16 + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.reshape_linear_cache, args) + + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 0, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.reshape_linear_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestReshapeLinearCache)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_paged_cache.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_paged_cache.py new file mode 100755 index 0000000..f1aa6bb --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_reshape_paged_cache.py @@ -0,0 +1,141 @@ +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * + +def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype): + args = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype) + slot_mapping = args[11] + args = args[0:4] + [slot_mapping] + return args + +class TestReshapePagedCacheOp(BtTestCase): + def op_impl_base(self, *args): + k, v, k_cache, v_cache, slot_mapping = args + num_tokens = k.shape[0] + block_size = k_cache.shape[2] + for i in range(num_tokens): + if slot_mapping[i] >= 0: + block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + k_cache[block_id, :, block_offset, :] = k[i] + if v is not None: + v_cache[block_id, :, block_offset, :] = v[i] + return (k_cache, v_cache) if v is not None else k_cache + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device") + def test_reshape_paged_cache(self): + test_cases = 100 + num_tokens_list = torch.randint(low=1, high=1024, size=(test_cases, ), dtype=torch.int32) + num_heads_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32) + head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32) + head_size_list *= 16 + block_size_list = torch.randint(low=1, high=4, size=(test_cases, ), dtype=torch.int32) + block_size_list *= 16 + only_reshape_key_list = [True, False] + + for i in range(test_cases): + num_tokens = num_tokens_list[i] + num_heads = num_heads_list[i] + head_size = head_size_list[i] + block_size = block_size_list[i] + min_blocks = (num_tokens + block_size - 1) // block_size + num_blocks = min(min_blocks + 10, 2 * min_blocks) + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + only_reshape_key = random.choice(only_reshape_key_list) + for dtype in dtype_list: + print("num_tokens: {}, num_heads: {}, head_size: {}, num_blocks: {}, block_size: {}, testing...".format( + num_tokens, num_heads, head_size, num_blocks, block_size), flush=True) + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device="mlu") + _, key, value = qkv.unbind(dim=1) + + key_cache = torch.randn(num_blocks, num_heads, block_size, head_size, dtype=dtype, device="mlu") + value_cache = torch.randn(num_blocks, num_heads, block_size, head_size, dtype=dtype, device="mlu") + + num_slots = num_blocks * block_size + slot_mapping = random.sample(range(num_slots.item()), num_tokens.item()) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="mlu") + slot_mapping[-1] = -1 + + ref_key_cache, ref_value_cache = key_cache.clone(), value_cache.clone() + if only_reshape_key: + value, ref_value_cache, value_cache = None, None, None + self.op_impl_base(key, value, ref_key_cache, ref_value_cache, slot_mapping) + ops.reshape_paged_cache(key, value, key_cache, value_cache, slot_mapping) + self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + if not only_reshape_key: + self.assertTensorsEqual(value_cache.cpu().float(), ref_value_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_large_tensor(self): + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + print("reshape_paged_cache: test_large_tensor") + if dtype == torch.float: + dtype_size = 4 + elif dtype == torch.half or dtype == torch.bfloat16: + dtype_size = 2 + head_num = 16 + head_size = 128 + token_num = 20 + block_size = 16 + block_num = ((2**32 - 1) // dtype_size // head_num // head_size // block_size) + k = torch.randn(token_num, head_num, head_size, dtype=dtype, device="mlu") + v = torch.randn(token_num, head_num, head_size, dtype=dtype, device="mlu") + k_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu") + v_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu") + num_slots = block_num * block_size + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_num] + ref_key_cache, ref_value_cache = k_cache.clone(), v_cache.clone() + self.op_impl_base(k, v, ref_key_cache, ref_value_cache, slot_mapping) + ops.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping) + for i in range(block_size): + self.assertTensorsEqual(k_cache[:, :, i, :].cpu().float(), ref_key_cache[:, :, i, :].cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + self.assertTensorsEqual(v_cache[:, :, i, :].cpu().float(), ref_value_cache[:, :, i, :].cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + block_num = (2**32 // dtype_size // head_num // head_size // block_size) + k_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu") + v_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu") + num_slots = block_num * block_size + slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_num] + self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.reshape_paged_cache, + k, v, k_cache, v_cache, slot_mapping) + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device") + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + k = torch.randn(1024, 8, 128, dtype=torch.half, device="mlu") + v = torch.randn(1024, 8, 128, dtype=torch.half, device="mlu") + k_cache = torch.randn(1024, 8, 4, 128, dtype=torch.half, device="mlu") + v_cache = None + slot_mapping = random.sample(range(1024 * 4), 1024) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="mlu") + self.assertException("v.has_value() == v_cache.has_value().", ops.reshape_paged_cache, + k, v, k_cache, v_cache, slot_mapping) + v_cache = torch.randn(1024, 8, 4, 128, dtype=torch.half, device="mlu") + v_cache = v_cache.as_strided(size=(1024, 8, 4, 128), stride=(8*4*128, 4* 128, 127, 1)) + self.assertException("v_cache need be contiguous.", ops.reshape_paged_cache, + k, v, k_cache, v_cache, slot_mapping) + v_cache = v_cache.contiguous() + v = v.as_strided(size=(1024, 8, 128), stride=(1, 1024, 1024 * 8)) + self.assertException("v last dim must be contiguous.", ops.reshape_paged_cache, + k, v, k_cache, v_cache, slot_mapping) + v = v.as_strided(size=(1024, 8, 128), stride=(1024, 8, 1)) + self.assertException("v second dim must be contiguous.", ops.reshape_paged_cache, + k, v, k_cache, v_cache, slot_mapping) + + + @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device") + def test_inductor(self): + batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16 + args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype) + self.base_opcheck(torch.ops.torch_mlu_ops.reshape_paged_cache, args) + +if __name__ == '__main__': + exit(run_unittest(TestReshapePagedCacheOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_self_attn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_self_attn.py new file mode 100755 index 0000000..225850c --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_self_attn.py @@ -0,0 +1,158 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from torch.nn.parameter import Parameter +from torch.nn import functional as F +import torch.nn as nn + +class SelfAttn(torch.nn.Module): + def __init__( + self, + qkv_weights, + qkv_biass, + o_weight, + o_bias, + norm_weight, + norm_bias, + input_size, + head_size, + query_factor, + eps = 1e-5 + ) -> None: + super().__init__() + assert len(qkv_weights) == 3 and len(qkv_biass) == 3, 'length of weights and biass must be 3' + self.layernorm = torch.nn.LayerNorm(input_size) + self.layernorm.eps = eps + self.layernorm.weight = nn.Parameter(norm_weight) + self.layernorm.bias = nn.Parameter(norm_bias) + self.weights = qkv_weights + self.biass = qkv_biass + self.o_weight = o_weight + self.o_bias = o_bias + self.head_size = head_size + self.head_num = self.weights[0].size(0) // head_size + self.query_factor = query_factor + + def forward(self, input: torch.Tensor): + n = input.size(0) + t = input.size(1) + normed_input = self.layernorm(input) + q = F.linear(normed_input, nn.Parameter(self.weights[0]), nn.Parameter(self.biass[0])).view(n, t, self.head_num, self.head_size) + k = F.linear(normed_input, nn.Parameter(self.weights[1]), nn.Parameter(self.biass[1])).view(n, t, self.head_num, self.head_size) + v = F.linear(normed_input, nn.Parameter(self.weights[2]), nn.Parameter(self.biass[2])).view(n, t, self.head_num, self.head_size) + qk = torch.einsum('bthd,bshd->bhts', q, k) * self.query_factor + attn = torch.softmax(qk, dim=-1, dtype=v.dtype) + qkv = torch.einsum('bhts,bshd->bthd', attn, v).reshape(n, t, -1) + output = F.linear(qkv, nn.Parameter(self.o_weight), nn.Parameter(self.o_bias)) + input + return output + +class BTSelfAttn(torch.nn.Module): + def __init__( + self, + qkv_weights, + qkv_biass, + o_weight, + o_bias, + norm_weight, + norm_bias, + head_size, + query_factor, + eps = 1e-5 + ) -> None: + super().__init__() + self.weights = qkv_weights + self.biass = qkv_biass + self.o_weight = o_weight + self.o_bias = o_bias + self.norm_weight = norm_weight + self.norm_bias = norm_bias + self.head_size = head_size + self.query_factor = query_factor + self.eps = eps + + def forward(self, input: torch.Tensor): + n, t = input.size(0), input.size(1) + q, k, v = ops.fused_norm_attention_project(input, + self.weights[0], + self.biass[0], + self.weights[1], + self.biass[1], + self.weights[2], + self.biass[2], + self.norm_weight, + self.norm_bias, + self.eps, + "nhtc", + self.head_size, + False) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + attn_out = ops.flash_attention(q, k, v, None, None, None, None, None, t, t, + self.query_factor, False, -1, -1, q.dtype).flatten(-2, -1) + output = ops.attention_project(attn_out, self.o_weight, self.o_bias, input) + return output + +dtype_list = [torch.half, torch.float] +if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) +class TestSelfAttn(BtTestCase): + def op_impl_base(self, *args): + return super().op_impl_base(*args) + + def test_self_attn(self): + N, T, input_size, hidden_size, head_size, eps, query_factor = 5, 2048, 512, 768, 64, 1e-5, 0.125 + for dtype in dtype_list: + print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format( + N, T, input_size, hidden_size, dtype), flush=True) + input = torch.randn(N, T, input_size, dtype=dtype, device="mlu") + weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") / 10 + bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu") + norm_weight = torch.randn(input_size, dtype=dtype, device="mlu") + norm_bias = torch.randn(input_size, dtype=dtype, device="mlu") + weights = torch.chunk(weight, 3) + biass = torch.chunk(bias, 3) + o_weight = torch.randn(input_size, hidden_size, dtype=dtype, device="mlu") / 10 + o_bias = torch.randn(input_size, dtype=dtype, device="mlu") + + torch_self_attn = SelfAttn(weights, biass, o_weight, o_bias, norm_weight, norm_bias, + input_size, head_size, query_factor, eps) + tmo_self_attn = BTSelfAttn(weights, biass, o_weight, o_bias, norm_weight, norm_bias, + head_size, query_factor, eps) + + # test self_attn + print("test self_attn...") + torch_out = torch_self_attn(input) + tmo_out = tmo_self_attn(input) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.011, use_MSE=True, use_RAE=True) + + def test_sd_attention(self): + N, Tq, Tk, hq, hk, head_size, query_factor = 2, 4096, 77, 8, 8, 40, 0.125 + for dtype in dtype_list: + q = torch.randn(N, Tq, hq, head_size, dtype=dtype, device="mlu") + k = torch.randn(N, Tk, hk, head_size, dtype=dtype, device="mlu") + v = torch.randn(N, Tk, hk, head_size, dtype=dtype, device="mlu") + + qk = torch.einsum('bthd,bshd->bhts', q, k) * query_factor + attn = torch.softmax(qk, dim=-1, dtype=v.dtype) + torch_out = torch.einsum('bhts,bshd->bthd', attn, v).reshape(N, Tq, -1) + + qt = q.transpose(1, 2).contiguous() + kt = k.transpose(1, 2).contiguous() + vt = v.transpose(1, 2).contiguous() + tmo_out = ops.flash_attention(qt.transpose(1, 2), + kt.transpose(1, 2), + vt.transpose(1, 2), + None, None, None, None, None, + Tq, Tk, query_factor, False).flatten(-2, -1) + self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(), + 0.004, use_MSE=True, use_RAE=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestSelfAttn)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_session_cache_attn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_session_cache_attn.py new file mode 100644 index 0000000..98a4c8f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_session_cache_attn.py @@ -0,0 +1,348 @@ +import math +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +class TestSessionCacheAttnOp(BtTestCase): + def op_impl_base(self, *args): + def dequant_from_cache(key_cache, value_cache, key_cache_scale, value_cache_scale, cache_bs_id, cache_lens, + cache_seq_offset, quant_bit, quant_layout, max_cache_len): + batch = cache_bs_id.size(0) + head_num = key_cache.size(1) + head_size = value_cache.size(-1) + cache_len = key_cache.size(-2) + cache_shape = (batch, max_cache_len, head_num, head_size) + key_cache_mem = torch.zeros(cache_shape, dtype=torch.float, device='mlu') + value_cache_mem = torch.zeros_like(key_cache_mem) + for i in range(batch): + batch_id = cache_bs_id[i] + cache_offset = 0 if cache_seq_offset is None else cache_seq_offset[i] + key_cache_data = key_cache[batch_id] # head_num_kv, cache_mem_len, head_size + value_cache_data = value_cache[batch_id] + if quant_bit == 4: + key_quant_data_int8 = torch.zeros(head_num, cache_len, head_size, dtype=torch.int8, device="mlu") + value_quant_data_int8 = torch.zeros(head_num, cache_len, head_size, dtype=torch.int8, device="mlu") + key_quant_data_int8[...,::2] = key_cache_data << 4 >> 4 + key_quant_data_int8[...,1::2] = key_cache_data >> 4 + key_quant_data_fp32 = key_quant_data_int8.clone().to(torch.float) + value_quant_data_int8[:,0::2,:] = value_cache_data << 4 >> 4 + value_quant_data_int8[:,1::2,:] = value_cache_data >> 4 + value_quant_data_fp32 = value_quant_data_int8.clone().to(torch.float) + else: + key_quant_data_fp32 = key_cache_data.clone().to(torch.float) + value_quant_data_fp32 = value_cache_data.clone().to(torch.float) + + if quant_layout == 'per_channel': + key_quant_data_fp32 = key_quant_data_fp32 * key_cache_scale[..., None, :] + value_quant_data_fp32 = value_quant_data_fp32 * value_cache_scale[..., None, :] + else: # per token + key_cache_scale_data = key_cache_scale[batch_id] + value_cache_scale_data = value_cache_scale[batch_id] + key_quant_data_fp32 = key_quant_data_fp32 * key_cache_scale_data[..., None] + value_quant_data_fp32 = value_quant_data_fp32 * value_cache_scale_data[..., None] + key_quant_data_fp32 = key_quant_data_fp32[:, cache_offset:cache_offset+max_cache_len, :] #cut vaild cache + value_quant_data_fp32 = value_quant_data_fp32[:, cache_offset:cache_offset+max_cache_len, :] + # head_num_kv, max_cache_len, headsize-> max_cache_len, head_num_kv, headsize + key_cache_mem[i] = key_quant_data_fp32.transpose(0,1) + value_cache_mem[i] = value_quant_data_fp32.transpose(0,1) + return key_cache_mem, value_cache_mem + def scale_dot_attn(q_sess, key_cache, value_cache, cu_seq_lens_q, cu_seq_lens_cache, alibi_slope, is_causal, softmax_scale): + # cache_shape: [batch, cache_len, head_kv,head_size] + batch = cu_seq_lens_q.size(0) - 1 + head_num_q = q_sess.size(-2) + head_num_kv = key_cache.size(-2) + assert head_num_q >= head_num_kv and head_num_q % head_num_kv == 0 + group = head_num_q // head_num_kv + inf = 1e6 + device= 'mlu' + out_list = [] + for i in range(batch): + q = q_sess[cu_seq_lens_q[i]:cu_seq_lens_q[i+1], ...] + k = key_cache[cu_seq_lens_cache[i]:cu_seq_lens_cache[i+1], ...] # [cache_len, head_num_kv, head_size] + v = value_cache[cu_seq_lens_cache[i]:cu_seq_lens_cache[i+1], ...] + k = torch.repeat_interleave(k, group, dim=1) #[cache_len, head_num_q, head_size] + v = torch.repeat_interleave(v, group, dim=1) + qk = torch.einsum('qhd,khd->hqk', q, k) * softmax_scale + seq_q, seq_k = q.size(0), k.size(0) + if alibi_slope is not None: + slope = alibi_slope.reshape(1, head_num_q, 1, 1) + slope_bias = torch.zeros(1, head_num_q, seq_q, seq_k).to(device=device) + if is_causal: + relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).to(device=device) + slope_bias = relative_pos * slope + else: + row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1) + col_idx = torch.arange(seq_k, dtype=torch.long) + relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).to(device=device) + slope_bias = -slope * relative_pos.to(dtype=slope.dtype) + qk += (slope_bias.squeeze(0)) + if is_causal: + assert seq_q <= seq_k, "seq_q <= seq_k if causal=True" + zeros = torch.zeros(seq_q, seq_k-seq_q, dtype=torch.float, device="mlu") + tri = torch.full((seq_q, seq_q), -inf, dtype=torch.float, device="mlu").triu(diagonal=1) + mask = torch.cat([zeros, tri], dim=1) # (q, k-q) + (q, q) => (q, k) + qk += mask + attn = torch.softmax(qk, dim=-1, dtype=torch.float).to(q.dtype) + qkv = torch.einsum('hqk,khd->qhd', attn, v) + out_list.append(qkv) + output = torch.cat(out_list, dim=0) + return output + + q_sess, k_sess, v_sess, key_cache1, value_cache1, key_cache_scale1, value_cache_scale1, cache_lens1,\ + cache_seq_offset1, quant_bit1, quant_layout1, max_cache_len1,\ + key_cache2, value_cache2, key_cache_scale2, value_cache_scale2, cache_lens2,\ + cache_seq_offset2, quant_bit2, quant_layout2, max_cache_len2,\ + sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale = args + #1. dequant cache + #input cache shape [max_batch, head_num_kv, cache_mem_len, head_size] + #output cache shape [batch, max_cache_len, head_num_kv, head_size] + key_cache_mem1, value_cache_mem1 = dequant_from_cache(key_cache1, value_cache1, key_cache_scale1, + value_cache_scale1, cache_bs_id, cache_lens1, cache_seq_offset1, + quant_bit1, quant_layout1, max_cache_len1) + key_cache_mem2, value_cache_mem2 = None, None + if key_cache2 is not None: + key_cache_mem2, value_cache_mem2 = dequant_from_cache(key_cache2, value_cache2, key_cache_scale2, + value_cache_scale2, cache_bs_id, cache_lens2, cache_seq_offset2, + quant_bit2, quant_layout2, max_cache_len2) + #2. concat cache + batch = cache_bs_id.size(0) + head_num_kv = k_sess.size(-2) + head_size = k_sess.size(-1) + # concat cache1 and cache2 + if key_cache2 is not None: + cache_lens = cache_lens1 + cache_lens2 + max_cache_len = max_cache_len1 + max_cache_len2 + key_cache_mem = torch.zeros(batch, max_cache_len, head_num_kv, head_size, dtype=torch.float, device='mlu') + value_cache_mem = torch.zeros(batch, max_cache_len, head_num_kv, head_size, dtype=torch.float, device='mlu') + for i in range(batch): #concat + len1 = cache_lens1[i] + len2 = cache_lens2[i] + key_cache_mem[i, :len1, ...] = key_cache_mem1[i, :len1, ...] + key_cache_mem[i, len1:len1+len2, ...] = key_cache_mem2[i, :len2, ...] + value_cache_mem[i, :len1, ...] = value_cache_mem1[i, :len1, ...] + value_cache_mem[i, len1:len1+len2, ...] = value_cache_mem2[i, :len2, ...] + else: + key_cache_mem, value_cache_mem = key_cache_mem1, value_cache_mem1 + cache_lens = cache_lens1 + #concat mem cache and sess cache + concat_cache_lens = cache_lens + sess_lens + cu_seq_lens_cache = torch.zeros((batch+1), dtype=torch.int32) + cu_seq_lens_cache[1:] = torch.cumsum(concat_cache_lens, dim=-1) + total_cache_len = cu_seq_lens_cache[-1] + concate_key_cache = torch.zeros((total_cache_len, head_num_kv, head_size), dtype=torch.float, device='mlu') + concate_value_cache = torch.zeros((total_cache_len, head_num_kv, head_size), dtype=torch.float, device='mlu') + for i in range(batch): #concat + mem_len = cache_lens[i] + sess_len = sess_lens[i] + off = cu_seq_lens_cache[i] + concate_key_cache[off:off+mem_len, ...] = key_cache_mem[i, :mem_len, ...] + concate_key_cache[off+mem_len:off+mem_len+sess_len, ...] = k_sess[i, :sess_len, ...] + concate_value_cache[off:off+mem_len, ...] = value_cache_mem[i, :mem_len, ...] + concate_value_cache[off+mem_len:off+mem_len+sess_len, ...] = v_sess[i, :sess_len, ...] + #3. attn + attn_out = scale_dot_attn(q_sess, concate_key_cache, concate_value_cache, cu_seq_lens_q, cu_seq_lens_cache, None, is_causal, softmax_scale) + return attn_out, concate_key_cache, concate_value_cache + + def test_session_cache_attn_quant_kv(self): + max_batch = 64 + cache_mem_len = 4096 + head_num_kv = 1 + head_num_q = 8 + head_size = 128 + batch = 32 + max_sess_len = 100 + max_cache_len = 3072 + dtype_list=[torch.float16] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + quant_bit_list = [4, 8] + quant_layout_list = ['per_token', 'per_channel'] + is_causal_list = [True, False] + arg = product(is_causal_list, dtype_list, quant_bit_list, quant_layout_list) + softmax_scale = 1 / math.sqrt(head_size) + for is_causal, dtype, quant_bit, quant_layout in arg: + print(f"is_causal: {is_causal}, dtype: {dtype}, quant_bit: {quant_bit}, quant_layout: {quant_layout}") + if quant_bit == 8: + cache_shape_key = (max_batch, head_num_kv, cache_mem_len, head_size) + cache_shape_value = cache_shape_key + else: + cache_shape_key = (max_batch, head_num_kv, cache_mem_len, head_size//2) + cache_shape_value = (max_batch, head_num_kv, cache_mem_len//2, head_size) + if quant_layout == "per_token": + scale_shape = (max_batch, head_num_kv, cache_mem_len) + quant_mode = 1 + else: # per channel + scale_shape = (head_num_kv, head_size) + quant_mode = 0 + key_cache = torch.randint(-127, 128, cache_shape_key, device='mlu').to(torch.int8) + value_cache = torch.randint(-127, 128, cache_shape_value, device='mlu').to(torch.int8) + key_cache_scale = torch.randn(scale_shape, device='mlu', dtype=torch.float) + value_cache_scale = torch.randn(scale_shape, device='mlu', dtype=torch.float) + key_cache_scale = torch.fill(key_cache_scale, 0.01) + value_cache_scale = torch.fill(value_cache_scale, 0.01) + cache_lens = torch.randint(1, max_cache_len + 1, (batch,), dtype=torch.int32, device='mlu') + sess_lens = torch.randint(1, max_sess_len + 1, (batch,), dtype=torch.int32, device='mlu') + context_lens = cache_lens + sess_lens + max_cache_len_new = torch.max(context_lens) + cu_seq_lens_q = torch.zeros(batch+1, dtype=torch.int32) + cu_seq_lens_q[1:] = torch.cumsum(sess_lens, dim=-1) + total_sess_len = cu_seq_lens_q[-1] + cu_seq_lens_q=cu_seq_lens_q.mlu() + cache_bs_id = random.sample([*range(0, max_batch)], batch) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() #block_tables + mem_cache_seq_offset = torch.randint(0, cache_mem_len-max_cache_len, (batch,), dtype=torch.int32, device='mlu') + + q_sess = torch.randn(total_sess_len, head_num_q, head_size, dtype=dtype, device='mlu') + k_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu') + v_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu') + #fill sess cache + cache_seq_offset = torch.zeros((batch + 1), dtype=torch.int32) + cache_seq_offset[1:]= torch.cumsum(context_lens, dim=-1) + cache_seq_offset = cache_seq_offset.mlu() + context_seq_offset = cache_seq_offset[:-1] + sess_seq_offset = context_seq_offset + cache_lens + total_mem_len = cache_seq_offset[-1] + cache_shape_mem = (total_mem_len, head_num_kv, head_size)#NT,H,C + key_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu') + value_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu') + for i in range(batch): + offset = sess_seq_offset[i] + len = sess_lens[i] + key_cache_mem[offset:offset+len, ...] = k_sess[i, :len, ...] + value_cache_mem[offset:offset+len, ...] = v_sess[i, :len, ...] + + # baseline + base_output, base_key_cache, base_value_cache = self.op_impl_base(q_sess.float(), k_sess.float(), v_sess.float(), + key_cache, value_cache, key_cache_scale, value_cache_scale, cache_lens, mem_cache_seq_offset, + quant_bit, quant_layout, max_cache_len, + None, None, None, None, None, None, -1, None, -1, + sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale) + + #tmo + #1. dequant cache + ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache, value_cache, key_cache_scale, + value_cache_scale, cache_lens, max_cache_len, context_seq_offset, + cache_bs_id, mem_cache_seq_offset, quant_mode, quant_bit) + self.assertTensorsEqual(base_key_cache.cpu().float(), key_cache_mem.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_value_cache.cpu().float(), value_cache_mem.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + #2. flash_attn + tmo_output = ops.flash_attention(q_sess, key_cache_mem, value_cache_mem, None, cu_seq_lens_q, cache_seq_offset, + None, None, max_sess_len, max_cache_len_new, softmax_scale, + is_causal, -1, -1, torch.float, False, None, None, None) + self.assertTensorsEqual(base_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_session_cache_attn_kv_mixquant(self): + max_batch = 64 + cache_mem_len_int4 = 4064 + cache_mem_len_int8 = 32 + head_num_kv = 1 + head_num_q = 8 + head_size = 128 + batch = 32 + max_sess_len = 100 + max_cache_len_int4 = 3072 + max_cache_len_int8 = 32 + dtype_list=[torch.float16] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + is_causal_list = [True, False] + arg = product(is_causal_list, dtype_list) + softmax_scale = 1 / math.sqrt(head_size) + for is_causal, dtype in arg: + print(f"is_causal: {is_causal}, dtype: {dtype}") + cache_shape_int8 = (max_batch, head_num_kv, cache_mem_len_int8, head_size) + cache_shape_int4_key = (max_batch, head_num_kv, cache_mem_len_int4, head_size//2) + cache_shape_int4_value = (max_batch, head_num_kv, cache_mem_len_int4//2, head_size) + scale_shape_int4 = (max_batch, head_num_kv, cache_mem_len_int4) #per_token + quant_mode_int4 = 1 + scale_shape_int8 = (head_num_kv, head_size) # per channel + quant_mode_int8 = 0 + key_cache_int8 = torch.randint(-127, 128, cache_shape_int8, device='mlu').to(torch.int8) + value_cache_int8 = torch.randint(-127, 128, cache_shape_int8, device='mlu').to(torch.int8) + key_cache_int4 = torch.randint(-127, 128, cache_shape_int4_key, device='mlu').to(torch.int8) + value_cache_int4 = torch.randint(-127, 128, cache_shape_int4_value, device='mlu').to(torch.int8) + key_cache_scale_int4 = torch.randn(scale_shape_int4, device='mlu', dtype=torch.float) + value_cache_scale_int4 = torch.randn(scale_shape_int4, device='mlu', dtype=torch.float) + key_cache_scale_int8 = torch.randn(scale_shape_int8, device='mlu', dtype=torch.float) + value_cache_scale_int8 = torch.randn(scale_shape_int8, device='mlu', dtype=torch.float) + key_cache_scale_int4 = torch.fill(key_cache_scale_int4, 0.01) + value_cache_scale_int4 = torch.fill(value_cache_scale_int4, 0.01) + key_cache_scale_int8 = torch.fill(key_cache_scale_int8, 0.01) + value_cache_scale_int8 = torch.fill(value_cache_scale_int8, 0.01) + cache_lens_int4 = torch.randint(1, max_cache_len_int4 + 1, (batch,), dtype=torch.int32, device='mlu') + cache_lens_int8 = torch.randint(1, max_cache_len_int8 + 1, (batch,), dtype=torch.int32, device='mlu') + sess_lens = torch.randint(max_sess_len-1, max_sess_len, (batch,), dtype=torch.int32, device='mlu') + cu_seq_lens_q = torch.zeros(batch+1, dtype=torch.int32) + cu_seq_lens_q[1:] = torch.cumsum(sess_lens, dim=-1) + total_sess_len = cu_seq_lens_q[-1] + cu_seq_lens_q=cu_seq_lens_q.mlu() + context_lens = cache_lens_int4 + cache_lens_int8 + sess_lens + max_cache_len_new = torch.max(context_lens) + cache_bs_id = random.sample([*range(0, max_batch)], batch) + cache_bs_id = torch.IntTensor(cache_bs_id).mlu() #block_tables + cache_seq_offset_int4 = torch.randint(0, cache_mem_len_int4-max_cache_len_int4, (batch,), dtype=torch.int32, device='mlu') + cache_seq_offset_int8 = None + + q_sess = torch.randn(total_sess_len, head_num_q, head_size, dtype=dtype, device='mlu') + k_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu') + v_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu') + #fill sess cache + cache_seq_offset1 = torch.zeros((batch + 1), dtype=torch.int32) + cache_seq_offset2 = torch.zeros((batch + 1), dtype=torch.int32) + cache_seq_offset1[1:]= torch.cumsum(context_lens, dim=-1) + cache_seq_offset2[:-1] = cache_seq_offset1[:-1] + cache_lens_int4.cpu() + cache_seq_offset1 = cache_seq_offset1.mlu() + cache_seq_offset2 = cache_seq_offset2.mlu() + + context_seq_offset1 = cache_seq_offset1[:-1] + context_seq_offset2 = cache_seq_offset2[:-1] + sess_seq_offset = context_seq_offset2 + cache_lens_int8 + total_mem_len = cache_seq_offset1[-1] + cache_shape_mem = (total_mem_len, head_num_kv, head_size)#NT,H,C + key_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu') + value_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu') + for i in range(batch): + offset = sess_seq_offset[i] + len = sess_lens[i] + key_cache_mem[offset:offset+len, ...] = k_sess[i, :len, ...] + value_cache_mem[offset:offset+len, ...] = v_sess[i, :len, ...] + + # baseline + base_output, base_key_cache, base_value_cache = self.op_impl_base(q_sess.float(), k_sess.float(), v_sess.float(), + key_cache_int4, value_cache_int4, key_cache_scale_int4, value_cache_scale_int4, cache_lens_int4, + cache_seq_offset_int4, 4, 'per_token', max_cache_len_int4, + key_cache_int8, value_cache_int8, key_cache_scale_int8, value_cache_scale_int8, cache_lens_int8, + cache_seq_offset_int8, 8, 'per_channel', max_cache_len_int8, + sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale) + #tmo + #1. dequant cache + ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache_int4, value_cache_int4, + key_cache_scale_int4, value_cache_scale_int4, cache_lens_int4, max_cache_len_int4, + context_seq_offset1, cache_bs_id, cache_seq_offset_int4, quant_mode_int4, 4) + ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache_int8, value_cache_int8, + key_cache_scale_int8, value_cache_scale_int8, cache_lens_int8, max_cache_len_int8, + context_seq_offset2, cache_bs_id, cache_seq_offset_int8, quant_mode_int8, 8) + + self.assertTensorsEqual(base_key_cache.cpu().float(), key_cache_mem.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(base_value_cache.cpu().float(), value_cache_mem.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + #2. flash_attn + tmo_output = ops.flash_attention(q_sess, key_cache_mem, value_cache_mem, None, cu_seq_lens_q, cache_seq_offset1, + None, None, max_sess_len, max_cache_len_new, softmax_scale, + is_causal, -1, -1, torch.float, False, None, None, None) + self.assertTensorsEqual(base_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestSessionCacheAttnOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_cached_kv_attn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_cached_kv_attn.py new file mode 100755 index 0000000..773761f --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_cached_kv_attn.py @@ -0,0 +1,454 @@ +import math +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +def gen_args(batch, head_size, is_pagedattn, has_alibi, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads): + input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size)).mlu().half() + input_q = input_qkv[:,0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size) + context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu() + max_context_len = int(max(context_lens)) + if is_pagedattn: + block_size = 16 + else: + block_size = max_seqlen + num_blocks = batch * ((max_seqlen + block_size - 1) // block_size) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape = (num_blocks, num_kv_heads, block_size, head_size) + scale_shape = (num_blocks, num_kv_heads, block_size) + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + if kv_data_type is not torch.int8: + key_cache = torch.randn(size=cache_shape, dtype=torch.float16).mlu() + value_cache = torch.randn(size=cache_shape, dtype=torch.float16).mlu() + key_cache_scale = None + value_cache_scale = None + else: + key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_data_type).mlu() + value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_data_type).mlu() + key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + alibi_slopes = None + if has_alibi: + alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu() + alibi_slopes.uniform_(0, 0.125) + softmax_scale = 1 / math.sqrt(head_size) + return (input_q.contiguous(), key_cache, value_cache, torch.empty_like(input_q), block_tables, context_lens, None, + key_cache_scale, value_cache_scale, alibi_slopes, max_context_len, -1, -1, softmax_scale, False, -1) + +def gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, dtype, kv_dtype, max_seqlen, seq_q, head_num, num_kv_heads): + input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype = dtype).mlu() + input_q = input_qkv[:,0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size) + context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu() + max_context_len = int(max(context_lens)) + if is_pagedattn: + block_size = 16 + else: + block_size = max_seqlen + num_blocks = batch * ((max_seqlen + block_size - 1) // block_size) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size) + cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v) + scale_shape = (num_blocks, num_kv_heads, block_size) + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + if kv_dtype is torch.int8: + key_cache = torch.zeros(cache_shape_k).uniform_(-128, 127).to(torch.int8).mlu() + value_cache = torch.zeros(cache_shape_v).uniform_(-128, 127).to(torch.int8).mlu() + key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + else: + key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu() + value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu() + key_cache_scale = None + value_cache_scale = None + alibi_slopes = None + if has_alibi: + alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu() + alibi_slopes.uniform_(0, 0.125) + return input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len + +class TestSingleQueryAttnOp(BtTestCase): + def op_impl_base(self, *args): + q, k_cache, v_cache, out, block_tables, context_lens, k_cache_quant_scale, v_cache_quant_scale, \ + alibi_slopes, max_contxt_len, windows_size_left, windows_size_right, softmax_scale, return_lse, \ + kv_cache_quant_bit_size = args + base_output = single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale, + v_cache_quant_scale, alibi_slopes, windows_size_left, windows_size_right, softmax_scale, return_lse) + return base_output + + def test_single_query_attention(self): + head_num = 16 + batch_list = [5, 12] + num_kv_heads = 4 + head_size_list = [(128, 128), (192, 384)] + seq_len_list = [512] + is_pagedattn_list = [False, True] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + is_pagedattn_list = [False] + has_alibi_list = [True, False] + seq_q_list = [1, 5] + window_size_list = [(-1, -1), (10, -1)] + data_type_list = [torch.float16, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + data_type_list.append(torch.bfloat16) + + args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, seq_len_list, seq_q_list, window_size_list, data_type_list) + for batch, (head_size, head_size_v), is_pagedattn, has_alibi, max_seqlen, seq_q, (window_size_left, window_size_right), dtype in args: + print("batch: {}, max_seqlen: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, seq_q {}, window_size_left {},\ +window_size_right {}, dtype {}, testing...".format( + batch, max_seqlen, head_size, head_size_v, is_pagedattn, has_alibi, seq_q, window_size_left, window_size_right, dtype)) + # prepare input + params = gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, dtype, dtype, max_seqlen, seq_q, head_num, num_kv_heads) + input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params + softmax_scale = 1 / math.sqrt(head_size) + torch_output = self.op_impl_base(input_q, key_cache, value_cache, + None, block_tables, context_lens, key_cache_scale, + value_cache_scale, alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, False, -1) + tmo_output1 = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache, + None, block_tables, context_lens, key_cache_scale, + value_cache_scale, alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output1.cpu().float(), + 0.003, use_MSE=True) + if seq_q == 1: + torch_output, torch_lse = self.op_impl_base(input_q, key_cache, value_cache, + None, block_tables, context_lens, key_cache_scale, + value_cache_scale, alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + tmo_output, tmo_lse = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache, + None, block_tables, context_lens, key_cache_scale, + value_cache_scale, alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, True) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True) + self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), + 0.0001, use_MSE=True) + + # @unittest.skip("not test") + def test_single_query_attention_quantize_kv(self): + head_num = 16 + batch_list = [5, 12] + num_kv_heads = 4 + head_size_list = [(128, 128), (16, 384)] + seq_len_list = [512] + is_pagedattn_list = [False, True] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + is_pagedattn_list = [False] + has_alibi_list = [True, False] + quant_mode_list = ['per_token', 'per_channel'] + kv_data_type_list = [torch.int8] + seq_q_list = [1, 5] + window_size_list = [(-1, -1), (10, -1)] + args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, kv_data_type_list, quant_mode_list, seq_len_list, seq_q_list, window_size_list) + for batch, (head_size, head_size_v), is_pagedattn, has_alibi, kv_data_type, quant_mode, max_seqlen, seq_q, (window_size_left, window_size_right) in args: + print("batch: {}, max_seqlen: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, kv_datatype {}, \ +quant_mode {}, seq_q {}, window_size_left {}, window_size_right {}, testing...".format( + batch, max_seqlen, head_size, head_size_v, is_pagedattn, has_alibi, kv_data_type, quant_mode, seq_q, window_size_left, window_size_right)) + # prepare input + params = gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, torch.float16, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads) + input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params + softmax_scale = 1 / math.sqrt(head_size) + torch_output_contiguous = self.op_impl_base(input_q.contiguous(), + key_cache, value_cache, None, block_tables, context_lens, + key_cache_scale, value_cache_scale, alibi_slopes, + max_context_len, window_size_left, window_size_right, softmax_scale, False, -1) + tmo_output_contiguous = ops.single_query_cached_kv_attn(input_q.contiguous(), + key_cache, value_cache, None, block_tables, context_lens, + key_cache_scale, value_cache_scale, alibi_slopes, + max_context_len, window_size_left, window_size_right, softmax_scale) + self.assertTensorsEqual(torch_output_contiguous.cpu().float(), tmo_output_contiguous.cpu().float(), + 0.003, use_MSE=True) + if seq_q == 1: + torch_output, torch_lse = self.op_impl_base(input_q, key_cache, value_cache, + None, block_tables, context_lens, key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, window_size_left, window_size_right, softmax_scale, True, -1) + tmo_output, tmo_lse = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache, + None, block_tables, context_lens, key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, window_size_left, window_size_right, softmax_scale, True) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.005, use_MSE=True) + self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), 0.0003, use_MSE=True) + + def test_single_query_attention_int4_kv(self): + head_num = 16 + batch_list = [5, 12] + num_kv_heads = 4 + head_size_list = [(64, 128), (256, 128), (64, 384)] + seq_len_list = [512] + is_pagedattn_list = [False, True] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + is_pagedattn_list = [False] + has_alibi_list = [True, False] + quant_mode_list = ['per_token', 'per_channel', 'per_token_group'] + data_type_list = [torch.float, torch.half] + if torch_mlu.mlu.is_bf16_supported(): + data_type_list.append(torch.bfloat16) + kv_data_type = torch.int8 + seq_q_list = [1, 5] + #int4 range + quant_bit = 4 + int_max = float(2 ** (quant_bit - 1) - 1) + int_min = -float(2 ** (quant_bit - 1)) + window_size_list = [(-1, -1), (20, -1)] + args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, quant_mode_list, seq_len_list, seq_q_list, window_size_list, data_type_list) + for batch, (head_size, head_size_v), is_pagedattn, has_alibi, quant_mode, max_seqlen, seq_q, (window_size_left, window_size_right), data_type in args: + print("kv4: batch: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, quant_mode {}, max_seqlen: {}, seq_q {}, \ +window_size_left {}, window_size_right {}, data_type {}, testing...".format( + batch, head_size, head_size_v, is_pagedattn, has_alibi, quant_mode, max_seqlen, seq_q, window_size_left, window_size_right, data_type)) + # prepare input + input_qkv = torch.randn((batch, seq_q, 3 * head_num, head_size), dtype=data_type).mlu() + input_q = input_qkv[..., 0:head_num,:] + + context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu() + max_context_len = context_lens.max().item() + if max_seqlen % 2 == 1: + max_seqlen += 1 + if is_pagedattn: + block_size = 16 + else: + block_size = max_seqlen + num_blocks = batch * ((max_seqlen + block_size - 1) // block_size) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size) + cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v) + cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, int(head_size/2)) + cache_shape_v_int4 = (num_blocks, num_kv_heads, int(block_size/2), head_size_v) + cache_shape_v_int4_tmp = (num_blocks, num_kv_heads, head_size_v, int(block_size/2)) + + if quant_mode == "per_channel": + scale_shape_k = (num_kv_heads, head_size) + scale_shape_v = (num_kv_heads, head_size_v) + elif quant_mode == "per_token": + scale_shape_k = (num_blocks, num_kv_heads, block_size) + scale_shape_v = (num_blocks, num_kv_heads, block_size) + elif quant_mode == "per_token_group": + scale_shape_k = (num_blocks, num_kv_heads, block_size, 1) #group_size = head_size + scale_shape_v = (num_blocks, num_kv_heads, block_size, 1) + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + + key_cache = torch.zeros(cache_shape_k).uniform_(int_min, int_max).to(kv_data_type).mlu() + value_cache = torch.zeros(cache_shape_v).uniform_(int_min, int_max).to(kv_data_type).mlu() + key_cache_scale = torch.randn(size=scale_shape_k, dtype=torch.float32).mlu() + value_cache_scale = torch.randn(size=scale_shape_v, dtype=torch.float32).mlu() + key_cache_view = key_cache.reshape(-1, head_size) + value_cache_view = value_cache.transpose(2, 3).reshape(-1, block_size) + key_cache_int4 = PairlyPackInt8(key_cache_view).view(cache_shape_k_int4) + value_cache_int4 = PairlyPackInt8(value_cache_view).view(cache_shape_v_int4_tmp).transpose(2,3).contiguous() + + alibi_slopes = None + if has_alibi: + alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu() + alibi_slopes.uniform_(0, 0.125) + softmax_scale = 1 / math.sqrt(head_size) + + torch_output = self.op_impl_base(input_q.contiguous(), + key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, False, 4) + tmo_output_contigous = ops.single_query_cached_kv_attn(input_q.contiguous(), + key_cache_int4, value_cache_int4, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, False, 4) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output_contigous.cpu().float(), + 0.003, use_MSE=True) + tmo_output_inplace = torch.empty((batch, seq_q, head_num, head_size_v), dtype=data_type, device="mlu") + ops.single_query_cached_kv_attn(input_q, key_cache_int4, value_cache_int4, + tmo_output_inplace, block_tables, context_lens, + key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, False, 4) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output_inplace.cpu().float(), + 0.003, use_MSE=True) + if seq_q == 1: + torch_output, torch_lse = self.op_impl_base(input_q, + key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, True, 4) + tmo_output1, tmo_lse = ops.single_query_cached_kv_attn(input_q, + key_cache_int4, value_cache_int4, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, True, 4) + self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), + 0.003, use_MSE=True) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output1.cpu().float(), + 0.003, use_MSE=True) + + def test_single_query_attention_inplace(self): + head_num = 16 + batch_list = [5] + num_kv_heads = 4 + head_size_list = [64] + seq_len_list = [512] + is_pagedattn_list = [False, True] + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + is_pagedattn_list = [False] + has_alibi_list = [True, False] + kv_data_type_list = [torch.int8, torch.float16] + seq_q_list = [1, 5] + window_size_list = [(-1, -1), (20, -1)] + args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, kv_data_type_list, seq_len_list, seq_q_list, window_size_list) + for batch, head_size, is_pagedattn, has_alibi, kv_data_type, max_seqlen, seq_q, (window_size_left, window_size_right) in args: + print("batch: {}, max_seqlen: {}, head_size: {}, is_pagedattn: {}, has_alibi {}, kv_datatype {}, seq_q {}, window_size_left {}, window_size_right {}, testing...".format( + batch, max_seqlen, head_size, is_pagedattn, has_alibi, kv_data_type, seq_q, window_size_left, window_size_right)) + # prepare input + params = gen_params(batch, head_size, head_size, is_pagedattn, has_alibi, torch.float16, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads) + input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params + softmax_scale = 1 / math.sqrt(head_size) + tmo_output = torch.empty_like(input_q) + tmo_output_contigous = torch.empty_like(input_q) + torch_output = self.op_impl_base(input_q, key_cache, value_cache, None, + block_tables, context_lens, key_cache_scale, + value_cache_scale, alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale, False, -1) + ops.single_query_cached_kv_attn(input_q, key_cache, value_cache, tmo_output, + block_tables, context_lens, key_cache_scale, + value_cache_scale, alibi_slopes, max_context_len, + window_size_left, window_size_right, softmax_scale) + ops.single_query_cached_kv_attn(input_q.contiguous(), key_cache, value_cache, + tmo_output_contigous, block_tables, context_lens, + key_cache_scale, value_cache_scale, alibi_slopes, + max_context_len, window_size_left, window_size_right, softmax_scale) + self.assertTensorsEqual(tmo_output.cpu().float(), tmo_output_contigous.cpu().float(), + 0.000, use_MSE=True) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True) + + # 防呆测试 + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + func = ops.single_query_cached_kv_attn + batch, seq_q, head_num, num_kv_heads, head_size_qk, head_size_v, max_seqlen, softmax_scale = 5, 1, 8, 3, 64, 128, 512, 0.625 + dtype = torch.float16 + input = torch.randn((batch, seq_q, head_num, head_size_qk), dtype = dtype).mlu() + context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu() + max_context_len = int(max(context_lens)) + block_size = 16 + num_blocks = batch * ((max_seqlen + block_size - 1) // block_size) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size_qk) + cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v) + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu() + value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu() + key_cache_scale = None + value_cache_scale = None + window_size_left, window_size_right = 10, 10 + self.assertException("only support windows_size_right < 0 currently.", + func, input, key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + window_size_right = -1 + self.assertException("num_heads need be mutiple of num_kv_heads.", + func, input, key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + num_kv_heads = 1 + cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size_qk) + cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v) + key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu() + value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu() + self.assertException("illegal quant bit size, only support 4, 8 or -1.", + func, input, key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, 16) + input1 = torch.randn((batch, seq_q, head_num, head_size_qk * 2), dtype = dtype).mlu() + input = input1[..., 0::2] + self.assertException("q last two dim need be contiguous.", + func, input, key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + input = torch.randn((batch, seq_q, head_num, head_size_qk), dtype = dtype).mlu() + key_cache1 = key_cache[..., :8, :] + value_cache1 = value_cache[..., :8, :] + self.assertException("k_cache and v_cache need be contiguous.", + func, input, key_cache1, value_cache1, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + self.assertException("q_ori need be mlu tensor.", + func, input.cpu(), key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + input = torch.randn((batch, 5, head_num, head_size_qk), dtype = dtype).mlu() + self.assertException("return lse only support seq_q = 1 currently.", + func, input, key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + self.assertException("block_tables type need be torch::kInt32 or torch::kLong.", + func, input, key_cache, value_cache, + None, block_tables.float(), context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, False, -1) + self.assertException("context_lens type need be torch::kInt32.", + func, input, key_cache, value_cache, + None, block_tables, context_lens.to(torch.int64), + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, False, -1) + self.assertException("context_lens need be contiguous.", + func, input, key_cache, value_cache, + None, block_tables, context_lens[0::2], + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, False, -1) + scale_shape = (num_blocks) + key_cache = torch.zeros(cache_shape_k).uniform_(-128, 127).to(torch.int8).mlu() + value_cache = torch.zeros(cache_shape_v).uniform_(-128, 127).to(torch.int8).mlu() + key_cache_scale = torch.randn(scale_shape, dtype=torch.float32).mlu() + value_cache_scale = torch.randn(scale_shape, dtype=torch.float32).mlu() + self.assertException("k_cache_quant_scale must be 2d or 3d or 4d.", + func, input, key_cache, value_cache, + None, block_tables, context_lens, + key_cache_scale, value_cache_scale, + None, max_context_len, + window_size_left, window_size_right, softmax_scale, True, -1) + + def test_inductor(self): + batch, seq_q, head_num, num_kv_heads, head_size, max_seqlen = 1, 5, 16, 16, 128, 512 + is_pagedattn_list = [False, True] if "MLU3" not in torch.mlu.get_device_name() else [False] + has_alibi_list = [True, False] + test_flags = product(is_pagedattn_list, has_alibi_list) + for is_pagedattn, has_alibi in test_flags: + args = gen_args(batch, head_size, is_pagedattn, has_alibi, torch.int8, max_seqlen, seq_q, head_num, num_kv_heads) + self.base_opcheck(torch.ops.torch_mlu_ops.single_query_cached_kv_attn, args) + +if __name__ == '__main__': + exit(run_unittest(TestSingleQueryAttnOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_mixed_cached_kv_attn.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_mixed_cached_kv_attn.py new file mode 100644 index 0000000..66a9c0b --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_single_query_mixed_cached_kv_attn.py @@ -0,0 +1,393 @@ +import math +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +def gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len, data_type, quant_bit, quant_mode, is_normal=True): + int_max = float(2 ** (quant_bit - 1) - 1) + int_min = -float(2 ** (quant_bit - 1)) + context_lens = torch.randint(seq_q, seq_len + 1, (batch, ), dtype=torch.int32).mlu() + if is_normal is False and batch > 3: # replace some batch's context to 0 + num = batch // 3 + index = torch.randint(0, batch, (num,)) + context_lens[index] = 0 + max_context_len = context_lens.max().item() + block_size = 16 + if is_pagedattn is False: + block_size = seq_len + num_blocks = batch * ((seq_len + block_size - 1) // block_size) + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + cache_shape = (num_blocks, num_kv_heads, block_size, head_size) + if quant_mode == "per_token": + scale_shape = (num_blocks, num_kv_heads, block_size, 1) + else: # per channel + scale_shape = (num_kv_heads, head_size) + + if quant_bit == 4: + cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size//2) + cache_shape_v_int4 = (num_blocks, num_kv_heads, block_size//2, head_size) + cache_shape_v_int4_tmp = (num_blocks, num_kv_heads, head_size, block_size//2) + key_cache_int8 = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu() + value_cache_int8 = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu() + # pre_process int4_kv_cache + key_cache_view = key_cache_int8.reshape(-1, head_size) + value_cache_view = value_cache_int8.transpose(2, 3).reshape(-1, block_size) + key_cache = PairlyPackInt8(key_cache_view).view(cache_shape_k_int4) + value_cache = PairlyPackInt8(value_cache_view).view(cache_shape_v_int4_tmp).transpose(2,3).contiguous() + key_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + value_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + elif quant_bit == 8: + key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu() + value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu() + key_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + value_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu() + elif quant_bit == -1: + key_cache = torch.randn(cache_shape, dtype=data_type).mlu() + value_cache = torch.randn(cache_shape, dtype=data_type).mlu() + key_scale = None + value_scale = None + else: + print("gen case error, quant_bit_lp must be in {-1, 4, 8}") + block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq) + block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq) + output = [key_cache, value_cache, key_scale, value_scale, context_lens, block_tables] + if quant_bit == 4: + output.append(key_cache_int8) + output.append(value_cache_int8) + return output + +def concate_cache_linear(cachek1, cachek2, cachev1, cachev2, context1, context2, block_tables1, block_tables2, + scalek1, scalek2, scalev1, scalev2): + if scalek1 is not None: + if scalek1.dim() == 2: # per_channel: [kv_head_num, head_size] + scalek1 = scalek1.reshape(1, scalek1.shape[0], 1, scalek1.shape[1]) + scalev1 = scalev1.reshape(1, scalev1.shape[0], 1, scalev1.shape[1]) + elif scalek1.dim() == 3: # per_token: [num_blocks, k_head_num, block_size] + scalek1 = scalek1.reshape(*scalek1.shape, 1) + scalev1 = scalev1.reshape(*scalev1.shape, 1) + cachek1 *= scalek1 + cachev1 *= scalev1 + if scalek2 is not None: + if scalek2.dim() == 2: # per_channel: [kv_head_num, head_size] + scalek2 = scalek2.reshape(1, scalek2.shape[0], 1, scalek2.shape[1]) + scalev2 = scalev2.reshape(1, scalev2.shape[0], 1, scalev2.shape[1]) + elif scalek2.dim() == 3: # per_token: [num_blocks, k_head_num, block_size] + scalek2 = scalek2.reshape(*scalek2.shape, 1) + scalev2 = scalev2.reshape(*scalev2.shape, 1) + cachek2 *= scalek2 + cachev2 *= scalev2 + new_context = context1 + context2 + seq_len1 = cachek1.shape[2] + seq_len2 = cachek2.shape[2] + new_max_context = seq_len1 + seq_len2 + batch = cachek1.shape[0] + num_head = cachek1.shape[1] + head_size = cachek1.shape[3] + new_cache_k = torch.randn(batch, num_head, new_max_context, head_size, dtype=torch.float32) + new_cache_v = torch.randn(batch, num_head, new_max_context, head_size, dtype=torch.float32) + new_block_table = torch.arange(0, batch) + new_block_table = new_block_table.view(batch, 1) + for i in range(batch): + len1 = context1[i] + len2 = context2[i] + block_id1 = block_tables1[i] + block_id2 = block_tables2[i] + new_cache_k[i, :, :len1, :] = cachek1[block_id1, :, :len1, :] + new_cache_v[i, :, :len1, :] = cachev1[block_id1, :, :len1, :] + new_cache_k[i, :, len1:len1 + len2, :] = cachek2[block_id2, :, :len2, :] + new_cache_v[i, :, len1:len1 + len2, :] = cachev2[block_id2, :, :len2, :] + return new_cache_k.mlu(), new_cache_v.mlu(), new_block_table.mlu(), new_context.mlu() + +# cache1 and cache2 are float +def concat_cache_paged(cachek1, cachek2, cachev1, cachev2, context1, context2, block_tables1, block_tables2, + scalek1, scalek2, scalev1, scalev2): + batch = context1.shape[0] + block_size = cachek1.shape[2] + if scalek1 is not None: + if scalek1.dim() == 2: # per_channel: [kv_head_num, head_size] + scalek1 = scalek1.reshape(1, scalek1.shape[0], 1, scalek1.shape[1]) + scalev1 = scalev1.reshape(1, scalev1.shape[0], 1, scalev1.shape[1]) + elif scalek1.dim() == 3: # per_token: [num_blocks, k_head_num, block_size] + scalek1 = scalek1.reshape(*scalek1.shape, 1) + scalev1 = scalev1.reshape(*scalev1.shape, 1) + cachek1 *= scalek1 + cachev1 *= scalev1 + if scalek2 is not None: + if scalek2.dim() == 2: # per_channel: [kv_head_num, head_size] + scalek2 = scalek2.reshape(1, scalek2.shape[0], 1, scalek2.shape[1]) + scalev2 = scalev2.reshape(1, scalev2.shape[0], 1, scalev2.shape[1]) + elif scalek2.dim() == 3: # per_token: [num_blocks, k_head_num, block_size] + scalek2 = scalek2.reshape(*scalek2.shape, 1) + scalev2 = scalev2.reshape(*scalev2.shape, 1) + cachek2 *= scalek2 + cachev2 *= scalev2 + new_context = context1 + context2 + max_num_blocks_per_seq = block_tables1.shape[1] + block_tables1.shape[1] + new_cache_k = torch.concat((cachek1, cachek2), dim = 0) + new_cache_v = torch.concat((cachev1, cachev2), dim = 0) + num_block1 = cachek1.shape[0] + new_block_table = torch.zeros(batch, max_num_blocks_per_seq, dtype=torch.int32) + for i in range(batch): + len1 = context1[i] + len2 = context2[i] + block_num1 = (len1 + block_size - 1) // block_size + block_num2 = (len2 + block_size - 1) // block_size + len1_pad = block_num1 * block_size + block1 = block_tables1[i] + block2 = block_tables2[i] + new_block_table[i, :block_num1] = block1[:block_num1] + new_block_table[i, block_num1:block_num1 + block_num2] = block2[:block_num2] + num_block1 + if len1 != len1_pad: + reg_block_id = new_block_table[i, block_num1 - 1] # last block of cache1 + cat_block_id = new_block_table[i, block_num1] # frist block of cache2 + reg_len = len1 % block_size + pad_len = len1_pad - len1 + new_cache_k[reg_block_id, :, reg_len:, :] = new_cache_k[cat_block_id, :, :pad_len, :] + new_cache_v[reg_block_id, :, reg_len:, :] = new_cache_v[cat_block_id, :, :pad_len, :] + for j in range(block_num2-1): + block_id1 = new_block_table[i, block_num1 + j] # current + block_id2 = new_block_table[i, block_num1 + j + 1] # next + new_cache_k[block_id1, :, :reg_len, :] = new_cache_k[block_id1, :, pad_len:, :] + new_cache_k[block_id1, :, reg_len:, :] = new_cache_k[block_id2, :, :pad_len, :] + new_cache_v[block_id1, :, :reg_len, :] = new_cache_v[block_id1, :, pad_len:, :] + new_cache_v[block_id1, :, reg_len:, :] = new_cache_v[block_id2, :, :pad_len, :] + block_id = new_block_table[i, block_num1 + block_num2 - 1] + new_cache_k[block_id, :, :reg_len, :] = new_cache_k[block_id, :, pad_len:, :] + new_cache_v[block_id, :, :reg_len, :] = new_cache_v[block_id, :, pad_len:, :] + return new_cache_k.mlu(), new_cache_v.mlu(), new_block_table.mlu(), new_context.mlu() + +class TestSingleQueryMixedKVAttnOp(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + q = create_tensor_from_dic(dic['q']) + k_cache_lp = create_tensor_from_dic(dic['k_cache_lp']) + v_cache_lp = create_tensor_from_dic(dic['v_cache_lp']) + k_cache_hp = create_tensor_from_dic(dic['k_cache_hp']) + v_cache_hp = create_tensor_from_dic(dic['v_cache_hp']) + out = create_tensor_from_dic(dic['out']) + block_tables_lp = dic['block_tables_lp']['data'] + block_tables_hp = dic['block_tables_hp']['data'] + context_lens_lp = dic['context_lens_lp']['data'] + context_lens_hp = dic['context_lens_hp']['data'] + k_cache_quant_scale_lp = create_tensor_from_dic(dic['k_cache_quant_scale_lp']) + v_cache_quant_scale_lp = create_tensor_from_dic(dic['v_cache_quant_scale_lp']) + k_cache_quant_scale_hp = create_tensor_from_dic(dic['k_cache_quant_scale_hp']) + v_cache_quant_scale_hp = create_tensor_from_dic(dic['v_cache_quant_scale_hp']) + alibi_slopes = create_tensor_from_dic(dic['alibi_slopes']) + max_contxt_len_lp = dic['max_contxt_len_lp']['data'] + max_contxt_len_hp = dic['max_contxt_len_hp']['data'] + softmax_scale = dic['softmax_scale']['data'] + return_lse = dic['return_lse']['data'] + kv_cache_quant_bit_size_lp = dic['kv_cache_quant_bit_size_lp']['data'] + kv_cache_quant_bit_size_hp = dic['kv_cache_quant_bit_size_hp']['data'] + self.launch(q, k_cache_lp, v_cache_lp, k_cache_hp, v_cache_hp, out, block_tables_lp, + block_tables_hp, context_lens_lp, context_lens_hp, k_cache_quant_scale_lp, + v_cache_quant_scale_lp, k_cache_quant_scale_hp, v_cache_quant_scale_hp, + alibi_slopes, max_contxt_len_lp, max_contxt_len_hp, softmax_scale, + return_lse, kv_cache_quant_bit_size_lp, kv_cache_quant_bit_size_hp) + + def launch(self, *args): + torch_output, torch_lse = self.op_impl_base(*args) + tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(*args) + self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), + 0.0003, use_MSE=True) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True) + + def op_impl_base(self, *args): + q, k_cache_lp, v_cache_lp, k_cache_hp, v_cache_hp, out, block_tables_lp, block_tables_hp, \ + context_lens_lp, context_lens_hp, k_cache_quant_scale_lp, v_cache_quant_scale_lp, k_cache_quant_scale_hp, \ + v_cache_quant_scale_hp, alibi_slopes, max_contxt_len_lp, max_contxt_len_hp, softmax_scale, return_lse, \ + kv_cache_quant_bit_size_lp, kv_cache_quant_bit_size_hp = args + + if kv_cache_quant_bit_size_lp == 4: + num_blocks, num_kv_heads, block_size_lp, head_size = v_cache_lp.size() + block_size = block_size_lp * 2 + k_cache_lp = UnpackInt4(k_cache_lp).reshape(num_blocks, num_kv_heads, block_size, head_size) + v_cache_lp = UnpackInt4(v_cache_lp.transpose(2,3)).reshape(num_blocks, num_kv_heads, head_size, block_size).transpose(2,3) + + torch_output_lp, torch_lse_lp = single_query_cached_kv_attn(q.contiguous().float(), k_cache_lp.float(), v_cache_lp.float(), + block_tables_lp, context_lens_lp, k_cache_quant_scale_lp, v_cache_quant_scale_lp, alibi_slopes, -1, -1, softmax_scale, return_lse) + torch_output_hp, torch_lse_hp = single_query_cached_kv_attn(q.contiguous().float(), k_cache_hp.float(), v_cache_hp.float(), + block_tables_hp, context_lens_hp, k_cache_quant_scale_hp, v_cache_quant_scale_hp, alibi_slopes, -1, -1, softmax_scale, return_lse) + torch_output, torch_lse = update_out_and_lse_torch(torch_output_lp, torch_lse_lp, torch_output_hp, torch_lse_hp, None, None, None) + return (torch_output, torch_lse) if return_lse else torch_output + + def test_single_query_mixedkv_attention(self): + head_num = 16 + batch = 12 + num_kv_heads = 4 + seq_q = 1 + head_size = 128 + seq_len_lp = 512 + seq_len_hp = 128 + is_pagedattn_list = [True, False] + has_alibi_list = [True, False] + is_normal_list = [True, False] # if false, lp_k/v_len of some batch is 0 + quant_bit_list = [(-1, -1), (4, 8), (4, -1), (8, -1), (8, 8)] + data_type_list = [torch.float, torch.float16] + if torch_mlu.mlu.is_bf16_supported(): + data_type_list.append(torch.bfloat16) + args = product(is_pagedattn_list, has_alibi_list, data_type_list, quant_bit_list, is_normal_list) + for is_pagedattn, has_alibi, data_type, quant_bit, is_normal in args: + quant_bit_lp, quant_bit_hp = quant_bit + print("test separate:{} + {}: batch:{}, seq_len_lp:{}, seq_len_hp:{}, head_size:{}, is_pagedattn:{}, has_alibi:{}, data_type:{}, is_normal:{} ...".format( + quant_bit_lp, quant_bit_hp, batch, seq_len_lp, seq_len_hp, head_size, is_pagedattn, has_alibi, data_type, is_normal)) + if is_pagedattn: + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + continue + torch.manual_seed(1) + input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=data_type).mlu() + input_q = input_qkv[:, 0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size) + #gen k/v cache + params_lp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_lp, data_type, quant_bit_lp, + "per_token", is_normal) + params_hp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_hp, data_type, quant_bit_hp, + "per_channel") + if quant_bit_lp == 4: + key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp, _, _ = params_lp + else: + key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp = params_lp + key_cache_hp, value_cache_hp, key_scale_hp, value_scale_hp, context_lens_hp, block_tables_hp = params_hp + max_context_len_lp = context_lens_lp.max().item() + max_context_len_hp = context_lens_hp.max().item() + + alibi_slopes = None + if has_alibi: + alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu() + alibi_slopes.uniform_(0, 0.125) + softmax_scale = 1 / math.sqrt(head_size) + torch_output, torch_lse = self.op_impl_base(input_q, + key_cache_lp, value_cache_lp, + key_cache_hp, value_cache_hp, + None, #output + block_tables_lp, block_tables_hp, + context_lens_lp, context_lens_hp, + key_scale_lp, value_scale_lp, + key_scale_hp, value_scale_hp, + alibi_slopes, + max_context_len_lp, max_context_len_hp, + softmax_scale, True, + quant_bit_lp, quant_bit_hp) + + tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(input_q, + key_cache_lp, value_cache_lp, + key_cache_hp, value_cache_hp, + None, #output + block_tables_lp, block_tables_hp, + context_lens_lp, context_lens_hp, + key_scale_lp, value_scale_lp, + key_scale_hp, value_scale_hp, + alibi_slopes, + max_context_len_lp, max_context_len_hp, + softmax_scale, True, + quant_bit_lp, quant_bit_hp) + self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), + 0.0003, use_MSE=True) + self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True) + + def test_single_query_mixedkv_attention_concate(self): + head_num = 16 + batch = 12 + num_kv_heads = 4 + seq_q = 1 + head_size = 128 + seq_len_lp = 512 + seq_len_hp = 128 + is_pagedattn_list = [True, False] + has_alibi = False #only support without alibi + is_normal_list = [True, False] # if false, lp_k/v_len of some batch is 0 + quant_bit_list = [(-1, -1), (4, 8), (4, -1), (8, -1), (8, 8)] + data_type_list = [torch.float, torch.float16] + if torch_mlu.mlu.is_bf16_supported(): + data_type_list.append(torch.bfloat16) + args = product(is_pagedattn_list, data_type_list, quant_bit_list, is_normal_list) + for is_pagedattn, data_type, quant_bit, is_normal in args: + quant_bit_lp, quant_bit_hp = quant_bit + print("test concate {} + {}: seq_len_lp: {}, seq_len_hp: {}, is_pagedattn: {}, data_type {}, is_normal {} ...".format( + quant_bit_lp, quant_bit_hp, batch, seq_len_lp, seq_len_hp, head_size, is_pagedattn, data_type, is_normal)) + if is_pagedattn: + mlu_name = torch.mlu.get_device_name() + if "MLU3" in mlu_name: + print("pagedattn is not implement on mlu370, skip it") + continue + torch.manual_seed(1) + input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=data_type).mlu() + input_q = input_qkv[:, 0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size) + #gen k/v cache + params_lp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_lp, data_type, quant_bit_lp, + "per_token", is_normal) + params_hp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_hp, data_type, quant_bit_hp, + "per_channel") + if quant_bit_lp == 4: + key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp, key_cache_lp_torch, value_cache_lp_torch = params_lp + else: + key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp = params_lp + key_cache_lp_torch, value_cache_lp_torch = key_cache_lp, value_cache_lp + key_cache_hp, value_cache_hp, key_scale_hp, value_scale_hp, context_lens_hp, block_tables_hp = params_hp + max_context_len_lp = context_lens_lp.max().item() + max_context_len_hp = context_lens_hp.max().item() + + alibi_slopes = None + if has_alibi: + alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu() + alibi_slopes.uniform_(0, 0.125) + softmax_scale = 1 / math.sqrt(head_size) + # concate cache + if is_pagedattn: + concat_cache_k, concat_cache_v, concat_block_tables, concat_context = concat_cache_paged( + key_cache_lp_torch.float().cpu(), key_cache_hp.float().cpu(), + value_cache_lp_torch.float().cpu(), value_cache_hp.float().cpu(), + context_lens_lp.cpu(), context_lens_hp.cpu(), + block_tables_lp.cpu(), block_tables_hp.cpu(), + key_scale_lp.cpu() if key_scale_lp is not None else None, + key_scale_hp.cpu() if key_scale_hp is not None else None, + value_scale_lp.cpu() if value_scale_lp is not None else None, + value_scale_hp.cpu() if value_scale_hp is not None else None) + else: + concat_cache_k, concat_cache_v, concat_block_tables, concat_context = concate_cache_linear( + key_cache_lp_torch.float().cpu(), key_cache_hp.float().cpu(), + value_cache_lp_torch.float().cpu(), value_cache_hp.float().cpu(), + context_lens_lp.cpu(), context_lens_hp.cpu(), + block_tables_lp.cpu(), block_tables_hp.cpu(), + key_scale_lp.cpu() if key_scale_lp is not None else None, + key_scale_hp.cpu() if key_scale_hp is not None else None, + value_scale_lp.cpu() if value_scale_lp is not None else None, + value_scale_hp.cpu() if value_scale_hp is not None else None) + + torch_output_concat, torch_lse_concat = single_query_cached_kv_attn(input_q.contiguous().float(), + concat_cache_k.float(), concat_cache_v.float(), concat_block_tables, + concat_context, None, None, alibi_slopes, -1, -1, softmax_scale, True) + tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(input_q, + key_cache_lp, value_cache_lp, + key_cache_hp, value_cache_hp, + None, #output + block_tables_lp, block_tables_hp, + context_lens_lp, context_lens_hp, + key_scale_lp, value_scale_lp, + key_scale_hp, value_scale_hp, + alibi_slopes, + max_context_len_lp, max_context_len_hp, + softmax_scale, True, + quant_bit_lp, quant_bit_hp) + if is_normal: # only compare lse when context_len_lp is normal, seq=0 case nan-value in lse + self.assertTensorsEqual(torch_lse_concat.cpu().float(), tmo_lse.cpu().float(), + 0.0003, use_MSE=True) + self.assertTensorsEqual(torch_output_concat.cpu().float(), tmo_output.cpu().float(), + 0.003, use_MSE=True) + + def test_inductor(self): + return super().test_inductor() + +if __name__ == '__main__': + exit(run_unittest(TestSingleQueryMixedKVAttnOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant.py new file mode 100755 index 0000000..b4883f1 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant.py @@ -0,0 +1,555 @@ +import torch +import torch_mlu +import unittest +import torch_mlu_ops as tmo +from common_utils import * +from itertools import product +import random +import os + +def generate_token_count(num_expert, + total_token_count): + token_count = torch.randint(low=1, high=1024, size=(num_expert, ), \ + dtype=torch.int32).to(dtype=torch.float32) + sum = torch.sum(token_count, dim=-1) * 1.0 + token_count *= total_token_count / sum.item() + token_count = token_count.to(dtype=torch.int32) + cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32) + end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count + cusum_token_count[1:] = torch.cumsum(token_count, dim=-1) + cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count) + cusum_token_count[-1] = total_token_count + token_count = cusum_token_count[1:] - cusum_token_count[0:-1] + return token_count, cusum_token_count + +def gen_case(num_tokens, + topk, + hidden_size, + multi_scale, + need_gather, + num_expert, + expert_size, + start_expert_id, + dtype, + device): + input = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) + effective_count = num_tokens * topk if need_gather else num_tokens + if multi_scale: + token_count, cusum_token_count = generate_token_count(num_expert, effective_count) + token_count = token_count.to(device=device) + cusum_token_count = cusum_token_count.to(device=device) + scale = torch.randn((num_expert, hidden_size), dtype=torch.float32, device=device) + else: + token_count = None + cusum_token_count = None + scale = torch.randn((hidden_size), dtype=torch.float32, device=device) + + if need_gather: + gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32) // topk + gather_ids = gather_ids.mlu() + else: + gather_ids = None + + if expert_size < num_expert and multi_scale: + token_count = token_count[start_expert_id:] + cusum_token_count = cusum_token_count[start_expert_id:] + gather_index_start_position = cusum_token_count[0:1] + scale = scale[start_expert_id:] + effective_count = cusum_token_count[-1] - cusum_token_count[0] + else: + gather_index_start_position = None + + if not need_gather: + gather_index_start_position = None + + return input, scale, gather_ids, token_count, cusum_token_count, gather_index_start_position, effective_count + +def per_token_smooth_quantize_base(x: torch.Tensor, + smooth: torch.Tensor, + zero: torch.Tensor = None, + token_count: torch.Tensor = None): + output_shape = x.size() + output_scale_shape = x.size()[0:-1] + output, output_scale = QuantByRow(x.flatten(0, -2) * smooth, 8) + return output.reshape(output_shape), output_scale.reshape(output_scale_shape) + +def quantize_base(x: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor = None + ) -> torch.Tensor: + return (x * scale).round().clamp(-128.0, 127.0).to(torch.int8) + +class TestSmoothQuantOp(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + x = create_tensor_from_dic(dic['x']) + smooth = create_tensor_from_dic(dic['smooth']) + zero = None if dic['zero']['data'] is None else create_tensor_from_dic(dic['zero']) + token_count = None if dic['token_count']['data'] is None else dic['token_count']['data'] + gather_index = None if dic['gather_index']['data'] is None else dic['gather_index']['data'] + gather_index_start_position = None if dic['gather_index_start_position']['data'] is None else dic['gather_index_start_position']['data'] + output = None if dic['output']['data'] is None else create_tensor_from_dic(dic['output']) + output_scale = None if dic['output_scale']['data'] is None else create_tensor_from_dic(dic['output_scale']) + dynamic_quant = dic['dynamic_quant']['data'] + self.launch(x, smooth, zero, token_count, gather_index, gather_index_start_position, output, output_scale, dynamic_quant) + + def launch(self, *args): + tmo_out = tmo.moe_quantize(*args) + out_base = None if args[6] is None else args[6].clone() + scale_base = None if args[7] is None else args[7].clone() + args = list(args) + args[6] = out_base + args[7] = scale_base + torch_out = self.op_impl_base(*args) + + if args[-1]: + self.assertTensorsEqual(torch_out[0].cpu().reshape(-1).float(), + tmo_out[0].cpu().reshape(-1).float(), + 0.01, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_out[1].cpu().reshape(-1).float(), + tmo_out[1].cpu().reshape(-1).float(), + 0.003, use_MSE=True, use_RAE=True) + else: + self.assertTensorsEqual(torch_out.cpu().reshape(-1).float(), + tmo_out.cpu().reshape(-1).float(), + 0.01, use_MSE=True, use_RAE=True) + + + def op_impl_base(self, *args): + x, smooth, zero, token_count, gather_index, gather_index_start_position, \ + output, output_scale, dynamic_quant = args + input = x.to(dtype=torch.float32).cpu() + input_scale = smooth.cpu() + cusum_token_count = None + if token_count is not None: + token_count = token_count.cpu() + cusum_token_count = torch.zeros(token_count.shape[0] + 1, dtype=torch.int32) + cusum_token_count[1:] = torch.cumsum(token_count, dim=-1) + + if gather_index_start_position is not None: + gather_index_start_position = gather_index_start_position.cpu() + + gather_index_start = 0 + if gather_index_start_position is not None: + gather_index_start = gather_index_start_position[0] + + if cusum_token_count is not None: + gather_index_end = cusum_token_count[-1] + gather_index_start + elif gather_index is not None: + gather_index_end = gather_index.numel() + else: + gather_index_end = input.numel() + + if gather_index is not None: + gather_index = gather_index.cpu() + gathered_input = input[gather_index[gather_index_start : gather_index_end]] + else: + gathered_input = input[gather_index_start : gather_index_end] + + if cusum_token_count is not None: + for i in range(token_count.shape[0]): + gathered_input[cusum_token_count[i] : cusum_token_count[i+1]] *= input_scale[i] + else: + gathered_input *= input_scale + + if output is None: + if not dynamic_quant: + return gathered_input.round().clamp(-128.0, 127.0).to(torch.int8), None + else: + return QuantByRow(gathered_input, 8) + else: + if not dynamic_quant: + output.copy_(gathered_input.round().clamp(-128.0, 127.0).to(torch.int8)) + output_scale = None + else: + out, scale = QuantByRow(gathered_input, 8) + output_fl = output.flatten() + output_fl[:out.numel()].copy_(out.flatten()) + # output = output_fl + output_scale_fl = output_scale.flatten() + output_scale_fl[:scale.numel()].copy_(scale.flatten()) + # output_scale = output_scale_fl + return (output, output_scale) if dynamic_quant else (output,) + + def test_random_case(self): + torch.manual_seed(333) + test_cases = 100 + num_tokens_list = torch.randint(low=1, high=4096, size=(test_cases, ), dtype=torch.int32) + topk_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32) + hidden_size_list = torch.randint(low=128, high=8193, size=(test_cases, ), dtype=torch.int32) + num_expert_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32) + expert_size_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32) + expert_size_list = torch.minimum(expert_size_list, num_expert_list) + start_expert_id_list = torch.randint(low=0, high=129, size=(test_cases, ), dtype=torch.int32) + start_expert_id_list = torch.minimum(start_expert_id_list, num_expert_list - expert_size_list) + start_expert_id_list = torch.maximum(start_expert_id_list, torch.zeros(test_cases, dtype=torch.int32)) + multi_scale_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + need_gather_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + input_with_stride_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + dynamic_quant_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32) + dtype_list = torch.randint(low=0, high=10, size=(test_cases, ), dtype=torch.int32) + dtypes = [torch.half, torch.half, torch.float32] + if torch_mlu.mlu.is_bf16_supported(): + dtypes += [torch.bfloat16, torch.bfloat16] + dtype_list = random.choices(dtypes, k=test_cases) + + device = "mlu" + for i in range(test_cases): + num_tokens = num_tokens_list[i].item() + topk = topk_list[i].item() + hidden_size = hidden_size_list[i].item() + num_expert = num_expert_list[i].item() + expert_size = expert_size_list[i].item() + start_expert_id = start_expert_id_list[i].item() + multi_scale = multi_scale_list[i].item() == 1 + need_gather = need_gather_list[i].item() == 1 + input_with_stride = input_with_stride_list[i].item() == 1 + dynamic_quant = dynamic_quant_list[i].item() == 1 + dtype = dtype_list[i] + + if not multi_scale or not torch_mlu.mlu.is_bf16_supported(): + need_gather = False + + inputs = gen_case(num_tokens, + topk, + hidden_size, + multi_scale, + need_gather, + num_expert, + expert_size, + start_expert_id, + dtype, + device) + + input = inputs[0] + input_scale = inputs[1] + gather_ids = inputs[2] + token_count = inputs[3] + cusum_token_count = inputs[4] + gather_index_start_position = inputs[5] + effective_count = inputs[6] + + if input_with_stride: + hidden_size = hidden_size - 64 + input = input[..., hidden_size : ] + input_scale = input_scale[..., hidden_size : ].contiguous() + + print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, " + "start_expert_id={}, multi_scale={}, need_gather={}, input_with_stride={}, " + "dynamic_quant={}, dtype={}, testing...".format( + num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \ + multi_scale, need_gather, input_with_stride, dynamic_quant, dtype)) + + torch_quant, torch_output_scale = self.op_impl_base(input, + input_scale, + None, + token_count, + gather_ids, + gather_index_start_position, + None, + None, + dynamic_quant) + tmo_output_scale = None + if dynamic_quant: + tmo_output, tmo_output_scale = \ + tmo.moe_quantize(input, input_scale, None, token_count, + gather_ids, gather_index_start_position, + None, None, dynamic_quant) + else: + tmo_output, = \ + tmo.moe_quantize(input, input_scale, None, token_count, + gather_ids, gather_index_start_position, + None, None, dynamic_quant) + tmo_output = tmo_output[:effective_count] + if tmo_output_scale is not None: + tmo_output_scale = tmo_output_scale[:effective_count] + + self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(), + tmo_output.cpu().reshape(-1).float(), + 0.01, use_MSE=True, use_RAE=True) + + if dynamic_quant: + self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(), + tmo_output_scale.cpu().reshape(-1).float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_interface(self): + channel = 16 + dtype_list = [torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + for dtype in dtype_list: + input = torch.randn((2, 3, 2, channel), dtype=dtype, device="mlu") + input_scale = torch.randn((channel)).float().mlu() + + print("test tmo.per_token_smooth_quantize...") + torch_quant, torch_scale = per_token_smooth_quantize_base(input, input_scale) + tmo_quant, tmo_scale = tmo.per_token_smooth_quantize(input, input_scale) + + self.assertTensorsEqual(torch_quant.cpu().float(), tmo_quant.cpu().float(), + 0.01, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_scale.cpu().float(), tmo_scale.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + print("test tmo.quantize ...") + input_scale = input_scale * 100.0 + torch_quant = quantize_base(input, input_scale) + tmo_quant = tmo.quantize(input, input_scale) + self.assertTensorsEqual(torch_quant.cpu().float(), tmo_quant.cpu().float(), + 0.003, use_MSE=True, use_RAE=True) + + print("test tmo.moe_quantize inplace ...") + num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, expert_size, start_expert_id, dtype, dynamic_quant = \ + 1024, 5, 512, True, True, 32, 4, 4, torch.half, True + inputs = gen_case(num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, + expert_size, start_expert_id, dtype, 'mlu') + input = inputs[0] + input_scale = inputs[1] + gather_ids = inputs[2] + token_count = inputs[3] + cusum_token_count = inputs[4] + gather_index_start_position = inputs[5] + effective_count = inputs[6] + tmo_output = torch.empty(num_tokens*topk, hidden_size, dtype=torch.int8, device='mlu') + tmo_output_scale = torch.empty(num_tokens*topk, dtype=torch.float, device='mlu') + if 'MLU370' not in torch_mlu.mlu.get_device_name(): + tmo.moe_quantize(input, input_scale, None, token_count, + gather_ids, gather_index_start_position, + tmo_output, tmo_output_scale, dynamic_quant) + tmo_output = tmo_output[:effective_count] + tmo_output_scale = tmo_output_scale[:effective_count] + torch_output = torch.empty_like(tmo_output) + torch_output_scale = torch.empty_like(tmo_output_scale) + self.op_impl_base(input, input_scale, + None, token_count, gather_ids, gather_index_start_position, + torch_output, torch_output_scale, dynamic_quant) + self.assertTensorsEqual(torch_output.cpu().reshape(-1).float(), + tmo_output.cpu().reshape(-1).float(), + 0.01, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(), + tmo_output_scale.cpu().reshape(-1).float(), + 0.003, use_MSE=True, use_RAE=True) + + input = input.reshape(2, 4, 8, 16, -1) + input_scale = input_scale[0] + tmo_output = torch.empty(input.size(), dtype=torch.int8, device='mlu') + tmo_output_scale = torch.empty(input.size()[:-1], dtype=torch.float, device='mlu') + tmo.moe_quantize(input, input_scale, None, None, None, None, + tmo_output, tmo_output_scale, dynamic_quant) + torch_quant = torch.empty_like(tmo_output) + torch_output_scale = torch.empty_like(tmo_output_scale) + self.op_impl_base(input, input_scale, + None, None, None, None, torch_quant, torch_output_scale, dynamic_quant) + self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(), + tmo_output.cpu().reshape(-1).float(), + 0.01, use_MSE=True, use_RAE=True) + self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(), + tmo_output_scale.cpu().reshape(-1).float(), + 0.003, use_MSE=True, use_RAE=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_prevent(self): + num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, expert_size, start_expert_id, dtype, dynamic_quant = \ + 1024, 5, 512, True, True, 32, 4, 4, torch.half, True + inputs = gen_case(num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, + expert_size, start_expert_id, dtype, 'mlu') + input = inputs[0] + input_scale = inputs[1] + gather_ids = inputs[2] + token_count = inputs[3] + cusum_token_count = inputs[4] + gather_index_start_position = inputs[5] + effective_count = inputs[6] + tmo_output = torch.empty(input.size(), dtype=torch.int8, device='mlu') + tmo_output_scale = torch.empty(input.size()[:-1], dtype=torch.float, device='mlu') + func = tmo.moe_quantize + + self.assertException("input must be mlu tensor.", + func, input.cpu(), input_scale, None, None, None, None, + None, None, dynamic_quant) + self.assertException(None, + func, input, input_scale, None, token_count, gather_ids, + gather_index_start_position.cpu(), None, None, dynamic_quant) + self.assertException("not support output_scale if dynamic_quant = false", + func, input, input_scale, None, token_count, gather_ids, + gather_index_start_position, None, tmo_output_scale, False) + self.assertException("input.dim() == 2 if has gather_index or token_count", + func, input.reshape(32, 32, -1), input_scale, None, None, gather_ids, + gather_index_start_position, None, None, dynamic_quant) + self.assertException("input.dim() >= 2", + func, input.reshape(-1), input_scale, None, None, None, None, + None, None, dynamic_quant) + self.assertException("output.dim() >= 2", + func, input, input_scale, None, token_count, gather_ids, + gather_index_start_position, tmo_output.reshape(-1), tmo_output_scale, True) + self.assertException("input and output must have the same shape", + func, input.reshape(2, 512, -1), input_scale, None, None, None, None, + tmo_output.reshape(32, 32, -1), None, dynamic_quant) + self.assertException("output_scale_shape must be equal to input_shape[0:-1]", + func, input.reshape(2, 512, -1), input_scale[0], None, None, None, None, + None, tmo_output_scale.reshape(32, 32), dynamic_quant) + self.assertException("gather_index must exist if gather_index_start_position has value", + func, input, input_scale, None, token_count, None, + gather_index_start_position, None, None, True) + self.assertException("gather_index.dim() == 1", + func, input, input_scale, None, token_count, gather_ids.reshape(1, -1), + gather_index_start_position, None, None, True) + + + def test_perf_case(self): + num_tokens_list = [1, 72, 512] + topk = 5 + hidden_size_list = [2048, 4096, 5120, 8192] + # [num_expert, start_expert_id, expert_size] + expert_options_list = [[8, 0, 8], [32, 24, 8]] + multi_scale_list = [True, False] + need_gather_list = [True, False] + dynamic_quant_list = [True] + dtype_list = [torch.half, torch.bfloat16] + + device = 'mlu' + args = product(num_tokens_list, hidden_size_list, expert_options_list,\ + multi_scale_list, need_gather_list, dynamic_quant_list, dtype_list) + for num_tokens, hidden_size, expert_options, multi_scale, need_gather, dynamic_quant, dtype in args: + num_expert = expert_options[0] + start_expert_id = expert_options[1] + expert_size = expert_options[2] + + if not multi_scale or not torch_mlu.mlu.is_bf16_supported(): + need_gather = False + + if not torch_mlu.mlu.is_bf16_supported(): + continue + + torch.manual_seed(444) + inputs = gen_case(num_tokens, + topk, + hidden_size, + multi_scale, + need_gather, + num_expert, + expert_size, + start_expert_id, + dtype, + device) + + input = inputs[0] + input_scale = inputs[1] + gather_ids = inputs[2] + token_count = inputs[3] + cusum_token_count = inputs[4] + gather_index_start_position = inputs[5] + effective_count = inputs[6] + + print("num_tokens={}, hidden_size={}, num_expert={}, expert_size={}, " + "start_expert_id={}, multi_scale={}, need_gather={}, input_with_stride={}, " + "dynamic_quant={}, dtype={}, testing...".format( + num_tokens, hidden_size, num_expert, expert_size, start_expert_id, \ + multi_scale, need_gather, False, dynamic_quant, dtype)) + + torch_quant, torch_output_scale = self.op_impl_base(input, + input_scale, + None, + token_count, + gather_ids, + gather_index_start_position, + None, + None, + dynamic_quant) + notify_start = torch.mlu.Event(enable_timing=True) + notify_end = torch.mlu.Event(enable_timing=True) + notify_start.record() + loop = 10 + for _ in range(loop): + if dynamic_quant: + tmo_output, tmo_output_scale = \ + tmo.moe_quantize(input, input_scale, None, token_count, + gather_ids, gather_index_start_position, + None, None, dynamic_quant) + else: + tmo_output, = \ + tmo.moe_quantize(input, input_scale, None, token_count, + gather_ids, gather_index_start_position, + None, None, dynamic_quant) + + notify_end.record() + notify_end.synchronize() + time = notify_start.hardware_time(notify_end) / loop + + tmo_output = tmo_output[:effective_count] + if tmo_output_scale is not None: + tmo_output_scale = tmo_output_scale[:effective_count] + print("time is: {:.1f}us".format(time)) + + self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(), + tmo_output.cpu().reshape(-1).float(), + 0.01, use_MSE=True, use_RAE=True) + if dynamic_quant: + self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(), + tmo_output_scale.cpu().reshape(-1).float(), + 0.003, use_MSE=True, use_RAE=True) + + def test_inductor(self): + multi_scale_list = [True, False] + need_gather_list = [True, False] if 'MLU370' not in torch_mlu.mlu.get_device_name() else [False] + dynamic_quant_list = [True, False] + num_tokens, hidden_size, num_expert, start_expert_id, expert_size, topk, \ + dtype, device = 1, 2048, 32, 24, 8, 5, torch.float16, 'mlu' + params = product(multi_scale_list, need_gather_list, dynamic_quant_list) + # quantize + print(f"check ops.quantize...") + input, input_scale, gather_ids, token_count, _, _, _ = gen_case(num_tokens, + topk, + hidden_size, + False, + False, + num_expert, + expert_size, + start_expert_id, + dtype, + device) + output = torch.empty(input.size(), dtype=torch.int8, device=device) + args = (input, input_scale, output, torch.Tensor(), None, None, None, None, 'per_token', False) + self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args) + + # per_token_smooth_quantize + print(f"check ops.per_token_smooth_quantize...") + output_scale = torch.empty(input.size()[:-1], dtype=input_scale.dtype, device=device) + args = (input, input_scale, output, output_scale, None, token_count, None, None, 'per_token', True) + self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args) + + # moe_quantize + for multi_scale, need_gather, dynamic_quant in params: + if not multi_scale and need_gather: + continue + print(f"check ops.moe_quantize multi_scale: {multi_scale}, need_gather: {need_gather}, dynamic_quant: {dynamic_quant} ...") + input, input_scale, gather_ids, token_count, _, \ + gather_index_start_position, _ = gen_case(num_tokens, + topk, + hidden_size, + multi_scale, + need_gather, + num_expert, + expert_size, + start_expert_id, + dtype, + device) + output_shape = list(input.size()) + output_scale_shape = list(input.size()[:-1]) + if gather_ids is not None: + output_tokens = gather_ids.size(0) + output_shape[0] = output_tokens + output_scale_shape[0] = output_tokens + output = torch.empty(output_shape, dtype=torch.int8, device=device) + output_scale = torch.empty(output_scale_shape, dtype=input_scale.dtype, device=device) if dynamic_quant else None + args = (input, input_scale, + output, output_scale, None, + token_count, gather_ids, gather_index_start_position, + 'per_token', dynamic_quant) + self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args) + +if __name__ == '__main__': + exit(run_unittest(TestSmoothQuantOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_group_gemm.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_group_gemm.py new file mode 100755 index 0000000..1ecc9cd --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_group_gemm.py @@ -0,0 +1,217 @@ +import torch +from torch_mlu import mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +from torch.nn import functional as F + + +def gen_mix_w4w8_param(bc, seq, k, n, experts_num, topk, data_type, has_bias, quant_wise): + bs = bc * seq + token_topk = bs * topk + expert_id = torch.randint(experts_num, (token_topk,), device="mlu") + sorted_expert_id, indices = expert_id.sort() + gather_idx = indices // topk + token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32) + quant_group = k // quant_wise + quant_flag = random.choices([4,8], k=experts_num * quant_group) + b_count = (sum(quant_flag) // 4) * (quant_wise // 2) * n + b = torch.randint(-128, 127, (b_count,), dtype=torch.int32, device="mlu").to(torch.int8) + b_scale = torch.normal(0, 0.01, (quant_group, experts_num, n), device="mlu", dtype=torch.float32) + a = torch.randint(-128, 127, (bs, k), device="mlu", dtype=torch.int32).to(torch.int8) + a = a[gather_idx] + a_scale = torch.normal(0, 0.01, (bs,), device="mlu", dtype=torch.float32) + a_scale = a_scale[gather_idx] + c = torch.randn(token_topk, n, device="mlu", dtype=data_type) + bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None + return a, b, token_count, None, c, None, None, a_scale, b_scale, data_type, bs, bias, quant_flag + +def gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias=False, quant_bit = 8, quant_group = 1): + bs = batch * seq + token_topk = bs * topk + expert_id = torch.randint(experts_num, (token_topk,), device="mlu") + sorted_expert_id, indices = expert_id.sort() + gather_idx = indices // topk + gather_idx = gather_idx.to(torch.int32) + token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32) + + a = torch.randn(bs, k, device="mlu", dtype=data_type) + if not idx_mode: + a = a[gather_idx] + b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type) + c = torch.randn(token_topk, n, device="mlu", dtype=data_type) + + a, a_scale = QuantByRow(a, 8) + if idx_mode: + a_scale = a_scale[gather_idx] + b_shape = b.shape + b, b_scale = QuantByRow(b.view(-1, b.shape[-1]), quant_bit, quant_group) + b = b.view(b_shape) + if quant_bit == 4: + b = PairlyPackInt8(b) + b_scale = b_scale.view(experts_num, -1) if quant_group == 1 else b_scale.view(experts_num, -1, quant_group).permute(2, 0, 1).contiguous() + alpha = None + beta = None + bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None + gather_idx_ = gather_idx if idx_mode else None + quant_flag = None + return a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, data_type, bs, bias, quant_flag + +class TestSmoothQuantGroupGemmOp(BtTestCase): + def run_gen_case(self, dic): + dump_data = dic.pop('dump_data') + if dump_data: + self.launch(*dic.values()) + else: + a = create_tensor_from_dic(dic['a']) + b = create_tensor_from_dic(dic['b']) + m_list = dic['m_list']['data'] + expand_idx = dic['expand_idx']['data'] + c = None if dic['c']['data'] is None else create_tensor_from_dic(dic['c']) + alpha = None if dic['alpha']['data'] is None else create_tensor_from_dic(dic['alpha']) + beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta']) + a_scale = create_tensor_from_dic(dic['a_scale'], 0, 0.01) + b_scale = create_tensor_from_dic(dic['b_scale'], 0, 0.01) + dtype = dic['dtype']['data'] + max_m = dic['max_m']['data'] + bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias']) + quant_flag = dic['quant_flag']['data'] + self.launch(a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag) + + def launch(self, *args): + total_m = args[2].sum().item() + torch_out = self.op_impl_base(*args) + tmo_out = ops.smooth_quant_group_gemm(*args) + self.assertTensorsEqual(tmo_out.cpu().float()[0:total_m], torch_out.cpu().float()[0:total_m], 0.006, use_MSE=True) + + def op_impl_base(self, *args): + a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag = args + a = a.reshape(-1, a.size(-1)) + if expand_idx is not None: + a = a[expand_idx] + total_m = m_list.sum().item() + a_list = a[:total_m].split(tuple(m_list)) + + c_list = [] + if c is not None: + c = c.reshape(-1, c.size(-1)) + c_list = c[:total_m].split(tuple(m_list)) + + if a_scale is not None: + a_scale_list = a_scale[:total_m].split(tuple(m_list)) + + k = a.shape[1] + n = b.size(1) if quant_flag is None else b_scale.shape[2] + if b_scale is not None and b_scale.dim() == 3: # for quant_grouped + b_scale = b_scale.transpose(0, 1).contiguous() + if quant_flag is not None: + quant_group = b_scale.shape[1] + group_wise = k // quant_group + quant_flag = torch.tensor(quant_flag).view(-1, quant_group) + b_offset_cu = torch.cumsum(quant_flag.sum(dim=1), dim=0) // 4 * (group_wise // 2) * n + b_offset_cu = torch.nn.functional.pad(b_offset_cu, (1,0), "constant", 0) + + output_list = [] + experts = b.size(0) if quant_flag is None else b_scale.size(0) + for i in range(experts): + if (a_list[i].size(0) > 0): + if a_scale is not None and b_scale is not None: + if quant_flag is None: + gemm_out = smooth_quant_matmul(a_list[i], a_scale_list[i], b[i], b_scale[i], dtype) + else: + gemm_out = smooth_quant_matmul_w4w8_mixed(a_list[i], a_scale_list[i], + b[b_offset_cu[i]:b_offset_cu[i+1]], + b_scale[i], dtype, quant_flag = quant_flag[i]) + else: + gemm_out = F.linear(a_list[i], b[i]) + if bias is not None: + gemm_out += bias[i] + if alpha is not None: + gemm_out *= alpha[i] + if beta is not None and c_list != []: + gemm_out += c_list[i] * beta[i] + output_list.append(gemm_out) + real_res = torch.cat(output_list, dim=0) + output = torch.empty(a.shape[0], n, device=real_res.device).to(real_res.dtype) + output[:total_m] = real_res + return output + + def test_smooth_quant_group_gemm(self): + bs_list = [1, 3] + seq_list = [5, 8] + k_list = [512, 1024] + n_list = [512, 768, 2048] + expert_list = [8, 32] + topk_list = [2, 5] + dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half] + idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True] + has_bias_list = [True, False] + + args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list) + for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias in args: + print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, topk: {topk}, \ +dtype: {data_type}, idx_mode: {idx_mode}, has_bias: {has_bias} testing...", flush=True) + param = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias) + torch_out = self.op_impl_base(*param) + tmo_out = ops.smooth_quant_group_gemm(*param) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + + + def test_sq_group_gemm_quant_group(self): + bs_list = [1, 3] + seq_list = [5, 8] + k_list = [512, 1024] + n_list = [512, 768, 2048] + expert_list = [8, 32] + topk_list = [2, 5] + dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half] + idx_list = [False] + quant_bit_list = [4, 8] + quant_group_size_list = [128, 256] + has_bias_list = [True, False] + + args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list, quant_bit_list, quant_group_size_list) + for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias, quant_bit, quant_group_size in args: + print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, \ +topk: {topk}, dtype: {data_type}, idx_mode: {idx_mode}, has_bias:{has_bias}, quant_bit: {quant_bit}, quant_group_size: {quant_group_size} testing...", flush=True) + quant_group = k // quant_group_size + param = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias, quant_bit, quant_group) + torch_out = self.op_impl_base(*param) + tmo_out = ops.smooth_quant_group_gemm(*param) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + + def test_sq_group_gemm_w4w8_mixed(self): + bs_l = [1, 3] + seq_l = [5, 8] + k_l = [1024, 2048, 3072] + n_l = [512, 768, 2048] + expert_l = [8, 32] + topk_l = [2, 5] + dtype_l = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half] + has_bias_l = [True, False] + group_wise_l = [128, 256, 512] + + args = product(bs_l, seq_l, k_l, n_l, expert_l, topk_l, dtype_l, has_bias_l, group_wise_l) + for bc, seq, k, n, experts, topk, data_type, has_bias, group_wise in args: + print(f"bs: {bc}, seq_len: {seq}, k: {k}, n: {n}, experts: {experts}, \ +topk: {topk}, dtype: {data_type}, has_bias: {has_bias}, group_wise: {group_wise}, testing...", flush=True) + param = gen_mix_w4w8_param(bc, seq, k, n, experts, topk, data_type, has_bias, group_wise) + torch_out = self.op_impl_base(*param) + tmo_out = ops.smooth_quant_group_gemm(*param) + self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True) + + def test_inductor(self): + batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5 + dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half] + idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True] + has_bias_list = [True, False] + args = product( dtype_list, idx_list, has_bias_list) + for data_type, idx_mode, has_bias in args: + args = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias) + new_args = list(args)[:9] + new_args.extend([args[-2], "half" if data_type == torch.half else "bfloat16", args[-1], None, args[-3]]) + self.base_opcheck(torch.ops.torch_mlu_ops.group_gemm, new_args) + +if __name__ == '__main__': + exit(run_unittest(TestSmoothQuantGroupGemmOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_matmul.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_matmul.py new file mode 100644 index 0000000..a411122 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_smooth_quant_matmul.py @@ -0,0 +1,101 @@ +import torch +import torch_mlu +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +dtype_dict = { + torch.half: "half", + torch.bfloat16: "bfloat16", +} +dtype_list = [torch.half] +if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + +# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16 +class TestSmoothQuantMatmulOp(BtTestCase): + def op_impl_base(self, *args): + a, a_scale, b, b_scale, dtype, bias, c, act_mode, alpha, beta, use_hp_active = args + if a_scale is not None: + a = torch.mul(a, a_scale.unsqueeze(-1)).to(dtype) + if b_scale is not None: + b = torch.mul(b, b_scale.unsqueeze(-1)).to(dtype) + output = torch.matmul(a, b.permute(1,0)) + if bias is not None: + output += bias + output = torch.mul(output, alpha) + if c is not None: + residual = torch.mul(c, beta) + output = torch.add(output, residual) + if act_mode != "none": + act = act_mode_dict[act_mode] + output = act(output) + return output + + def test_smooth_quant_matmul(self): + m_list = [32, 64, 128] + n_list = [64, 128, 256] + k_list = [128, 256, 512] + has_bias_list = [True, False] + has_c_list = [True, False] + act_mode_list = ["none", "silu", "gelu"] + use_hp_active_list = [True, False] + + args = product(m_list, n_list, k_list, has_bias_list, has_c_list, act_mode_list, dtype_list, use_hp_active_list) + for m, n, k, has_bias, has_c, act_mode, dtype, use_hp_active in args: + if has_c and act_mode != "none": + continue + a = torch.randn(m, k, device="mlu", dtype=dtype) + b = torch.randn(n, k, device="mlu", dtype=dtype) + bias, c = None, None + if has_bias: + bias = torch.randn(n, device="mlu", dtype=dtype) + if has_c: + c = torch.randn(m, n, device="mlu", dtype=dtype) + + input_smooth = torch.randn(k, device="mlu", dtype=torch.float).abs() + 0.1 + quant_input, input_scale = QuantByRow(a * input_smooth, 8) + quant_weight, weight_scale = QuantByRow(b / input_smooth, 8) + + torch_output = self.op_impl_base(quant_input, input_scale, quant_weight, weight_scale, dtype, bias, c, + act_mode, 1.0, 1.0, use_hp_active) + tmo_output = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, dtype, bias, c, + act_mode, 1.0, 1.0, use_hp_active) + self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.006, use_MSE=True) + + def test_inductor(self): + has_bias_list = [True, False] + has_c_list = [True, False] + act_mode_list = ["none", "silu", "gelu"] + arg = product(has_c_list, has_bias_list, act_mode_list, dtype_list) + for has_c, has_bias, act_mode, dtype in arg: + if has_c and act_mode != "none": + continue + print(f"===has_c: {has_c}, has_bias: {has_bias}, act_mode: {act_mode}, dtype: {dtype}===") + M, K, N = 2, 16, 32 + quant_bit_size, use_hp_active, act_coef, alpha, beta, trans_a, trans_b = 8, True, 1., 0.8, 0.3, False, True + a = torch.randint(0, 10, (M, K), dtype=torch.int8).mlu() + b = torch.randint(0, 10, (N, K), dtype=torch.int8).mlu() + c = torch.randn(M, N, device="mlu", dtype=dtype) if has_c else None + bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None + a_scale = torch.randn(M, device="mlu", dtype=torch.float) + b_scale = torch.randn(N, device="mlu", dtype=torch.float) + a_zero, b_zero, c_zero = None, None, None + c_scale, gemm_output_scale, gemm_output_zero = None, None, None + quant_algo, a_quant_layout, b_quant_layout = "smooth_quant", "quantize_per_token", "quantize_per_channel" + str_dtype = dtype_dict[dtype] + + args = [a, a_scale, a_zero, + b, b_scale, b_zero, + bias, c, c_scale, c_zero, + gemm_output_scale, gemm_output_zero, + str_dtype, None, quant_algo, + a_quant_layout, b_quant_layout, + quant_bit_size, act_mode, use_hp_active, act_coef, + alpha, beta, trans_a, trans_b,] + self.base_opcheck(torch.ops.torch_mlu_ops.quant_matmul, args) + +if __name__ == '__main__': + random.seed(0) + torch.manual_seed(0) + exit(run_unittest(TestSmoothQuantMatmulOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_swap_blocks.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_swap_blocks.py new file mode 100755 index 0000000..52e4f16 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_swap_blocks.py @@ -0,0 +1,88 @@ +import random +import torch +import torch_mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +import os + +def gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair): + shape = (num_blocks, num_heads, block_size, head_size) + if dtype in {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}: + info = torch.iinfo(dtype) + if cpy == "mlu to mlu": + src = torch.randint(info.min, info.max, size=shape, dtype=dtype).mlu() + dst = torch.randint(info.min, info.max, size=shape, dtype=dtype).mlu() + elif cpy == "mlu to cpu": + src = torch.randint(info.min, info.max, size=shape, dtype=dtype).mlu() + dst = torch.randint(info.min, info.max, size=shape, dtype=dtype).cpu() + elif cpy == "cpu to mlu": + src = torch.randint(info.min, info.max, size=shape, dtype=dtype).cpu() + dst = torch.randint(info.min, info.max, size=shape, dtype=dtype).mlu() + else: + print("unkown copy direction.", flush=True) + exit(1) + else: + if cpy == "mlu to mlu": + src = torch.randn(size=shape, dtype=dtype).mlu() + dst = torch.randn(size=shape, dtype=dtype).mlu() + elif cpy == "mlu to cpu": + src = torch.randn(size=shape, dtype=dtype).mlu() + dst = torch.randn(size=shape, dtype=dtype).cpu() + elif cpy == "cpu to mlu": + src = torch.randn(size=shape, dtype=dtype).cpu() + dst = torch.randn(size=shape, dtype=dtype).mlu() + else: + print("unkown copy direction.", flush=True) + exit(1) + + values = list(range(num_pair)) + random.shuffle(values) + src_to_dst = {key: value for key, value in zip(range(num_pair), values)} + return dst, src, src_to_dst + +class TestSwapBlocksOp(BtTestCase): + def op_impl_base(self, *args): + dst, src, block_mapping = args + for key, value in block_mapping.items(): + dst[value] = src[key] + return dst + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test swap blocks due to ASan issues") + def test_swap_blocks(self): + num_blocks_list = [3600] + num_heads_list = [8] + head_size_list = [64,128] + block_size_list = [16] + num_pairs_list = [6,512] + + types = [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.half, torch.float] + if torch_mlu.mlu.is_bf16_supported(): + types.append(torch.bfloat16) + cpys = ["mlu to mlu", "mlu to cpu", "cpu to mlu"] + args = product(num_blocks_list, num_heads_list, block_size_list, head_size_list, types, cpys, num_pairs_list) + for num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair in args: + print("num_blocks: {}, num_heads: {}, block_size: {}, head_size: {}, dtype: {}, dir: {}, num_pairs: {} testing..." + .format(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair), flush=True) + + dst, src, src_to_dst = gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair) + + ref_src, ref_dst = src.clone(), dst.clone() + # cpu + self.op_impl_base(ref_dst, ref_src, src_to_dst) + # mlu + ops.swap_blocks(dst, src, src_to_dst) + # diff + self.assertTensorsEqual(src.cpu().float(), ref_src.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + self.assertTensorsEqual(dst.cpu().float(), ref_dst.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True) + + @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues") + def test_inductor(self): + num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair = 3600, 8, 16, 64, torch.half, "mlu to mlu", 512 + dst, src, block_mapping = gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair) + self.base_opcheck(torch.ops.torch_mlu_ops.swap_blocks, (dst, src, block_mapping)) + + +if __name__ == '__main__': + exit(run_unittest(TestSwapBlocksOp)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_update_out_and_lse.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_update_out_and_lse.py new file mode 100755 index 0000000..b97d4b4 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_update_out_and_lse.py @@ -0,0 +1,150 @@ +import torch +from torch_mlu import mlu +import unittest +import torch_mlu_ops as ops +from common_utils import * +from itertools import product +from torch.nn import functional as F +import time + +def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack): + if not pack: + out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype) + lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32) + block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype) + block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32) + seq_offset = torch.randint(low=0, high=(max_seq_len - block_seq_len + 1), size=(batch, ), dtype=torch.int32, device="mlu") + cu_seqs = None + block_cu_seqs = None + else: + seq_lens = torch.randint(low=1, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32) + block_seq_lens = torch.randint(low=1, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32) + block_seq_lens = torch.minimum(seq_lens, block_seq_lens) + seq_offset = torch.zeros_like(seq_lens) + for i in range(batch): + seq_offset[i] = torch.randint(low=0, high=seq_lens[i]-block_seq_lens[i]+1, size=(1,), dtype=torch.int32) + seq_offset = seq_offset.mlu() + cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu() + block_cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu() + total_seqs = torch.sum(seq_lens) + block_total_seqs = torch.sum(block_seq_lens) + + out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype) + lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32) + block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype) + block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32) + return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + +class TestUpdateOutAndLse(BtTestCase): + def op_impl_base(self, *args): + out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs = args + return update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs) + + def test_update_out_and_lse(self): + test_case_num = 10 + dtype_choice = [torch.float16, torch.float32] + random.seed(time.time()) + + count = test_case_num + while count > 0: + batch = random.randint(1, 16 + 1) + head_num = random.randint(1, 16 + 1) + head_size = random.randint(1, 512 + 1) + block_seq_len = random.choice([1, random.randint(2, 2048 + 1)]) + max_seq_len = 1 if block_seq_len == 1 else max(random.randint(2, 2048 + 1), block_seq_len) + pack = random.choice([True, False]) + + # 避免测试出现mlu显存不够 + if batch * head_num * head_size * max_seq_len > 10 * 1024 * 1024 * 1024: + continue + if torch_mlu.mlu.is_bf16_supported(): + dtype_choice.append(torch.bfloat16) + else: + if batch * head_num * head_size * max_seq_len > 1 * 1024 * 1024 * 1024: + continue + dtype = random.choice(dtype_choice) + + print(f"test_update_out_and_lse] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}") + + out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \ + gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack) + + out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True) + self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True) + count -= 1 + print(f"[test_update_out_and_lse] {test_case_num} cases test pass") + + def test_combine_ring_attn(self): + batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack =\ + 1, 8, 128, 8192, 8192, torch.bfloat16, True + + if not torch_mlu.mlu.is_bf16_supported(): + dtype = torch.float16 + + print(f"[test_combine_ring_attn] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}") + + out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \ + gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack) + # total_seq和block_total_seq为所有batch的真实seq长度之和,假设所有seq的 + # out: [sum(out_seqs), 64, 128] block_out [sum(block_out_seqs), 64, 128] + # lse: [1, 8, 8192] block_lse [1, 8, 8192] + # seq_offset: [128] + # cu_seqs block_cu_seqs: [128 + 1] + + out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + + ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + + self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True) + self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True) + print("[test_combine_ring_attn] pass") + + def test_combine_decoder_attn(self): + batch_list = [16, 128] + head_num_list = [64] + head_size_list = [128] + block_seq_len_list = [1] + max_seq_len_list = [1] + dtype_list = [torch.float16] + pack_list = [False] + + if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + + args = product(batch_list, head_num_list, head_size_list, block_seq_len_list, max_seq_len_list, + dtype_list, pack_list) + + for batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack in args: + print(f"[test_combine_decoder_attn] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}") + + out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \ + gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack) + # out: [128, 1, 64, 128] block_out [128, 1, 64, 128] + # lse: [128, 64, 1] block_lse [128, 1] + # seq_offset cu_seqs block_cu_seqs: None + + out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + + ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs) + + self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True) + self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True) + print("[test_combine_decoder_attn] pass") + + def test_inductor(self): + pack_list = [True, False] + for pack in pack_list: + batch, head_num, head_size, dtype = 16, 8, 128, torch.float16 + if pack: + block_seq_len, max_seq_len = 1024, 2048 + else: + block_seq_len, max_seq_len = 1, 1 + out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \ + gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack) + self.base_opcheck(torch.ops.torch_mlu_ops.update_out_and_lse, (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)) + print("[test_update_out_and_lse] test_inductor check pass") + +if __name__ == '__main__': + exit(run_unittest(TestUpdateOutAndLse)) diff --git a/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_weight_only_quant_matmul.py b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_weight_only_quant_matmul.py new file mode 100755 index 0000000..c639414 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tests/ops_pytest/test_weight_only_quant_matmul.py @@ -0,0 +1,113 @@ +import torch +import torch_mlu +import torch_mlu_ops as ops +from common_utils import * +from itertools import product + +dtype_dict = { + torch.half: "half", + torch.bfloat16: "bfloat16", +} +dtype_list = [torch.half] +if torch_mlu.mlu.is_bf16_supported(): + dtype_list.append(torch.bfloat16) + +# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16 +class TestWeightOnlyQuantMatmulOp(BtTestCase): + def op_impl_base(self, *args): + a, b, scale, zero, bias, c, act_mode, quant_bit_size, alpha, beta, use_hp_active = args + if quant_bit_size == 4: + n = b.shape[0] + b = UnpackInt4(b).view(n, -1) + if scale is not None: + if scale.dim() == 2: + group_size = b.size(1) // scale.size(1) + scale_bd = scale.unsqueeze(-1).repeat(1, 1, group_size).reshape(b.shape) + else: + scale_bd = scale.unsqueeze(-1) + b = torch.mul(b, scale_bd).to(a.dtype) + output = torch.matmul(a, b.permute(1,0)) + if bias is not None: + output += bias + output = torch.mul(output, alpha) + if c is not None: + residual = torch.mul(c, beta) + output = torch.add(output, residual) + if act_mode != "none": + act = act_mode_dict[act_mode] + output = act(output) + return output + + def test_weight_only_quant_matmul(self): + m_list = [32, 64, 128] + n_list = [128, 256, 512] + group_list = [1, 2] + k_list = [128, 256, 512] + quant_bit_list = [8, 4] + has_bias_list = [True, False] + has_c_list = [True, False] + act_mode_list = ["none", "silu", "gelu"] + use_hp_active_list = [True, False] + + args = product(m_list, n_list, k_list, group_list, quant_bit_list, has_bias_list, has_c_list, act_mode_list, dtype_list, use_hp_active_list) + for m, n, k, group, quant_bit, has_bias, has_c, act_mode, dtype, use_hp_active in args: + if has_c and act_mode != "none": + continue + a = torch.randn(m, k, device="mlu", dtype=dtype) + b = torch.randn(n, k, device="mlu", dtype=dtype) + bias, c = None, None + zero = None + if has_bias: + bias = torch.randn(n, device="mlu", dtype=dtype) + if has_c: + c = torch.randn(m, n, device="mlu", dtype=dtype) + quant_weight_int8, weight_scale = QuantByRow(b, quant_bit, group) + if group != 1: + if act_mode != "none": + continue + weight_scale = weight_scale.to(a.dtype) + if quant_bit == 4: + quant_weight_int4 = PairlyPackInt8(quant_weight_int8) + + args = (a, quant_weight_int4 if quant_bit == 4 else quant_weight_int8, + weight_scale, zero, bias, c, act_mode, quant_bit, 1.0, 1.0, use_hp_active) + torch_output = self.op_impl_base(*args) + tmo_output = ops.weight_only_quant_matmul(*args) + self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.004, use_MSE=True) + + def test_inductor(self): + M, K, N, group_num = 2, 256, 32, 4 + quant_bit_size, act_mode, use_hp_active, act_coef, alpha, beta, trans_a, trans_b = 8, 'none', True, 1., 0.8, 0.3, False, True + has_res_list = [True, False] + group_quant_list = [True, False] + args = product(has_res_list, group_quant_list, dtype_list) + for has_res, group_quant, dtype in args: + print(f"==has_res: {has_res}, group_quant: {group_quant}, dtype: {dtype}==") + a = torch.randn((M, K), dtype=dtype).mlu() + b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu() + c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None + bias = torch.randn(N, device="mlu", dtype=dtype) + group_wise_scale = torch.randn((N, group_num), device="mlu", dtype=dtype) + b_quant_layout = "quantize_group_wise" if group_quant else "quantize_per_channel" + b_scale = group_wise_scale if group_quant else None + gemm_output_scale = None if group_quant else torch.randn(N, device="mlu", dtype=torch.float) + a_scale, a_zero, b_zero, c_zero = None, None, None, None + c_scale, gemm_output_zero = None, None + quant_algo, a_quant_layout = "weight_only", "quantize_none" + dtype_str = dtype_dict[dtype] + + args = [a, a_scale, a_zero, + b, b_scale, b_zero, + bias, c, c_scale, c_zero, + gemm_output_scale, gemm_output_zero, + dtype_str, None, quant_algo, + a_quant_layout, b_quant_layout, + quant_bit_size, act_mode, use_hp_active, act_coef, + alpha, beta, trans_a, trans_b,] + self.base_opcheck(torch.ops.torch_mlu_ops.quant_matmul, args) + + +if __name__ == '__main__': + random.seed(0) + torch.manual_seed(0) + exit(run_unittest(TestWeightOnlyQuantMatmulOp)) diff --git a/torch_mlu_ops-v1.3.2/tools/diff.py b/torch_mlu_ops-v1.3.2/tools/diff.py new file mode 100755 index 0000000..cf67aeb --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tools/diff.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 + +import numpy as np +import sys + + +def check(x: np.ndarray, y: np.ndarray) -> tuple: + error = np.abs(x - y) + diff1 = np.sum(error)/np.sum(np.abs(x)) + diff2 = np.sqrt(np.sum(error**2)/np.sum(x**2)) + + max = np.max(error) + mean = np.mean(error) + return diff1, diff2, max, mean + +if "__main__" == __name__: + # read command line arguments + x = np.loadtxt(sys.argv[1]) + y = np.loadtxt(sys.argv[2]) + # compute mean squared error + diff1, diff2, max, mean = check(x, y) + + print("diff1: ", diff1) + print("diff2: ", diff2) + print("max error: ", max) + print("mean error: ", mean) diff --git a/torch_mlu_ops-v1.3.2/tools/header_check.py b/torch_mlu_ops-v1.3.2/tools/header_check.py new file mode 100755 index 0000000..78d9373 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tools/header_check.py @@ -0,0 +1,91 @@ +import argparse +import logging +import sys +import os +import re + +copyright = "/" + "*" * 73 + "\n" +copyright += " * Copyright (C) [2023-2024] by Cambricon, Inc.\n" +copyright += " *\n" +copyright += " * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS\n" +copyright += " * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n" +copyright += " * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n" +copyright += " * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\n" +copyright += " * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\n" +copyright += " * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\n" +copyright += " * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n" +copyright += " " + "*" * 73 + "/\n" + +tmo_path = os.path.abspath(sys.path[0] + "../..") + +def arg(): + """ Argparser. """ + parser = argparse.ArgumentParser(''' + A format script for checking cc/h is following llvm format + or not.''') + parser.add_argument( + "--path", help="Location of .h/.cpp file to generate template.") + return parser.parse_args() + + +class Header(object): + """ Object to gen/check header. """ + + def __init__(self, path): + """ Init function with args. """ + # Check path first. + self.__path = os.path.abspath(path) + self.filename = os.path.split(path)[1] + if path.endswith(".cpp") or path.endswith(".cc"): + self.__has_macro = False + elif path.endswith(".h") or path.endswith(".mluh"): + self.__has_macro = True + else: + logging.info("Skip: format_checker is only for .cpp/.cc/.h") + logging.info("But met " + self.__path) + sys.exit(0) + if not self.__path.startswith(tmo_path): + logging.info("Skip: format_checker is only for genesis file") + sys.exit(0) + # Gen header. + self.__regex = copyright.replace( + "(", "\(").replace(")", "\)").replace("[", "\[").replace( + "]", "\]").replace("-", "\-").replace("*", "\*") + if self.__has_macro: + file_macro = self.__path.replace(tmo_path + "/", "").replace( + "/", "_").replace(".", "_").upper()+"_" + self.__macro = "#ifndef " + file_macro + "\n" + self.__macro += "#define " + file_macro + "\n" + self.__macro += "#endif // " + file_macro + "\n" + + def check(self): + """ To check header/macro. """ + fo = open(self.__path, "r") + file_string = fo.read() + fo.close() + ptrn_header = re.compile(self.__regex, re.DOTALL) + h_r = ptrn_header.match(file_string) + if not h_r: + logging.error(args.path + + " header mismatch, pattern should be like:") + logging.error("\n" + self.__regex) + sys.exit(-1) + + if self.__has_macro: + ptrn_macro = re.compile(".*" + self.__macro.replace("\n", "\n.*"), + re.DOTALL) + m_r = ptrn_macro.match(file_string) + if not m_r: + logging.error(args.path + + " macro mismatch, pattern should be like:") + logging.error("\n" + self.__macro) + sys.exit(-1) + +if __name__ == "__main__": + args = arg() + logging.basicConfig(level=logging.INFO) + if not os.path.exists(args.path): + logging.info("Skip: no such file") + sys.exit(0) + h = Header(args.path) + h.check() diff --git a/torch_mlu_ops-v1.3.2/tools/py_coverage/.coveragerc b/torch_mlu_ops-v1.3.2/tools/py_coverage/.coveragerc new file mode 100644 index 0000000..d15b96e --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tools/py_coverage/.coveragerc @@ -0,0 +1,3 @@ +[run] +omit = + /tmp/* diff --git a/torch_mlu_ops-v1.3.2/tools/py_coverage/README.md b/torch_mlu_ops-v1.3.2/tools/py_coverage/README.md new file mode 100644 index 0000000..7b1756d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tools/py_coverage/README.md @@ -0,0 +1,11 @@ +## PYTEST覆盖率脚本使用方式 + +```bash +bash run_test.sh +``` + +- 必须在Torch-MLU-Ops docker容器内运行。 + +- 脚本原理是使用coverage命令收集py文件运行时的覆盖率数据,所以脚本依托于`tests/kernels_pytest/run_test.sh`和`tests/ops_pytest/run_test.sh`运行。 + +- 脚本统计的是`tests/kernels_pytest`和`tests/ops_pytest`下py文件的覆盖率数据。 diff --git a/torch_mlu_ops-v1.3.2/tools/py_coverage/run_test.sh b/torch_mlu_ops-v1.3.2/tools/py_coverage/run_test.sh new file mode 100644 index 0000000..545250d --- /dev/null +++ b/torch_mlu_ops-v1.3.2/tools/py_coverage/run_test.sh @@ -0,0 +1,22 @@ +KERNELS_PYTEST_DIR="$( cd "$( dirname "$0" )/../../tests/kernels_pytest" && pwd )" +OPS_PYTEST="$( cd "$( dirname "$0" )/../../tests/ops_pytest" && pwd )" + +if ! command -v coverage &> /dev/null +then + echo "install coverage..." + pip install coverage + + if [ $? -eq 0 ]; then + echo "install coverage success" + else + echo "install coverage failure" + exit 1 + fi +fi + +${KERNELS_PYTEST_DIR}/build.sh +${KERNELS_PYTEST_DIR}/run_test.sh coverage +${OPS_PYTEST}/run_test.sh coverage + +coverage report +coverage html diff --git a/torch_mlu_ops-v1.3.2/torch_mlu_ops/__init__.py b/torch_mlu_ops-v1.3.2/torch_mlu_ops/__init__.py new file mode 100644 index 0000000..6dfc908 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/torch_mlu_ops/__init__.py @@ -0,0 +1,6 @@ +import torch +from ._ops import * +if torch.__version__ >= '2.3.0': + from .abstract import * + +from ._version import __version__ diff --git a/torch_mlu_ops-v1.3.2/torch_mlu_ops/_ops.py b/torch_mlu_ops-v1.3.2/torch_mlu_ops/_ops.py new file mode 100644 index 0000000..3807e86 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/torch_mlu_ops/_ops.py @@ -0,0 +1,2704 @@ +import torch +import torch_mlu +import inspect +from ._utils import torchDtype2Str as _torchDtype2Str +from typing import Tuple, List, Dict, Optional, Union +from ._utils import get_custom_op_library_path as _get_custom_op_library_path +from ._utils import add_gen_case_decorator + +torch.ops.load_library(_get_custom_op_library_path()) + +__all__ = [ + "active", + "apply_rotary", + "attention_project", + "batch_matmul", + "copy_blocks", + "dequant_from_linear_cache", + "dequant_from_paged_cache", + "ffn", + "flash_attention", + "flash_attn_sq_mm_allreduce", + "fused_layer_norm", + "fused_moe", + "fused_norm_attention_project", + "fused_norm_residual_ffn", + "fused_rms_norm", + "fused_rope", + "group_gemm", + "matmul", + "matmul_allreduce", + "moe_active", + "moe_cast_gating", + "moe_combine_result", + "moe_expand_input", + "moe_gen_idx", + "moe_quantize", + "moe_softmax_topk", + "offline_quant_to_linear_cache", + "offline_quant_to_paged_cache", + "per_token_smooth_quantize", + "preload", + "quant_to_linear_cache", + "quant_to_paged_cache", + "quantize", + "reshape_linear_cache", + "reshape_paged_cache", + "single_query_cached_kv_attn", + "single_query_mixed_cached_kv_attn", + "smooth_quant_group_gemm", + "smooth_quant_matmul", + "smooth_quant_matmul_allreduce", + "swap_blocks", + "update_out_and_lse", + "weight_only_quant_matmul", +] + +def quantize(x: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor = None + ) -> torch.Tensor: + """ + Apply quantization to the input tensor. + + Math: + scaled = x * scale + output = scaled.round().clamp(-128, 127).to(torch.int8) + + Args: + x (torch.Tensor): The tensor to be quantized. Shape is (..., C). + scale (torch.Tensor): The scale multipled to the input tensor. Shape is (C). + zero (torch.Tensor): Not supported, must pass None. + + DataType: + x: float, half or bfloat16 + scale: float + output: int8 + + Return: + A int8-tensor with the same shape of x. + """ + output = torch.empty_like(x, dtype=torch.int8, device=x.device) + torch.ops.torch_mlu_ops.smooth_quant(x, scale, output, torch.Tensor(), zero, + None, None, None, 'per_token', False) + return output + +def per_token_smooth_quantize(x: torch.Tensor, + smooth: torch.Tensor, + zero: torch.Tensor = None, + token_count: torch.Tensor = None + ) -> Tuple[torch.Tensor]: + """ + Apply per_token smooth quantization to the input tensor. + + Math: + if token_count: + input_list = input.split(token_count) + group = token_count.size() + result = [] + for i in range(group): + result.append(input_list[i] * smooth[i]) + smoothed = concat(result, dim=0) + else: + smoothed = input * smooth + max, _ = smoothed.abs().max(dim=-1, keepdim=True) + scale = max.to(torch.float) / 127.0 + quanted = (smoothed / scale).round().clamp(-128, 127).to(torch.int8) + output = (quanted, scale) + + Args: + input (torch.Tensor): The tensor to be quantized. Shape is (..., C) or (token_count.sum(), C). + smooth (torch.Tensor): The smooth scale multipled to the input tensor. Shape is (C) or (group, C). + zero (torch.Tensor): Not supported, must pass None. + token_count (torch.Tensor): The tensor to separate input. Shape is (group). + + DataType: + x: float, half or bfloat16 + smooth: float + token_count: int32 + output: (int8, float) + + Return: + A tuple of two tensors. + quantized_x: int8-tensor with the same shape of x. + scale: float-tensor with the shape of x.shape[0:-1]. + """ + output = torch.empty(x.size(), dtype=torch.int8, device=x.device) + output_scale = torch.empty(x.size()[:-1], dtype=smooth.dtype, device=smooth.device) + torch.ops.torch_mlu_ops.smooth_quant(x, smooth, output, output_scale, zero, + token_count, None, None, 'per_token', True) + return (output, output_scale) + +def weight_only_quant_matmul(a: torch.Tensor, + b: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor = None, + bias: torch.Tensor = None, + c: torch.Tensor = None, + act_mode: str = "none", + quant_bit_size: int = 8, + alpha: float = 1.0, + beta: float = 1.0, + use_hp_active: bool = False): + """ + Perform matrix multiplication on tensor a and quantized tensor b. + Math: + ab = matmul(a, b) * scale + bias + acted = active(ab) + output = alpha * acted + beta * c + + Args: + a (torch.Tensor): Shape is (M, K). + b (torch.Tensor): If quant_bit_size = 8, shape is (N, K). + If quant_bit_size = 4, shape is (N, K//2). + scale (torch.Tensor): If use groupwise quantization, shape must be (N, group_num), data type must be + the same as a; otherwise shape must be (N), data tyte must be float. + zero (torch.Tensor): Not supported, must pass None. + bias (torch.Tensor): Shape is (N). + c (torch.Tensor): Shape is (M, N). + act_mode (str): Choose the activation algorithm, must be 'silu', 'gelu' or 'none'. If use groupwise + quantization, act_mode must be 'none'. + quant_bit_size (int): The data format of b. + alpha (float): coefficient of acted. + beta (float): coefficient of c. + use_hp_active (bool): Describing the algorithm that used in the implementation of the activation function. + When the value is true, use the high-precision algorithm, otherwise use the fastest algorithm of activation. + + Type: + a: float, half, bfloat16 + b: int8 + scale: float or same as a + zero: float + bias: same type as a + c: same type as a + output: same type as a + + Return: + A tensor with the shape of (M, N). + """ + is_group_quant = scale.dim() == 2 + b_scale = scale if is_group_quant else None + gemm_out_scale = None if is_group_quant else scale + b_quant_layout = 'quantize_group_wise' if is_group_quant else 'quantize_per_channel' + return quant_matmul(a, None, None, b, b_scale, None, bias, c, None, None, + gemm_out_scale, zero, None, "weight_only", "quantize_none", b_quant_layout, + quant_bit_size, act_mode, use_hp_active, 1.0, alpha, beta, False, True) + +def smooth_quant_matmul(a: torch.Tensor, + a_scale: Optional[torch.Tensor], + b: torch.Tensor, + b_scale: torch.Tensor, + dtype: torch.dtype, + bias: torch.Tensor = None, + c: torch.Tensor = None, + act_mode: str = "none", + alpha: float = 1.0, + beta: float = 1.0, + use_hp_active: bool = False): + """ + Perform smooth quantized matrix multiplication on tensor a and b. + + Math: + if a_scale is not None: + ab = matmul(a, b) * a_scale.outer(b_scale) + bias + else: + ab = matmul(a, b) * b_scale + bias + acted = active(ab) + output = alpha * acted + beta * c + + Args: + a (torch.Tensor): Shape is (M, K). + a_scale (torch.Tensor): The per_token scale of a, shape is (M). If None, quant_layout is "quantize_per_tensor", and scale should be converted into b_scale. + b (torch.Tensor): Shape is (N, K). + b_scale (torch.Tensor): The per_channel scale of b, shape is (N). + dtype (torch.dtype): Specify the data type of output, must be torch.float, torch.half or torch.bfloat16. + bias (torch.Tensor): Shape is (N). + c (torch.Tensor): Shape is (M, N). + act_mode (str): Choose the activation algorithm, must be 'silu', 'gelu' or 'none'. + alpha (float): coefficient of acted. + beta (float): coefficient of c. + use_hp_active (bool): Describing the algorithm that used in the implementation of the activation function. + When the value is true, use the high-precision algorithm, otherwise use the fastest algorithm of activation. + + Type: + a: int8 + a_scale: float + b: int8 + b_scale: float + bias: same as output + c: same as output + output: specified by dtype + + Return: + A tensor with the shape of (M, N). + """ + is_pertoken = a_scale is not None + assert dtype in (torch.half, torch.float, torch.bfloat16), "dtype must be half, float or bfloat16, got: {}.".format(dtype) + return quant_matmul(a, a_scale, None, b, b_scale, None, bias, c, None, None, b_scale, None, + _torchDtype2Str(dtype), "smooth_quant", + "quantize_per_token" if is_pertoken else "quantize_per_tensor", + "quantize_per_channel", + 8, act_mode, use_hp_active, 1.0, alpha, beta, False, True) + +def smooth_quant_matmul_allreduce(cncl_comm, + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + dtype: torch.dtype, + bias: torch.Tensor = None, + c: torch.Tensor = None, + alpha: float = 1.0, + beta: float = 1.0, + block_m: int = 0): + """ + Parallelly apply smooth quantized matrix multiplication with TP-allreduce. + + Math: + if a_scale is not None: + ab = matmul(a, b) * a_scale.outer(b_scale) + bias + else: + ab = matmul(a, b) * b_scale + bias + acted = active(ab) + output = alpha * acted + beta * c + output_allreduce = allreduce(output) + + Args: + cncl_comm (int): The handle of CnclComm. + a (torch.Tensor): Shape is (M, K). + a_scale (torch.Tensor): The per_token scale of a, shape is (M) + b (torch.Tensor): Shape is (N, K). + b_scale (torch.Tensor): The per_channel scale of b, shape is (N). + dtype (torch.dtype): Specify the data type of output, must be torch.float, torch.half or torch.bfloat16. + bias (torch.Tensor): Shape is (N). + c (torch.Tensor): Shape is (M, N). + alpha (float): coefficient of acted. + beta (float): coefficient of c. + block_m (int): If the value is greater than 0, the specified value will be used; if the value is less than 0, the recommended value will be used. + + Type: + a: int8 + a_scale: float + b: int8 + b_scale: float + bias: same as output + c: same as output + output: specified by dtype + + Return: + A tensor with the shape of (M, N). + """ + is_pertoken = a_scale is not None + assert dtype in (torch.half, torch.float, torch.bfloat16), "dtype must be half, float or bfloat16, got: {}.".format(dtype) + return quant_matmul_allreduce(cncl_comm, a, a_scale, None, b, b_scale, + None, bias, c, None, None, b_scale, None, + _torchDtype2Str(dtype), "smooth_quant", + "quantize_per_token" if is_pertoken else "quantize_per_tensor", + "quantize_per_channel", + 8, alpha, beta, False, True, block_m) + +def fused_layer_norm(x: torch.Tensor, + residual: Optional[torch.Tensor], + gamma: Optional[torch.Tensor], + beta: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + store_output_before_norm: bool, + quant_scale: torch.Tensor = None, + out: torch.Tensor = None, + dynamic_quant: bool = False): + """ + Apply layernorm to the input tensor. + Please refer to https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm + + Math: + if bias is not None: + x1 = x + bias + if residual is not None: + x1 += residual + normed = layernorm(x1, gamma, beta, eps) + if dynamic_quant: + max, _ = (normed * quant_scale).abs().max(dim=-1, keepdim=True) + smooth_quant_scale = max.to(torch.float) / int_max + normed = (normed / smooth_quant_scale).round().to(int8) + elif quant_scale is not None: + normed = (normed * quant_scale).round().clamp(-128, 127).to(torch.int8) + if store_output_before_norm == True: + output = (normed, x1) + else: + output = normed + + Args: + x (torch.Tensor): The input tensor to be normalized. Shape is (p1, p2, ..., pn, C). + residual (torch.Tensor): The tensor add to x. Shape is the same as x. + gamma (torch.Tensor): The weight of layernorm. Shape is (C). + beta (torch.Tensor): The bias of layernorm. Shape is (C). + eps (float): The eps of layernorm. + store_output_before_norm (bool): If return the tensor before applying layernorm. + quant_scale (torch.Tensor): The smooth_scale of output if dynamic_quant, else it is the scale of quantization. Shape is (C). + out (torch.Tensor): The output tensor of normalizes. Shape is the same as x. + dynamic_quant(bool): Flag quant_mode of output, dynamic quant or not. + smooth_quant_scale(torch.Tensor): The output scale tensor. Shape is [p1, p2, ..., pn]. + + Type: + x: float, half, bfloat16 + residual: same as x + beta: same as x + gamma: same as x + bias: same as x + quant_scale: float + output: same as x, int8 + smooth_quant_scale: float + + Return: + Support inplace output. + If store_output_before_norm is True, return Two tensors before and after applying layernorm, + otherwise return a tensor after applying layernorm and quantization. + """ + if out is None: + if quant_scale is not None: + out = torch.empty(*x.shape, dtype=torch.int8, device=x.device) + else: + out = torch.empty(*x.shape, dtype=x.dtype, device=x.device) + else: + assert store_output_before_norm is False and quant_scale is None, "store_output_before_norm and quant_out are not support when inplace out." + residual_out = torch.empty(*x.shape, dtype=x.dtype, device=x.device) if store_output_before_norm else None + smooth_quant_scale = torch.empty(x.shape[:-1], dtype=torch.float, device=x.device) if dynamic_quant else None + torch.ops.torch_mlu_ops.fused_layernorm(x, out, residual, gamma, beta, bias, quant_scale, residual_out, smooth_quant_scale, "layernorm", eps, store_output_before_norm, dynamic_quant) + output = [out] + if store_output_before_norm: + output.append(residual_out) + if dynamic_quant: + output.append(smooth_quant_scale) + return output[0] if len(output) == 1 else tuple(output) + + +def fused_rms_norm(x: torch.Tensor, + residual: Optional[torch.Tensor], + gamma: Optional[torch.Tensor], + beta: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + store_output_before_norm: bool, + quant_scale: torch.Tensor = None, + out: torch.Tensor = None, + dynamic_quant: bool = False): + """ + Apply rmsnorm to the input tensor. + Paramter limitations the same as "fused_layer_norm", except 'beta' should be None. + """ + if out is None: + if quant_scale is not None: + out = torch.empty(*x.shape, dtype=torch.int8, device=x.device) + else: + out = torch.empty(*x.shape, dtype=x.dtype, device=x.device) + else: + assert store_output_before_norm is False and quant_scale is None, "store_output_before_norm and quant_out are not support when inplace out." + residual_out = torch.empty(*x.shape, dtype=x.dtype, device=x.device) if store_output_before_norm else None + smooth_quant_scale = torch.empty(x.shape[:-1], dtype=torch.float, device = x.device) if dynamic_quant else None + torch.ops.torch_mlu_ops.fused_layernorm(x, out, residual, gamma, beta, bias, quant_scale, residual_out, smooth_quant_scale, "rmsnorm", eps, store_output_before_norm, dynamic_quant) + output = [out] + if store_output_before_norm: + output.append(residual_out) + if dynamic_quant: + output.append(smooth_quant_scale) + return output[0] if len(output) == 1 else tuple(output) + +def attention_project(input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + residual: torch.Tensor = None, + alpha: float = 1.0, + beta: float = 1.0 + ) -> torch.Tensor: + """ + Perform attention project operation which including adding bias and residual. + + Math: + proj = matmul(input, weight) + if bias is not None: + proj += bias + if residual is not None: + proj = alpha * proj + beta * residual + output = proj + + Args: + input (torch.Tensor): The input tensor. Shape is (batch, seq, input_size) or (total_seq, input_size). + weight (torch.Tensor): The weight tensor. Shape is (hidden_size, input_size). + bias (torch.Tensor): The bias tensor. Shape is (hidden_size). + residual (torch.Tensor): The residual tensor. Shape is the same as input. + alpha (float): The coefficient of proj. + beta (float): The coefficient of residual. + + Type: + input: float, half, bfloat16 + weight: same as input + bias: same as input + residual: same as input + output: same as input + + Return: + A tensor with the shape of (batch, seq, hidden_size) or (total_seq, hidden_size). + """ + output = torch.ops.torch_mlu_ops.attention_project(input, weight, bias, None, None, None, + None, None, None, residual, "nthc", 1, + 1e-5, alpha, beta, False) + return output[0] + +def fused_norm_attention_project(input: torch.Tensor, + q_weight: torch.Tensor, + q_bias: Optional[torch.Tensor], + k_weight: Optional[torch.Tensor], + k_bias: Optional[torch.Tensor], + v_weight: Optional[torch.Tensor], + v_bias: Optional[torch.Tensor], + norm_weight: Optional[torch.Tensor], + norm_bias: Optional[torch.Tensor], + eps: float, + out_layout: str = "nthc", + head_size: int = 1, + norm_out: bool = False): + """ + Perform attention project operation which including layernorm, adding bias/residual and permuting. + + Math: + normed = layernorm(input, norm_weight, norm_bias, eps) + proj_q = matmul(normed, q_weight) + q_bias + if k_weight is not None: + proj_k = matmul(normed, k_weight) + k_bias + if v_weight is not None: + proj_v = matmul(normed, v_weight) + v_bias + if out_layout is 'nhtc': + c = proj_q.reshape(batch, seq, hidden_size_q // head_size, head_size).transpose(1, 2) + proj_k = proj_k.reshape(batch, seq, hidden_size_k // head_size, head_size).transpose(1, 2) + proj_v = proj_v.reshape(batch, seq, hidden_size_v // head_size, head_size).transpose(1, 2) + outputs = [proj_q] + if k_weight is not None: + outputs.append(proj_k) + if v_weight is not None: + outputs.append(proj_v) + if norm_out: + outputs.append(normed) + return tuple(outputs) + Args: + input (torch.Tensor): The input tensor. Shape is (batch, seq, input_size) or (total_seq, input_size). + q_weight (torch.Tensor): The q_weight tensor. Shape is (hidden_size_q, input_size). + q_bias (torch.Tensor): The q_bias tensor. Shape is (hidden_size_q). + k_weight (torch.Tensor): The k_weight tensor. Shape is (hidden_size_k, input_size). + k_bias (torch.Tensor): The k_bias tensor. Shape is (hidden_size_k). + v_weight (torch.Tensor): The v_weight tensor. Shape is (hidden_size_v, input_size). + v_bias (torch.Tensor): The v_bias tensor. Shape is (hidden_size_v). + norm_weight (torch.Tensor): The weight of layernorm. Shape is (input_size). + norm_bias (torch.Tensor): The bias of layernorm. Shape is (input_size). + eps (float): The eps of layernorm. + out_layout (str): The layout of output tensor. Must be 'nthc' or 'nhtc'. + head_size (int): The last dimension of output when out_layout is 'nhtc'. + norm_out (bool): If return the layernorm result. + + Type: + input: float, half, bfloat16 + q_weight: same as input + q_bias: same as input + k_weight: same as input + k_bias: same as input + v_weight: same as input + v_bias: same as input + norm_weight: same as input + norm_bias: same as input + output: same as input + + Return: + If len(outputs) > 1, return a tuple of tensors, otherwise return a tensor. + + """ + output = torch.ops.torch_mlu_ops.attention_project(input, q_weight, q_bias, k_weight, k_bias, v_weight, v_bias, + norm_weight, norm_bias, None, out_layout, head_size, eps, 1.0, + 1.0, norm_out) + return tuple(output) if len(output) > 1 else output[0] + +def ffn(input: torch.Tensor, + up_fc_weight: torch.Tensor, + up_fc_bias: Optional[torch.Tensor], + down_proj_weight: torch.Tensor, + down_proj_bias: Optional[torch.Tensor], + gate_up_proj_weight: Optional[torch.Tensor], + gate_up_proj_bias: Optional[torch.Tensor], + act_mode: str + ) -> torch.Tensor: + """ + Perform feed forward operation. + + Math: + up = matmul(input, up_fc_weight) + if up_fc_bias is not None: + up += up_fc_bias + act = active(up) + if gate_up_proj_weight is not None: + gate = matmul(input, gate_up_proj_weight) + if gate_up_proj_bias is not None: + gate += gate_up_proj_bias + act *= gate + down = matmul(act, down_proj_weight) + if gate_up_proj_bias is not None: + down += gate_up_proj_bias + output = down + + Args: + input (torch.Tensor): The input tensor. Shape is (batch, seq, hidden_size) or (total_seq, hidden_size). + up_fc_weight (torch.Tensor): The weight of upper matmul. Shape is (inner_size, hidden_size). + up_fc_bias (torch.Tensor): The bias of upper matmul. Shape is (inner_size). + down_proj_weight (torch.Tensor): The weight of lower matmul. Shape is (hidden_size, inner_size). + down_proj_bias (torch.Tensor): The bias of lower matmul. Shape is (hidden_size). + gate_up_proj_weight (torch.Tensor): The weight of gate matmul. Shape is (inner_size, hidden_size). + gate_up_proj_bias (torch.Tensor): The bias of gate matmul. Shape is (inner_size). + act_mode (str): The activation mode, must be 'silu', 'gelu', 'relu' or 'none'. + + Type: + input: float, half, bfloat16. + up_fc_weight: same as input. + up_fc_bias: same as input. + down_proj_weight: same as input. + down_proj_bias: same as input. + gate_up_proj_weight: same as input. + gate_up_proj_bias: same as input. + output: same as input. + + Return: + An output tensor. + """ + return torch.ops.torch_mlu_ops.ffn(input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, + gate_up_proj_weight, gate_up_proj_bias, None, None, act_mode, "none", + 1e-5, 1., 1.) + +def fused_norm_residual_ffn(input: torch.Tensor, + up_fc_weight: torch.Tensor, + up_fc_bias: Optional[torch.Tensor], + down_proj_weight: torch.Tensor, + down_proj_bias: Optional[torch.Tensor], + gate_up_proj_weight: Optional[torch.Tensor], + gate_up_proj_bias: Optional[torch.Tensor], + layernorm_weight: Optional[torch.Tensor], + layernorm_bias: Optional[torch.Tensor], + eps: float, + act_mode: str, + residual_is: str = "input", + alpha: float = 1.0, + beta: float = 1.0 + ) -> torch.Tensor: + """ + Perform feed forward operation. + + Math: + residual = input + normed = input + if layernorm_weight is not None: + normed = layernorm(input, layernorm_weight, layernorm_bias, eps) + if residual_is == 'normed_input': + residual = normed + up = matmul(normed, up_fc_weight) + if up_fc_bias is not None: + up += up_fc_bias + act = active(up) + if gate_up_proj_weight is not None: + gate = matmul(input, gate_up_proj_weight) + if gate_up_proj_bias is not None: + gate += gate_up_proj_bias + act *= gate + down = matmul(act, down_proj_weight) + if gate_up_proj_bias is not None: + down += gate_up_proj_bias + if residual_is != 'none': + down = alpha * down + beta * residual + output = down + + Args: + input (torch.Tensor): The input tensor. Shape is (batch, seq, hidden_size) or (total_seq, hidden_size). + up_fc_weight (torch.Tensor): The weight of upper matmul. Shape is (inner_size, hidden_size). + up_fc_bias (torch.Tensor): The bias of upper matmul. Shape is (inner_size). + down_proj_weight (torch.Tensor): The weight of lower matmul. Shape is (hidden_size, inner_size). + down_proj_bias (torch.Tensor): The bias of lower matmul. Shape is (hidden_size). + gate_up_proj_weight (torch.Tensor): The weight of gate matmul. Shape is (inner_size, hidden_size). + gate_up_proj_bias (torch.Tensor): The bias of gate matmul. Shape is (inner_size). + layernorm_weight (torch.Tensor): The weight of layernorm. + layernorm_bias (torch.Tensor): The bias of layernorm. + eps (float): The eps of layernorm. + act_mode (str): The activation mode, must be 'silu', 'gelu', 'relu' or 'none'. + residual_is (str): A string controlling which tensor is residual, must be 'input', 'normed_input' or 'none'. + alpha (float): The coefficient of lower matmul result. + bete (float): The coefficient of residual. + + Type: + input: float, half, bfloat16. + up_fc_weight: same as input. + up_fc_bias: same as input. + down_proj_weight: same as input. + down_proj_bias: same as input. + gate_up_proj_weight: same as input. + gate_up_proj_bias: same as input. + layernorm_weight: same as input. + layernorm_bias: same as input. + output: same as input. + + Return: + An output tensor. + """ + return torch.ops.torch_mlu_ops.ffn(input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, + gate_up_proj_weight, gate_up_proj_bias, layernorm_weight, layernorm_bias, act_mode, + residual_is, eps, alpha, beta) + +def flash_attention(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seq_lens_q: Optional[torch.Tensor], + cu_seq_lens_kv: Optional[torch.Tensor], + alibi_slope: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], + max_seq_len_q: int, + max_seq_len_kv: int, + softmax_scale: float, + is_causal: bool, + window_size_left: int = -1, + window_size_right: int = -1, + compute_dtype: torch.dtype = torch.float, + return_lse: bool = False, + block_tables: Optional[torch.Tensor] = None, + k_cache_quant_scale: Optional[torch.Tensor] = None, + v_cache_quant_scale: Optional[torch.Tensor] = None + ): + """ + Apply attention operation on q, k and v. + + Math: + qk = bmm(q, k) * softmax_scale + if alibi_slope is not None: + qk += create_alibi(alibi_slope) + if attn_bias is not None: + qk += attn_bias + if is_causal: + qk = tril_mask(qk) + if return_lse: + lse = logsumexp(qk, dim=-1) + attn = softmax(qk, dim=-1) + attn_out = bmm(attn, v) + output = (attn_out, lse) if return_lse else attn_out + + Args: + q (torch.Tensor): The query tensor. Shape is (batch, seq_q, head_num_q, head_size_qk) or (total_seq_q, head_num_q, head_size_qk). + k (torch.Tensor): The key tensor. If block_tables is None, the shape of k is (batch, seq_kv, head_num_kv, head_size_qk) or (total_seq_kv, head_num_kv, head_size_qk). + If block_tables is not None, if use paged attention, shape is (total_blocks, head_num_kv, block_size, head_size_qk), + where the total_blocks = max_batch * (memory_cache_len / block_size). If not use paged attention, shape is (max_batch, head_num_kv, memory_cache_len, head_size_qk). + v (torch.Tensor): The value tensor. Shape is the same as k, except, head_size_v could be different with head_size_qk. + out (torch.Tensor): The output tensor. Shape is the same as q. + cu_seq_lens_q (torch.Tensor): The cusum of seq_q if q is 3-D tensor. Shape is (batch+1). + cu_seq_lens_kv (torch.Tensor): The cusum of seq_kv if k/v is 3-D tensor. Shape is (batch+1). + alibi_slope (torch.Tensor): The alibi_slope to generate alibi position embedding. Shape is (head_num_q). + attn_bias (torch.Tensor): The bias added to qk. Shape is (batch, head_num_q, max_seq_len_q, max_seq_len_kv). + max_seq_len_q (int): The maximum value of seq_q. + max_seq_len_kv (int): The maximum value of seq_kv. + softmax_scale (float): The scale factor multiplied to qk. + is_causal (bool): Controlling if use causal mask. + window_size_left (int): left window size, set -1 if it is unlimited. The maximum lengths of key and value involved in the attention calculation is windows_size_left + 1. + window_size_right (int): right window size, set -1 if it is unlimited. The maximum lengths of key and value involved in the attention calculation is windows_size_right + 1 + compute_dtype (torch.dtype): The dtype used to calculate softmax. + return_lse (bool): Controlling if return lse tensor. + k_cache_quant_scale (torch.Tensor): Reserved parameter, the quantized scale of k_cache, must be None. + v_cache_quant_scale (torch.Tensor): Reserved parameter, The quantized scale of v_cache, must be None. + block_tables (torch.Tensor): The tensor recording the k/v_cache's positions of each batch. + If use paged attention, shape is (batch, max_block_num), where max_block_num = memory_cache_len / block_size. + If not use paged attention, shape is (batch, 1). + + Type: + q: float, half, bfloat16. + k: same as q. + v: same as q. + cu_seq_lens_q: int32. + cu_seq_lens_kv: int32. + alibi_slope: float. + attn_bias: same as q. + attn_out : same as q. + lse: float. + + Return: + If return_lse is True, return a tuple of (attn_out, lse), else return attn_out. + """ + tmo_out = out + if out is None: + tmo_out = torch.empty(q.size()[:-1] + (v.size()[-1],), dtype=q.dtype, device=q.device) + out_lse = None + if return_lse: + is_pack = cu_seq_lens_q is not None + batch = cu_seq_lens_q.shape[0] - 1 if is_pack else q.shape[0] + head_num = q.shape[1] if is_pack else q.shape[2] + out_lse = torch.empty((batch, head_num, max_seq_len_q), dtype=torch.float64, device=q.device) + torch.ops.torch_mlu_ops.flash_attention(q, k, v, + tmo_out, out_lse, + cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, + attn_bias, k_cache_quant_scale, v_cache_quant_scale, + block_tables, max_seq_len_q, max_seq_len_kv, + softmax_scale, is_causal, + window_size_left, window_size_right, + _torchDtype2Str(compute_dtype), return_lse) + return (tmo_out, out_lse) if return_lse else tmo_out + + +def single_query_cached_kv_attn( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + out: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + k_cache_quant_scale: Optional[torch.Tensor], + v_cache_quant_scale: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_contxt_len: int, + windows_size_left: int, + windows_size_right: int, + softmax_scale: float, + return_lse: bool = False, + kv_cache_quant_bit_size = -1 +) -> torch.Tensor: + """ + Apply attention operation on query and cached key/value. + + Math: + out_list = [] + lse_list = [] + for i in range(q.shape[0]): + key = collect(k_cache, block_tables[i], context_lens[i]) + value = collect(v_cache, block_tables[i], context_lens[i]) + qk = bmm(q[i], key) * softmax_scale + if alibi_slope is not None: + alpha = qk + create_alibi(alibi_slope) + attn = softmax(alpha, dim=-1) + attn_out = bmm(attn, value) + lse = logsumexp(qk, dim = -1) + out_list.append(attn_out) + lse_list.append(lse) + output = concat(out_list, dim = 0) + output_lse = concat(lse_list, dim = 0) + + Args: + q (torch.Tensor): The query tensor, shape is (batch, seq_q, head_num_q, head_size_qk). + k_cache (torch.Tensor): The key cache tensor. + If use paged attention, shape is (total_blocks, head_num_kv, block_size, head_size_qk), where the total_blocks = max_batch * (memory_cache_len / block_size). + If not use paged attention, shape is (max_batch, head_num_kv, memory_cache_len, head_size_qk). + v_cache (torch.Tensor): The value cache tensor. shape is the same as k_cache, except, head_size_v could be different with head_size_qk. + out (torch.Tensor): The output tensor, shape is the same as q + block_tables (torch.Tensor): The tensor recording the k/v_cache's positions of each batch. + If use paged attention, shape is (batch, max_block_num), where max_block_num = max_cache_len / block_size. + If not use paged attention, shape is (batch, 1). + context_lens (torch.Tensor): The tensor recording the context lengths of each batch, shape is (batch). + k_cache_quant_scale (torch.Tensor): The quantized scale of k_cache. + When quantize mode is per_token, shape is + paged:(total_blocks, head_num_kv, block_size), linear: (max_batch, head_num_kv, memory_cache_len). + Or, paged:(total_blocks, head_num_kv, block_size, 1), linear: (max_batch, head_num_kv, memory_cache_len, 1). + when quantize mode is per_channel, shape is (head_num_kv, head_size) + v_cache_quant_scale (torch.Tensor): The quantized scale of v_cache, shape is the same as k_cache_quant_scale. + alibi_slopes (torch.Tensor): The alibi_slope to generate alibi position embedding. Shape is (head_num_q). + max_contxt_len (int): The maximum value of context_lens. + If max_contxt_len <= 0, indicating an uncertain maximum value of context_lens. + If 0 < max_contxt_len < max(context_lens), the correctness is not guaranteed. + windows_size_left (int): left window size, set -1 if it is unlimited. The maximum lengths of key and value involved in the attention calculation is windows_size_left + 1. + windows_size_right (int): Reserved parameter, set -1. + softmax_scale (float): The scale factor multiplied to qk. + return_lse (bool): Controlling if return lse tensor. + kv_cache_quant_bit_size (int): The data format of kv_cache, 4 for int4, 8 for int8, and -1 for the same as query. + + Type: + q: float, half, bfloat16. + k_cache: float, half, bfloat16, int8, int8(pairwise-int4). + v_cache: same as k_cache. + block_tables: int32. + context_lens: int32. + k_cache_quant_scale: float. + v_cache_quant_scale: float. + alibi_slopes: float. + + Return: + If return_lse is True, return a tuple of (output, lse), else return output. + """ + tmo_out = torch.empty((q.shape[0], q.shape[1], q.shape[2], v_cache.shape[-1]), dtype=q.dtype, device=q.device) if out is None else out + out_lse = None + if return_lse: + batch = q.shape[0] + seq_q = q.shape[1] + head_num = q.shape[2] + out_lse = torch.empty((batch, head_num, seq_q), dtype=torch.float, device=q.device) + torch.ops.torch_mlu_ops.single_query_cached_kv_attn( + q, k_cache, v_cache, tmo_out, block_tables, context_lens, out_lse, + k_cache_quant_scale, v_cache_quant_scale, alibi_slopes, max_contxt_len, + windows_size_left, windows_size_right, softmax_scale, return_lse, + kv_cache_quant_bit_size) + return (tmo_out, out_lse) if return_lse else tmo_out + +def single_query_mixed_cached_kv_attn( + q: torch.Tensor, + k_cache_lp: torch.Tensor, + v_cache_lp: torch.Tensor, + k_cache_hp: torch.Tensor, + v_cache_hp: torch.Tensor, + out: torch.Tensor, + block_tables_lp: torch.Tensor, + block_tables_hp: torch.Tensor, + context_lens_lp: torch.Tensor, + context_lens_hp: torch.Tensor, + k_cache_quant_scale_lp: Optional[torch.Tensor], + v_cache_quant_scale_lp: Optional[torch.Tensor], + k_cache_quant_scale_hp: Optional[torch.Tensor], + v_cache_quant_scale_hp: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_contxt_len_lp: int, + max_contxt_len_hp: int, + softmax_scale: float, + return_lse: bool = False, + kv_cache_quant_bit_size_lp = -1, + kv_cache_quant_bit_size_hp = -1, +) -> Union[Tuple[torch.Tensor], torch.Tensor]: + """ + Apply attention operation on query and cached key/value. + + Math: + lse1, attn_out1 = single_query_cached_kv_attn(query, key_cache_lp, value_cache_lp, ...) + lse2, attn_out2 = single_query_cached_kv_attn(query, key_cache_hp, value_cache_hp, ...) + output = update_out_and_lse(lse1, attn_out1, lse2, attn_out2) + + Args: + q (torch.Tensor): The query tensor, shape is (batch, seq_q, head_num_q, head_size). seq_q must be 1 now. + k_cache_lp/k_cache_hp (torch.Tensor): The key cache tensor. + If use paged attention, shape is (total_blocks, head_num_kv, block_size, head_size), where the total_blocks = max_batch * (memory_cache_len / block_size). + If not use paged attention, shape is (max_batch, head_num_kv, memory_cache_len, head_size). + v_cache_lp/vcache_hp (torch.Tensor): The value cache tensor, shape is the same as k_cache, just, head_size_v could be different with head_size_qk. + out (torch.Tensor): The output tensor, shape is the same as q + block_tables_lp/block_tables_hp (torch.Tensor): The tensor recording the k/v_cache's positions of each batch. + If use paged attention, shape is (batch, max_block_num), where max_block_num = memory_cache_len / block_size. + If not use paged attention, shape is (batch, 1). + context_lens_lp/context_lens_hp (torch.Tensor): The tensor recording the context lengths of each batch, shape is (batch). + k_cache_quant_scale_lp/k_cache_quant_scale_hp (torch.Tensor): The quantized scale of k_cache. + When quantize mode is per_token, shape is + paged:(total_blocks, head_num_kv, block_size, 1), linear: (max_batch, head_num_kv, memory_cache_len, 1). + when quantize mode is per_channel, shape is (head_num_kv, head_size) + v_cache_quant_scale_lp/v_cache_quant_scale_hp (torch.Tensor): The quantized scale of v_cache, shape is the same as k_cache_quant_scale. + alibi_slopes (torch.Tensor): The alibi_slope to generate alibi position embedding. Shape is (head_num_q). + max_contxt_len_lp/max_contxt_len_hp (int): The maximum value of context_lens. + If max_contxt_len <= 0, indicating an uncertain maximum value of context_lens. + If 0 < max_contxt_len < max(context_lens), the correctness is not guaranteed. + softmax_scale (float): The scale factor multiplied to qk. + return_lse (bool): Controlling if return lse tensor. + kv_cache_quant_bit_size_lp/kv_cache_quant_bit_size_hp (int): The data format of kv_cache, -1 for float-point. + + Type: + q: float, half, bfloat16. + k_cache_lp/k_cache_hp: float, half, bfloat16, int8, int8(pairwise-int4). + v_cache_lp/v_cache_hp: same as k_cache. + block_tables_lp/block_tables_hp: int32. + context_lens: int32. + k_cache_quant_scale_lp/k_cache_quant_scale_hp: float. + v_cache_quant_scale_lp/v_cache_quant_scale_hp: float. + alibi_slopes: float. + + Return: + If return_lse is True, return a tuple of (output, lse), else return output. + """ + + batch = q.shape[0] + seq_q = q.shape[1] + head_num = q.shape[2] + head_size_v = v_cache_lp.shape[-1] + tmo_out = torch.empty((q.shape[0], q.shape[1], q.shape[2], head_size_v), dtype=q.dtype, device=q.device) if out is None else out + out_hp = torch.empty((q.shape[0], q.shape[1], q.shape[2], head_size_v), dtype=q.dtype, device=q.device) + out_lse = torch.empty((batch, head_num, seq_q), dtype=torch.float32, device=q.device) + out_lse_hp = torch.empty((batch, head_num, seq_q), dtype=torch.float32, device=q.device) + + torch.ops.torch_mlu_ops.single_query_cached_kv_attn( + q, k_cache_lp, v_cache_lp, tmo_out, block_tables_lp, context_lens_lp, out_lse, + k_cache_quant_scale_lp, v_cache_quant_scale_lp, alibi_slopes, max_contxt_len_lp, + -1, -1, softmax_scale, True, kv_cache_quant_bit_size_lp) + torch.ops.torch_mlu_ops.single_query_cached_kv_attn( + q, k_cache_hp, v_cache_hp, out_hp, block_tables_hp, context_lens_hp, out_lse_hp, + k_cache_quant_scale_hp, v_cache_quant_scale_hp, alibi_slopes, max_contxt_len_hp, + -1, -1, softmax_scale, True, kv_cache_quant_bit_size_hp) + torch.ops.torch_mlu_ops.update_out_and_lse(tmo_out, out_lse, out_hp, out_lse_hp, None, None, None) + return (tmo_out, out_lse) if return_lse else tmo_out + +def apply_rotary( + input: torch.Tensor, + sin_cache: torch.Tensor, + cos_cache: torch.Tensor, + position_ids: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + interleaved: bool, + discrete: bool, + dynamic_ntk: bool, + max_seqlen: int, +) -> torch.Tensor: + """ + Apply rotary embedding to the input tensor. + + Math: + rotary_dim = sin_cache.shape[-1] + input_rot = input[..., :rotary_dim].clone() + if interleaved is True: + x1, x2 = input_rot[..., ::2].clone(), input_rot[..., 1::2].clone() + input_rot[..., ::2], input_rot[..., 1::2] = -x2, x1 + else: + x1, x2 = input_rot.chunk(2, dim=-1) + input_rot = concat((-x2, x1), dim=-1) + input = input_rot * sin_cache[position_ids] + input * cos_cache[position_ids] + return input + + Args: + input (torch.Tensor): Shape is (batch, seq, head_num, head_size) or (total_seq, head_num, head_size). + The head_num is usually equal to head_num_q+head_num_k. + sin_cache (torch.Tensor): If dynamic_ntk is true, the shape must be (batch, rotary_seq_len, rotary_dim), + otherwise shape must be (rotary_seq_len, rotary_dim). + cos_cache (torch.Tensor): If dynamic_ntk is true, the shape must be (batch, rotary_seq_len, rotary_dim), + otherwise shape must be (rotary_seq_len, rotary_dim). + position_ids (torch.Tensor): The position index of input tokens. + If discrete is True, shape must be (batch, seq) or (total_seq), which indicates each token has a position index. + If discrete is False, a tensor shape of (batch) specifies the start position of each batch, + a tensor of None indicates start position is 0 for each batch. + cu_seqlens (torch.Tensor): The sequence length of each batch in pack mode. + interleaved (bool): If interleaved is True, apply cross rotary embedding, otherwise apply fold rotary embedding. + dynamic_ntk (bool): A boolean indicates whether all batches share the same sin_cache and cos_cache. + max_seqlen (int): The maximum sequence lengths of input. + + Type: + input: float, half, bfloat16. + sin_cache: same as input. + cos_cache: same as input. + position_ids: int32. + cu_seqlens: int32. + + Return: + Return the input tensor after applying rotary embedding. + """ + torch.ops.torch_mlu_ops.apply_rotary( + input, + sin_cache, + cos_cache, + position_ids, + cu_seqlens, + interleaved, + discrete, + dynamic_ntk, + max_seqlen, + ) + return input + +def reshape_linear_cache( + key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + packed: bool, + context_seq_offset: Optional[torch.Tensor], + cache_bs_id: Optional[torch.Tensor], + cache_seqlen_offset: Optional[torch.Tensor], +) -> Union[Tuple[torch.Tensor], torch.Tensor]: + """ + Put key and value into key_cache and value_cache. + + Math: + for idx in range(batch): + key_cache[idx] = key[idx][:, context_length[idx]] + value_cache[idx] = value[idx][:, context_length[idx]] + Args: + key (torch.Tensor): The key tensor. If use pack mode, shape is (batch, seqlen, head_num_kv, head_size), + else, shape is (total_seqlen, head_num_kv, head_size). + value (torch.Tensor): The value tensor. Shape is the same as key. + key_cache (torch.Tensor): The key_cache tensor. Shape is (max_batch, head_num_kv , cache_mem_len, head_size). + value_cache (torch.Tensor): The value_cache tensor. Shape is the same as key_cache. + context_lengths (torch.Tensor): Store key or cache lengths. If packed, shape is (batch + 1), else + shape is (batch). + max_context_len (int): The maximum sequence length of context. + packed (bool): A boolean value indicates whether to use pack mode. + context_seq_offset (torch.Tensor): Store the sequence offset of context. Shape if (batch). + cache_bs_id (torch.Tensor): Indicate context will be put into which batch of cache. A negative value + indicates that the corresponding batch will be ignored. Shape is (batch). + cache_seqlen_offset (torch.Tensor): Indicate the offset of the position in the cache. Shape is (batch). + Type: + key: float, half, bfloat16. + value: float, half, bfloat16. + key_cache: float, half, bfloat16. + value_cache: float, half, bfloat16. + context_lengths: int32. + max_context_len: int32. + packed: bool. + context_seq_offset: int32. + cache_bs_id: int32. + cache_seqlen_offset: int32. + Return: + Support inplace outputs. + Directly return the given key_cache and value_cache if value_cache existed, else return key_cache only. + """ + torch.ops.torch_mlu_ops.reshape_linear_cache( + key, + value, + key_cache, + value_cache, + context_lengths, + max_context_len, + packed, + context_seq_offset, + cache_bs_id, + cache_seqlen_offset, + ) + return (key_cache, value_cache) if value_cache is not None else key_cache + +def reshape_paged_cache( + k: torch.Tensor, + v: Optional[torch.Tensor], + k_cache: torch.Tensor, + v_cache: Optional[torch.Tensor], + slot_mapping: torch.Tensor +) -> Union[Tuple[torch.Tensor], torch.Tensor]: + """ + Perform reshape_paged_cache operation. + + Math: + for i in range(num_tokens): + if slot_mapping[i] >= 0: + block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + k_cache[block_id, :, block_offset, :] = k[i] + v_cache[block_id, :, block_offset, :] = v[i] + + Args: + k (torch.Tensor): The key tensor. Shape is (num_token, head_num_kv, head_size). + v (torch.Tensor): The value tensor. Shape is the same as v. + k_cache (torch.Tensor): The key_cache tensor. Shape is (block_num, head_num_kv, block_size, head_size). + v_cache (torch.Tensor): The value_cache tensor. Shape is the same as key_cache. + slot_mapping (torch.Tensor): The slot_mapping tensor. Shape is (num_token). + + Types: + k: float, half, bfloat16. + v: float, half, bfloat16. + k_cache: float, half, bfloat16. + v_cache: float, half, bfloat16. + slot_mapping: int32. + + Return: + Support inplace outputs. + Directly return the given k_cache and v_cache if v_cache existed, else return k_cache only. + """ + torch.ops.torch_mlu_ops.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping) + return (k_cache, v_cache) if v_cache is not None else k_cache + +def quant_to_paged_cache( + k: torch.Tensor, + v: Optional[torch.Tensor], + k_cache: torch.Tensor, + v_cache: Optional[torch.Tensor], + k_cache_quant_scale: torch.Tensor, + v_cache_quant_scale: Optional[torch.Tensor], + slot_mapping: torch.Tensor, +) -> Tuple[torch.Tensor]: + """ + Quant k and v, put quanted_kv into k_cache and v_cache, put quanted_scale into + k_cache_quant_scale and v_cache_quant_scale. + + Math: + key_i = key_i[batch_idx] + max_value = max(key_i.abs(), dim=-1) + int_max = float(2 ** (quant_bit - 1) - 1) + scale_i = float(max_value / int_max) + key_i /= key_i + k_cache[pos] = k_i + k_cache_quant_scale[pos] = scale_i + + Args: + k (torch.Tensor): The key tensor. Shape is [token_nums, head_num_kv, head_size]. + v (torch.Tensor): The value tensor. Shape is the same as k. + k_cache (torch.Tensor): The k_cache tensor. Shape is [block_nums, head_num_kv, block_size, head_size]. + v_cache (torch.Tensor): The v_cache tensor. Shape is the same as v_cache. + k_cache_quant_scale (torch.Tensor): The k_cache_quant_scale tensor. Shape is [block_nums, head_num_kv, block_size]. + v_cache_quant_scale (torch.Tensor): The v_cache_quant_scale tensor. Shape is the same as v_cache_quant_scale. + slot_mapping (torch.Tensor): The slot_mapping tensor. Shape is [token_nums]. + + Type: + k: float, half, bfloat16. + v: float, half, bfloat16. + k_cache: int8. + v_cache: int8. + k_cache_quant_scale: float. + v_cache_quant_scale: float. + slot_mapping: int32. + + Return: + Support inplace outputs. + Directly return the given k_cache, v_cache, k_cache_quant_scale and v_cache_quant_scale if v_cache existed, + else return k_cache and k_cache_quant_scale. + """ + torch.ops.torch_mlu_ops.quant_to_paged_cache( + k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping + ) + return (k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale) if v_cache is not None else (k_cache, k_cache_quant_scale) + +def offline_quant_to_paged_cache( + k: torch.Tensor, + v: Optional[torch.Tensor], + k_cache_scale: torch.Tensor, + v_cache_scale: Optional[torch.Tensor], + slot_mapping: torch.Tensor, + k_cache: torch.Tensor, + v_cache: Optional[torch.Tensor], +) -> Union[Tuple[torch.Tensor], torch.Tensor]: + """ + Apply quant_to_paged_cache operation on key and value. + + Math: + tokens_num = k.shape[0] + block_size = k_cache.shape[2] + for i in range(tokens_num): + if slot_mapping[i] >= 0: + key_i = k[i] + value_i = v[i] + block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + key_cache_i = k_cache[block_id, :, block_offset, :] + value_cache_i = v_cache[block_id, :, block_offset, :] + quant_key_i = quant2int8(key_i, k_cache_scale) + quant_value_i = quant2int8(value_i, v_cache_scale) + k_cache_i[...] = quant_key_i + v_cache_i[...] = quant_value_i + + Args: + k (torch.Tensor): The key tensor, shape is (num_tokens, num_heads, head_size). + v (torch.Tensor): The value tensor. shape is (num_tokens, num_heads, head_size). + k_cache_scale (torch.Tensor): The key_cache_scale tensor. shape is (num_heads, head_size). + v_cache_scale (torch.Tensor): The value_cache_scale tensor. shape is (num_heads, head_size). + slot_mapping (torch.Tensor): The slot_mapping tensor. shape is (num_tokens). + k_cache (torch.Tensor): The key_cache tensor. shape is (num_blocks, num_heads, block_size, head_size). + v_cache (torch.Tensor): The value_cache tensor. shape is (num_blocks, num_heads, block_size, head_size). + + Type: + k: float, half, bfloat16. + v: float, half, bfloat16. + k_cache_scale: float. + v_cache_scale: float. + slot_mapping: int32. + k_cache: int8. + v_cache: int8. + + Return: + Support inplace outputs. + Directly return the given k_cache and v_cache if v_cache existed, else return k_cache only. + + """ + torch.ops.torch_mlu_ops.offline_quant_to_paged_cache( + k, v, k_cache_scale, v_cache_scale, slot_mapping, k_cache, v_cache, + ) + return (k_cache, v_cache) if v_cache is not None else k_cache + +def quant_to_linear_cache( + key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + packed: bool, + context_seq_offset: Optional[torch.Tensor], + cache_bs_id: Optional[torch.Tensor], + cache_seqlen_offset: Optional[torch.Tensor], + quant_bit: int = 8, +) -> Tuple[torch.Tensor]: + """ + Quant key and value, put quanted_kv into key_cache and value_cache, put quanted_scale into + key_cache_quant_scale and value_cache_quant_scale. + + Math: + key_i = key[batch_idx] + if groupwise quantize: + key_i = key_i.reshape(head_num, -1, group_size) + else: + key_i = key_i.reshape(head_num, -1, head_size) + max_value = max(key_i.abs(), dim=-1) + int_max = float(2 ** (quant_bit - 1) - 1) + scale_i = max_value / int_max + key_i /= scale_i + key_i = clip(key_i, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1) + if quant_bit == 4: + key_i = key_i.flatten() + d0 = key_i[0::2].to(uint8) + d1 = key_i[0::2].to(uint8) + dp = (d1 << 4) + (d0 & 0x0f) + key_i = key_i.to(int8).reshape(head_num, -1, head_size // 2) + k_cache[pos] = k_i + k_cache_quant_scale[pos] = scale_i + + Args: + key (torch.Tensor): The key tensor. If packed, shape is (total_seqlen, head_num_kv, head_size), + else, shape is (batch, seqlen, head_num_kv, head_size). + value (torch.Tensor): The value tensor. Shape is the same as key. + key_cache (torch.Tensor): The key cache tensor. Shape is (max_batch, head_num_kv, cache_mem_len, head_size). + value_cache (torch.Tensor): The value cache tenosr. Shape is the same as key_value. + key_cache_quant_scale (torch.Tensor): The key_cache_quant_scale tensor. If use groupwise quantization, shape is + (max_batch, head_num_kv, cache_mem_len, group_num), otherwise shape is + (max_batch, head_num_kv, cache_mem_len). + value_cache_quant_scale (torch.Tensor): The value_cache_quant_scale tensor. Shape is the same as key_cache_quant_scale. + context_lengths (torch.Tensor): The context_lengths tensor. If packed, shape is (batch + 1), else, shape is (batch). + max_context_len: The maximum sequence length of context. + packed: A boolean value indicates whether to use pack mode. + context_seq_offset (torch.Tensor): Indicate the sequence offset of context, shape is (batch). + cache_bs_id (torch.Tensor): Indicate key and value will be put into which batch of kv_cache. A negative value indicates + that the corresponding batch will be ignored. Shape is (batch). + cache_seqlen_offset (torch.Tensor): Indicate the offset of the position in the cache. Shape is (batch). + + Types: + key: float, half, bfloat16. + value: float, half, bfloat16. + key_cache: int8. + value_cache: int8. + key_cache_quant_scale: float32. + value_cache_quant_scale: float32. + context_lengths: int32. + max_context_len: int32. + packed: bool. + context_seq_offset: int32. + cache_bs_id: int32. + cache_seqlen_offset: int32. + + Return: + Support inplace outputs. + Directly return the given key_cache, value_cache, key_cache_quant_scale and value_cache_quant_scale if value_cache existed, + else return key_cache and key_cache_quant_scale. + """ + torch.ops.torch_mlu_ops.quant_to_linear_cache( + key, + value, + key_cache, + value_cache, + key_cache_quant_scale, + value_cache_quant_scale, + context_lengths, + max_context_len, + packed, + context_seq_offset, + cache_bs_id, + cache_seqlen_offset, + quant_bit, + ) + return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale) + +def offline_quant_to_linear_cache( + key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + quant_mode: int, + packed: bool, + context_seq_offset: Optional[torch.Tensor], + cache_bs_id: Optional[torch.Tensor], + cache_seqlen_offset: Optional[torch.Tensor], +) -> Tuple[torch.Tensor]: + """ + Quantize current key and value, Then store key and value to key_cache and value_cache. + + Math: + key_i = key[batch_idx].transpose(1, 0) + if quant_mode == 0: + key_i /= key_cache_quant_scale.reshape(head_num, 1, head_size) + else: + key_i /= key_cache_quant_scale.reshape(head_num, seq, 1) + key_i = clip(round(key_i), -128, 127) + key_cache[pos] = key_i + + Args: + key (torch.Tensor): The key tensor. If packed, shape is (total_seqlen, head_num_kv, head_size), + else, shape is (batch, seqlen, head_num_kv, head_size). + value (torch.Tensor): The value tensor. Shape is the same as key. + key_cache (torch.Tensor): The key_cache tensor. Shape is (max_batch, head_num, cache_mem_len, head_size). + value_cache (torch.Tensor): The value_cache tensor. Shape is the same as key_cache. + key_cache_quant_scale (torch.Tensor): If per_channel quantize, shape is (head_num_kv, head_size). + If per_head quantize, shape is (head_num_kv, cache_mem_len). + value_cache_quant_scale (torch.Tensor): Shape is the same as key_cache_quant_scale. + context_lengths (torch.Tensor): A tensor indicate context lengths. Shape is (batch+1) if packed is True, which + stores the cumsum of seqlen, otherwise shape is (batch). + max_context_len: The maximum sequence length of context. + quant_mode: A int value indicates which quantize mode will be used. If quant_mode is 0, use per_channel + quantize, otherwise use per_head quantize. + packed: A boolean value indicates whether to use pack mode. + context_seq_offset (torch.Tensor): Indicate the sequence offset of context, shape is (batch). + cache_bs_id (torch.Tensor): Indicate context will be put into which batch of cache. A negative value + indicates that the corresponding batch will be ignored. Shape is (batch). + cache_seqlen_offset (torch.Tensor): Indicate the offset of the position in the cache. Shape is (batch). + + Types: + key: float, half, bfloat16. + value: float, half, bfloat16. + key_cache: int8. + value_cache: int8. + key_cache_quant_scale: float32. + value_cache_quant_scale: float32. + context_lengths: int32. + max_context_len: int32. + quant_mode: int32. + packed: bool. + context_seq_offset: int32. + cache_bs_id: int32. + cache_seqlen_offset: int32. + + Return: + Support inplace outputs. + Directly return the given key_cache, value_cache, key_cache_quant_scale and value_cache_quant_scale if value_cache existed, + else return key_cache and key_cache_quant_scale. + """ + torch.ops.torch_mlu_ops.offline_quant_to_linear_cache( + key, + value, + key_cache, + value_cache, + key_cache_quant_scale, + value_cache_quant_scale, + context_lengths, + max_context_len, + quant_mode, + packed, + context_seq_offset, + cache_bs_id, + cache_seqlen_offset, + ) + return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale) + +def swap_blocks(dst: torch.Tensor, + src: torch.Tensor, + block_mapping: Dict[int, int]) -> torch.Tensor: + """ + Copy src value to dst according to block_mapping. + + Math: + for key, value in block_mapping.items(): + dst[value] = src[key] + + Args: + dst (torch.Tensor): Destination tensor. Shape is (num_blocks, num_heads, block_size, head_size). + src (torch.Tensor): Source tensor. Shape is (num_blocks, num_heads, block_size, head_size). + block_mapping (dict): Mapping table of src and dst. + + Types: + dst: Unlimited. + src: Unlimited. + block_mapping: [int, int]. + + Return: + Support inplace outputs. + Directly return the given dst. + """ + torch.ops.torch_mlu_ops.swap_blocks(dst, src, block_mapping) + return dst + +def copy_blocks( + k_caches: List[torch.Tensor], + v_caches: Optional[List[torch.Tensor]], + block_mapping: Dict[int, List[int]] +) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], List[torch.Tensor]]: + """ + Remap the k_caches and v_caches with the block indexes pairs specified by block_mapping. + + Math: + for src, dsts in block_mapping.items(): + srcs = [src for i in range(len(dsts))] + for key_cache in k_caches: + key_cache[dsts] = key_cache[srcs] + for value_cache in v_caches: + value_cache[dsts] = value_cache[srcs] + + Args: + k_caches (List[torch.Tensor]): A tensor list of k_cache. + v_caches (List[torch.Tensor]): A tensor list of v_cache. + block_mapping (Dict[int, List[int]]): The pairs of source and destination block indexes. + + Type: + k_caches: int8, uint8, int16, int32, int64, half, float, bfloat16. + v_caches: int8, uint8, int16, int32, int64, half, float, bfloat16. + + Return: + Support inplace outputs. + Directly return the given k_caches and v_caches. + """ + torch.ops.torch_mlu_ops.copy_blocks(k_caches, v_caches if v_caches is not None else [], block_mapping) + return (k_caches, v_caches) if v_caches is not None else k_caches + +def quant_matmul( + a_tensor: torch.Tensor, + a_scale: Optional[torch.Tensor], + a_zero: Optional[torch.Tensor], + b_tensor: torch.Tensor, + b_scale: Optional[torch.Tensor], + b_zero: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + c_tensor: Optional[torch.Tensor], + c_scale: Optional[torch.Tensor], + c_zero: Optional[torch.Tensor], + gemm_output_scale: Optional[torch.Tensor], + gemm_output_zero: Optional[torch.Tensor], + data_type: Optional[str], + quant_algo: str, + a_quant_layout: str, + b_quant_layout: str, + quant_bit_size: int = 8, + act_mode: str = "none", + use_hp_active: bool = False, + act_coef: float = 1.0, + alpha: float = 1.0, + beta: float = 1.0, + trans_a: bool = False, + trans_b: bool = True, +) -> torch.Tensor: + """ + Full-featured function to do quantized matmul. + It is recommended to use the function with fewer parameters such as 'weight_only_quant_matmul' or 'smooth_quant_matmul'. + """ + return torch.ops.torch_mlu_ops.quant_matmul( + a_tensor, + a_scale, + a_zero, + b_tensor, + b_scale, + b_zero, + bias, + c_tensor, + c_scale, + c_zero, + gemm_output_scale, + gemm_output_zero, + data_type, + None, + quant_algo, + a_quant_layout, + b_quant_layout, + quant_bit_size, + act_mode, + use_hp_active, + act_coef, + alpha, + beta, + trans_a, + trans_b, + ) + +def quant_matmul_allreduce( + cncl_comm, + a_tensor: torch.Tensor, + a_scale: Optional[torch.Tensor], + a_zero: Optional[torch.Tensor], + b_tensor: torch.Tensor, + b_scale: Optional[torch.Tensor], + b_zero: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + c_tensor: Optional[torch.Tensor], + c_scale: Optional[torch.Tensor], + c_zero: Optional[torch.Tensor], + gemm_output_scale: Optional[torch.Tensor], + gemm_output_zero: Optional[torch.Tensor], + data_type: Optional[str], + quant_algo: str, + a_quant_layout: str, + b_quant_layout: str, + quant_bit_size: int = 8, + alpha: float = 1.0, + beta: float = 1.0, + trans_a: bool = False, + trans_b: bool = True, + block_m: int = 0 +) -> torch.Tensor: + return torch.ops.torch_mlu_ops.quant_matmul_allreduce( + cncl_comm, + a_tensor, + a_scale, + a_zero, + b_tensor, + b_scale, + b_zero, + bias, + c_tensor, + c_scale, + c_zero, + gemm_output_scale, + gemm_output_zero, + data_type, + None, + quant_algo, + a_quant_layout, + b_quant_layout, + quant_bit_size, + alpha, + beta, + trans_a, + trans_b, + block_m + ) + +def active(input: torch.Tensor, act_mode: str, is_gated: bool, active_coef: float = 1.0) -> torch.Tensor: + """ + Apply activation to the input tensor. + + Math: + C = input.shape[-1] + if is_gated = True: + output = active(input[..., :C//2]) * input[..., C//2:] + else: + output = active(input) + + Args: + input (torch.Tensor): Shape is (..., C) + act_mode (str): The activation mode, must be 'silu', 'gelu', 'swish' or 'quick_gelu'. + is_gated (bool): If use gated activation. + active_coef (float): The coefficient used in the swish activation. Default is 1.0. + + Type: + input: float, half, bfloat16. + output: same as input. + + Return: + A tensor with the same shape and dtype of input. + """ + out_shape = list(input.shape) + if is_gated: + out_shape[-1] = input.size(-1) // 2 + output = torch.empty(out_shape, dtype=input.dtype, device=input.device) + torch.ops.torch_mlu_ops.active(input, output, None, None, act_mode, + is_gated, 0, 0, active_coef) + return output + +def fused_moe( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], + residual: Optional[torch.Tensor], + input_smooth: Optional[torch.Tensor], + act_smooth: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + topk: int, + renormalize: bool, + gated: bool, + act_mode: str, + start_expert_id: int = 0, + block_n: int = 0, + cncl_comm: int = 0, + w1_quant_flag: Optional[List[int]] = None, + w2_quant_flag: Optional[List[int]] = None, + ) -> torch.Tensor: + """ + Apply mixtrue of experts operation to the input tensor. + + Math: + Please refer to 'test/pytest/test_moe.py' in this repository. + + Args: + hidden_states (torch.Tensor): Shape is (batch, seq, hidden_size) or (total_seq, hidden_size). + gating_output (torch.Tensor): The output of router FC. Shape is (batch, seq, num_expert) or (total_seq, num_expert). + w1 (torch.Tensor): shape are shown as followed: + (expert_size, inner_size * (1 + gated), hidden_size) for normal mode. + (expert_size, inner_size * (1 + gated), hidden_size / 2) for int4 quantization. + w2 (torch.Tensor): shape are shown as followed: + (expert_size, hidden_size, inner_size) for normal mode. + (part_a, hidden_size, part_b, inner_size) for retiled mode, which part_a * part_b must be expert_size. + (expert_size, hidden_size, inner_size / 2) for int4 quantization. + bias1 (torch.Tensor): Must be None. + bias2 (torch.Tensor): Must be None. + residual (torch.Tensor): Shape is (batch, seq, hidden_size) or (total_seq, hidden_size). + input_smooth (torch.Tensor): The smooth scale of hidden_states, shape is (expert_size, hidden_size). + act_smooth (torch.Tensor): The smooth scale of the activation output, shape is (expert_size, inner_size). + w1_scale (torch.Tensor): The per-channel scale of w1, shape are shown as followed: + (expert_size, (1 + gated) * inner_size) for normal quantization. + (w1_quant_group, expert_size, (1 + gated) * inner_size) for groupwise quantization. + w2_scale (torch.Tensor): The per-channel scale of w2, shape are shown as followed: + (expert_size, hidden_size) for normal quantization. + (w2_quant_group, expert_size, hidden_size) for groupwise quantization. + topk (int): The number of experts selected from the total experts, topk <= num_expert. + renormalize (bool): If do normalization on the topk_logits. + gated (bool): If use gated activation. + act_mode (str): The activation mode, must be 'silu' or 'gelu'. + start_expert_id (int): The starting id of the expert. + block_n (int): Take effect only if block_n > 0, otherwise the recommended value will be used. + cncl_comm (int): The handle of CnclComm. + w1_quant_flag (List[int]): A list of int values, the elements must be 4 or 8, which specify the quantization bits of each group, + the other values may cause undefined behavior. Length is (expert_size * w1_quant_group). + w2_quant_flag (List[int]): A list of int values, the elements must be 4 or 8, which specify the quantization bits of each group, + the other values may cause undefined behavior. Length is (expert_size * w2_quant_group). + + Type: + hidden_states: float, half, bfloat16. + gating_output: same as hidden_states. + w1: float, half, bfloat16, int8, int8(int4x2). + w2: same as w1. + bias1: same as hidden_states. + bias2: same as hidden_states. + residual: same as hidden_states. + input_smooth: float. + act_smooth: float. + w1_scale: float. + w2_scale: float. + output: same as hidden_states. + + Return: + A tensor with the same shape of hidden_states. + + Note: + The elements in w1_quant_flag and w2_quant_flag must be 4 or 8. + """ + return torch.ops.torch_mlu_ops.fused_moe( + hidden_states, + gating_output, + w1, + w2, + bias1, + bias2, + residual, + input_smooth, + act_smooth, + w1_scale, + w2_scale, + w1_quant_flag, + w2_quant_flag, + topk, + renormalize, + gated, + act_mode, + start_expert_id, + block_n, + cncl_comm + ) + +def matmul( + a: torch.Tensor, + b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + act_mode: str = 'none', + alpha: float = 1.0, + beta: float = .0, + fast_act: bool = True, + approximate: bool = True, + dtype: torch.dtype = None, + a_scale: float = 1.0, + b_scale: float = 1.0, + trans_a: bool = False, + trans_b: bool = True, + ) -> torch.Tensor: + """ + Perform matrix multiplication operation. + + Math: + ab = matmul(a, b) + if c is not None: + ab = alpha * ab + beta * c + if bias is not None: + ab += bias + if act_mode is not 'none': + ab = active(ab) + output = ab + + Args: + a (torch.Tensor): If trans_a, shape is (k, m), else shape is (m, k). + b (torch.Tensor): If trans_b, shape is (n, k), else shape is (k, n). + bias (torch.Tensor): Shape is (n). + c (torch.Tensor): Shape is (m, n). + act_mode (str): The activation type, must be 'gelu', 'relu', 'silu' or 'none'. + alpha (float): The coefficient of ab. + beta (float): The coefficient of c. + fast_act (bool): Describing the algorithm that used in the implementation of the activation function. + When the value is true, use the fastest algorithm of activation, otherwise use the high-precision algorithm. + approximate(bool): Describes the implementation logic of different GELU approximation algorithms. + d (torch.Tensor): Shape is (m, n). + a_scale (float): quant_scale of a. + b_scale (float): quant_scale of b. + + Type: + a: float, half, bfloat16, int8. + b: same as a. + bias: float, half, bfloat16. + c: float, half, bfloat16. + output: float, half, bfloat16. + + Return: + Return output tensor. + """ + d_dtype = None if dtype is None else _torchDtype2Str(dtype) + return torch.ops.torch_mlu_ops.matmul(a, b, None, bias, c, d_dtype, act_mode, alpha, beta, + fast_act, approximate, a_scale, b_scale, trans_a, trans_b) + +def batch_matmul( + a: torch.Tensor, + b: torch.Tensor, + c: Optional[torch.Tensor] = None, + alpha: float = 1.0, + beta: float = .0, + a_scale: float = 1.0, + b_scale: float = 1.0, + trans_a: bool = False, + trans_b: bool = True, + ) -> torch.Tensor: + """ + Perform matrix multiplication operation. + + Math: + ab = batch_matmul(a, b) + if c is not None: + c = alpha * ab + beta * c + output = c + else: + output = ab + + Args: + a (torch.Tensor): Shape is (batch, m, k). + b (torch.Tensor): Shape is (batch, n, k). + c (torch.Tensor): Shape is (batch, m, n), c will be changed when the calculation is completed. + alpha (float): The coefficient of ab. + beta (float): The coefficient of c. + a_scale (float): quant_scale of a. + b_scale (float): quant_scale of b. + + Type: + a: float, half, bfloat16, int8. + b: same as a. + c: float, half, bfloat16. + output: float, half, bfloat16. + + Return: + Return output tensor. + """ + tmo_out = c + if c is None: + assert a.dtype in (torch.half, torch.float, torch.bfloat16), "dtype must be half, float or bfloat16, when c is None." + m = a.shape[-1] if trans_a else a.shape[-2] + n = b.shape[-2] if trans_b else b.shape[-1] + batch = a.shape[0] + tmo_out = torch.zeros((batch, m, n), dtype=a.dtype, device=a.device) + torch.ops.torch_mlu_ops.batch_matmul(a, b, tmo_out, alpha, beta, a_scale, b_scale, trans_a, trans_b) + return tmo_out + +def matmul_allreduce(cncl_comm, + a: torch.Tensor, + b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + alpha: float = 1.0, + beta: float = .0, + block_m: int = 0 + ): + """ + Parallelly apply matrix multiplication with TP-allreduce. + + Math: + ab = matmul(a, b) + if c is not None: + ab = alpha * ab + beta * c + if bias is not None: + ab += bias + output = ab + output_allreduce = allreduce(output) + + Args: + cncl_comm (int): The handle of CnclComm. + a (torch.Tensor): Shape is (m, k). + b (torch.Tensor): Shape is (n, k). + bias (torch.Tensor): Shape is (n). + c (torch.Tensor): Shape is (m, n). + alpha (float): The coefficient of ab. + beta (float): The coefficient of c. + block_m (int): If the value is greater than 0, the specified value will be used; if the value is less than 0, the recommended value will be used. + + Type: + a: float, half, bfloat16. + b: same as a. + bias: same as a. + c: same as a. + output: same as a. + + Return: + Return output tensor. + """ + return torch.ops.torch_mlu_ops.matmul_allreduce(cncl_comm, a, b, bias, c, None, alpha, beta, block_m) + +def group_gemm(a: torch.Tensor, + b: torch.Tensor, + m_list: torch.Tensor, + expand_idx: Optional[torch.Tensor], + c: Optional[torch.Tensor], + alpha: Optional[torch.Tensor], + beta: Optional[torch.Tensor], + max_m: int, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ''' + Perform the grouped matrix multiplication operation. + + Math: + if expand_idx: + a = a[expand_idx] + a_list = a.split(m_list) + c_list = c.split(m_list) + d_list = d.split(m_list) + experts = len(m_list) + for i in range(experts): + d_list[i] = matmul(a_list[i], b[i]) * alpha[i] + beta[i] * c_list[i] + bias[i] + d = concat(d_list, dim=0) + + Args: + a (torch.Tensor): Shape is (m, k). If expand_idx exists, total_m = m_list.sum(), otherwise total_m = m. + b (torch.Tensor): Shape is (experts, n, k). + m_list (torch.Tensor): Shape is (experts). + expand_idx (torch.Tensor): optional. Shape is (total_m). + c (torch.Tensor): optional. Shape is (total_m, n). + alpha (torch.Tensor): optional. Shape is (experts). + beta (torch.Tensor): optional. Shape is (experts). + max_m (int): Maximum possible value in m_list. + bias (torch.Tensor): optional. Shape is (experts, n). + + Type: + a: half, bfloat16. + b: same as a. + m_list: int32. + expand_idx: int32. + c: same as a. + alpha: float. + beta: float. + d: same as a. + bias: same as a. + + Return: + Return d. + ''' + return torch.ops.torch_mlu_ops.group_gemm(a, b, m_list, expand_idx, c, alpha, beta, None, None, bias, _torchDtype2Str(a.dtype), None, None, max_m) + +def smooth_quant_group_gemm(a: torch.Tensor, + b: torch.Tensor, + m_list: torch.Tensor, + expand_idx: Optional[torch.Tensor], + c: Optional[torch.Tensor], + alpha: Optional[torch.Tensor], + beta: Optional[torch.Tensor], + a_scale: torch.Tensor, + b_scale: torch.Tensor, + dtype, + max_m: int, + bias: Optional[torch.Tensor] = None, + quant_flag: Optional[List[int]] = None + ) -> torch.Tensor: + ''' + Perform smooth quantized group_gemm operation. + + Math: + if expand_idx: + a = a[expand_idx] + a_list = a.split(m_list) + c_list = c.split(m_list) + d_list = d.split(m_list) + a_scale_list = a_scale.split(m_list) + experts = len(m_list) + for i in range(experts): + ab_scale = a_scale_list[i].outer(b_scale[i]) + ab = matmul(a_list[i], b[i]) * ab_scale + d_list[i] = ab * alpha[i] + beta[i] * c_list[i] + bias[i] + d = concat(d_list, dim=0) + + Args: + a (torch.Tensor): Shape is (m, k). If expand_idx exists, total_m = m_list.sum(), otherwise total_m = m. + b (torch.Tensor): Shape is (experts, n, k), or (experts, n, k // 2) for int 4 quantization. + m_list (torch.Tensor): Shape is (experts). + expand_idx (torch.Tensor): optional. Shape is (total_m). + c (torch.Tensor): optional. Shape is (total_m, n). + alpha (torch.Tensor): optional. Shape is (experts). + beta (torch.Tensor): optional. Shape is (experts), the value can only be 0 or 1. + a_scale (torch.Tensor): Shape is (total_m). + b_scale (torch.Tensor): Shape is (experts, n), or (quant_group, experts, n) for groupwise quantization. + max_m (int): Maximum possible value in m_list. + dtype (torch.dtype). Specify the data type of output, must be torch.half or torch.bfloat16. + bias (torch.Tensor): optional. Shape is (experts, n). + quant_flag (List[int]): A list of int values, the elements must be 4 or 8, which specify the quantization bits of each group, + the other values may cause undefined behavior. Length is (experts * quant_group). + + Type: + a: int8. + b: int8, int8(int4x2). + m_list: int32. + expand_idx: int32. + c: half, bfloat16. + alpha: float. + beta: float. + a_scale: float. + b_scale: float. + d: half, bfloat16. + bias: half, bfloat16. + + Return: + Return d. + + Note: + The elements in quant_flag must be 4 or 8. + ''' + return torch.ops.torch_mlu_ops.group_gemm(a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, bias, _torchDtype2Str(dtype), quant_flag, None, max_m) + +def preload( + weight: torch.Tensor, + size: int, +) -> None: + torch.ops.torch_mlu_ops.preload(weight, size) + return weight + +def flash_attn_sq_mm_allreduce(cncl_comm: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seq_lens_q: Optional[torch.Tensor], + cu_seq_lens_kv: Optional[torch.Tensor], + alibi_slope: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], + smooth: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + bias: Optional[torch.Tensor], + max_seq_len_q: int, + max_seq_len_kv: int, + softmax_scale: float, + is_causal: bool, + window_size_left: int = -1, + window_size_right: int = -1, + compute_dtype: torch.dtype = torch.float, + block_seq: int = 0) -> torch.Tensor: + """ + Parallelly apply attention-pertoken_smooth_quant-smooth_quant_matmul with TP-allreduce. + + Math: + attn_out = flash_attn(q, k, v, cu_seq_lens_q, cu_seq_lens_kv) + quant_a, scale_a = per_token_smooth_quantize(attn_out, smooth) + output = smooth_quant_matmul(quant_a, scale_a, weight, weight_scale, bias) + output_allreduce = allreduce(output) + + Args: + cncl_comm (int): The handle of CnclComm. + q (torch.Tensor): The query tensor. Shape is (batch, seq_q, head_num_q, head_size) or (total_seq_q, head_num_q, head_size). + k (torch.Tensor): The key tensor. Shape is (batch, seq_kv, head_num_kv, head_size) or (total_seq_kv, head_num_kv, head_size). + v (torch.Tensor): The value tensor. Shape is (batch, seq_kv, head_num_kv, head_size) or (total_seq_kv, head_num_kv, head_size). + cu_seq_lens_q (torch.Tensor): The cusum of seq_q if q is 3-D tensor. Shape is (batch+1). + cu_seq_lens_kv (torch.Tensor): The cusum of seq_kv if k/v is 3-D tensor. Shape is (batch+1). + alibi_slope (torch.Tensor): Reserved parameter, must be None. + attn_bias (torch.Tensor): Reserved parameter, must be None. + smooth (torch.Tensor): The smooth scale. Shape is (head_num_q*head_size). + weight (torch.Tensor): The weight of smooth_quant_matmul. Shape is (head_num_q*head_size*TP_NUM, head_num_q*head_size). + bias (torch.Tensor): The bias of smooth_quant_matmul. Shape is (head_num_q*head_size*TP_NUM). + max_seq_len_q (int): The maximum value of seq_q. + max_seq_len_kv (int): The maximum value of seq_kv. + softmax_scale (int): The scale factor multiplied to qk. + is_causal (bool): Controlling if use causal mask. + window_size_left (int): Reserved parameter, must be -1. + window_size_right (int): Reserved parameter, must be -1. + compute_dtype (torch.dtype): The dtype used to calculate softmax. + block_seq (int): Reserved parameter. + + Type: + q: half, bfloat16. + k: same as q. + v: same as q. + cu_seq_lens_q: int32. + cu_seq_lens_kv: int32. + smooth: float. + weight: int8. + weight_scale: float. + bias: same as q. + + Return: + Return the result of allreduce. + + Note: + batch must be 1. + """ + return torch.ops.torch_mlu_ops.flash_attn_sq_mm_allreduce(cncl_comm, q, k, v, + cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias, + smooth, weight, weight_scale, bias, max_seq_len_q, max_seq_len_kv, + softmax_scale, is_causal, window_size_left, window_size_right, + _torchDtype2Str(compute_dtype), block_seq) + +def moe_cast_gating(input: torch.Tensor, + weight: torch.Tensor) -> None: + """ + Cast input data type to torch.float32, and perform gating operation. + + Math: + input_fp32 = input.to(torch.float32). + gating_output = F.Linear(input_fp32, weight). + + Args: + input (torch.Tensor): The input tensor, shape is (..., hidden_size), hidden_size must be less than or equal to 16384. + weight (torch.Tensor): The input tensor, shape is (expert_num, hidden_size), expert_num must be less than or equal to 128. + + Types: + input: half, bfloat16. + weight: float. + + Return: + return a float tensor with shape [..., expert_num]. + """ + + return torch.ops.torch_mlu_ops.moe_cast_gating(input, weight) + +def moe_softmax_topk(input: torch.Tensor, + topk: int, + normalize: bool = False, + num_expert_group: int = -1, + topk_group: int = 0, + mask: Optional[torch.Tensor] = None, + normed_by : str = "topk_logit") -> Tuple[torch.Tensor]: + """ + Apply grouped softmax-topk to input tensor. + + Math: + if num_expert_group <= 1: + if origin_mask is not None: + softmax = softmax * origin_mask + reduce_weight, expert_id = torch.topk(softmax, k=topk, dim=-1) + if normalize: + if normed_by == "topk_logit": + reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True) + if normed_by == "softmax_logit": + reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True) + return reduce_weight, expert_id + else: + group_size = softmax.shape[-1] // num_expert_group + new_shape = softmax.shape[:-1] + (num_expert_group, group_size) + group_data = softmax.view(new_shape) + group_max_value = group_data.max(dim=-1).values + group_idx = torch.topk(group_max_value, k=topk_group, dim=-1)[1] + mask_shape = softmax.shape[:-1] + (num_expert_group,) + mask = torch.zeros((mask_shape), dtype = torch.bool, device = group_idx.device) + mask.scatter_(-1, group_idx, True) + mask = mask.unsqueeze(-1).expand(new_shape) + masked_data = group_data.masked_fill(~mask, 0.0) + masked_data = masked_data.reshape(softmax.shape) + reduce_weight, expert_id = torch.topk(masked_data, k=topk, dim=-1) + if normalize: + if normed_by == "topk_logit": + reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True) + if normed_by == "softmax_logit": + reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True) + return reduce_weight, expert_id + + Args: + input (torch.Tensor): The input tensor, shape is (..., expert_num). + topk (int): The number of experts that each token would be dispatched to. + normalize (bool): If do normalization to the reduce_weight after topk operation. + num_expert_group(int): The number of groups that each expert would be grouped to. + topk_group(int): The number of groups to be selected from num_expert_group. + mask (torch.Tensor): The mask multiplied to the softmax result. The dimension must be + the same as input. The last two dims must be the same as input and the other dims must + be 1. + normed_by (str): The mode of normalization, which can be "topk_logit" or "softmax_logit" + + Type: + input: float, half, bfloat16. + mask: same as input + + Return: + Return the reduce_weight and expert_id tensor. + + Note: + Input and mask must be contiguous. + """ + outputs = torch.ops.torch_mlu_ops.moe_softmax_topk(input, + topk, + num_expert_group, + topk_group, + normalize, + mask, + normed_by) + return tuple(outputs) + +def moe_expand_input(input: torch.Tensor, + gather_idx: torch.Tensor, + cusum_token_count: Optional[torch.Tensor] = None, + start_expert_id: int = 0, + expert_size: int = 0) -> torch.Tensor: + """ + Expands the input tensor with the indices specified by gather_idx. + + Math: + if cusum_token_count is None: + output = input[gather_idx] + else: + idx = gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]] + output = input[idx] + + Args: + input (torch.Tensor): The input tensor, shape is (token_num, hidden_size). + gather_idx (torch.Tensor): The input tensor containing the indices to index. Shape is (expand_token_num). + cusum_token_count (torch.Tensor): The input tensor storing the prefix sum of token count. Shape is (expert_num + 1). + output (torch.Tensor): The output tensor storing expanded input. Shape is (expand_token_num, hidden_size). + start_expert_id: the first expert id. + expert_size: the number of experts currently being processed. + + Types: + input: int8, float, half, bfloat16. + gather_idx: int32. + cusum_token_count: int32. + output: float, half, bfloat16. + start_expert_id: int32. + expert_size: int32. + + Return: + return the input selected with gather idx. + """ + return torch.ops.torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count, + start_expert_id, expert_size) + +def moe_gen_idx(expert_id: torch.Tensor, + expert_num: int) -> Tuple[torch.Tensor]: + """ + Generate expand_idx, combine_idx, token_count, and cusum_token_count. + + Args: + expert_id (torch.Tensor): The input tensor stores the expert id of each token, the shape must be [num_token, topk]. + expert_num: the number of expert. + + Types: + expert_id: int32. + expert_num: int32. + + Return: + return a tensor list includes expand_idx, combine_idx, token_count and cusum_token_count. + """ + + outputs = torch.ops.torch_mlu_ops.moe_gen_idx(expert_id, expert_num) + return tuple(outputs) + +def moe_combine_result(input: torch.Tensor, + reduce_weight: torch.Tensor, + gather_ids: torch.Tensor, + residual: Optional[torch.Tensor], + cusum_token_count: Optional[torch.Tensor], + start_expert_id: int, + expert_size: int, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + ''' + Perform combine result operation in moe. + + Math: + if cusum_token_count: + input = input[gather_ids - cusum_token_count[start_expert_id + expert_size]] + else: + input = input[gather_ids] + + if bias: + for i in range(start_expert_id : start_expert_id + expert_size): + out[cusum_token_count[i] : cusum_token_count[i+1]] += bias[i] + + if cusum_token_count: + reduce_weight *= (gather_ids >= cusum_token_count[start_expert_id]) * + (gather_ids < cusum_token_count[start_expert_id + expert_size]) + out = input * reduce_weight + out = out.sum(1) + if residual: + out += residual + + Args: + input (torch.Tensor): Shape is (num_tokens, hidden_size). + reduce_weight (torch.Tensor): Shape is (num_token, topk). + gather_ids (torch.Tensor): Shape is (num_tokens). + residual (torch.Tensor): optional. Shape is (num_token, hidden_size). + bias (torch.Tensor): optional. Shape is (num_expert, hidden_size). + cusum_token_count (torch.Tensor): optional. Shape is (num_expert + 1). + start_expert_id (int): begin expert, used in expert parrallelism. + expert_size (int): expert size, used in expert parrallelism. + + Type: + input: bfloat16, float16, float32. + reduce_weight: float32. + gather_ids: int32. + residual: bfloat16, float16, float32. + bias: bfloat16, float16, float32. + cusum_token_count: int32. + start_expert_id: int32. + expert_size: int32. + + Return: + Return output. + ''' + return torch.ops.torch_mlu_ops.moe_combine_result(input, + reduce_weight, + gather_ids, + residual, + cusum_token_count, + start_expert_id, + expert_size, + bias) + +def moe_quantize(x: torch.Tensor, + smooth: torch.Tensor, + zero: Optional[torch.Tensor] = None, + token_count: Optional[torch.Tensor] = None, + gather_index: Optional[torch.Tensor] = None, + gather_index_start_position: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + dynamic_quant: bool = True + ) -> Tuple[torch.Tensor]: + """ + Apply quantization to the input tensor in MoE network. + + Math: + if gather_idx and gather_index_start_position: + start_pos = gather_index_start_position[0] + input = input[gather_index[start_pos : start_pos + token_count.sum()]] + elif gather_idx: + input = input[gather_index] + + if token_count: + input_list = input.split(token_count) + group = token_count.size() + result = [] + for i in range(group): + result.append(input_list[i] * smooth[i]) + smoothed = concat(result, dim=0) + else: + smoothed = input * smooth + + if dynamic_quant: + max, _ = smoothed.abs().max(dim=-1, keepdim=True) + output_scale = max.to(torch.float) / 127.0 + quanted = (smoothed / output_scale).round().clamp(-128, 127).to(torch.int8) + output = (quanted, output_scale) + else: + quanted = smoothed.round().clamp(-128, 127).to(torch.int8) + output = (quanted) + + Args: + input (torch.Tensor): The tensor to be quantized. Shape is (..., C) or (token_count.sum(), C). + smooth (torch.Tensor): The smooth scale multipled to the input tensor. Shape is (group, C). + zero (torch.Tensor): Not supported, must pass None. + token_count (torch.Tensor): The tensor to separate input. Shape is (group). + gather_index (torch.Tensor): The indices tensor to expand input. Shape is (expanded_tokens_num). + gather_index_start_position (torch.Tensor): The tensor indicating the start position of gather_index. Shape is (1). + output: (torch.Tensor): The tensor to store the output. + output_scale: (torch.Tensor): The tensor to store output_scale. + dynamic_quant: whether do dynamic quant. + + DataType: + x: float, half or bfloat16. + smooth: float. + token_count: int32. + gather_index: int32. + gather_index_start_position: int32. + dynamic_quant: bool. + output: int8. + output_scale: float. + + Return: + Returns (output, output_scale) if dynamic_quant is True, otherwise returns output only. + """ + output_shape = list(x.size()) + output_scale_shape = list(x.size()[:-1]) + if gather_index is not None: + output_tokens = gather_index.size(0) + output_shape[0] = output_tokens + output_scale_shape[0] = output_tokens + output = torch.empty(output_shape, dtype=torch.int8, device=x.device) if output is None else output + output_scale = torch.empty(output_scale_shape, dtype=smooth.dtype, device=smooth.device) if output_scale is None and dynamic_quant else output_scale + torch.ops.torch_mlu_ops.smooth_quant(x, smooth, + output, + output_scale, + zero, + token_count, gather_index, + gather_index_start_position, + 'per_token', dynamic_quant) + return (output, output_scale) if dynamic_quant else (output,) + +def moe_active(input: torch.Tensor, + act_mode: str, + is_gated: bool, + output: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cusum_token_count: Optional[torch.Tensor] = None, + start_expert_id: int = 0, + expert_size: int = 0) -> torch.Tensor: + """ + Apply activation to the input tensor. + + Math: + C = input.shape[-1] + if bias: + start_idx = 0 + for i in range(start_expert_id : start_expert_id + expert_size): + deal_token_num = cusum_token_count[i+1] - cusum_token_count[i] + input[start_idx:start_idx+deal_token_num] += bias[i] + start_idx += deal_token_num + if is_gated: + output = active(input[..., :C//2]) * input[..., C//2:] + else: + output = active(input) + + Args: + input (torch.Tensor): Shape is (..., C). + act_mode (str): The activation mode, must be 'silu' or 'gelu'. + is_gated (bool): If use gated activation. + bias(torch.Tensor): optional, Shape is (expert_num, C). + cusum_token_count(torch.Tensor): optional, Shape is (expert_num + 1). + start_expert_id(int): optional, begin expert, used in expert parallelism. + expert_size(int): optional, expert_size, used in expert parallelism. + + Type: + input: float, half, bfloat16. + act_mode: str. + is_gated: bool. + bias: same as input. + cusum_token_count: int32. + start_expert_id: int32. + expert_size: int32. + output: same as input. + + Return: + if is_gated, output shape is (input.size()[:-1], C // 2), + else, the same as input shape. + """ + if output is None: + if is_gated: + output = torch.empty(input.size()[:-1] + (input.size()[-1] // 2,), dtype=input.dtype, device=input.device) + else: + output = torch.empty_like(input) + torch.ops.torch_mlu_ops.active(input, + output, + bias, + cusum_token_count, + act_mode, + is_gated, + start_expert_id, + expert_size) + return output + +def fused_rope(qkv: torch.Tensor, + k_cache_hp: torch.Tensor, + v_cache_hp: torch.Tensor, + sin_cache: torch.Tensor, + cos_cache: torch.Tensor, + position_id: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + k_cache_lp: Optional[torch.Tensor] = None, + v_cache_lp: Optional[torch.Tensor] = None, + cache_bs_id_hp: Optional[torch.Tensor] = None, + cache_seq_offsets_hp: Optional[torch.Tensor] = None, + cache_bs_id_lp: Optional[torch.Tensor] = None, + cache_seq_offsets_lp: Optional[torch.Tensor] = None, + k_scale_hp: Optional[torch.Tensor] = None, + v_scale_hp: Optional[torch.Tensor] = None, + k_scale_lp: Optional[torch.Tensor] = None, + v_scale_lp: Optional[torch.Tensor] = None, + slot_mapping_hp: Optional[torch.Tensor] = None, + slot_mapping_lp: Optional[torch.Tensor] = None, + eps: float = 1e-5) : + """ + Perform query and key fold rope + key layernorm + (key, value perchannel quantize) + + reshape key value to kv cache. + + Math: + q = qkv[:, :, 0:head_num_q] + k = qkv[:, :, head_num_q:head_num_q + head_num_k].clone + v = qkv[:, :, head_num_q + head_num_k:].clone + + q = apply_rotary(q, sin_table, cos_table, position_ids) + k = apply_rotary(k, sin_table, cos_table, position_ids) + k = layernorm(k, gamma, beta) + + if (key_scale is not None and value is not None): + k = quantize(k, key_scale) + v = quantize(v, value_scale) + + if (slot_mapping is not None): + reshape_paged_cache(key_cache, value_cache, k, v, slot_mapping) + else: + reshape_linear_cache(key_cache, value_cache, k, v, cache_bs_id, cache_seq_offset) + + Args: + qkv (torch.Tensor): The qkv tensor. Shape is (batch, 1, head_num_q + head_num_k * 2, head_size). + key_cache_hp (torch.Tensor): The high precision key cache tensor. Shape is (max_bs, head_num_k, max_decode_len, head_size) or + (num_blocks, head_num_k, block_size, head_size). + value_cache_hp (torch.Tensor): The high precision value cache tensor. Shape is (max_bs, head_num_k, max_decode_len, head_size) or + (num_blocks, head_num_k, block_size, head_size). + sin_table (torch.Tensor): The rotary sin table tensor. Shape is (rotary_seq, head_size). + cos_table (torch.Tensor): The rotary cos table tensor. Shape is (rotary_seq, head_size). + position_ids (torch.Tensor): The rotary seq_len offset of each batch. Shape is (batch). + gamma (torch.Tensor): The layernorm gamma tensor. Shape is (head_size). + beta (torch.Tensor): The layenorm beta tensor. Shape is (head_size). + key_cache_lp (torch.Tensor): The low precision key cache tensor. Shape is (max_bs, head_num_k, max_decode_len, head_size) or + (num_blocks, head_num_k, block_size, head_size / 2). + value_cache_lp (torch.Tensor): The low precision value cache tensor. Shape is (max_bs, head_num_k, max_decode_len / 2, head_size) or + (num_blocks, head_num_k, block_size / 2, head_size). + cache_bs_id_hp (Torch.Tensor): The high precision cache batch offset tensor of each batch. Shape is (batch). This tensor + can be None, that means cache_bs_id is (0, 1, 2, 3, 4, ...). + cache_seq_offset_hp: (Torch.Tensor): The high precision cache seq_len offset tensor of each batch. Shape is (batch). + cache_bs_id_lp (Torch.Tensor): The low precision cache batch offset tensor of each batch. Shape is (batch). This tensor + can be None, that means cache_bs_id is (0, 1, 2, 3, 4, ...). + cache_seq_offset_lp: (Torch.Tensor): The low precision cache seq_len offset tensor of each batch. Shape is (batch). + key_scale_hp (Torch.Tensor): The high precision key scale tensor that quantize key to int8. Shape is (head_num_k, head_size). + value_scale_hp (Torch.Tensor): The high precision value scale tensor that quantize value to int8. Shape is (head_num_k, head_size). + key_scale_lp (Torch.Tensor): The low precision key scale tensor that quantize key to int8. Shape is (max_bs, head_num_k, group_num). + value_scale_lp (Torch.Tensor): The low precision value scale tensor that quantize value to int8. Shape is (max_bs, head_num_k, group_num). + slot_mapping_hp (Torch.Tensor): The slot_mapping tensor of high precision cache. Shape is (batch). + slot_mapping_lp (Torch.Tensor): The slot_mapping tensor of low precision cache. Shape is (batch). + eps (float): The layernorm eps param. + + Type: + qkv: half, bfloat16. + key_cache_hp: same as qkv if not quant kv else int8. + value_cache_hp: same as qkv if not quant kv else int8. + sin_table: same as qkv. + cos_table: same as qkv. + position_ids: int32. + gamma: same as qkv. + beta: same as qkv. + key_cache_lp: int8(int4x2). + value_cache_lp: int8(int4x2). + cache_bs_id_hp: int32. + cache_seq_offsets_hp: int32. + cache_bs_id_lp: int32. + cache_seq_offsets_lp: int32. + key_scale_hp: float. + value_scale_hp: float. + key_scale_lp: float. + value_scale_lp: float. + slot_mapping_hp: int32. + slot_mapping_lp: int32. + + Return: + If both k_cache_lp and k_scale_lp are None, return a tuple of (qkv, k_cache_hp, v_cache_hp). + If k_cache_lp exists but k_scale_lp is None, return a tuple of (qkv, k_cache_hp, v_cache_hp, k_cache_lp, v_cache_lp). + If both k_cache_lp and k_scale_lp exist, return a tuple of (qkv, k_cache_hp, v_cache_hp, k_cache_lp, v_cache_lp, k_scale_lp, v_scale_lp). + """ + torch.ops.torch_mlu_ops.fused_rope(qkv, k_cache_hp, v_cache_hp, k_cache_lp, v_cache_lp, + sin_cache, cos_cache, position_id, gamma, beta, k_scale_hp, v_scale_hp, + k_scale_lp, v_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp, + cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp, eps) + out = (qkv, k_cache_hp, v_cache_hp) + if k_cache_lp is not None: + out += (k_cache_lp, v_cache_lp) + if k_scale_lp is not None: + out += (k_scale_lp, v_scale_lp) + return out + +def update_out_and_lse(out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + seq_offsets: Optional[torch.Tensor] = None, + cu_seqs: Optional[torch.Tensor] = None, + block_cu_seqs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]: + """ + Update out and log-sum-exp(lse) according to block out and block lse. + + Math: + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + Args: + out (torch.Tensor): The output tensor. In pad mode, shape is (batch, max_seq_len, head_num, head_size). + In pack mode, shape is (total_seq_len, head_num, head_size). + lse (torch.Tensor): The lse tensor. Shape is (batch, head_num, max_seq_len) + block_out (torch.Tensor): The output tensor. In pad mode, shape is (batch, block_seq_len, head_num, head_size). + In pack mode, shape is (total_block_seq_len, head_num, head_size). + block_lse (torch.Tensor): The lse tensor. Shape is (batch, head_num, block_seq_len) + seq_offsets (torch.Tensor): The seq offset of origin out and origin lse. Shape is (batch). + cu_seqs (torch.Tensor): The cumulative sum of out seq_lens. Shape is (batch + 1). + block_cu_seqs (torch.Tensor): The cumulative sum of block out seq_lens. Shape is (batch + 1). + Type: + out: half, bfloat16, float. + lse: float. + block_out: half, bfloat16, float. + block_lse: float. + seq_offsets: int32_t. + cu_seqs: int32_t. + block_cu_seqs: int32_t. + Return: + Support inplace outputs. + Directly return the given out and lse. + """ + torch.ops.torch_mlu_ops.update_out_and_lse(out, lse, block_out, block_lse, + seq_offsets, cu_seqs, block_cu_seqs) + return (out, lse) + +def dequant_from_linear_cache(key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + context_seq_offset: Optional[torch.Tensor], + cache_bs_id: Optional[torch.Tensor], + cache_seq_offset: Optional[torch.Tensor], + quant_mode: int = 0, + quant_bit: int = 8) -> None: + """ + De-quantizes the key and value tensors from the provided linear cache and scale. + + This function is used during inference to convert the quantized key and value tensors stored in the cache back to + their original floating-point format. It supports both 8-bit and 4-bit quantization and different quantization modes + (per-channel or per-head). + + + Math: + Using the key de-quantization processes as an example: + len_i = context_lengths[i] + start_pos_i = sum(context_lentgths[:i]) if context_seq_offset is None else context_seq_offset[i] + end_pos_i = start_pos_i + len_i + key_i = key[start_pos_i:end_pos_i,...].transpose(1, 0) + cache_pos = cache_bs_id[i] + key_cache_i = key_cache[cache_pos,cache_seq_offset[i]:cache_seq_offset[i]+len[i],...] + key_scale_i = key_scale[i,...] if quant_per_batch else key_scale + if quant_mode == 0: + key_scale_i = key_scale_i.reshape(head_num, -1, head_size) + else: + key_scale_i = key_scale_i.reshape(head_num, seq, -1) + if quant_bit == 8: + key_i[:] = ((key_cache_i.to(torch.float32) * key_scale_i).to(key.dtype) + else: # quant_bit == 4 + key_unpack_i = key_i.flatten() + key_cache_flat_i = key_cache_i.flatten() + key_unpack_i[0::2] = key_cache >> 4 + key_unpack_i[1::2] = key_cache << 4 >> 4 + key_i[:] = (key_unupack_i.reshape(k_i).to(torch.float32) * k_scale_quant_scale).to(key.dtype) + + Args: + key (torch.Tensor): The key tensor with shape (total_seqlen, head_num, head_size). + value (torch.Tensor, optional): The value tensor with shape (total_seqlen, head_num, head_size). Default is None. + key_cache (torch.Tensor): The key cache tensor, shape depends on the quantization bit width, + where max_batch should be greater than or equal to the actual batch size. + - For 8-bit quantization: shape is (max_batch, head_num, cache_mem_len, head_size). + - For 4-bit quantization: shape is (max_batch, head_num, cache_mem_len, head_size//2). + value_cache (torch.Tensor, optional): The value cache tensor, shape depends on the quantization bit width. Default is None. + - For 8-bit quantization: shape is (max_batch, head_num, cache_mem_len, head_size). + - For 4-bit quantization: shape is (max_batch, head_num, cache_mem_len//2, head_size). + key_cache_quant_scale (torch.Tensor): The quantization scale tensor for the key cache. Shape depends on the quantization mode. + - For per-channel quantization: shape is (head_num, head_size). + - For per-token quantization: shape is (max_batch, head_num, cache_mem_len). + value_cache_quant_scale (torch.Tensor, optional): The quantization scale tensor for the value cache, same shape as key_cache_quant_scale. Default is None. + context_lengths (torch.Tensor): The actual lengths of the input sequences within the current context, shape (batch). + max_context_len (int32): The maximum length of all sequences in the current context, used to determine cache size. + context_seq_offset (torch.Tensor, optional): The starting position offset of each input sequence in the context, where cache_mem_len should be greater than or equal to max_context_len. + If not provided, it is calculated as the cumulative sum of context_lengths. Shape (batch). Default is None. + cache_bs_id (torch.Tensor, optional): The batch index in the cache where the key and value tensors will be placed. Shape (batch). Default is None. + cache_seq_offset (torch.Tensor, optional): The starting position offset of each input sequence in the cache. Shape (batch). Default is None. + quant_mode (int, optional): Quantization mode: 0 for per-channel quantization (each channel uses a different quantization scale), + 1 for per-token quantization (each token uses a different quantization scale). Default is 0. + quant_bit (int, optional): Quantization bit width: 8 for 8-bit quantization, 4 for 4-bit quantization. Default is 8. + + Type: + key: half, bfloat16 + value: half, bfloat16 + key_cache: int8 + value_cache: int8 + key_cache_quant_scale: float32 + value_cache_quant_scale: float32 + context_lengths: int32 + max_context_len: int32 + context_seq_offset: int32 + cache_bs_id: int32 + cache_seq_offset: int32 + quant_mode: int32 + quant_bit: int32 + + Return: + None. + """ + + torch.ops.torch_mlu_ops.dequant_from_linear_cache( + key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, + context_lengths, max_context_len, context_seq_offset, cache_bs_id, cache_seq_offset, + quant_mode, quant_bit + ) + +def dequant_from_paged_cache(key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + context_seq_offset: Optional[torch.Tensor], + block_tables: torch.Tensor, + quant_mode: int = 0, + quant_bit: int = 8) -> None: + """ + Dequantize `key_cache` and `value_cache` to `key` and `value` based on quantization parameters. + + Math: + block_num = context_lengths[i] // block_size + rem_token_num = context_lengths[i] % block_size + # Index key_cache based on block_tables and context_lengths + key_cache_i = torch.concat( + [key_cache[block_tables[i, j], ...] for j in range(block_num)] + + ([key_cache[block_tables[i, block_num], :, :rem_token_num, :]] if rem_token_num > 0 else []) + ) + # Get context begin and end indices + context_begin_i = torch.cumsum(context_lengths, dims=-1)[i - 1] if context_seq_offset is None else context_seq_offset[i] + context_end_i = context_begin_i + context_lengths[i] + + # Slice key based on context_lengths and context_seq_offset + key_i = key[context_begin_i:context_end_i, ...] + + # Determine the scale based on quant_mode + if quant_mode == 0: + key_scale_i = key_cache_quant_scale[:, None, :] + elif quant_mode == 1: + key_scale_i = torch.concat( + [key_cache_quant_scale[block_tables[i, j], ..., None] for j in range(block_num)] + + ([key_cache_quant_scale[block_tables[i, block_num], :, :rem_token_num, None]] if rem_token_num > 0 else []) + ) + + # Dequantize key_cache to key + key_i[:] = (key_cache_i * key_scale_i).transpose(1, 0) + + Args: + key (torch.Tensor): Output tensor for dequantized key. Shape is (total_seqlens, head_num, head_size). + value (torch.Tensor): Output tensor for dequantized value. Shape is (total_seqlens, head_num, head_size). + key_cache (torch.Tensor): Quantized key cache tensor. Shape is (total_blocks, head_num, block_size, head_size). + value_cache (torch.Tensor): Quantized value cache tensor. Shape is (total_blocks, head_num, block_size, head_size). + key_cache_quant_scale (torch.Tensor): Quantization scale for key cache. Shape is (head_num, head_size) for per-channel or (max_batch, head_num, block_size) for per-token. + value_cache_quant_scale (torch.Tensor): Quantization scale for value cache. Shape is (head_num, head_size) for per-channel or (max_batch, head_num, block_size) for per-token. + context_lengths (torch.Tensor): Lengths of contexts for each batch. Shape is (batch). + max_context_len (int): Maximum context length. + context_seq_offset (torch.Tensor): Optional sequence offset for context lengths. Shape is (batch). + block_tables (torch.Tensor): Block tables for indexing. Shape is (batch, max_block_num). + quant_mode (int): Quantization mode. 0 for per-channel, 1 for per-token. Default is 0. + quant_bit (int): Quantization bit. Default is 8. + + Type: + key: half, bfloat16 + value: half, bfloat16 + key_cache: int8 + value_cache: int8 + key_cache_quant_scale: float32 + value_cache_quant_scale: float32 + context_lengths: int32 + max_context_len: int + context_seq_offset: int32 + block_tables: int32 + quant_mode: int + quant_bit: int + + Return: + None + """ + + torch.ops.torch_mlu_ops.dequant_from_paged_cache( + key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, + context_lengths, max_context_len, context_seq_offset, block_tables, quant_mode, quant_bit + ) + +# add dump_gese as decorator automatically for every function mentioned in __all__ +add_gen_case_decorator(inspect.currentframe()) diff --git a/torch_mlu_ops-v1.3.2/torch_mlu_ops/_utils.py b/torch_mlu_ops-v1.3.2/torch_mlu_ops/_utils.py new file mode 100644 index 0000000..c1d2090 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/torch_mlu_ops/_utils.py @@ -0,0 +1,121 @@ +import torch +from torch import Tensor +import os +import importlib.machinery +import inspect +from datetime import datetime +import sys +# sys.setrecursionlimit(10000) + +def enable_gen_case(): + def check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + enable_gen_case = check_env_flag('TMO_GEN_CASE') + enalbe_dump_data = check_env_flag('TMO_GEN_CASE_DUMP_DATA') + tmo_gen_case_op_name = os.getenv('TMO_GEN_CASE_OP_NAME') + op_names = tmo_gen_case_op_name.split(';') if tmo_gen_case_op_name else [] + tmo_gen_case_path = os.getenv('TMO_GEN_CASE_PATH', os.getcwd()) + return enable_gen_case, op_names, tmo_gen_case_path, enalbe_dump_data + +ENABLE_GEN_CASE, GEN_CASE_OP_NAMES, GEN_CASE_PATH, ENALBE_DUMP_DATA = enable_gen_case() + +def gen_case_dump(**kwargs): + torch.save(kwargs['map'], kwargs['path']) + +def dump_info(value): + info = dict() + if value.__class__ in (list, tuple): + cls_set = {elem.__class__ for elem in value} + if {torch.Tensor, list, tuple} & cls_set: + temp_li = [] + for elem in value: + temp_li.append(dump_info(elem)) + info = {'type': value.__class__, 'has_compound': True, 'data': temp_li} + else: + info = {'type': value.__class__, 'has_compound': False, 'data': value} + elif value.__class__ is dict: + temp_dic = dict() + for k, v in value.items(): + temp_dic[k] = dump_info(v) + info = {'type': value.__class__, 'data': temp_dic} + elif value.__class__ is torch.Tensor: + if value is None: + info = {'type': value.__class__, 'data': value} + else: + info = {'type': value.__class__, + 'shape': value.shape, + 'dtype': value.dtype, + 'device': value.device, + 'is_contiguous': value.is_contiguous(), + 'stride': value.stride(), + 'data': value if value.dtype in (torch.int16, torch.int32, torch.int64) else ''} + else: + info = {'type': value.__class__, 'data': value} + return info + +def dump_pt_case(func, *args, **kwargs): + op_name = func.__name__ + case_dir = os.path.join(GEN_CASE_PATH, 'tmo_gen_case', op_name) + os.makedirs(case_dir, exist_ok=True) + time_stamp = int(datetime.now().timestamp()*1e6) + case_path = os.path.join(case_dir, op_name + '_' + str(time_stamp) + '.pt') + print("[torch_mlu_ops] dump case ====> ", case_path) + + signature = inspect.signature(func) + param_objs = [obj for obj in signature.parameters.values()] # signature.parameters是有序的mapping + params_map = {"op": op_name, "dump_data": ENALBE_DUMP_DATA} + if ENALBE_DUMP_DATA: + for obj, value in zip(param_objs[:len(args)], args): + params_map[obj.name] = value + for obj in param_objs[len(args):]: + params_map[obj.name] = obj.default + params_map = {**params_map, **kwargs} + else: + for obj, value in zip(param_objs[:len(args)], args): + params_map[obj.name] = dump_info(value) + for obj in param_objs[len(args):]: + params_map[obj.name] = dump_info(obj.default) + for k, v in kwargs.items(): + if k in signature.parameters.keys(): + params_map[k] = dump_info(v) + gen_case_dump(map = params_map, path = case_path) + +def dump_case(func): + def wrapper_dump(*args, **kwargs): + dump_pt_case(func, *args, **kwargs) + return func(*args, **kwargs) + def wrapper_nothing(*args, **kwargs): + return func(*args, **kwargs) + return wrapper_dump if (ENABLE_GEN_CASE and \ + (GEN_CASE_OP_NAMES == [] or func.__name__ in GEN_CASE_OP_NAMES)) \ + else wrapper_nothing + +def add_gen_case_decorator(frame): + if ENABLE_GEN_CASE: + _current_module = inspect.getmodule(frame) + members = dict() + for k, v in inspect.getmembers(_current_module): + members[k] = v + + func_names = members['__all__'] if GEN_CASE_OP_NAMES == [] else GEN_CASE_OP_NAMES + for name in func_names: + setattr(_current_module, name, dump_case(members[name])) + +dtype2str_map = { torch.half: "half", + torch.float16: "half", + torch.float: "float", + torch.bfloat16: "bfloat16", + torch.int32: "int32", + torch.int8: "int8" + } + +def torchDtype2Str(torch_dtype: torch.dtype) -> str: + assert torch_dtype in dtype2str_map, "unrecognized torch type: {}".format(torch_dtype) + return dtype2str_map[torch_dtype] + + +def get_custom_op_library_path(): + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) + ext_specs = importlib.machinery.FileFinder(os.path.dirname(__file__), loader_details).find_spec("_C") + assert ext_specs is not None, "Could not find custom op library" + return ext_specs.origin diff --git a/torch_mlu_ops-v1.3.2/torch_mlu_ops/abstract.py b/torch_mlu_ops-v1.3.2/torch_mlu_ops/abstract.py new file mode 100644 index 0000000..d77ade2 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/torch_mlu_ops/abstract.py @@ -0,0 +1,627 @@ +from typing import Tuple, List, Dict, Optional +import torch +from torch import Tensor +import torch._custom_ops + +@torch._custom_ops.impl_abstract("torch_mlu_ops::attention_project") +def attention_project_abstract( + input: Tensor, + q_weight: Tensor, + q_bias: Optional[Tensor], + k_weight: Optional[Tensor], + k_bias: Optional[Tensor], + v_weight: Optional[Tensor], + v_bias: Optional[Tensor], + norm_weight: Optional[Tensor], + norm_bias: Optional[Tensor], + residual: Optional[Tensor], + out_layout: str, + head_size: int, + eps: float, + alpha: float, + beta: float, + norm_out: bool, +) -> List[Tensor]: + input_view = input + if input.dim() == 2: + input_view = input.unsqueeze(0) + n = input_view.size(0) + t = input_view.size(1) + hidden_size_q = q_weight.size(0) + hidden_size_k = k_weight.size(0) if k_weight is not None else 0 + hidden_size_v = v_weight.size(0) if v_weight is not None else 0 + head_num_q = hidden_size_q // head_size + head_num_k = hidden_size_k // head_size + head_num_v = hidden_size_v // head_size + + out_q = torch.empty(n, t, hidden_size_q, dtype=input_view.dtype, device=input_view.device) + out_k = torch.empty(n, t, hidden_size_k, dtype=input_view.dtype, device=input_view.device) if hidden_size_k > 0 else None + out_v = torch.empty(n, t, hidden_size_v, dtype=input_view.dtype, device=input_view.device) if hidden_size_v > 0 else None + + if out_layout == "nhtc": + out_q = torch.empty(n, head_num_q, t, head_size, dtype=input_view.dtype, device=input_view.device) + out_k = torch.empty(n, head_num_k, t, head_size, dtype=input_view.dtype, device=input_view.device) if hidden_size_k > 0 else None + out_v = torch.empty(n, head_num_v, t, head_size, dtype=input_view.dtype, device=input_view.device) if hidden_size_v > 0 else None + + out_ln = torch.empty_like(input_view) if norm_out else None + res = [out_q] + if k_weight is not None: + res.append(out_k) + if v_weight is not None: + res.append(out_v) + if norm_out: + res.append(out_ln) + + return res + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::ffn") +def ffn_abstract( + input: Tensor, + up_fc_weight: Tensor, + up_fc_bias: Optional[Tensor], + down_proj_weight: Tensor, + down_proj_bias: Optional[Tensor], + gate_up_proj_weight: Optional[Tensor], + gate_up_proj_bias: Optional[Tensor], + layernorm_weight: Optional[Tensor], + layernorm_bias: Optional[Tensor], + act_mode: str, + residual_is: str, + eps: float, + alpha: float, + beta: float, +) -> Tensor: + return torch.empty_like(input) + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attention") +def flash_attention_abstract( + q: Tensor, + k: Tensor, + v: Tensor, + out: Tensor, + output_lse: Optional[Tensor], + cu_seq_lens_q: Optional[Tensor], + cu_seq_lens_kv: Optional[Tensor], + alibi_slope: Optional[Tensor], + attn_bias: Optional[Tensor], + k_quant_scale: Optional[Tensor], + v_quant_scale: Optional[Tensor], + block_tables: Optional[Tensor], + max_seq_len_q: int, + max_seq_len_kv: int, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + compute_dtype: str, + return_lse: bool, +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::single_query_cached_kv_attn") +def single_query_cached_kv_attn_abstract( + q_ori: Tensor, + k_cache: Tensor, + v_cache: Tensor, + output: Tensor, + block_tables: Tensor, + context_lens: Tensor, + output_lse: Optional[Tensor], + k_cache_quant_scale: Optional[Tensor], + v_cache_quant_scale: Optional[Tensor], + alibi_slopes: Optional[Tensor], + max_contxt_len: int, + windows_size_left: int, + windows_size_right: int, + softmax_scale: float, + return_lse: bool, + kv_cache_quant_bit_size: int +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::apply_rotary") +def apply_rotary_abstract( + input: Tensor, + sin_cache: Tensor, + cos_cache: Tensor, + position_ids: Optional[Tensor], + cu_seqlens: Optional[Tensor], + interleaved: bool, + discrete: bool, + dynamic_ntk: bool, + max_seqlen: int, +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_linear_cache") +def reshape_linear_cache_abstract( + key: Tensor, + value: Optional[Tensor], + key_cache: Tensor, + value_cache: Optional[Tensor], + context_lengths: Tensor, + max_context_len: int, + packed: bool, + context_seq_offset: Optional[Tensor], + cache_bs_id: Optional[Tensor], + cache_seqlen_offset: Optional[Tensor], +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::reshape_paged_cache") +def reshape_paged_cache_abstract( + k: Tensor, + v: Optional[Tensor], + k_cache: Tensor, + v_cache: Optional[Tensor], + slot_mapping: Tensor +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_paged_cache") +def quant_to_paged_cache_abstract( + k: Tensor, + v: Optional[Tensor], + k_cache: Tensor, + v_cache: Optional[Tensor], + k_cache_scale: Tensor, + v_cache_scale: Optional[Tensor], + slot_mapping: Tensor, +) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_paged_cache") +def offline_quant_to_paged_cache_abstract( + k: Tensor, + v: Optional[Tensor], + k_cache_scale: Tensor, + v_cache_scale: Optional[Tensor], + slot_mapping: Tensor, + k_cache: Tensor, + v_cache: Optional[Tensor], +) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_to_linear_cache") +def quant_to_linear_cache_abstract( + key: Tensor, + value: Optional[Tensor], + key_cache: Tensor, + value_cache: Optional[Tensor], + key_cache_scale: Tensor, + value_cache_scale: Optional[Tensor], + context_lengths: Tensor, + max_context_len: int, + packed: bool, + context_seq_offset: Optional[Tensor], + cache_bs_id: Optional[Tensor], + cache_seqlen_offset: Optional[Tensor], + quant_bit: int = 8, +) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::offline_quant_to_linear_cache") +def offline_quant_to_linear_cache_abstract( + key: Tensor, + value: Optional[Tensor], + key_cache: Tensor, + value_cache: Optional[Tensor], + key_cache_scale: Tensor, + value_cache_scale: Optional[Tensor], + context_lengths: Tensor, + max_context_len: int, + quant_mode: int, + packed: bool, + context_seq_offset: Optional[Tensor], + cache_bs_id: Optional[Tensor], + cache_seqlen_offset: Optional[Tensor], +) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::swap_blocks") +def swap_blocks_abstract( + dst: Tensor, src: Tensor, block_mapping: Dict[int, int] +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks") +def copy_blocks_abstract( + k_caches: List[Tensor], v_caches: List[Tensor], block_mapping: Dict[int, List[int]] +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::copy_blocks_out_of_place") +def copy_blocks_out_of_place_abstract(k_caches: List[Tensor], v_caches: List[Tensor], block_mapping: Dict[int, List[int]]) -> (List[Tensor], List[Tensor]): + return ([torch.empty_like(k) for k in k_caches], [torch.empty_like(v) for v in v_caches]) + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul") +def quant_matmul_abstract( + a_tensor: Tensor, + a_scale: Optional[Tensor], + a_zero: Optional[Tensor], + b_tensor: Tensor, + b_scale: Optional[Tensor], + b_zero: Optional[Tensor], + bias: Optional[Tensor], + c_tensor: Optional[Tensor], + c_scale: Optional[Tensor], + c_zero: Optional[Tensor], + gemm_output_scale: Optional[Tensor], + gemm_output_zero: Optional[Tensor], + data_type: Optional[str], + d: Optional[Tensor], + quant_algo: str, + a_quant_layout: str, + b_quant_layout: str, + quant_bit_size: int = 8, + act_mode: str = "none", + use_hp_active: bool = False, + act_coef: float = 1.0, + alpha: float = 1.0, + beta: float = 1.0, + trans_a: bool = False, + trans_b: bool = True, +) -> Tensor: + + if data_type is None: + output_type = a_tensor.dtype + elif data_type == "float": + output_type = torch.float32 + elif data_type == "bfloat16": + output_type = torch.bfloat16 + else: + output_type = torch.float16 + return torch.empty(a_tensor.size(0), b_tensor.size(0), dtype=output_type, device=a_tensor.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::quant_matmul_allreduce") +def quant_matmul_allreduce_abstract( + cncl_comm, + a_tensor: torch.Tensor, + a_scale: Optional[torch.Tensor], + a_zero: Optional[torch.Tensor], + b_tensor: torch.Tensor, + b_scale: Optional[torch.Tensor], + b_zero: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + c_tensor: Optional[torch.Tensor], + c_scale: Optional[torch.Tensor], + c_zero: Optional[torch.Tensor], + gemm_output_scale: Optional[torch.Tensor], + gemm_output_zero: Optional[torch.Tensor], + data_type: Optional[str], + d: Optional[torch.Tensor], + quant_algo: str, + a_quant_layout: str, + b_quant_layout: str, + quant_bit_size: int = 8, + alpha: float = 1.0, + beta: float = 1.0, + trans_a: bool = False, + trans_b: bool = True, + block_m: int = 0 +) -> torch.Tensor: + output_type = torch.float16 + if data_type == "float": + output_type = torch.float32 + elif data_type == "bfloat16": + output_type = torch.bfloat16 + return torch.empty(a_tensor.size(0), b_tensor.size(0), dtype=output_type, device=a_tensor.device) + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::active") +def active_abstract(input: Tensor, + output: Tensor, + bias: Optional[Tensor], + cusum_token_count: Optional[Tensor], + act_mode: str, + is_gated: bool, + start_expert_id: int = 0, + expert_size: int = 0, + active_coef: float = 1.0) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::smooth_quant") +def smooth_quant_abstract( + input: Tensor, + input_scale: Tensor, + output: Tensor, + output_scale: Tensor, + input_zero: Optional[Tensor], + token_count: Optional[Tensor], + gather_index: Optional[Tensor], + gather_index_start_position: Optional[Tensor], + quant_mode: str, + dynamic_quant: bool +) -> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_layernorm") +def fused_layernorm_abstract( + input: Tensor, + output: Tensor, + residual: Optional[Tensor], + beta: Optional[Tensor], + gamma: Optional[Tensor], + bias: Optional[Tensor], + quant_scale: Optional[Tensor], + residual_out: Optional[Tensor], + smooth_quant_scale: Optional[Tensor], + norm_mode: str, + eps: float, + store_output_before_norm: bool, + dynamic_quant: bool, +)-> None: + return None + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_moe") +def fused_moe_abstract( + hidden_states: Tensor, + gating_output: Tensor, + w1: Tensor, + w2: Tensor, + bias1: Optional[Tensor], + bias2: Optional[Tensor], + residual: Optional[Tensor], + input_smooth: Optional[Tensor], + act_smooth: Optional[Tensor], + w1_scale: Optional[Tensor], + w2_scale: Optional[Tensor], + topk: int, + renormalize: bool, + gated: bool, + act_mode: str, + start_expert_id: int, + block_n: int, + cncl_comm: int, + w1_quant_flag: Optional[List], + w2_quant_flag: Optional[List] + ) -> Tensor: + return torch.empty_like(hidden_states) + + +@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul") +def matmul_abstract( + a: Tensor, + b: Tensor, + d: Optional[Tensor], + bias: Optional[Tensor], + c: Optional[Tensor], + data_type: Optional[str], + act_mode: str, + alpha: float, + beta: float, + fast_act: bool, + approximate: bool, + a_scale: float, + b_scale: float, + trans_a: bool, + trans_b: bool +) -> Tensor: + m = a.size(1) if trans_a else a.size(0) + n = b.size(0) if trans_b else b.size(1) + if data_type is None: + output_type = a.dtype + elif data_type == "float": + output_type = torch.float32 + elif data_type == "bfloat16": + output_type = torch.bfloat16 + else: + output_type = torch.half + return torch.empty(m, n, dtype=output_type, device=a.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::batch_matmul") +def batch_matmul_abstract( + a: Tensor, + b: Tensor, + c: Tensor, + alpha: float, + beta: float, + a_scale: float, + b_scale: float, + trans_a: bool, + trans_b: bool +) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::matmul_allreduce") +def matmul_allreduce_abstract( + cncl_comm, + a: torch.Tensor, + b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + d: Optional[torch.Tensor] = None, + alpha: float = 1.0, + beta: float = .0, + block_m: int = 0 +) -> Tensor: + return torch.empty(a.size(0), b.size(0), dtype=a.dtype, device=a.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::group_gemm") +def group_gemm_abstract( + a: Tensor, + b: Tensor, + m_list: Tensor, + expand_idx: Optional[Tensor], + c: Optional[Tensor], + alpha: Optional[Tensor], + beta: Optional[Tensor], + a_scale: Optional[Tensor], + b_scale: Optional[Tensor], + bias: Optional[Tensor], + data_type: Optional[str], + quant_flag: Optional[List], + b_offset: Optional[Tensor], + max_m: int +) -> Tensor: + if data_type is None: + output_type = a.dtype + elif data_type == "float": + output_type = torch.float32 + elif data_type == "bfloat16": + output_type = torch.bfloat16 + else: + output_type = torch.half + total_m = a.size(0) if expand_idx is None else expand_idx.size(0) + + return torch.empty(total_m, b.size(1), dtype=output_type, device=a.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::preload") +def preload_abstract( + weight: Tensor, + size: int +) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::flash_attn_sq_mm_allreduce") +def flash_attn_sq_mm_allreduce_abstract(cncl_comm, + q, + k, + v, + cu_seq_lens_q, + cu_seq_lens_k, + alibi_slope, + attn_bias, + smooth, + weight, + weight_scale, + bias, + max_seq_len_q, + max_seq_len_kv, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + compute_dtype, + block_seq) -> torch.Tensor: + res_q = q.unsqueeze(0) if cu_seq_lens_q is not None else q + res_q = res_q.flatten(-2, -1).flatten(0, 1) + return torch.empty(res_q.size(0), weight.size(0), dtype=q.dtype, device=q.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_softmax_topk") +def moe_softmax_topk_abstract(input, + topk, + num_expert_group, + topk_group, + normalize, + mask: Optional[torch.Tensor] = None, + normed_by: str = "topk_logit") -> Tuple[torch.Tensor, torch.Tensor]: + out_shape = list(input.size())[:-1] + [topk] + reduce_weight = torch.empty(out_shape, dtype=torch.float32, device=input.device) + expert_id = torch.empty(out_shape, dtype=torch.int, device=input.device) + return (reduce_weight, expert_id) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_expand_input") +def moe_expand_input_abstract(input: Tensor, + gather_idx: Tensor, + cusum_token_count: Optional[Tensor] = None, + start_expert_id: int = 0, + expert_size: int = 0) -> Tensor: + return torch.empty(gather_idx.size(0), input.size(-1), dtype=input.dtype, device=input.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_gen_idx") +def moe_gen_idx_abstract(expert_id: Tensor, + expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + token_num, topk = expert_id.size(0), expert_id.size(1) + expand_idx = torch.empty((token_num * topk), dtype=torch.int32, device=expert_id.device) + combine_idx = torch.empty((token_num * topk), dtype=torch.int32, device=expert_id.device) + token_count = torch.empty((expert_num,), dtype=torch.int32, device=expert_id.device) + cusum_token_count = torch.empty((expert_num + 1,), dtype=torch.int32, device=expert_id.device) + return (expand_idx, combine_idx, token_count, cusum_token_count) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_combine_result") +def moe_combine_result_abstract(input: torch.Tensor, + reduce_weight: torch.Tensor, + gather_ids: torch.Tensor, + residual: Optional[torch.Tensor], + cusum_token_count: Optional[torch.Tensor], + start_expert_id: int, + expert_size: int, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + num_tokens, hidden_size, topk = input.size(0), input.size(1), reduce_weight.size(1) + num_token = num_tokens // topk + return torch.empty(num_token, hidden_size, dtype=input.dtype, device=input.device) + +@torch._custom_ops.impl_abstract("torch_mlu_ops::fused_rope") +def fused_rope_abstract(qkv: torch.Tensor, + key_cache_hp: torch.Tensor, + value_cache_hp: torch.Tensor, + key_cache_lp: Optional[torch.Tensor], + value_cache_lp: Optional[torch.Tensor], + sin_table: torch.Tensor, + cos_table: torch.Tensor, + position_ids: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + key_scale_hp: Optional[torch.Tensor], + value_scale_hp: Optional[torch.Tensor], + key_scale_lp: Optional[torch.Tensor], + value_scale_lp: Optional[torch.Tensor], + cache_bs_id_hp: Optional[torch.Tensor], + cache_seq_offsets_hp: Optional[torch.Tensor], + cache_bs_id_lp: Optional[torch.Tensor], + cache_seq_offsets_lp: Optional[torch.Tensor], + slot_mapping_hp: Optional[torch.Tensor], + slot_mapping_lp: Optional[torch.Tensor], + eps: float) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::moe_cast_gating") +def moe_cast_gating_abstract(input: torch.Tensor, + weight: torch.Tensor) ->Tensor: + output_shape = input.shape[:-1] + (weight.shape[0],) + output = torch.empty(output_shape, dtype=torch.float, device="mlu") + return output + +@torch._custom_ops.impl_abstract("torch_mlu_ops::update_out_and_lse") +def update_out_and_lse_abstract(out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + seq_offsets: Optional[torch.Tensor] = None, + cu_seqs: Optional[torch.Tensor] = None, + block_cu_seqs: Optional[torch.Tensor] = None) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_linear_cache") +def dequant_from_linear_cache_abstract(key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + context_seq_offset: Optional[torch.Tensor], + cache_bs_id: Optional[torch.Tensor], + cache_seq_offset: Optional[torch.Tensor], + quant_mode: int = 0, + quant_bit: int = 8) -> None: + return None + +@torch._custom_ops.impl_abstract("torch_mlu_ops::dequant_from_paged_cache") +def dequant_from_paged_cache_abstract(key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + context_seq_offset: Optional[torch.Tensor], + block_tables: torch.Tensor, + quant_mode: int = 0, + quant_bit: int = 8) -> None: + return None diff --git a/torch_mlu_ops-v1.3.2/version.tpl b/torch_mlu_ops-v1.3.2/version.tpl new file mode 100644 index 0000000..bca7075 --- /dev/null +++ b/torch_mlu_ops-v1.3.2/version.tpl @@ -0,0 +1 @@ +__version__ = '{{VERSION}}'