22 Commits

Author SHA1 Message Date
Chranos
3fed2190ad fixing kvcache bug 2026-02-06 16:39:42 +08:00
Chranos
c1b6f39a11 fix: pass lm_head to LogitsProcessor instead of calling forward()
In vLLM v0.6.2, ParallelLMHead.forward() raises RuntimeError since
its weights should be used through LogitsProcessor.linear_method.apply().
Pass lm_head as first arg to LogitsProcessor which handles the
hidden_states -> logits projection internally.
2026-02-06 15:05:49 +08:00
Chranos
3e301ce158 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
87f96e1001 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
e1a2afd244 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
63a1a05999 testing dynamic register 2026-02-06 15:05:49 +08:00
Chranos
6d814b0cd4 testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
dc239a740c testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
a476b6458b testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
80e9a636af testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
16353d5d2a testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
70bee4e3ec testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
83c958a7c5 testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
9b84dd52be testing dynamic register 2026-02-06 15:05:48 +08:00
Chranos
2cb9f6ce1d testing dynamic register 2026-02-06 15:05:48 +08:00
31e7cd3bf9 删除 .DS_Store 2026-02-05 16:21:10 +08:00
Chranos
6b650ae280 add gitignore 2026-02-05 16:19:33 +08:00
Chranos
92f0016e6f add dynamic register 2026-02-05 15:53:43 +08:00
Chranos
9563c9af0d opt llama3 2026-02-05 11:53:52 +08:00
Chranos
3b3e614cb6 opt llama3 2026-02-05 11:42:01 +08:00
Chranos
3cf13dd8c5 add ops 2026-02-04 17:51:35 +08:00
Chranos
79dfc69789 add ops 2026-02-04 17:39:32 +08:00
315 changed files with 58053 additions and 31 deletions

BIN
.DS_Store vendored

Binary file not shown.

240
.gitignore vendored Normal file
View File

@@ -0,0 +1,240 @@
# version file generated by setuptools-scm
/vllm/_version.py
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*
# FlashMLA interface copied from source
vllm/third_party/flashmla/flash_mla_interface.py
# triton jit
.triton
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/.deps/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# generated files
**/generated/**
# uv
uv.lock
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
docs/argparse
docs/examples/*
!docs/examples/README.md
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VSCode
.vscode/
# Claude
CLAUDE.md
.claude/
# Codex
AGENTS.md
.codex/
# Cursor
.cursor/
# DS Store
.DS_Store
# Results
*.csv
# Python pickle files
*.pkl
# Sphinx documentation
_build/
# vim swap files
*.swo
*.swp
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h
# Benchmark dataset
benchmarks/**/*.json
# Linting
actionlint
shellcheck*/
# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
# Ignore ep_kernels_workspace folder
ep_kernels_workspace/
# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
!vllm/benchmarks/lib/
# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi

View File

@@ -3,6 +3,7 @@
# 寒武纪 mlu370 文本生成
该模型测试框架在寒武纪mlu370 X8/X4加速卡上基于vllm 推理引擎,适配了 Qwen1.5-1.8B-Chat 模型。
* Qwen1.5-1.8B-Chat 是通义千问系列中一款约18亿参数、轻量级的中英文对话大模型,专为高效推理和多场景聊天交互设计。
* Llama-2-7b-chat-hf:Meta 发布的 LLaMA 2 系列中 70 亿参数的对话优化版开源大模型,适合多轮聊天与通用任务。
* ChatGLM3-6B:智谱 AI 推出的第 3 代 ChatGLM 系列中 60 亿参数的中英双语对话大模型,支持推理、代码和多任务能力。

View File

@@ -0,0 +1,4 @@
BasedOnStyle: Chromium
ColumnLimit: 100
PointerAlignment: Right
AllowShortIfStatementsOnASingleLine: true

View File

@@ -0,0 +1,15 @@
Checks: "bugprone-*,google-*,-google-explicit-constructor,cppcoreguidelines-avoid-const-or-ref-data-members,cppcoreguidelines-init-variables,cppcoreguidelines-interfaces-global-init,cppcoreguidelines-misleading-capture-default-by-value,cppcoreguidelines-missing-std-forward,cppcoreguidelines-no-malloc,cppcoreguidelines-pro-type-member-init,cppcoreguidelines-rvalue-reference-param-not-moved,cppcoreguidelines-slicing,cppcoreguidelines-virtual-class-destructor,performance-unnecessary-copy-initialization",
CheckOptions: [
{
key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion,
value: false
},
{
key: cppcoreguidelines-narrowing-conversions.WarnOnIntegerToFloatingPointNarrowingConversion,
value: false
},
{
key: cppcoreguidelines-narrowing-conversions.IgnoreConversionFromTypes,
value: size_t;ptrdiff_t;size_type;difference_type
}
]

View File

@@ -0,0 +1,13 @@
.cache/
.devcontainer/
build/
__pycache__/
.vscode/
*.so
*.out
*.o
core_dump_*
record.txt
*.log
CMakePresets.json
.clangd

4
torch_mlu_ops-v1.3.2/.gitattributes vendored Normal file
View File

@@ -0,0 +1,4 @@
# git archive exclude files
tools/ci export-ignore
docker export-ignore
# legacy export-ignore

25
torch_mlu_ops-v1.3.2/.gitignore vendored Normal file
View File

@@ -0,0 +1,25 @@
.cache/
.devcontainer/
build/
packages/
__pycache__/
.idea/
.vscode/
*.so
*.out
*.o
*.pyc
*.deb
*.pt
core_dump_*
record.txt
*.log
CMakePresets.json
.clangd
compile_flags.txt
legacy/samples/pytorch/outputs/
legacy/src/thirdparty/bangtransformer.egg-info/
legacy/src/thirdparty/dist/
torch_mlu_ops.egg-info/
torch_mlu_ops/_version.py
dist/

View File

@@ -0,0 +1,15 @@
leak:libtorch_cpu.so
leak:libtorch_mlu.so
leak:libtorch_mlu_python.so
leak:libtriton.so
leak:libtorch_python.so
leak:libcatch_python.so
leak:pybind11::cpp_function::initialize_generic
leak:pybind11::cpp_function::make_function_record
leak:pybind11::detail::process_attribute
leak:libstdc++.so
leak:numpy
leak:Objects
leak:Modules
leak:Python
leak:libcrypto.so

View File

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2024, Cambricon Technologies
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,183 @@
<div align="center">
Torch-MLU-Ops
===========================
![system](https://img.shields.io/badge/system-ubuntu22.04_debian10.11-blue.svg)![python](https://img.shields.io/badge/python-3.10-green)![pytorch](https://img.shields.io/badge/pytorch-2.1_2.4_2.5-green)![release](https://img.shields.io/badge/release-1.3.2-green.svg)![license](https://img.shields.io/badge/license-BSD_3-blue.svg)
<div align="center">
<b>
<a href="https://www.cambricon.com/docs/sdk_1.15.0/bangtransformer_0.4.0/index.html">
<font size="4"> 📖 Torch_MLU_Ops用户手册</font>
</a>
</b>
&nbsp;&nbsp;&nbsp;&nbsp;
<b>
<a href="https://developer.cambricon.com/">
<font size="4"> 🌏 寒武纪开发者社区</font>
</a>
</b>
&nbsp;&nbsp;&nbsp;&nbsp;
<b>
<a href="https://sdk.cambricon.com/download?sdk_version=V1.13.0&component_name=Basis">
<font size="4"> 🛠️ 依赖组件获取</font>
</a>
</b>
</div>
---
<div align="left">
## 目录
- [Torch-MLU-Ops](#torch-mlu-ops)
- [目录](#目录)
- [简介](#简介)
- [安装](#安装)
- [环境要求](#环境要求)
- [通过docker镜像启动](#通过docker镜像启动)
- [编译安装](#编译安装)
- [测试](#测试)
- [目录结构](#目录结构)
- [关键特性](#关键特性)
## 简介
Torch-MLU-Ops是寒武纪设计和开发的PyTorch第三方算子库。对于使用PyTorch框架的开发者通过Torch-MLU-Ops能够便捷地使用这些自定义算子进行算子的集成、评测和业务部署。
Torch-MLU-Ops已全量覆盖LLM(Large Language Model)推理场景下的常见算子。作为后端已支持Cambricon vLLM、Cambricon TGI、Cambricon Stable Diffusion web UI、Cambricon ComfyUI以及Cambricon Diffusers。
## 安装
### 环境要求
Torch-MLU-Ops支持的操作系统以及软件依赖如下。
不同平台支持的软件版本如下:
| 操作系统 | 编译器版本 | CMake版本 |
| ------------------------| ----------------------| -------------------------- |
| Ubuntu22.04 x86-64 | gcc 9.4.0或者更高版本 | 3.10或者更高版本 |
| Debian10.11 x86-64 | gcc 9.4.0或者更高版本 | 3.10或者更高版本 |
编译、运行Torch-MLU-Ops同时还需要配置以下依赖。
| Torch-MLU-Ops | Torch-MLU | CNNL | CNNL_Extra | CNCL | CNToolkit |
| -----------------| ------------------| ---------| --------------- | --------| -----------|
| v1.3.2 | v1.24.1 | v1.28.3 | v1.12.3 | v1.24.1 | v3.15.6 |
| v1.3.1 | v1.24.1 | v1.28.3 | v1.12.2 | v1.24.1 | v3.15.5 |
| v1.3.0 | v1.24.0 | v1.28.2 | v1.12.1 | v1.24.0 | v3.15.4 |
| v1.2.3 | v1.24.0 | v1.28.1 | v1.12.1 | v1.24.0 | v3.15.3 |
| v1.2.2 | v1.23.1 | v1.27.4 | v1.11.3 | v1.23.0 | v3.14.3 |
| v1.2.1 | v1.23.1 | v1.27.3 | v1.11.3 | v1.22.3 | v3.14.2 |
| v1.2.0 | v1.23.0 | v1.27.1 | v1.11.1 | v1.22.1 | v3.14.1 |
| v1.1.4 | v1.22.2 | v1.26.7 | v1.10.4 | v1.21.1 | v3.13.7 |
| v1.1.3 | v1.22.2 | v1.26.6 | v1.10.3 | v1.21.1 | v3.13.5 |
此外运行Torch-MLU-Ops还依赖PyTorch环境。Torch-MLU-Ops已支持的PyTorch版本请参考。
| Torch-MLU-Ops | PyTorch |
| -----------------| ------------------|
| v1.3.2 | v2.1, v2.4, v2.5 |
| v1.3.1 | v2.1, v2.4, v2.5 |
| v1.3.0 | v2.1, v2.4, v2.5 |
| v1.2.3 | v2.1, v2.4, v2.5 |
| v1.2.2 | v2.1, v2.4, v2.5 |
| v1.2.1 | v2.1, v2.4, v2.5 |
| v1.2.0 | v2.1, v2.4, v2.5 |
| v1.1.4 | v2.1, v2.3, v2.4 |
| v1.1.3 | v2.1, v2.3 |
### 通过docker镜像启动
寒武纪提供的解决方案镜像已提供了Torch-MLU-Ops所需要的依赖请参考Cambricon PyTorch Container Image使用。您可以从[寒武纪开发者社区](https://account.cambricon.com/interaction/SMmmcGFmw837gghtZ7VR4)获取该镜像。
### 编译安装
Torch-MLU-Ops编译脚本依赖于`NEUWARE_HOME`环境变量该环境变量指向寒武纪CNToolkit的安装路径通常为`/usr/local/neuware`
在运行示例程序前,您需要将`NEUWARE_HOME`中的`lib64`目录添加到`LD_LIBRARY_PATH`环境变量中,例如:
```shell
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NEUWARE_HOME/lib64
```
如果您使用寒武纪提供的docker镜像以上环境变量均已经被设置好了。
设置好`NEUWARE_HOME`环境变量后您可以使用以下命令编译Torch-MLU-Ops。
从远端clone仓库进入项目的主目录
```shell
cd torch_mlu_ops
```
源码安装:
```shell
pip install -e .
```
wheel包安装
```shell
python setup.py bdist_wheel
pip install dist/torch_mlu_ops*.whl
```
> _注意如果安装过程中遇到权限问题请获取相关权限或切换到拥有权限的user执行pip install 这步操作。_
移除Torch-MLU-OpsOptional
```shell
pip uninstall torch_mlu_ops
```
### 测试
在测试前请确保Torch-MLU-Ops已安装您可以通过 `pip list` 命令查询Torch-MLU-Ops的安装情况。若有如下类似打印则表明安装成功。
```shell
torch-mlu-ops 1.1.0+pt21
```
Torch-MLU-Ops模块的导入方法请参考。
```python
import torch_mlu_ops as tmo
```
`./tests/ops_pytest`目录下,提供了`ops`级别的测试示例程序,您可以参考如下命令进行测试。
```shell
cd tests/ops_pytest/
./run_test.sh
# or 单独测试某个测例,例如
python test_flash_attention.py
```
`./tests/kernels_pytest`目录下,提供了`kernels`级别的测试示例程序,您可以参考如下命令进行测试。
```shell
cd tests/kernels_pytest
./build.sh # 需要先编译
./run_test.sh
```
## 目录结构
```
torch_mlu_ops
|-- benchmarks ## 性能benchmarks测试代码目录
|-- csrc ## C以及BangC源码目录
| |-- common
| |-- kernels
| |-- ops
| `-- torch_api
|-- docs ## 文档目录
| |-- release_notes
| `-- user_guide
|-- tests ## 测试代码目录包含了kernel和ops级别的测试
| |-- kernels_pytest
| `-- ops_pytest
|-- tools ## 工具目录主要包含CI相关代码
| |-- ci
`-- torch_mlu_ops ## PyTorch第三方算子接口实现相关代码
```
## 关键特性
Torch-MLU-Ops提供的自定义算子列表以及接口说明请参考[Torch-MLU-Ops用户手册](./docs/user_guide/)。

View File

@@ -0,0 +1,51 @@
## benchmark测试脚本使用方式
Torch-MLU-Ops benchmark测试脚本为用户提供了进行算子性能测试的便捷入口。
用户可通过以下命令获取各个参数的含义。
```bash
# 测试命令帮助
python3 benchmark_xxx.py --help
```
各个参数含义如下:
`options`:
- -h, --help show this help message and exit
- --repeat_times REPEAT_TIMES repeat times for testing
- --csv write the report data to csv
- -o O specify the output folder name under --csv mode
```bash
# 测试命令示例如下
python3 benchmark_active.py --repeat_times 10 --csv -o './active/'
```
支持如下算子:
| op_name |
| ---------------------------------|
| active |
| apply_rotary |
| attention_project |
| ffn |
| flash_attn |
| fused_layer_norm |
| fused_moe |
| fused_norm_attention_project |
| fused_norm_residual_ffn |
| fused_rms_norm |
| group_gemm |
| matmul |
| offline_quant_to_linear_cache |
| per_token_smooth_quantize |
| preload |
| quantize |
| reshape_linear_cache |
| quant_to_linear_cache |
| reshape_paged_cache |
| single_query_cached_kv_attn |
| smooth_quant_matmul |
| weight_only_quant_matmul |
| moe_gen_idx |
| moe_expand_input |
| moe_softmax_topk |
| moe_combine_result |

View File

@@ -0,0 +1,64 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# Benchmark grid for tmo.active: only the batch dimension varies, every
# entry shares seq_len=5, hidden_size=1024, non-gated gelu, bf16 input.
e2e_time_param_dict_list = [
    {"batch": b, "seq_len": 5, "hidden_size": 1024,
     "act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}
    for b in (1, 16, 72, 1024, 4096, 8192, 32768)
]


def main():
    """Run the tmo.active benchmark over the grid above.

    Prints a table with hardware time, end-to-end latency and IO
    efficiency per configuration; with --csv also writes it to disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["input_shape", "act_mode", "is_gated", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    # Device bandwidth, used as the denominator of the IO-efficiency metric.
    bd = get_band_width()
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        hidden_size = cfg["hidden_size"]
        act_mode, is_gated = cfg["act_mode"], cfg["is_gated"]
        for dtype in cfg["input_dtype"]:
            # Fall back to fp16 on devices without bf16 support
            # (note: other benchmarks skip instead; this one substitutes).
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.active,
                                                        input,
                                                        act_mode,
                                                        is_gated,
                                                        repeats=args.repeat_times)
            # Gated activation halves the extra traffic per element.
            io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated)
            io_eff = io_bytes / hardware_time / bd
            contents.append([f"{batch,seq_len,hidden_size}", f"{act_mode}", f"{is_gated}",
                             f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,84 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark grid for tmo.apply_rotary: each entry is one rotary-embedding
# configuration (shape, rotary dim, layout flags) tried in fp16 and bf16.
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "rotary_dim": 128,
                             "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "rotary_dim": 64,
                             "interleaved": True, "discrete": False, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "rotary_dim": 128,
                             "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "rotary_dim": 128,
                             "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "rotary_dim": 64,
                             "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "rotary_dim": 96,
                             "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 4, "seq_len": 1, "head_num": 80, "head_size": 128, "rotary_dim": 128,
                             "interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}]


def main():
    """Benchmark tmo.apply_rotary over the parameter grid above.

    For each configuration and dtype, builds the input tensor plus
    sin/cos caches (per-batch caches when dynamic_ntk, shared otherwise),
    measures hardware time and end-to-end latency via benchmark_forward,
    prints a grid table and optionally saves it to csv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "head_num", "head_size", "rotary_dim", "interleaved", "discrete", "dynamic_ntk", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        # full/partial rotary: rotary_dim may be smaller than head_size
        rotary_dim = params_dict["rotary_dim"]
        # cross/fold pairing of the rotated dimensions
        interleaved = params_dict["interleaved"]
        # discrete: positions supplied explicitly via pos_ids
        discrete = params_dict["discrete"]
        # dynamic_ntk: per-batch sin/cos caches instead of shared ones
        dynamic_ntk = params_dict["dynamic_ntk"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # bf16 configurations are skipped on devices without bf16 support
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, head_num, head_size).to(device).to(dtype)  # [batch, seqlen, head_num, head_size]
            if dynamic_ntk:
                sin_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
                cos_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
            else:
                sin_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
                cos_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
            if discrete:
                pos_ids = torch.randint(0, seq_len, (batch * seq_len,)).to(device).to(torch.int32)
            else:
                pos_ids = None
            # NOTE(review): the bare None below is an optional positional arg of
            # apply_rotary that this benchmark never exercises — confirm against
            # the tmo.apply_rotary signature before reordering arguments.
            hardware_time, e2e_time = benchmark_forward(tmo.apply_rotary,
                                                        input,
                                                        sin_cache,
                                                        cos_cache,
                                                        pos_ids,
                                                        None,
                                                        interleaved,
                                                        discrete,
                                                        dynamic_ntk,
                                                        seq_len,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{head_num}", f"{head_size}", f"{rotary_dim}", f"{interleaved}", f"{discrete}", f"{dynamic_ntk}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,74 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark grid for tmo.attention_project: input_size always equals
# hidden_size; only that common size varies across entries.
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "input_size": size, "hidden_size": size,
     "has_residual": True, "has_bias": True,
     "input_dtype": [torch.float16, torch.bfloat16]}
    for size in (1600, 2048, 4096, 6144, 6656, 8192, 12288, 14336)
]


def main():
    """Benchmark tmo.attention_project over the grid above.

    Measures hardware time and end-to-end latency per configuration and
    dtype via benchmark_forward, prints a grid table, and optionally
    writes it to csv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "input_size", "hidden_size", "has_residual", "has_bias",
              "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        input_size, hidden_size = cfg["input_size"], cfg["hidden_size"]
        has_residual, has_bias = cfg["has_residual"], cfg["has_bias"]
        for dtype in cfg["input_dtype"]:
            # bf16 configurations are skipped on devices without bf16 support
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device)
            weight = torch.randn(hidden_size, input_size).to(dtype).to(device)
            residual = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device) if has_residual else None
            bias = torch.randn(hidden_size).to(dtype).to(device) if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.attention_project,
                                                        x,
                                                        weight,
                                                        bias,
                                                        residual,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            rows.append([f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}",
                         f"{has_residual}", f"{has_bias}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,60 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark grid for tmo.batch_matmul: batch and m are fixed; each entry
# pairs a reduction dim k with an output dim n, always with a C addend.
e2e_time_param_dict_list = [
    {"batch": 2, "m": 1024, "k": k, "n": n, "has_c": True,
     "input_dtype": [torch.float16, torch.bfloat16]}
    for k, n in ((1600, 6400), (2048, 8192), (4096, 11008),
                 (5120, 16384), (6144, 24576))
]


def main():
    """Benchmark tmo.batch_matmul over the grid above.

    Measures hardware time and end-to-end latency per shape and dtype via
    benchmark_forward, prints a grid table, and optionally writes csv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "m", "k", "n", "has_c", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        batch, m, k, n = cfg["batch"], cfg["m"], cfg["k"], cfg["n"]
        has_c = cfg["has_c"]
        for dtype in cfg["input_dtype"]:
            # bf16 configurations are skipped on devices without bf16 support
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(batch, m, k).to(device).to(dtype)
            # b is laid out [batch, n, k] (transposed operand).
            b = torch.randn(batch, n, k).to(device).to(dtype)
            c = torch.randn(batch, m, n).to(device).to(dtype) if has_c else None
            hardware_time, e2e_time = benchmark_forward(tmo.batch_matmul,
                                                        a,
                                                        b,
                                                        c,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            rows.append([f"{batch}", f"{m}", f"{k}", f"{n}", f"{has_c}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,120 @@
import argparse
import random
import os
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
from itertools import product
from tabulate import tabulate
# Benchmark grid for tmo.dequant_from_linear_cache: the cross product of
# batch sizes, context lengths, quant modes and quant bits is exercised.
e2e_time_param_dict_list = [
    {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
     "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "head_size": 128,
     "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [4, 8],
     "use_offset": True},
]


def main():
    """Benchmark tmo.dequant_from_linear_cache over the grid above.

    Builds quantized key/value caches (int4 packed two-per-byte, or int8)
    together with per-channel or per-head scales, times the dequant op via
    benchmark_forward, and prints one table row per configuration.

    Bug fixes vs. the previous revision:
    - the report row no longer repeats the quant_mode cell, so every row
      now aligns with the 13 title columns;
    - the "input_dytpe" column-header typo is corrected to "input_dtype".
    """
    parser = argparse.ArgumentParser(description="Benchmark for dequant from linear cache.")
    parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["max_batch_size", "batch_size", "max_context_len", "head_num_q", "head_num_kv",
              "cache_mem_len", "head_size", "input_dtype", "quant_mode", "quant_bit",
              "use_offset", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    mlu_name = torch.mlu.get_device_name()
    for params_dict in e2e_time_param_dict_list:
        max_batch_size = params_dict["max_batch_size"]
        batch_size_list = params_dict["batch_size"]
        max_context_len_list = params_dict["max_context_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        cache_mem_len = params_dict["cache_mem_len"]
        input_dtype_list = params_dict["input_dtype"]
        quant_mode_list = params_dict["quant_mode"]
        quant_bit_list = params_dict["quant_bit"]
        use_offset = params_dict["use_offset"]
        for batch_size, max_context_len, quant_mode, quant_bit, dtype in product(
                batch_size_list, max_context_len_list, quant_mode_list, quant_bit_list,
                input_dtype_list):
            # Fixed seeds keep the generated caches identical across runs.
            torch.manual_seed(2766)
            torch.mlu.manual_seed(2766)
            # bf16 configurations are skipped on devices without bf16 support
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # MLU3xx devices cannot hold tensors this large; skip those shapes.
            if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv
                                       * head_size >= 2**31 - 1):
                print("large tensor is not support on {}, skip".format(mlu_name))
                continue
            total_heads = head_num_q + head_num_kv * 2
            assert max_context_len <= cache_mem_len, "max_context_len should smaller than or " \
                "equal to cache_mem_len."
            max_seq_offset = cache_mem_len - max_context_len
            # Generate key and value slices out of a fused [q|k|v] context tensor.
            context_lens = torch.randint(size=[batch_size], low=max_context_len,
                                         high=max_context_len + 1,
                                         dtype=torch.int32, device="mlu")
            if use_offset:
                context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                                 dtype=torch.int32, device="mlu")
            else:
                context_paddings = torch.zeros_like(context_lens)
            cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
            total_seqlen = cu_context_lens[-1]
            context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
            context_seq_offset[1:] = cu_context_lens[:-1]
            context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
            key = context[..., head_num_q:head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
            # Generate key_cache and value_cache (int4 packs two values per byte,
            # hence the halved trailing / sequence dimensions below).
            cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu()
            # low=-1 appears to be a sentinel offset — TODO confirm against the op's contract.
            cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size],
                                             dtype=torch.int32, device="mlu")
            if quant_bit == 4:
                key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
                                          head_num_kv, cache_mem_len, head_size // 2), device="mlu")
                value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
                                            head_num_kv, cache_mem_len // 2, head_size), device="mlu")
                key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8)
            else:
                cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size),
                                      low=-128, high=127, dtype=torch.int32, device="mlu")
                cache = cache.to(torch.int8)
                key_cache, value_cache = cache[[0, 1]]
            # Generate key_cache_scale and value_cache_scale.
            if quant_mode == 0:  # quant_mode == 0 is per channel
                cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
            else:  # any other quant_mode is per head
                cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len),
                                          dtype=torch.float, device="mlu")
            key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
            hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_linear_cache,
                                                        key, value, key_cache, value_cache,
                                                        key_cache_scale, value_cache_scale,
                                                        context_lens, max_context_len,
                                                        context_seq_offset if use_offset else None,
                                                        cache_bs_id, cache_seq_offset, quant_mode,
                                                        quant_bit, repeats=args.repeat_times)
            # One cell per title column — must stay aligned with `titles`.
            content = [f"{max_batch_size}", f"{batch_size}", f"{max_context_len}", f"{head_num_q}",
                       f"{head_num_kv}", f"{cache_mem_len}", f"{head_size}", f"{dtype}",
                       f"{quant_mode}", f"{quant_bit}", f"{use_offset}", f"{hardware_time}",
                       f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,110 @@
import argparse
import math
import random
import os
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
from itertools import product
from tabulate import tabulate
# Sweep configuration for the dequant_from_paged_cache benchmark; list-valued
# fields are expanded with itertools.product inside main().
e2e_time_param_dict_list = [
    {
        "max_batch_size": 128,
        "batch_size": [1, 32, 64],
        "max_context_len": [1024, 2048, 3072, 4096],
        "head_num_q": 32,
        "head_num_kv": 1,
        "cache_mem_len": 6144,
        "block_size": 16,
        "head_size": 128,
        "input_dtype": [torch.float16, torch.bfloat16],
        "quant_mode": [0, 1],
        "quant_bit": [8],
        "use_offset": True,
    },
]
def main():
    """Benchmark tmo.dequant_from_paged_cache over the configured parameter sweeps.

    Prints a latency table; with --csv also writes it via save_to_csv.
    Returns early on MLU300 devices, where the op is unsupported.
    """
    parser = argparse.ArgumentParser(description="Benchmark for dequant from paged cache.")
    parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    if "MLU3" in torch.mlu.get_device_name():
        print("Op dequant_from_paged_cache does not support MLU300 devices.")
        return
    # Fixed typo in the printed header: "input_dytpe" -> "input_dtype".
    titles = ["batch_size", "max_context_len", "head_num_q", "head_num_kv",
              "cache_mem_len", "block_size", "head_size", "input_dtype", "quant_mode", "quant_bit",
              "use_offset", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch_size_list = params_dict["batch_size"]
        max_context_len_list = params_dict["max_context_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        block_size = params_dict["block_size"]
        head_size = params_dict["head_size"]
        cache_mem_len = params_dict["cache_mem_len"]
        input_dtype_list = params_dict["input_dtype"]
        quant_mode_list = params_dict["quant_mode"]
        quant_bit_list = params_dict["quant_bit"]
        use_offset = params_dict["use_offset"]
        for quant_mode, batch_size, max_context_len, quant_bit, dtype in product(
                quant_mode_list, batch_size_list, max_context_len_list, quant_bit_list,
                input_dtype_list):
            torch.manual_seed(2766)
            torch.mlu.manual_seed(2766)
            total_heads = head_num_q + head_num_kv * 2
            assert max_context_len <= cache_mem_len, "max_context_len should be smaller " \
                "than or equal to cache_mem_len."
            max_seq_offset = cache_mem_len - max_context_len
            max_block_num = int(math.ceil(max_context_len / block_size))
            total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size
            block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size,
                                                                                    max_block_num)
            # Generates key and value from context
            context_lens = torch.randint(size=[batch_size], low=max_context_len,
                                         high=max_context_len + 1, dtype=torch.int32,
                                         device="mlu")
            # torch.randint requires high > low, so fall back to zero padding when the
            # cache leaves no room for an offset (max_context_len == cache_mem_len).
            if use_offset and max_seq_offset > 0:
                context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                                 dtype=torch.int32, device="mlu")
            else:
                context_paddings = torch.zeros_like(context_lens)
            cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
            total_seqlen = cu_context_lens[-1]
            # Per-batch start offsets into the flattened [total_seqlen, ...] context.
            context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
            context_seq_offset[1:] = cu_context_lens[:-1]
            context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
            key = context[..., head_num_q:head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
            # Generates key_cache and value_cache (int8 paged layout)
            cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size),
                                  low=-128, high=127, dtype=torch.int32, device="mlu")
            cache = cache.to(torch.int8)
            key_cache, value_cache = cache[[0, 1]]
            # Generates key_cache_scale and value_cache_scale
            if quant_mode == 0:  # quant_mode == 0 is per channel
                cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float,
                                          device="mlu")
            else:  # quant_mode != 1 (== 1 for extend) is per head
                cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size),
                                          dtype=torch.float, device="mlu")
            key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
            hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_paged_cache,
                                                        key, value, key_cache, value_cache,
                                                        key_cache_scale, value_cache_scale,
                                                        context_lens, max_context_len,
                                                        context_seq_offset if use_offset else None,
                                                        block_tables, quant_mode,
                                                        quant_bit, repeats=args.repeat_times)
            content = [f"{batch_size}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}",
                       f"{cache_mem_len}", f"{block_size}", f"{head_size}", f"{dtype}", f"{quant_mode}",
                       f"{quant_bit}", f"{use_offset}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,89 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# (hidden_size, inner_size, gated_ffn) combos benchmarked at batch=1, seq_len=1024.
_FFN_SHAPES = [
    (1600, 6400, False), (2048, 8192, True), (4096, 11008, True), (4096, 14336, True),
    (4096, 16384, True), (5120, 13824, True), (5120, 27392, True), (6656, 17920, True),
    (8192, 22016, True), (8192, 24576, True), (8192, 28672, True), (8192, 49152, True),
    (12288, 32768, True), (14336, 57344, True),
]
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "hidden_size": hidden, "inner_size": inner,
     "gated_ffn": gated, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]}
    for hidden, inner, gated in _FFN_SHAPES
]
def main():
    """Run the tmo.ffn benchmark for every configured shape and print a summary table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        hidden_size, inner_size = cfg["hidden_size"], cfg["inner_size"]
        gated_ffn = cfg["gated_ffn"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # bf16 configs are skipped on devices without bfloat16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
            up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
            down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
            gate_up_proj_weight, gate_up_proj_bias = None, None
            if gated_ffn:
                gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
                gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.ffn, input, up_proj_weight,
                                                        up_proj_bias, down_proj_weight,
                                                        down_proj_bias, gate_up_proj_weight,
                                                        gate_up_proj_bias, act_mode,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}",
                             f"{gated_ffn}", f"{act_mode}", f"{dtype}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,92 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# for e2e time test
# (seq_q, seq_kv) shapes benchmarked; all other attention dimensions are fixed.
_SEQ_SHAPES = [(32768, 32768), (16384, 16384), (8192, 24576), (4096, 28672), (4096, 32768)]
e2e_time_param_dict_list = [
    {"batch": 1, "seq_q": seq_q, "seq_kv": seq_kv, "head_num": 8,
     "head_num_kv": 1, "head_size": 128, "use_causal": True,
     "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]}
    for seq_q, seq_kv in _SEQ_SHAPES
]
def main():
    """Benchmark tmo.flash_attention over the configured attention shapes.

    Prints a latency table; with --csv also writes it via save_to_csv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_q", "seq_kv", "head_num", "head_num_kv", "head_size", "use_causal", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_q = params_dict["seq_q"]
        seq_kv = params_dict["seq_kv"]
        head_num = params_dict["head_num"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        use_causal = params_dict["use_causal"]
        softmax_scale = params_dict["softmax_scale"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            if seq_q == seq_kv:
                # Self-attention: q/k/v packed into a single tensor laid out as
                # [q heads | k heads | v heads] along dim 2.
                qkv = torch.randn(batch, seq_q, head_num + 2 * head_num_kv, head_size).to(dtype).to(device)
                q = qkv[:, :, : head_num, :]
                k = qkv[:, :, head_num : head_num + head_num_kv, :]
                # Bug fix: the v slice used to end at head_num + head_num * 2 and only
                # yielded the right heads because out-of-range slices clamp; end it at
                # the true number of packed heads instead.
                v = qkv[:, :, head_num + head_num_kv : head_num + 2 * head_num_kv, :]
            else:
                # Cross-length attention. Previously only seq_q < seq_kv was handled and
                # seq_q > seq_kv left q/k/v unbound, raising NameError at the call below.
                q = torch.randn(batch, seq_q, head_num, head_size).to(device).to(dtype)
                kv = torch.randn(batch, seq_kv, head_num_kv * 2, head_size).to(device).to(dtype)
                k = kv[:, :, : head_num_kv, :]
                v = kv[:, :, head_num_kv :, :]
            hardware_time, e2e_time = benchmark_forward(tmo.flash_attention,
                                                        q = q,
                                                        k = k,
                                                        v = v,
                                                        out = None,
                                                        cu_seq_lens_q = None,
                                                        cu_seq_lens_kv = None,
                                                        alibi_slope = None,
                                                        attn_bias = None,
                                                        max_seq_len_q = seq_q,
                                                        max_seq_len_kv = seq_kv,
                                                        softmax_scale = softmax_scale,
                                                        is_causal = use_causal,
                                                        window_size_left = -1,
                                                        window_size_right = -1,
                                                        compute_dtype = dtype,
                                                        return_lse = False,
                                                        block_tables = None,
                                                        k_cache_quant_scale = None,
                                                        v_cache_quant_scale = None,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_q}", f"{seq_kv}", f"{head_num}", f"{head_num_kv}", f"{head_size}", f"{use_causal}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,103 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# (batch, seq_len) shapes benchmarked; every run uses hidden_size=8192 with a
# fused residual and dynamic quantization, bfloat16 input.
_SHAPES = [(1, 2048), (1, 4096), (1, 8192), (1, 32768), (490, 1), (525, 1)]
e2e_time_param_dict_list = [
    {"batch": batch, "seq_len": seq_len, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16}
    for batch, seq_len in _SHAPES
]
def main():
    # Benchmarks tmo.fused_layer_norm across e2e_time_param_dict_list, printing a
    # table of hardware/e2e latency plus an IO-efficiency estimate per config.
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "has_residual", "has_bias", "has_quant",
              "dynamic_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    # Device memory bandwidth (helper from common), denominator of the IO-efficiency column.
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        has_residual = params_dict["has_residual"]
        has_bias = params_dict["has_bias"]
        has_quant = params_dict["has_quant"]
        dynamic_quant = params_dict["dynamic_quant"]
        dtype = params_dict["input_dtype"]
        eps = 1e-6
        # Fall back to fp16 (rather than skipping the config) when bf16 is unsupported.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        x = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
        beta = torch.randn(hidden_size, dtype=dtype, device=device)
        gamma = torch.randn(hidden_size, dtype=dtype, device=device)
        residual, bias, quant_scale = None, None, None
        if has_residual:
            residual = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
        if has_bias:
            bias = torch.randn(hidden_size, dtype=dtype, device=device)
        if has_quant or dynamic_quant:
            quant_scale = torch.randn(hidden_size, dtype=torch.float, device=device)
        # With a fused residual the op also stores x+residual before normalization,
        # so count that extra write in the IO model below.
        store_output_before_norm = has_residual
        hardware_time, e2e_time = benchmark_forward(tmo.fused_layer_norm,
                                                    x,
                                                    residual,
                                                    gamma,
                                                    beta,
                                                    bias,
                                                    eps,
                                                    store_output_before_norm,
                                                    quant_scale,
                                                    None,
                                                    dynamic_quant,
                                                    repeats=args.repeat_times)
        n = x.nelement()
        sizeoft = x.element_size()
        # Approximate bytes moved: read x plus a 1-byte (int8) quantized output per
        # element; read residual and optionally write the pre-norm sum; read
        # gamma/beta; read the fp32 quant scale; write one fp32 scale per token when
        # dynamic_quant. NOTE(review): assumes the output is always int8-quantized
        # (true for every config above) — confirm before reusing with has_quant=False.
        io_bytes = (sizeoft + 1) * n + \
                   (1 + store_output_before_norm) * (sizeoft * n if has_residual else 0) + \
                   sizeoft * hidden_size * 2 + \
                   (hidden_size * 4 if has_quant else 0) + \
                   (batch * seq_len * 4 if dynamic_quant else 0)
        io_eff = io_bytes / hardware_time / bd
        content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{has_residual}", f"{has_bias}",
                   f"{has_quant}", f"{dynamic_quant}", f"{dtype}",
                   f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,143 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# (batch, seq_len) shapes benchmarked; the MoE layout (32 experts, topk=5,
# hidden=8192, inner=1024, smooth-quant int8 weights) is fixed across runs.
_SHAPES = [(1, 1), (16, 1), (72, 1), (128, 1), (490, 1), (525, 1),
           (1, 1024), (1, 2048), (1, 4096), (1, 8192), (1, 32768)]
e2e_time_param_dict_list = [
    {"batch": batch, "seq_len": seq_len, "hidden_size": 8192, "inner_size": 1024,
     "num_expert": 32, "start_expert_id": 0, "expert_size": 32,
     "gated_ffn": False, "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
     "topk": 5, "renormalize": False, "dtype": [torch.bfloat16]}
    for batch, seq_len in _SHAPES
]
def main():
    """Benchmark tmo.fused_moe for each configured MoE shape and print a summary table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    # Set True to print per-expert routed token counts (replaces the old `if False:`
    # dead-code block).
    debug_token_count = False
    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "num_expert", "topk", "act_mode", "quant_weight", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        inner_size = params_dict["inner_size"]
        gated_ffn = params_dict["gated_ffn"]
        act_mode = params_dict["act_mode"]
        num_expert = params_dict["num_expert"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        topk = params_dict["topk"]
        has_residual = params_dict["has_residual"]
        smooth_quant = params_dict["smooth_quant"]
        renormalize = params_dict["renormalize"]
        input_dtype_list = params_dict["dtype"]
        for dtype in input_dtype_list:
            # Fall back to fp16 (rather than skipping) when bf16 is unsupported.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            hidden_states = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
            router_logit = torch.randn(batch, seq_len, num_expert, device=device, dtype=torch.float32)
            if debug_token_count:
                softmax = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1)
                topk_logit, expert_id = torch.topk(softmax, k=topk, dim=1)
                if renormalize:
                    topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1)
                sorted_expert_id, indices = expert_id.int().flatten().sort()
                token_count = torch.bincount(sorted_expert_id, minlength=num_expert).int()
                print(token_count)
            residual = None
            if has_residual:
                residual = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
            bias1 = None  # torch.randn(expert_num, inner_size*(1+gated), device=device, dtype=data_type)
            bias2 = None  # torch.randn(expert_num, hidden_size, device=device, dtype=data_type)
            input_smooth, act_smooth, w1_scale, w2_scale = None, None, None, None
            if smooth_quant:
                # int8 weights: allocate them directly instead of first materializing
                # (and immediately discarding) full-precision weights, which wasted
                # hundreds of MB of device memory per config.
                weight1 = torch.randint(-128, 127, (num_expert, inner_size*(1+gated_ffn), hidden_size)).to(torch.int8).mlu()
                weight2 = torch.randint(-128, 127, (num_expert, hidden_size, inner_size)).to(torch.int8).mlu()
                input_smooth = torch.randn(expert_size, hidden_size, device=device, dtype=torch.float32).abs() + 0.1
                act_smooth = torch.randn(expert_size, inner_size, device=device, dtype=torch.float32).abs() + 0.1
                w1_scale = torch.randn(expert_size, (1+gated_ffn)*inner_size).to(device).to(torch.float32)
                w2_scale = torch.randn(expert_size, hidden_size).to(device).to(torch.float32)
            else:
                weight1 = torch.randn(num_expert, inner_size*(1+gated_ffn), hidden_size, device=device, dtype=dtype)
                weight2 = torch.randn(num_expert, hidden_size, inner_size, device=device, dtype=dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_moe,
                                                        hidden_states,
                                                        router_logit,
                                                        weight1[start_expert_id:start_expert_id+expert_size],
                                                        weight2[start_expert_id:start_expert_id+expert_size],
                                                        bias1,
                                                        bias2,
                                                        residual,
                                                        input_smooth,
                                                        act_smooth,
                                                        w1_scale,
                                                        w2_scale,
                                                        topk,
                                                        renormalize,
                                                        gated_ffn,
                                                        act_mode,
                                                        start_expert_id,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{num_expert}", f"{topk}", f"{act_mode}", f"{smooth_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,71 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Projection sizes benchmarked at batch=1, seq_len=1024; hidden_size always equals
# input_size, and only the smallest (1600) model uses head_size 80.
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "input_size": size, "head_size": 80 if size == 1600 else 128,
     "hidden_size": size, "input_dtype": [torch.float16, torch.bfloat16]}
    for size in (1600, 2048, 4096, 6144, 6656, 8192, 12288, 14336)
]
def main():
    """Benchmark tmo.fused_norm_attention_project for each configured projection shape."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "input_size", "hidden_size", "head_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        input_size, hidden_size = cfg["input_size"], cfg["hidden_size"]
        head_size = cfg["head_size"]
        for dtype in cfg["input_dtype"]:
            # bf16 configs are skipped on devices without bfloat16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, input_size).to(device).to(dtype)
            # One fused qkv weight/bias pair, split into three equal per-projection chunks.
            weight = torch.randn(hidden_size * 3, input_size).to(device).to(dtype)
            bias = torch.randn(hidden_size * 3).to(device).to(dtype)
            q_weight, k_weight, v_weight = torch.chunk(weight, 3)
            q_bias, k_bias, v_bias = torch.chunk(bias, 3)
            norm_weight = torch.randn(input_size).to(device).to(dtype)
            norm_bias = torch.randn(input_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_attention_project,
                                                        input, q_weight, q_bias,
                                                        k_weight, k_bias,
                                                        v_weight, v_bias,
                                                        norm_weight, norm_bias,
                                                        1e-6, 'nthc', head_size, False,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}",
                             f"{head_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,97 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# (hidden_size, inner_size, gated_ffn, residual_is) combos at batch=1, seq_len=1024.
_FFN_SHAPES = [
    (1600, 6400, False, "input"),
    (2048, 8192, True, "input"),
    (4096, 11008, True, "input"),
    (4096, 14336, True, "input"),
    (4096, 16384, True, "input"),
    (5120, 13824, True, "input"),
    (5120, 27392, True, "input"),
    (6656, 17920, True, "normed_input"),
    (8192, 22016, True, "normed_input"),
    (8192, 24576, True, "normed_input"),
    (8192, 28672, True, "normed_input"),
    (8192, 49152, True, "none"),
    (12288, 32768, True, "none"),
    (14336, 57344, True, "none"),
]
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "hidden_size": hidden, "inner_size": inner,
     "gated_ffn": gated, "residual_is": residual_is, "act_mode": "gelu",
     "input_dtype": [torch.float16, torch.bfloat16]}
    for hidden, inner, gated, residual_is in _FFN_SHAPES
]
def main():
    """Benchmark tmo.fused_norm_residual_ffn for each configured FFN shape."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "residual_is", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        hidden_size, inner_size = cfg["hidden_size"], cfg["inner_size"]
        gated_ffn = cfg["gated_ffn"]
        residual_is = cfg["residual_is"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # bf16 configs are skipped on devices without bfloat16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
            up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
            down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
            layernorm_weight = torch.randn(hidden_size).to(device).to(dtype)
            layernorm_bias = torch.randn(hidden_size).to(device).to(dtype)
            gate_up_proj_weight, gate_up_proj_bias = None, None
            if gated_ffn:
                gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
                gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_residual_ffn, input,
                                                        up_proj_weight, up_proj_bias,
                                                        down_proj_weight, down_proj_bias,
                                                        gate_up_proj_weight, gate_up_proj_bias,
                                                        layernorm_weight, layernorm_bias,
                                                        1e-6, act_mode, residual_is,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}",
                             f"{gated_ffn}", f"{residual_is}", f"{act_mode}", f"{dtype}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,90 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
from itertools import product
# Benchmark grid for tmo.fused_rms_norm: each dict fixes the tensor geometry
# (batch, seq_len, head_num, head_size) and feature flags; "dynamic_quant" and
# "input_dtype" hold lists that main() sweeps with itertools.product.
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 16, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 96, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_len": 1024, "head_num": 112, "head_size": 128, "has_residual": True,
                             "has_bias": True, "has_quant": True, "eps": 1e-6, "dynamic_quant": [False, True], "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.fused_rms_norm over the parameter grid and print a table.

    Command-line flags:
      --repeat_times  measured repetitions per case (forwarded to benchmark_forward)
      --csv           also dump the report table to a csv file
      -o              output folder used in --csv mode
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["input_shape", "has_residual", "has_bias", "has_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        has_residual = params_dict["has_residual"]
        has_quant = params_dict["has_quant"]
        has_bias = params_dict["has_bias"]
        eps = params_dict["eps"]
        # Fix: these two lists were previously read twice in a row; one read
        # of each key is enough (behavior otherwise unchanged).
        dynamic_quant_list = params_dict["dynamic_quant"]
        input_dtype_list = params_dict["input_dtype"]
        iters = product(dynamic_quant_list, input_dtype_list)
        for dynamic_quant, dtype in iters:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
            beta = torch.randn(head_size).to(dtype).to(device)
            gamma = torch.randn(head_size).to(dtype).to(device)
            residual, bias, quant_scale = None, None, None
            if has_residual:
                residual = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
            if has_bias:
                bias = torch.randn(head_size).to(dtype).to(device)
            if has_quant or dynamic_quant:
                # quant_scale stays fp32 (no .to(dtype)) — matches the kernel's
                # expected scale dtype in the other quant benchmarks.
                quant_scale = torch.randn(head_size).to(device)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rms_norm,
                                                        x,
                                                        residual,
                                                        gamma,
                                                        beta,
                                                        bias,
                                                        eps,
                                                        False,
                                                        quant_scale,
                                                        None,
                                                        dynamic_quant,
                                                        repeats=args.repeat_times)
            content = [f"{batch, seq_len, head_num, head_size}", f"{has_residual}", f"{has_bias}", f"{has_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,207 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark grid for tmo.fused_rope (decode step, seq_len=1). Three cache
# layouts are covered: contiguous cache (keyed by "max_decode_len"), paged
# cache ("num_blocks"/"block_size"), and mixed low-precision cache
# ("mixed_cache"); "quant_kv" toggles int8 KV quantization.
e2e_time_param_dict_list = [{"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": False, "max_decode_len": 2048, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": True, "num_blocks": 64, "block_size": 32, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": False, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "quant_kv": True, "paged_cache": True, "num_blocks": 128, "block_size": 16, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 40, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 72, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "head_num_q": 8, "head_num_k": 1, "head_size": 128, "rotary_dim": 128,
                             "mixed_cache": True, "input_dtype": [torch.float16, torch.bfloat16]},]
def main():
    """Benchmark tmo.fused_rope (RoPE + KV-cache update) over the grid above.

    Covers contiguous, paged, and mixed low-precision cache layouts, with and
    without int8 KV quantization, and prints a timing table (optionally csv).
    NOTE(review): devices whose name contains 'MLU3' are skipped entirely —
    presumably unsupported there; confirm with the kernel owners.
    """
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache", "max_decode_len", "num_blocks", \
              "block_size", "mixed_cache", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Optional keys fall back to defaults (quant_kv=True, both cache
        # layout flags False) when absent from the param dict.
        quant_kv = params_dict["quant_kv"] if "quant_kv" in params_dict else True
        paged_cache = params_dict["paged_cache"] if "paged_cache" in params_dict else False
        mixed_cache = params_dict["mixed_cache"] if "mixed_cache" in params_dict else False
        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict["max_decode_len"] if "max_decode_len" in params_dict else 32
        input_dtype_list = params_dict["input_dtype"]
        # NOTE(review): unlike the sibling benchmarks, there is no
        # is_bf16_supported() guard here — bf16 cases run unconditionally.
        for dtype in input_dtype_list:
            # discrete_batch: cache rows are addressed through cache_bs_id
            # rather than being a dense [0, bs) range; always on here.
            discrete_batch = True
            max_bs = bs + 1 if discrete_batch else bs
            # Packed QKV layout: q_heads + 2 * kv_heads heads per token.
            input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
            input = torch.randn(size=input_shape, dtype=dtype).mlu()
            # NOTE(review): input_ref is cloned but never used in this function.
            input_ref = input.clone()
            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            cache_dtype = dtype
            if quant_kv:
                # Per-(head, channel) scales; the op takes reciprocal scales.
                k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                cache_dtype = torch.int8
                k_scale_ops = 1 / k_scale
                v_scale_ops = 1 / v_scale
            else:
                k_scale = None
                v_scale = None
                k_scale_ops = None
                v_scale_ops = None
            # K and V caches share one allocation; dim 0 selects K vs V.
            if paged_cache:
                cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size), dtype=dtype, device='mlu')
            else:
                cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')
            if quant_kv:
                # Spread values over roughly [-128, 128) before the int8 cast.
                cache = (cache - 0.5) * 256
            cache = cache.to(cache_dtype)
            k_cache = cache[0]
            v_cache = cache[1]
            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if not paged_cache:
                if discrete_batch:
                    # Random, non-repeating cache-row ids for the bs requests.
                    cache_bs_id = random.sample([*range(0, max_bs)], bs)
                    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
                    # Offsets may be -1 — presumably "skip this request"; verify
                    # against the op's documentation.
                    cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                                      dtype=torch.int32, device='mlu')
            else:
                # Paged layout: one slot id per request; -1 may appear in the
                # sampled range (same skip semantics presumed as above).
                slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs)
                slot_mapping = torch.IntTensor(slot_mapping).mlu()
            position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
            # Low-precision ("_lp") secondary cache, only built for mixed_cache.
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None
            if mixed_cache:
                max_decode_len_lp = 1024
                # K is halved along head_size, V along decode length.
                k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)), dtype=dtype, device='mlu')
                v_cache_raw = torch.randn((max_bs, kv_heads, int(max_decode_len_lp / 2), head_size), dtype=dtype, device='mlu')
                # Rescale so the magnitudes fit an int4-like [-7, 7] range
                # before the int8 cast.
                max_value = torch.amax(torch.abs(k_cache_raw))
                k_cache_raw = k_cache_raw * (7 / max_value)
                max_value = torch.amax(torch.abs(v_cache_raw))
                v_cache_raw = v_cache_raw * (7 / max_value)
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)
                k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
                v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
                cache_bs_id_lp = random.sample([*range(0, max_bs)], bs)
                cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
                cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                                                     dtype=torch.int32, device='mlu')
            # Positional argument order below matches the tmo.fused_rope
            # signature; do not reorder.
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
                                                        input,
                                                        k_cache,
                                                        v_cache,
                                                        sin_table,
                                                        cos_table,
                                                        position_id,
                                                        gamma,
                                                        beta,
                                                        k_cache_lp,
                                                        v_cache_lp,
                                                        cache_bs_id,
                                                        cache_seq_offsets,
                                                        cache_bs_id_lp,
                                                        cache_seq_offsets_lp,
                                                        k_scale_ops,
                                                        v_scale_ops,
                                                        k_scale_lp,
                                                        v_scale_lp,
                                                        slot_mapping,
                                                        None,
                                                        1e-5,
                                                        repeats=args.repeat_times)
            content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}", f"{paged_cache}", \
                       f"{max_decode_len}", f"{num_blocks}", f"{block_size}", f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,117 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark grid for group GEMM: decode-like cases (seq_len=1, varying batch)
# followed by prefill-like cases (batch=1, growing seq_len), each with both
# (k, n) orientations of the MoE up/down projections.
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 16, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 16, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 72, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 72, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 128, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 128, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 490, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 490, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 525, "seq_len": 1, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 525, "seq_len": 1, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 2048, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 2048, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 4096, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 4096, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 8192, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 8192, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 32768, "k": 8192, "n": 1024, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
    {"batch": 1, "seq_len": 32768, "k": 1024, "n": 8192, "expert_num": 32,
     "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]},
]
def main():
    """Time group GEMM / smooth-quant group GEMM across the parameter grid.

    Tokens (m = batch * seq_len * topk) are spread as evenly as possible over
    the experts; results are printed as a table and optionally saved to csv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "seq_len", "k", "n", "expert_num", "topk", "smooth_quant", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        k, n = cfg["k"], cfg["n"]
        expert_num, topk = cfg["expert_num"], cfg["topk"]
        is_quant = cfg["is_quant"]
        for dtype in cfg["dtype"]:
            # Fall back to fp16 on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            max_m = batch * seq_len
            m = max_m * topk
            # Even token split: the first (m % expert_num) experts get one extra.
            base, extra = divmod(m, expert_num)
            per_expert = [base + 1 if idx < extra else base for idx in range(expert_num)]
            token_count = torch.tensor(per_expert, dtype=torch.int32, device='mlu')
            if is_quant:
                a = torch.randint(-128, 127, (m, k)).to(torch.int8).mlu()
                b = torch.randint(-128, 127, (expert_num, n, k)).to(torch.int8).mlu()
                a_scale = torch.randn(a.size(0), dtype=torch.float32, device='mlu')
                b_scale = torch.randn(expert_num, n, dtype=torch.float32, device='mlu')
                hardware_time, e2e_time = benchmark_forward(
                    tmo.smooth_quant_group_gemm,
                    a, b, token_count,
                    None, None, None, None,
                    a_scale, b_scale, dtype, max_m,
                    repeats=args.repeat_times)
            else:
                a = torch.randn(m, k, dtype=dtype, device='mlu')
                b = torch.randn(expert_num, n, k, dtype=dtype, device='mlu')
                hardware_time, e2e_time = benchmark_forward(
                    tmo.group_gemm,
                    a, b, token_count,
                    None, None, None, None,
                    max_m,
                    repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{k}", f"{n}", f"{expert_num}", f"{topk}",
                             f"{is_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,75 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark grid for tmo.matmul: fixed m=1024 with (k, n) shapes drawn from
# common LLM FFN projections; every case enables the C accumulator, bias and
# a relu epilogue, sweeping fp16/bf16 inputs.
e2e_time_param_dict_list = [{"m": 1024, "k": 1600, "n": 6400, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 2048, "n": 8192, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 4096, "n": 11008, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 4096, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 16384, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 27392, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 6144, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 6656, "n": 17920, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 22016, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 24576, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 28672, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 8192, "n": 49152, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 12288, "n": 32768, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 14336, "n": 57344, "has_c": True, "has_bias": True, "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.matmul (A @ B^T with optional bias, C accumulator and
    activation epilogue) over the grid above and print a timing table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        m, k, n = cfg["m"], cfg["k"], cfg["n"]
        has_c = cfg["has_c"]
        has_bias = cfg["has_bias"]
        act_mode = cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(m, k).to(device).to(dtype)
            b = torch.randn(n, k).to(device).to(dtype)
            c = torch.randn(m, n).to(device).to(dtype) if has_c else None
            bias = torch.randn(n).to(device).to(dtype) if has_bias else None
            # alpha = beta = 1.0 (last two positional args).
            hardware_time, e2e_time = benchmark_forward(
                tmo.matmul,
                a, b, bias, c, act_mode, 1.0, 1.0,
                repeats=args.repeat_times)
            contents.append([f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}",
                             f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,117 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Benchmark grid for tmo.moe_active: prefill-like (batch=1, growing seq_len)
# and decode-like (seq_len=1, growing batch) cases; "is_gated" toggles the
# gated-activation input layout and "is_ep" exercises expert-parallel slicing.
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": True, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 4096, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": False, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 8192, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 32768, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 1, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 16, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 32, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 64, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 128, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 256, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},
                            {"batch": 512, "seq_len": 1, "inner_size": 1024,
                             "act_mode": "gelu", "is_gated": False, "has_bias": True,
                             "is_ep": True, "input_dtype": [torch.bfloat16]},]
def gen_data(num_expert,
             total_tokens,
             inner_size,
             output_stride,
             dtype,
             is_gated,
             has_bias,
             is_ep):
    """Build random inputs for the tmo.moe_active benchmark.

    Args:
        num_expert:    total number of experts.
        total_tokens:  number of (token, expert) rows in the input.
        inner_size:    FFN inner size; gated layouts double the input width.
        output_stride: row stride (in elements) of the output view.
        dtype:         tensor dtype for input/bias/output.
        is_gated:      if True, input carries gate+up halves (width 2*inner_size).
        has_bias:      whether to create a per-expert bias.
        is_ep:         expert-parallel mode — pick a random expert sub-range.

    Returns:
        (input, bias, token_count, cusum_token_count, output,
         start_expert_id, expert_size)
    """
    ci = inner_size * (1 + is_gated)
    input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu')
    cusum_token_count, token_count = generate_token_count(num_expert, total_tokens)
    output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu')
    # Bug fix: Tensor.as_strided() returns a new view and does not modify the
    # tensor in place; the original discarded the result, so output_stride was
    # never applied. Rebind the view so the requested stride takes effect.
    output = output.as_strided(output.size(), (output_stride, 1))
    start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0
    expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert
    bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None
    return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size
def main():
    """Benchmark tmo.moe_active over the parameter grid and print timings plus
    an IO-efficiency estimate (bytes moved / hardware_time / bandwidth)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["input_shape", "act_mode", "is_gated", "has_bias", "expert_num", "start_expert_id",
              "expert_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        inner_size = params_dict["inner_size"]
        act_mode = params_dict["act_mode"]
        is_gated = params_dict["is_gated"]
        input_dtype_list = params_dict["input_dtype"]
        has_bias = params_dict["has_bias"]
        is_ep = params_dict["is_ep"]
        for dtype in input_dtype_list:
            # Fall back to fp16 when bf16 is unsupported on this device.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            # Fix: was the duplicated `expert_num = expert_num = random.randint(...)`.
            expert_num = random.randint(1, 256)
            input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
                gen_data(expert_num, batch * seq_len, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
            # Only the expert-parallel slice of the bias is handed to the op.
            real_bias = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.moe_active,
                                                        input,
                                                        act_mode,
                                                        is_gated,
                                                        output,
                                                        real_bias,
                                                        cusum_token_count.mlu() if has_bias or is_ep else None,
                                                        start_expert_id,
                                                        expert_size,
                                                        repeats=args.repeat_times)
            # Bug fix: the original single expression parsed as
            #   (input_bytes + bias_bytes + cusum_bytes) if has_bias or is_ep else 0
            # because a conditional expression binds looser than `+`. That zeroed
            # io_bytes whenever has_bias and is_ep were both False, and crashed on
            # real_bias=None when is_ep was True without has_bias. Accumulate each
            # contribution under its own guard instead.
            io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated)
            if has_bias:
                io_bytes += real_bias.element_size() * real_bias.nelement()
            if has_bias or is_ep:
                io_bytes += cusum_token_count.element_size() * cusum_token_count.nelement()
            io_eff = io_bytes / hardware_time / bd
            content = [f"{batch,seq_len,inner_size}", f"{act_mode}", f"{is_gated}", f"{has_bias}", f"{expert_num}",
                       f"{start_expert_id}", f"{expert_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,59 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Benchmark grid for tmo.moe_cast_gating: prefill-like (batch=1, long seq_len)
# and decode-like (seq_len=1, growing batch) cases at a fixed hidden size and
# expert count.
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 2048, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 16, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 128, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16},
                            {"batch": 512, "seq_len": 1, "hidden_size": 8192, "expert_num": 32, "input_dtype": torch.bfloat16}]
def main():
    """Benchmark tmo.moe_cast_gating and report timings plus IO efficiency.

    NOTE(review): like the fused_rope benchmark, devices whose name contains
    'MLU3' are skipped entirely — presumably unsupported; confirm.
    """
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "seq_len", "hidden_size", "expert_num", "input_dtype", "hardware_time(us)",
              "e2e_latency(us)", "IO efficiency"]
    contents = []
    bandwidth = get_band_width()
    for param_dict in e2e_time_param_dict_list:
        batch = param_dict["batch"]
        seq_len = param_dict["seq_len"]
        hidden_size = param_dict["hidden_size"]
        expert_num = param_dict["expert_num"]
        input_dtype = param_dict["input_dtype"]
        # Fall back to fp16 when bf16 is unsupported on this device.
        if input_dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            input_dtype = torch.half
        input = torch.randn(batch, seq_len, hidden_size, dtype=input_dtype, device="mlu")
        weight = torch.randn(expert_num, hidden_size, dtype=torch.float32, device="mlu")
        # Bug fix: pass repeats so --repeat_times is honored; the original call
        # omitted it (unlike every sibling benchmark), so benchmark_forward
        # always used its default repeat count here.
        hardware_time, e2e_time = benchmark_forward(tmo.moe_cast_gating,
                                                    input,
                                                    weight,
                                                    repeats=args.repeat_times)
        # Bytes moved per the original formula: input read, weight read, and a
        # (batch*seq_len, expert_num) fp32 result — presumably the gating
        # output; verify against the op's actual output shape.
        io_bytes = batch * seq_len * hidden_size * input.element_size() + \
            expert_num * hidden_size * weight.element_size() + batch * seq_len * expert_num * weight.element_size()
        io_coeff = io_bytes / hardware_time / bandwidth
        content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{expert_num}", f"{input_dtype}",
                   f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,166 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# moe_combine_result benchmark sweep: num_tokens varies while the expert layout
# stays fixed (32 experts, top-5, expert_size == num_expert, hidden 8192, bf16).
e2e_time_param_dict_list = [
    {"num_tokens": 16, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 128, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 490, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 525, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 2048, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 4096, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 8192, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
    {"num_tokens": 32768, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]},
]
def gen_case(num_tokens,
             topk,
             hidden_size,
             num_expert,
             expert_size,
             has_bias,
             has_residual,
             dtype,
             device):
    """Build random inputs for one moe_combine_result benchmark case.

    Returns a 6-tuple (input, reduce_weight, gather_ids, residual, bias,
    cusum_token_count); the optional tensors are None when disabled.
    """
    expanded_tokens = num_tokens * topk
    expert_output = torch.randn((expanded_tokens, hidden_size), dtype=dtype, device=device)
    combine_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device)
    permutation = torch.randperm(expanded_tokens, dtype=torch.int32, device=device)
    expert_bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) if has_bias else None
    shortcut = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) if has_residual else None
    if has_bias or expert_size < num_expert:
        # Per-expert prefix sums are only needed when experts must be addressed
        # individually (bias lookup) or when only a slice of the experts runs.
        cusum_token_count, _ = generate_token_count(num_expert, expanded_tokens)
        cusum_token_count = cusum_token_count.to(device=device)
    else:
        cusum_token_count = None
    return expert_output, combine_weight, permutation, shortcut, expert_bias, cusum_token_count
def get_io_bytes(num_tokens,
                 topk,
                 hidden_size,
                 num_expert,
                 expert_size,
                 start_expert_id,
                 has_bias,
                 has_residual,
                 dtype,
                 cusum_token_count,
                 gather_ids):
    """Estimate the bytes moved by moe_combine_result for IO-efficiency reporting.

    Counts: expanded-token reads (optionally filtered down to the expert slice
    [start_expert_id, start_expert_id + expert_size)), optional bias/residual
    reads, the fp32 reduce-weight read, and the combined-output write.
    """
    io_bytes = 0
    # Derive the element size from the dtype itself instead of assuming
    # "4 bytes for fp32, otherwise 2". Identical for fp16/bf16/fp32, and also
    # correct for int8 (1 byte) or fp64 (8 bytes) if those are ever swept.
    dtype_size = torch.empty((), dtype=dtype).element_size()
    if cusum_token_count is not None:
        # Only expanded tokens routed to the active expert slice are read.
        filtered_ids = (gather_ids >= cusum_token_count[start_expert_id]) * \
            (gather_ids < cusum_token_count[start_expert_id + expert_size])
        filtered_ids = filtered_ids.to(dtype=torch.float32)
        io_bytes += torch.sum(filtered_ids).item() * hidden_size * dtype_size
    else:
        io_bytes += num_tokens * topk * hidden_size * dtype_size
    if has_bias:
        io_bytes += expert_size * hidden_size * dtype_size
    if has_residual:
        io_bytes += num_tokens * hidden_size * dtype_size
    # reduce_weight read: one fp32 weight per expanded token.
    io_bytes += num_tokens * topk * 4
    # combined-output write: one row per original token.
    io_bytes += num_tokens * hidden_size * dtype_size
    return io_bytes
def main():
    # Benchmark tmo.moe_combine_result across e2e_time_param_dict_list and
    # report hardware/e2e latency plus an IO-efficiency coefficient.
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["num_tokens", "num_expert", "topk", "start_expert_id", "expert_size", \
        "hidden_size", "has_residual", "dtype", "hardware_time(us)", "e2e_latency(us)", "io_coeff"]
    contents = []
    bandwidth = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        num_tokens = params_dict["num_tokens"]
        num_expert = params_dict["num_expert"]
        topk = params_dict["topk"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        has_residual = params_dict["has_residual"]
        hidden_size = params_dict["hidden_size"]
        dtype_list = params_dict["dtype"]
        for dtype in dtype_list:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # has_bias is fixed to False here and in the matching
            # get_io_bytes call below, so `bias` stays None.
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              num_expert,
                              expert_size,
                              False,
                              has_residual,
                              dtype,
                              device)
            input = inputs[0]
            reduce_weight = inputs[1]
            gather_ids = inputs[2]
            residual = inputs[3]
            bias = inputs[4]
            cusum_token_count = inputs[5]
            io_bytes = get_io_bytes(num_tokens,
                                    topk,
                                    hidden_size,
                                    num_expert,
                                    expert_size,
                                    start_expert_id,
                                    False,
                                    has_residual,
                                    dtype,
                                    cusum_token_count,
                                    gather_ids)
            hardware_time, e2e_time = benchmark_forward(tmo.moe_combine_result, input, reduce_weight,
                                                        gather_ids,residual, cusum_token_count,
                                                        start_expert_id, expert_size,
                                                        repeats=args.repeat_times)
            io_coeff = io_bytes / hardware_time / bandwidth
            content = [f"{num_tokens}", f"{num_expert}", f"{topk}", f"{start_expert_id}", \
                f"{expert_size}", f"{hidden_size}", f"{has_residual}", f"{dtype}", \
                f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,93 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import numpy as np
# moe_expand_input benchmark sweep: token_num varies while the expert layout
# (32 experts, top-5, expert_size == num_expert, hidden 4096) stays fixed;
# each case runs for int8, fp16 and bf16 inputs.
e2e_time_param_dict_list = [{"token_num": 1, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 16, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 32, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 64, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 128, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 512, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 1024, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 4096, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 8192, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]},
                            {"token_num": 32768, "hidden_size": 4096, "expert_num": 32, "topk": 5,
                             "start_expert_id": 0, "expert_size": 32, "input_dtype": [torch.int8, torch.float16, torch.bfloat16]}]
def gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype):
    """Create random inputs for one moe_expand_input benchmark case.

    Returns (input, gather_idx, cusum_token_count, real_token_count) where
    cusum_token_count is None when every expert participates.
    """
    expanded = token_num * topk
    src = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
    idx = torch.randint(low=0, high=token_num, size=(expanded,)).to(torch.int32).to('mlu')
    prefix_counts, _ = generate_token_count(expert_num, expanded)
    prefix_counts = prefix_counts.to('mlu')
    if expert_num == expert_size:
        # All experts participate: the kernel needs no prefix sums and every
        # expanded token is real.
        return src, idx, None, expanded
    # Only the slice [start_expert_id, start_expert_id + expert_size) runs;
    # the real token count is the difference of the prefix sums at its edges.
    active = prefix_counts[start_expert_id + expert_size] - prefix_counts[start_expert_id]
    return src, idx, prefix_counts, active
def main():
    # Benchmark tmo.moe_expand_input across e2e_time_param_dict_list and
    # report hardware/e2e latency plus IO efficiency.
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["token_num", "hidden_size", "expert_num", "topk", "start_expert_id", "expert_size", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        token_num = params_dict["token_num"]
        hidden_size = params_dict["hidden_size"]
        expert_num = params_dict["expert_num"]
        topk = params_dict["topk"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input, gather_idx, cusum_token_count, real_token_count = \
                gen_tensor(token_num, hidden_size, expert_num,topk, start_expert_id, expert_size, dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.moe_expand_input,
                                                        input,
                                                        gather_idx,
                                                        cusum_token_count,
                                                        start_expert_id,
                                                        expert_size,
                                                        repeats=args.repeat_times)
            # Reads: input + gather_idx + optional prefix sums; write term is
            # real_token_count * element_size.
            # NOTE(review): the write term has no hidden_size factor -- if the
            # op writes full hidden-size rows per expanded token, this
            # undercounts the traffic; confirm against moe_expand_input's spec.
            io_bytes = input.element_size() * input.nelement() + \
                gather_idx.element_size() * gather_idx.nelement() + \
                (cusum_token_count.element_size() * cusum_token_count.nelement() if cusum_token_count is not None else 0) + \
                real_token_count * input.element_size()
            io_eff = io_bytes / hardware_time / bd
            content = [f"{token_num}", f"{hidden_size}", f"{expert_num}", f"{topk}", f"{start_expert_id}", f"{expert_size}",
                       f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,69 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# moe_gen_idx benchmark sweep: token_num varies for two routing layouts,
# (32 experts, top-5) and (8 experts, top-2); note the largest case uses
# 32767 tokens, not 32768.
e2e_time_param_dict_list = [{"token_num": 1, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 16, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 32, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 64, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 512, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 1024, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 4096, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 8192, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 32767, "expert_num": 32, "topk": 5, "input_dtype": torch.int32},
                            {"token_num": 1, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 16, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 32, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 64, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 512, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 1024, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 4096, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 8192, "expert_num": 8, "topk": 2, "input_dtype": torch.int32},
                            {"token_num": 32767, "expert_num": 8, "topk": 2, "input_dtype": torch.int32}]
def main():
    """Benchmark tmo.moe_gen_idx and report latency plus IO efficiency."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["token_num", "expert_num", "topk", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for case in e2e_time_param_dict_list:
        token_num = case["token_num"]
        expert_num = case["expert_num"]
        topk = case["topk"]
        dtype = case["input_dtype"]
        expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu')
        # NOTE(review): these empty tensors are only used below to size the IO
        # estimate -- presumably mirroring the op's outputs; confirm against
        # moe_gen_idx's actual output shapes.
        gather_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
        combine_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
        token_count = torch.empty((expert_num), dtype=dtype, device='mlu')
        cusum_token_count = torch.empty((expert_num + 1), dtype=dtype, device='mlu')
        hardware_time, e2e_time = benchmark_forward(tmo.moe_gen_idx,
                                                    expert_id,
                                                    expert_num,
                                                    repeats=args.repeat_times)
        traffic = (expert_id, gather_idx, combine_idx, token_count, cusum_token_count)
        io_bytes = sum(t.element_size() * t.nelement() for t in traffic)
        io_eff = io_bytes / hardware_time / bd
        contents.append([f"{token_num}", f"{expert_num}", f"{topk}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,114 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# moe_quantize benchmark sweep: token_num varies over two groups --
# hidden 8192 with a gather index (input expanded by topk inside the op) and
# hidden 1024 without one (input already expanded).
params_dict = [
    {"token_num": 1, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 16, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 128, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 490, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 512, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 525, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 2048, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 4096, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 8192, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 32768, "hidden_size": 8192, "expert_num": 32, "topk": 5,
     "has_gather_idx": True, "dtype": torch.bfloat16},
    {"token_num": 1, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 16, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 128, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 490, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 512, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 525, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 2048, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 4096, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 8192, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
    {"token_num": 32768, "hidden_size": 1024, "expert_num": 32, "topk": 5,
     "has_gather_idx": False, "dtype": torch.bfloat16},
]
def main():
    """Benchmark tmo.moe_quantize and report latency plus IO efficiency."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["token_num", "hidden_size", "expert_num", "topk", "has_gather_idx", "dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for param in params_dict:
        # ROBUSTNESS: extract by key instead of `param.values()` unpacking,
        # which silently assigns the wrong values if any dict literal is ever
        # written with its keys in a different order.
        token_num = param["token_num"]
        hidden_size = param["hidden_size"]
        expert_num = param["expert_num"]
        topk = param["topk"]
        has_gather_idx = param["has_gather_idx"]
        dtype = param["dtype"]
        # Fall back to fp16 when the device lacks bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        # The gather_idx path is disabled on MLU300-series devices.
        if "MLU3" in torch.mlu.get_device_name():
            has_gather_idx = False
        expand_token_num = token_num * topk
        # With a gather index the op expands the input itself, so the input
        # holds token_num rows; otherwise it is already expanded.
        input_shape = (token_num if has_gather_idx else expand_token_num, hidden_size)
        input = torch.randn(input_shape).to(device).to(dtype)
        scale = torch.randn(expert_num, hidden_size).to(device).to(torch.float32)
        # Spread expanded tokens as evenly as possible across the experts.
        avg, rem = expand_token_num // expert_num, expand_token_num % expert_num
        m_list = [avg + (i < rem) for i in range(expert_num)]
        token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')
        if has_gather_idx:
            gather_idx = torch.arange(0, token_num).repeat([topk])
            gather_idx = gather_idx[torch.randperm(gather_idx.size(0))].to(torch.int32).mlu()
        else:
            gather_idx = None
        hardware_time, e2e_time = benchmark_forward(tmo.moe_quantize,
                                                    input,
                                                    scale,
                                                    None,
                                                    token_count,
                                                    gather_idx,
                                                    None,
                                                    None,
                                                    None,
                                                    True,
                                                    repeats=args.repeat_times)
        # Traffic per expanded element: dtype-sized read + 1-byte quantized write.
        expand_num = topk if has_gather_idx else 1
        io_bytes = (input.element_size() + 1) * input.nelement() * expand_num
        io_eff = io_bytes / hardware_time / bd
        content = [f"{token_num}", f"{hidden_size}", f"{expert_num}",
                   f"{topk}", f"{has_gather_idx}", f"{dtype}",
                   f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,87 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
# moe_softmax_topk benchmark sweep over three groups: flat top-k without
# normalization, flat top-k with normalization, and grouped top-k
# (160 experts in 8 groups, top-3 groups, top-6 experts).
e2e_time_param_dict_list = [{"num_batch": 1, "seq_len": 1, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 32, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 72, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 1024, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 2048, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 4096, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 8192, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 32768, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 1, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 2, "seq_len": 16, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 2, "seq_len": 36, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 8, "seq_len": 128, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 16, "seq_len": 128, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 4, "seq_len": 1024, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 2, "seq_len": 4096, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 16, "seq_len": 2048, "num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1, "normalize": True, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 1, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 16, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 64, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 1024, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 2048, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 8192, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            {"num_batch": 1, "seq_len": 32768, "num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3, "normalize": False, "input_dtype": torch.bfloat16},
                            ]
def main():
    # Benchmark tmo.moe_softmax_topk across e2e_time_param_dict_list and
    # report hardware/e2e latency plus IO efficiency.
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["num_batch", "seq_len", "num_expert", "topk", "num_expert_group", "topk_group", "normalize", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        num_batch = params_dict["num_batch"]
        seq_len = params_dict["seq_len"]
        num_expert = params_dict["num_expert"]
        topk = params_dict["topk"]
        num_expert_group = params_dict["num_expert_group"]
        topk_group = params_dict["topk_group"]
        normalize = params_dict["normalize"]
        dtype = params_dict["input_dtype"]
        # Fall back to fp16 when the device lacks bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        input = torch.randn(num_batch, seq_len, num_expert, dtype=dtype, device='mlu')
        # 0/1 expert mask shared across batches; grouped top-k runs unmasked.
        mask = torch.randint(0, 2, (1, seq_len, num_expert), dtype = dtype, device='mlu')
        if num_expert_group > 1:
            mask = None
        normed_by = "softmax_logit"
        # NOTE(review): reduce_weight/expert_id are sized (num_batch, topk)
        # even though the input carries seq_len tokens per batch; if the op
        # emits per-token results, the io_bytes estimate below undercounts by
        # a factor of seq_len -- confirm against moe_softmax_topk's output
        # shapes. These tensors only feed the IO estimate, not the op call.
        reduce_weight = torch.empty(num_batch, topk, dtype=torch.float, device='mlu')
        expert_id = torch.empty(num_batch, topk, dtype=torch.int32, device='mlu')
        hardware_time, e2e_time = benchmark_forward(tmo.moe_softmax_topk,
                                                    input,
                                                    topk,
                                                    normalize,
                                                    num_expert_group,
                                                    topk_group,
                                                    mask,
                                                    normed_by,
                                                    repeats=args.repeat_times)
        io_bytes = input.element_size() * input.nelement() + \
            reduce_weight.element_size() * reduce_weight.nelement() + \
            expert_id.element_size() * expert_id.nelement()
        io_eff = io_bytes / hardware_time / bd
        content = [f"{num_batch}", f"{seq_len}", f"{num_expert}", f"{topk}", f"{num_expert_group}", f"{topk_group}", f"{normalize}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,145 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# offline_quant_to_linear_cache benchmark cases: fixed geometry, sweeping the
# packed/unpacked layout x per_head/per_channel quantization-mode matrix.
e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_head"},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_head"},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_channel"},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16], "quantize_mode": "per_channel"}
                            ]
def main():
    # Benchmark tmo.offline_quant_to_linear_cache (writing K/V into an int8
    # linear KV cache) over packed/unpacked layouts and per_head/per_channel
    # quantization modes. Only latency is reported (no IO efficiency here).
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dytpe", "quantize_mode", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        max_seq_offset = params_dict["max_seq_offset"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        quantize_mode = params_dict["quantize_mode"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # low == high - 1, so every entry is exactly max_context_len.
            context_lens = torch.randint(size=(batch, ), low=max_context_len,
                                         high=max_context_len+1,
                                         dtype=torch.int32, device='mlu')
            # max_seq_offset = max_context_len // 3 + 1
            # Likewise constant: every entry equals max_seq_offset.
            context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
                                                dtype=torch.int32, device='mlu')
            # NOTE(review): low=-1 permits a -1 offset -- presumably a
            # sentinel meaning "no cache offset"; confirm against the op's
            # contract for cache_seq_offsets.
            cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
                                              high=(cache_mem_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')
            # Prefix sums of context lengths, padded with a leading 0, as the
            # packed layout's per-batch row offsets.
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            # Q, K and V heads are laid out contiguously along the head axis.
            total_heads = head_num_q + 2 * head_num_kv
            if packed > 0:
                # Packed: all batches' tokens concatenated along dim 0.
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                # Unpacked: padded per-batch layout with room for the offset.
                context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
            context = context.to(dtype)
            cache = cache.to(dtype)
            # Slice K and V out of the combined QKV head axis.
            key = context[..., head_num_q : head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = cache[0]
            value_cache = cache[1]
            # Map each active batch to a distinct slot in the larger cache.
            cache_bs_id = None
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            # Rescale [0, 1) noise to roughly int8 range before the cast.
            key_cache = (key_cache - 0.5) * 256
            value_cache = (value_cache - 0.5) * 256
            key_cache = key_cache.to(torch.int8)
            value_cache = value_cache.to(torch.int8)
            # Four branches differ only in the scale shape (per_channel:
            # (head_num_kv, head_size); per_head: (head_num_kv, cache_mem_len)),
            # the quantize-mode flag (0/1), and whether the packed call passes
            # cu_context_lens with no context offsets or context_lens with
            # context_seq_offsets.
            if packed > 0:
                if quantize_mode == "per_channel":
                    key_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
                    value_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
                    hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
                                                                key, value,
                                                                key_cache, value_cache,
                                                                key_cache_quantize_scale,
                                                                value_cache_quantize_scale,
                                                                cu_context_lens, max_context_len, 0,
                                                                packed > 0, None,
                                                                cache_bs_id, cache_seq_offsets,
                                                                repeats=args.repeat_times)
                elif quantize_mode == "per_head":
                    key_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
                    value_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
                    hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
                                                                key, value,
                                                                key_cache, value_cache,
                                                                key_cache_quantize_scale,
                                                                value_cache_quantize_scale,
                                                                cu_context_lens, max_context_len, 1,
                                                                packed > 0, None,
                                                                cache_bs_id, cache_seq_offsets,
                                                                repeats=args.repeat_times)
            else:
                if quantize_mode == "per_channel":
                    key_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
                    value_cache_quantize_scale = torch.randn(head_num_kv, head_size).to('mlu').to(torch.float32)
                    hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
                                                                key, value,
                                                                key_cache, value_cache,
                                                                key_cache_quantize_scale,
                                                                value_cache_quantize_scale,
                                                                context_lens, max_context_len, 0,
                                                                packed > 0, context_seq_offsets,
                                                                cache_bs_id, cache_seq_offsets,
                                                                repeats=args.repeat_times)
                elif quantize_mode == "per_head":
                    key_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
                    value_cache_quantize_scale = torch.randn(head_num_kv, cache_mem_len).to('mlu').to(torch.float32)
                    hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
                                                                key, value,
                                                                key_cache, value_cache,
                                                                key_cache_quantize_scale,
                                                                value_cache_quantize_scale,
                                                                context_lens, max_context_len, 1,
                                                                packed > 0, context_seq_offsets,
                                                                cache_bs_id, cache_seq_offsets,
                                                                repeats=args.repeat_times)
            content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{quantize_mode}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,73 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from itertools import product
from tabulate import tabulate
import os
import random
# offline_quant_to_paged_cache benchmark sweep: token_num scales 1x / 32x / 64x
# with a single KV head, head_size 128 and 16-token pages, for fp16 and bf16.
e2e_time_param_dict_list = [
    {"token_num": [1024, 2048, 3072, 4096], "head_num_kv": 1, "head_size": 128, "block_size": 16,
     "input_dtype": [torch.float16, torch.bfloat16]},
    {"token_num": [1024 * 32, 2048 * 32, 3072 * 32, 4096 * 32], "head_num_kv": 1, "head_size": 128, "block_size": 16,
     "input_dtype": [torch.float16, torch.bfloat16]},
    {"token_num": [1024 * 64, 2048 * 64, 3072 * 64, 4096 * 64], "head_num_kv": 1, "head_size": 128, "block_size": 16,
     "input_dtype": [torch.float16, torch.bfloat16]},
]
def main():
    """Benchmark tmo.offline_quant_to_paged_cache over the configured sweeps.

    Parses --repeat_times/--csv/-o, bails out on MLU300 devices (the op is
    unsupported there), then for every (token_num, dtype) combination builds
    random fp key/value inputs, an int8 paged cache plus fp32 scales, times
    the op with benchmark_forward, and prints (optionally saves) the table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    header = ["token_num", "head_num_kv", "head_size", "block_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    # The op is not implemented for MLU300-series devices.
    if "MLU3" in torch.mlu.get_device_name():
        print("Op offline_quant_to_paged_cache does not support MLU300 devices.")
        return
    for cfg in e2e_time_param_dict_list:
        block_size = cfg["block_size"]
        head_num_kv = cfg["head_num_kv"]
        head_size = cfg["head_size"]
        for token_num, dtype in product(cfg["token_num"], cfg["input_dtype"]):
            # bf16 needs hardware support; skip the case otherwise.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Enough whole blocks to hold one slot per token.
            block_num = (token_num + block_size - 1) // block_size
            key = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
            value = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
            key_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
            value_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
            # Scatter every token to a distinct random cache slot.
            num_slots = block_num * block_size
            slot_mapping = torch.tensor(random.sample(range(num_slots), token_num),
                                        dtype=torch.int32, device="mlu")
            key_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
            value_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
            hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_paged_cache,
                                                        key, value,
                                                        key_cache_scale, value_cache_scale,
                                                        slot_mapping,
                                                        key_cache, value_cache,
                                                        repeats=args.repeat_times)
            rows.append([f"{token_num}", f"{head_num_kv}", f"{head_size}", f"{block_size}",
                         f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [header] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,61 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
import csv
# Sweep grid for the per_token_smooth_quantize benchmark: token counts are
# multiples of 5, crossed with hidden sizes and input dtypes via product()
# in main().
params_dict = {
    "token_num": [n * 5 for n in [1, 72, 512, 1024, 4096, 32768]],
    "hidden_size": [1024, 8192],
    "input_dtype": [torch.float16]
}
def main():
    """Benchmark tmo.per_token_smooth_quantize and report IO efficiency.

    For every (token_num, hidden_size, dtype) combination the op is timed
    with benchmark_forward; "IO efficiency" is the estimated bytes moved
    divided by hardware time and the device bandwidth from get_band_width().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    params_list = product(params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
    bd = get_band_width()
    for token_num, hidden_size, dtype in params_list:
        input_shape = (token_num, hidden_size)
        # Fixed: was (hidden_size) — a bare int, not a 1-tuple. torch.randn
        # accepted it, but the intent is a 1-D shape.
        smooth_shape = (hidden_size,)
        # bf16 needs hardware support; skip the case otherwise.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            continue
        input = torch.randn(input_shape).to(device).to(dtype)
        smooth = torch.randn(smooth_shape).to(device).to(torch.float32)
        zero = None
        token_count = None
        hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
                                                    input,
                                                    smooth,
                                                    zero,
                                                    token_count,
                                                    repeats=args.repeat_times)
        # Read input + write int8 output (+1 byte/elem), read smooth scales,
        # plus token_num * 4 bytes — presumably the per-token fp32 output
        # scales; confirm against the op's signature.
        io_bytes = (input.element_size() + 1) * input.nelement() + \
                   smooth.element_size() * smooth.nelement() + \
                   token_num * 4
        io_eff = io_bytes / hardware_time / bd
        contents.append([f"{token_num}", f"{hidden_size}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,46 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Input-shape / dtype sweep for the tmo.preload benchmark; each dict is one
# tensor shape with the dtypes to try.
e2e_time_param_dict_list = [{"input_shape": [100, 100, 100], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"input_shape": [100, 100], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"input_shape": [50, 50, 50], "input_dtype": [torch.float16, torch.bfloat16]},
                            {"input_shape": [1, 100, 1000], "input_dtype": [torch.float16, torch.bfloat16]}
                            ]
def main():
    """Time tmo.preload for each configured shape/dtype and print a table.

    preload is passed the tensor and its total byte size; results go to
    stdout as a grid table and optionally to CSV via save_to_csv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    header = ["input_shape", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        input_shape = cfg["input_shape"]
        for dtype in cfg["input_dtype"]:
            # bf16 needs hardware support; skip the case otherwise.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            tensor = torch.randn(input_shape).to(device).to(dtype)
            nbytes = tensor.element_size() * tensor.numel()
            hardware_time, e2e_time = benchmark_forward(tmo.preload,
                                                        tensor,
                                                        nbytes,
                                                        repeats=args.repeat_times)
            rows.append([f"{input_shape}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [header] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,201 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import random
import os
# Benchmark case table for tmo.quant_to_linear_cache, generated instead of
# hand-written: the original 40 dicts were perfectly regular. For each
# (group_size, quant_bit) pair there are
#   - 5 "prefill" cases: batch 1, packed input, context length swept, and
#   - 5 "decode"  cases: context length 1, unpacked input, batch swept.
# All cases use head_num_kv=1, head_size=128, bf16 input.
e2e_time_param_dict_list = []
for _group_size in (128, 16):
    for _quant_bit in (8, 4):
        # Prefill-shaped cases: one long packed sequence per batch slot.
        for _max_context_len in (1024, 2048, 4096, 8192, 32768):
            e2e_time_param_dict_list.append(
                {"max_batch": 8, "batch": 1, "cache_mem_len": 32768,
                 "max_context_len": _max_context_len, "head_num_kv": 1,
                 "head_size": 128, "packed": True,
                 "input_dtype": [torch.bfloat16],
                 "quant_bit": _quant_bit, "group_size": _group_size})
        # Decode-shaped cases: one new token per sequence, batch swept.
        for _batch in (16, 72, 128, 256, 512):
            e2e_time_param_dict_list.append(
                {"max_batch": 512, "batch": _batch, "cache_mem_len": 4096,
                 "max_context_len": 1, "head_num_kv": 1,
                 "head_size": 128, "packed": False,
                 "input_dtype": [torch.bfloat16],
                 "quant_bit": _quant_bit, "group_size": _group_size})
def main():
    """Benchmark tmo.quant_to_linear_cache across prefill/decode-shaped cases.

    Each configuration quantizes key/value activations into a pre-allocated
    linear (contiguous, non-paged) int8 KV cache — 8-bit or packed 4-bit,
    with per-row or grouped fp32 scales — times the op with benchmark_forward,
    and tabulates the timings plus an IO-efficiency estimate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    titles = ["batch", "max_context_len", "head_num_kv", "head_size", "packed", "input_dtype",
              "quant_bit", "group_size", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bandwidth = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        # Unpack one benchmark configuration.
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        input_dtype_list = params_dict["input_dtype"]
        quant_bit = params_dict["quant_bit"]
        group_size = params_dict["group_size"]
        for dtype in input_dtype_list:
            # Unlike the sibling benchmarks (which skip), this one falls back
            # to fp16 when bf16 is unsupported.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            # All sequences share the same context length in this benchmark.
            context_lens = torch.tensor([max_context_len] * batch).to(torch.int32).mlu()
            context_seq_offsets = torch.zeros(batch, dtype=torch.int32, device='mlu')
            # Prefill (max_context_len > 1): cache write offsets are all 0;
            # decode (max_context_len == 1): random offset within the cache.
            cache_seq_offsets = torch.randint(size=(batch, ),
                                              low=0,
                                              high = 1 if max_context_len > 1 else cache_mem_len,
                                              dtype=torch.int32, device='mlu')
            # Cumulative sequence lengths with a leading 0 (standard
            # "cu_seqlens" layout for packed/varlen inputs).
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            if packed > 0:
                # Packed layout: all sequences concatenated along dim 0.
                key = torch.randn((total_seqlen, head_num_kv, head_size),
                                  dtype=torch.float, device='mlu')
                value = torch.randn((total_seqlen, head_num_kv, head_size),
                                    dtype=torch.float, device='mlu')
            else:
                # Padded layout: [batch, max_context_len, heads, head_size].
                key = torch.randn((batch, max_context_len, head_num_kv, head_size),
                                  dtype=torch.float, device='mlu')
                value = torch.randn((batch, max_context_len, head_num_kv, head_size),
                                    dtype=torch.float, device='mlu')
            key = key.to(dtype)
            value = value.to(dtype)
            # Cache/scale layouts per quantization mode. NOTE(review): in the
            # 4-bit branches key_cache packs along head_size (head_size // 2)
            # while value_cache packs along cache_mem_len (cache_mem_len // 2)
            # — looks intentional (different packing axes for K vs V) but
            # confirm against the op's expected layouts.
            if quant_bit == 8 and group_size == head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
            if quant_bit == 8 and group_size != head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
            if quant_bit == 4 and group_size == head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
            if quant_bit == 4 and group_size != head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
            # Map each request to a distinct random cache slot.
            cache_bs_id = None
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            if packed > 0:
                # Packed path uses cu_context_lens and no per-sequence input
                # offsets.
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            key_cache_scale,
                                                            value_cache_scale,
                                                            cu_context_lens, max_context_len,
                                                            packed > 0, None,
                                                            cache_bs_id, cache_seq_offsets,
                                                            quant_bit,
                                                            repeats=args.repeat_times)
            else:
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            key_cache_scale,
                                                            value_cache_scale,
                                                            context_lens, max_context_len,
                                                            packed > 0, context_seq_offsets,
                                                            cache_bs_id, cache_seq_offsets,
                                                            quant_bit,
                                                            repeats=args.repeat_times)
            # Rough IO estimate: read K/V (element_size) + write quantized
            # bytes (+1), times 2 for key and value. NOTE(review): scale
            # traffic and 4-bit packing are not reflected — confirm this is
            # the intended metric.
            io_bytes = key.nelement() * (key.element_size() + 1) * 2
            io_eff = io_bytes / hardware_time / bandwidth
            content = [f"{batch}", f"{max_context_len}", f"{head_num_kv}", f"{head_size}", f"{packed}",
                       f"{dtype}", f"{quant_bit}", f"{group_size}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,61 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
# Sweep for the quantize benchmark. "dynamic" selects between
# tmo.per_token_smooth_quantize (True) and tmo.quantize (False) in main();
# as configured only the dynamic path runs.
params_dict = {"dynamic": [True],
               "token_num": [1, 72, 490, 512, 525, 1024, 4096, 8192, 32768],
               "hidden_size": [8192, 1024],
               "input_dtype": [torch.bfloat16]}
def main():
    """Benchmark dynamic (per-token) vs static quantization and IO efficiency.

    For each (dynamic, token_num, hidden_size, dtype) combination, times
    tmo.per_token_smooth_quantize (dynamic) or tmo.quantize (static) with
    benchmark_forward and reports hardware/e2e time plus an IO-efficiency
    estimate against the device bandwidth from get_band_width().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["dynamic", "token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    params_list = product(params_dict["dynamic"], params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
    bd = get_band_width()
    for param in params_list:
        dynamic, token_num, hidden_size, dtype = param[0], param[1], param[2], param[3]
        input_shape = (token_num, hidden_size)
        # Fall back to fp16 when bf16 is not supported by the device.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        input = torch.randn(input_shape).to(device).to(dtype)
        # Per-channel fp32 scale over the hidden dimension.
        scale = torch.randn(input_shape[-1]).to(device).to(torch.float32)
        zero = None
        if dynamic:
            hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
                                                        input,
                                                        scale,
                                                        zero,
                                                        None,
                                                        repeats=args.repeat_times)
        else:
            hardware_time, e2e_time = benchmark_forward(tmo.quantize,
                                                        input,
                                                        scale,
                                                        zero,
                                                        repeats=args.repeat_times)
        # Read input + write int8 output (+1 byte/elem) + read scales.
        io_bytes = (input.element_size() + 1) * input.nelement() + scale.element_size() * scale.nelement()
        io_eff = io_bytes / hardware_time / bd
        content = [f"{dynamic}", f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,109 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# reshape_linear_cache benchmark cases; "packed" toggles concatenated varlen
# input vs padded [batch, seq, ...] input layout.
# NOTE(review): the 3rd/4th dicts duplicate the 1st/2nd exactly — presumably
# intentional (repeat measurement); confirm before pruning.
e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16]}
                            ]
def main():
    """Benchmark tmo.reshape_linear_cache (copy K/V into a linear KV cache).

    Key/value are views into a fused QKV context tensor; each request is
    mapped to a random cache slot (cache_bs_id) with a random write offset
    (cache_seq_offsets). Packed mode passes cumulative sequence lengths,
    unpacked mode passes per-sequence lengths plus input offsets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        # Unpack one benchmark configuration.
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        max_seq_offset = params_dict["max_seq_offset"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # bf16 needs hardware support; skip the case otherwise.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # All sequences get exactly max_context_len tokens.
            context_lens = torch.randint(size=(batch, ), low=max_context_len,
                                         high=max_context_len+1,
                                         dtype=torch.int32, device='mlu')
            # max_seq_offset = max_context_len // 3 + 1
            context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
                                                dtype=torch.int32, device='mlu')
            # Random cache write offsets; -1 is included — presumably a
            # sentinel the op accepts, confirm against the op docs.
            cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
                                              high=(cache_mem_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')
            # Cumulative sequence lengths with leading 0 (varlen layout).
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            total_heads = head_num_q + 2 * head_num_kv
            if packed > 0:
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
            context = context.to(dtype)
            cache = cache.to(dtype)
            # K and V are head-dimension views into the fused QKV tensor.
            key = context[..., head_num_q : head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = cache[0]
            value_cache = cache[1]
            # Map each request to a distinct random cache slot.
            # (Removed a redundant `cache_bs_id = None` that was immediately
            # overwritten.)
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            if packed > 0:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            cu_context_lens, max_context_len,
                                                            packed > 0, None,
                                                            cache_bs_id, cache_seq_offsets,
                                                            repeats=args.repeat_times)
            else:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            context_lens, max_context_len,
                                                            packed > 0, context_seq_offsets,
                                                            cache_bs_id, cache_seq_offsets,
                                                            repeats=args.repeat_times)
            content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,76 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# reshape_paged_cache benchmark cases; "quantize" selects the int8-quantizing
# variant (tmo.quant_to_paged_cache) instead of the plain copy.
e2e_time_param_dict_list = [{"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
                             "head_num_kv": 32, "head_size": 128, "quantize": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
                             "head_num_kv": 32, "head_size": 128, "quantize": False, "input_dtype": [torch.float16, torch.bfloat16]}
                            ]
def main():
    """Benchmark tmo.reshape_paged_cache / tmo.quant_to_paged_cache.

    Scatters key/value token slices into a paged (block-structured) KV cache
    via a random slot_mapping; the quantize variant additionally takes
    per-token fp32 cache scales. Exits immediately on MLU300 devices.
    """
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    # Fixed header typo: "input_dytpe" -> "input_dtype".
    titles = ["num_tokens", "num_block", "block_size", "head_num_q", "head_num_kv", "head_size", "input_dtype", "quantize", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        num_tokens = params_dict["num_tokens"]
        num_blocks = params_dict["num_block"]
        block_size = params_dict["block_size"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        quantize = params_dict["quantize"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Consistency/robustness fix: sibling benchmarks skip bf16 on
            # devices without bf16 support; this one previously did not.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Fused QKV activations; K/V are views into the head dimension.
            qkv = torch.randn(num_tokens, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
            key = qkv[:, head_num_q : head_num_q + head_num_kv, :]
            value = qkv[:, head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
            value_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
            # Unique random slot per token; the last entry is forced to -1 —
            # presumably a "skip this token" sentinel, confirm against op docs.
            num_slots = num_blocks * block_size
            slot_mapping = random.sample(range(num_slots), num_tokens)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
            slot_mapping[-1] = -1
            if not quantize:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_paged_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            slot_mapping,
                                                            repeats=args.repeat_times)
            else:
                # Per-token fp32 quantization scales for the int8 cache path.
                k_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
                v_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_paged_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            k_cache_quant_scale,
                                                            v_cache_quant_scale,
                                                            slot_mapping,
                                                            repeats=args.repeat_times)
            content = [f"{num_tokens}", f"{num_blocks}", f"{block_size}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{dtype}", f"{quantize}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)

if __name__=="__main__":
    main()

View File

@@ -0,0 +1,116 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import math
import random
e2e_time_param_dict_list = [
{"batch": 16, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
{"batch": 128, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
{"batch": 512, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
{"batch": 16, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
{"batch": 128, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
]
def main():
    """Benchmark tmo.single_query_cached_kv_attn over the configured cases and
    print/optionally save a table of hardware and end-to-end latencies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "max_seq_len", "head_num_q", "head_num_kv", "head_size", "block_size", "alibi_bias", "kv_cache_dtype", "use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        max_seq_len = params_dict["max_seq_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        block_size = params_dict["block_size"]
        alibi_bias = params_dict["alibi_bias"]
        kv_cache_dtype = params_dict["kv_cache_dtype"]
        use_paged_attn = params_dict["use_paged_attn"]
        input_dtype_list = params_dict["input_dtype"]
        is_pertoken = params_dict["is_pertoken"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on boards without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Packed QKV for one decode step; only the Q slice is fed to the op.
            input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
            input_q = input_qkv[..., 0 : head_num_q, :]
            # Every sequence is exactly max_seq_len tokens long.
            context_lens = torch.randint(max_seq_len, max_seq_len + 1, (batch, ), dtype=torch.int32).to(device)
            max_context_len = int(max(context_lens))
            if use_paged_attn:
                mlu_name = torch.mlu.get_device_name()
                # NOTE(review): substring check "MLU3" matches the whole MLU3xx
                # family, not only MLU370 as the message suggests — confirm.
                if "MLU3" in mlu_name:
                    print("pagedattn is not implement on mlu370, skip it")
                    continue
                block_size = 16
            else:
                # Non-paged path: a single oversized block holds the whole sequence.
                block_size = max_seq_len + 512
            num_blocks = batch * ((max_seq_len + block_size - 1) // block_size)
            max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
            cache_shape = (num_blocks, head_num_kv, block_size, head_size)
            scale_shape = (num_blocks, head_num_kv, block_size) if is_pertoken else (head_num_kv, head_size)
            # Random, non-repeating block assignment per sequence.
            block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
            if kv_cache_dtype is not torch.int8:
                key_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
                value_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
                key_cache_scale = None
                value_cache_scale = None
            else:
                # int8 KV cache plus fp32 dequant scales.
                key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
                value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
                key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
                value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
            alibi_slopes = None
            if alibi_bias:
                alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
                alibi_slopes.uniform_(0, 0.125)
            softmax_scale = 1 / math.sqrt(head_size)
            hardware_time, e2e_time = benchmark_forward(tmo.single_query_cached_kv_attn,
                                                        input_q,
                                                        key_cache,
                                                        value_cache,
                                                        None,
                                                        block_tables,
                                                        context_lens,
                                                        key_cache_scale,
                                                        value_cache_scale,
                                                        alibi_slopes,
                                                        max_context_len,
                                                        -1,
                                                        -1,
                                                        softmax_scale,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{max_seq_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{alibi_bias}", f"{kv_cache_dtype}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,131 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import math
import random
# Benchmark cases for the mixed-precision cached-KV attention op: "lp"/"hp"
# are the low-precision (int4) and high-precision (int8) cache segments.
e2e_time_param_dict_list = [{"batch": 16, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 512, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": True, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 16, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 128, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
                             "head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
                             "use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}]
def gen_cache(batch, num_kv_heads, head_size, is_pagedattn, max_context_len, data_type, quant_bit, quant_mode):
    """Generate a (possibly quantized) KV cache plus its metadata for benchmarking.

    Args:
        batch: number of sequences.
        num_kv_heads: number of KV heads per block.
        head_size: per-head dimension.
        is_pagedattn: if False, one block covers a whole sequence (block_size = max_context_len).
        max_context_len: every sequence is generated with exactly this length.
        data_type: cache dtype when quant_bit == -1 (no quantization).
        quant_bit: -1 (no quant), 4 (int4 packed into int8) or 8 (int8).
        quant_mode: "per_token" or per-channel scale layout.

    Returns:
        (key_cache, value_cache, key_cache_scale, value_cache_scale,
         context_lens, block_tables); the scales are None when quant_bit == -1.

    Raises:
        ValueError: if quant_bit is not one of {-1, 4, 8}.
    """
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    context_lens = torch.randint(max_context_len, max_context_len + 1, (batch, ), dtype=torch.int32).mlu()
    block_size = 16
    if is_pagedattn is False:
        block_size = max_context_len
    # BUGFIX: use integer ceil-division. The previous float expression
    # int(batch * ((m + b - 1) / b)) truncated below batch * ceil(m / b), so
    # random.sample below could request more samples than the population and
    # raise ValueError (and the cache could be under-allocated).
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    num_blocks = batch * max_num_blocks_per_seq
    cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
    if quant_mode == "per_token":
        scale_shape = (num_blocks, num_kv_heads, block_size, 1)
    else:  # per channel
        scale_shape = (num_kv_heads, head_size)
    if quant_bit == 4:
        # int4 values are packed two-per-byte: K packs along head_size,
        # V packs along block_size.
        cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size // 2)
        cache_shape_v_int4 = (num_blocks, num_kv_heads, block_size // 2, head_size)
        key_cache = torch.zeros(cache_shape_k_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape_v_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == 8:
        key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == -1:
        key_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        value_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        key_cache_scale = None
        value_cache_scale = None
    else:
        # BUGFIX: the old code only printed here and then crashed with a
        # NameError at the return statement; fail fast instead.
        raise ValueError("gen case error, quant_bit must be in {-1, 4, 8}")
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    return key_cache, value_cache, key_cache_scale, value_cache_scale, context_lens, block_tables
def main():
    """Benchmark tmo.single_query_mixed_cached_kv_attn (int4 "lp" segment +
    int8 "hp" segment) and print/optionally save the latency table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "max_seq_len_lp", "max_seq_len_hp", "head_num_q", "head_num_kv", "head_size", "alibi_bias", "quant_bit_lp", "quant_bit_hp","use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        max_seq_len_lp = params_dict["max_seq_len_lp"]
        max_seq_len_hp = params_dict["max_seq_len_hp"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        alibi_bias = params_dict["alibi_bias"]
        quant_bit_lp = params_dict["quant_bit_lp"]
        quant_bit_hp = params_dict["quant_bit_hp"]
        use_paged_attn = params_dict["use_paged_attn"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on boards without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Packed QKV for one decode step; only the Q slice is fed to the op.
            input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
            input_q = input_qkv[..., 0 : head_num_q, :]
            # Two cache segments: low-precision per-token and high-precision
            # per-channel quantization.
            params_lp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_lp, dtype, quant_bit_lp, "per_token")
            params_hp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_hp, dtype, quant_bit_hp, "per_channel")
            key_cache_lp, value_cache_lp, key_cache_scale_lp, value_cache_scale_lp, context_lens_lp, block_tables_lp = params_lp
            key_cache_hp, value_cache_hp, key_cache_scale_hp, value_cache_scale_hp, context_lens_hp, block_tables_hp = params_hp
            alibi_slopes = None
            if alibi_bias:
                alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
                alibi_slopes.uniform_(0, 0.125)
            softmax_scale = 1 / math.sqrt(head_size)
            hardware_time, e2e_time = benchmark_forward(tmo.single_query_mixed_cached_kv_attn,
                                                        input_q,
                                                        key_cache_lp, value_cache_lp,
                                                        key_cache_hp, value_cache_hp,
                                                        None, #output
                                                        block_tables_lp, block_tables_hp,
                                                        context_lens_lp, context_lens_hp,
                                                        key_cache_scale_lp, value_cache_scale_lp,
                                                        key_cache_scale_hp, value_cache_scale_hp,
                                                        alibi_slopes,
                                                        max_seq_len_lp, max_seq_len_hp,
                                                        softmax_scale, True,
                                                        quant_bit_lp, quant_bit_hp,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{max_seq_len_lp}", f"{max_seq_len_hp}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{alibi_bias}", f"{quant_bit_lp}", f"{quant_bit_hp}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,69 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark cases for tmo.smooth_quant_matmul (int8 A/B with per-row /
# per-column fp32 scales); "output_dtype" lists the output dtypes to sweep.
e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
                             "act_mode": "none", "output_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
                             "act_mode": "silu", "output_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.smooth_quant_matmul over the configured (m, k, n) cases
    and print/optionally save a table of hardware and end-to-end latencies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "output_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        m = params_dict["m"]
        k = params_dict["k"]
        n = params_dict["n"]
        has_c = params_dict["has_c"]
        has_bias = params_dict["has_bias"]
        act_mode = params_dict["act_mode"]
        output_dtype_list = params_dict["output_dtype"]
        for dtype in output_dtype_list:
            # Skip bf16 cases on boards without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # int8 operands: A is (m, k); B is (n, k), i.e. transposed layout.
            a = torch.randn(m, k).to(device).to(torch.int8)
            b = torch.randn(n, k).to(device).to(torch.int8)
            a_scale = torch.randn(m).to(device)
            b_scale = torch.randn(n).to(device)
            c = None
            if has_c:
                c = torch.randn(m, n).to(device).to(dtype)
            bias = None
            if has_bias:
                bias = torch.randn(n).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_matmul,
                                                        a,
                                                        a_scale,
                                                        b,
                                                        b_scale,
                                                        dtype,
                                                        bias,
                                                        c,
                                                        act_mode,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Run every benchmark_*.py script in this directory, teeing each script's
# output to ${LOG_PATH}/<name>.log. Exits with the first failing script's
# return code.
LOG_PATH=${LOG_PATH:-.}
files=($(ls benchmark_*.py))
for file in "${files[@]}"; do
    echo "test ${file}..."
    op_name=$(basename "$file" .py)
    python "$file" > "${LOG_PATH}/${op_name}.log" 2>&1
    ret_tmp=$?
    cat "${LOG_PATH}/${op_name}.log"
    # BUGFIX: the failure message referenced an undefined variable ${sc};
    # also use an arithmetic comparison instead of string '!='.
    if [ $ret_tmp -ne 0 ]; then
        echo "${file} test failed..."
        exit $ret_tmp
    fi
done

View File

@@ -0,0 +1,99 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random
# Benchmark cases for tmo.update_out_and_lse; "pack" selects the varlen
# (packed, cu_seqlens-indexed) layout vs the fixed (batch, seq) layout.
e2e_time_param_dict_list = [{"batch": 16, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 128, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 512, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 1024, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
                             "dtype": [torch.float16], "pack": False},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 1024, "max_seq_len": 1024,
                             "dtype": [torch.float16], "pack": True},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 2048, "max_seq_len": 2048,
                             "dtype": [torch.float16], "pack": True},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 8192, "max_seq_len": 8192,
                             "dtype": [torch.float16], "pack": True},
                            {"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 32768, "max_seq_len": 32768,
                             "dtype": [torch.float16], "pack": True},]
def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack):
    """Generate inputs for tmo.update_out_and_lse.

    Returns (out, lse, block_out, block_lse, seq_offset, cu_seqs,
    block_cu_seqs); the last three are None in the non-packed layout.

    TODO(review): the name is a typo of ``gen_tensor``; kept as-is because
    main() calls it by this name.
    """
    if not pack:
        # Fixed-shape layout: (batch, seq, heads, head_size) outputs and
        # (batch, heads, seq) log-sum-exp.
        out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
        block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
        seq_offset = None
        cu_seqs = None
        block_cu_seqs = None
    else:
        # Packed (varlen) layout: every sequence has exactly max_seq_len /
        # block_seq_len tokens here (randint with a one-wide range).
        seq_lens = torch.randint(low=max_seq_len, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32)
        block_seq_lens = torch.randint(low=block_seq_len, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32)
        block_seq_lens = torch.minimum(seq_lens, block_seq_lens)
        # Random start offset of each block inside its sequence.
        seq_offset = torch.zeros_like(seq_lens)
        for i in range(batch):
            seq_offset[i] = torch.randint(low=0, high=seq_lens[i]-block_seq_lens[i]+1, size=(1,), dtype=torch.int32)
        seq_offset = seq_offset.mlu()
        # Cumulative sequence offsets (exclusive prefix sums with leading 0).
        cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu()
        block_cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu()
        total_seqs = torch.sum(seq_lens)
        block_total_seqs = torch.sum(block_seq_lens)
        out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype)
        lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
        block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype)
        block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
    return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
def main():
    """Benchmark tmo.update_out_and_lse over the configured cases and
    print/optionally save a table of hardware and end-to-end latencies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "head_num", "head_size", "block_seq_len", "max_seq_len", "dtype", "pack", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        block_seq_len = params_dict["block_seq_len"]
        max_seq_len = params_dict["max_seq_len"]
        dtype_list = params_dict["dtype"]
        pack = params_dict["pack"]
        for dtype in dtype_list:
            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            hardware_time, e2e_time = benchmark_forward(tmo.update_out_and_lse,
                                                        out,
                                                        lse,
                                                        block_out,
                                                        block_lse,
                                                        seq_offset,
                                                        cu_seqs,
                                                        block_cu_seqs,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{head_num}", f"{head_size}", f"{block_seq_len}", f"{max_seq_len}", f"{dtype}", f"{pack}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,70 @@
import torch
import torch_mlu
import torch_mlu_ops as tmo
# BUGFIX: save_to_csv was imported twice on this line.
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
# Benchmark cases for tmo.weight_only_quant_matmul; quant_bit 4 packs two
# weights per int8 byte, so B's K dimension is halved below.
e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
                             "act_mode": "none", "quant_bit": 8, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
                             "act_mode": "silu", "quant_bit": 4, "input_dtype": [torch.float16, torch.bfloat16]}]
def main():
    """Benchmark tmo.weight_only_quant_matmul over the configured cases and
    print/optionally save a table of hardware and end-to-end latencies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()
    device = 'mlu'
    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "quant_bit", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        m = params_dict["m"]
        k = params_dict["k"]
        n = params_dict["n"]
        quant_bit = params_dict["quant_bit"]
        has_c = params_dict["has_c"]
        has_bias = params_dict["has_bias"]
        act_mode = params_dict["act_mode"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on boards without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(m, k).to(device).to(dtype)
            # int4 weights are packed two-per-byte, so B stores k//2 bytes per row.
            b = torch.randn(n, k if quant_bit == 8 else k//2).to(device).to(torch.int8)
            scale = torch.randn(n).to(device)
            zero = None
            c = None
            if has_c:
                c = torch.randn(m, n).to(device).to(dtype)
            bias = None
            if has_bias:
                bias = torch.randn(n).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.weight_only_quant_matmul,
                                                        a,
                                                        b,
                                                        scale,
                                                        zero,
                                                        bias,
                                                        c,
                                                        act_mode,
                                                        quant_bit,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{quant_bit}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)
if __name__=="__main__":
    main()

View File

@@ -0,0 +1,65 @@
import torch
import torch_mlu
import time
from pathlib import Path
import csv
import os
import subprocess
from itertools import product
def benchmark_forward(fn, *inputs, repeats=1, **kwinputs):
    """Time ``repeats`` calls of ``fn(*inputs, **kwinputs)`` on the MLU.

    Returns (average_hardware_time_us, average_e2e_time_us): device time is
    measured with MLU events, wall time with time.perf_counter().
    NOTE(review): there is no warm-up iteration, so the first call's
    compilation/caching cost is included in the average — confirm intended.
    """
    notify_start = torch.mlu.Event(enable_timing=True)
    notify_end = torch.mlu.Event(enable_timing=True)
    notify_start.record()
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*inputs, **kwinputs)
    notify_end.record()
    # Block until the device has drained the submitted work.
    notify_end.synchronize()
    total_e2e_time = time.perf_counter() - t0
    # Convert seconds -> microseconds per call.
    average_e2e_time = total_e2e_time / repeats * 1e6
    total_hardware_time = notify_start.hardware_time(notify_end)
    average_hardware_time = total_hardware_time / repeats
    return average_hardware_time, average_e2e_time
def save_to_csv(table, file_path, file_name):
    """Write ``table`` (a list of rows) to a CSV file and print its location.

    ``file_path`` may be None (current directory), a directory, or a full
    file path (detected by its suffix). When it is not a full path,
    ``file_name`` provides the output name with its extension replaced
    by ``.csv``.
    """
    stem, _ = os.path.splitext(file_name)
    default_name = stem + '.csv'
    target = Path(file_path if file_path is not None else './')
    # A suffix means the caller gave an explicit file path; otherwise treat
    # the path as a directory and use the derived default name.
    if target.suffix:
        out_dir, out_name = target.parent, target.name
    else:
        out_dir, out_name = target, default_name
    if not out_dir.exists():
        out_dir.mkdir(parents=True, exist_ok=True)
    full_path = out_dir / out_name
    if not full_path.exists():
        full_path.touch()
    with open(full_path, mode="w", newline="") as fh:
        csv.writer(fh).writerows(table)
    print(f"output saved at: {full_path}")
def get_band_width(card_id: int = 0):
    """Return the MEM bandwidth of MLU card ``card_id`` as reported by ``cnmon``."""
    cmd = f"cnmon info -c {card_id} | grep 'MEM BandWidth'| cut -d ':' -f2 | cut -d ' ' -f 2"
    res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    assert res.returncode == 0, "Failed to get BandWidth."
    return int(res.stdout.decode().strip())
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split ``total_token_count`` tokens across ``num_expert`` experts.

    Returns ``(cusum_token_count, token_count)``: an int32 prefix-sum tensor of
    length ``num_expert + 1`` with ``cusum[0] == 0`` and
    ``cusum[-1] == total_token_count``, and the per-expert counts
    (``cusum[1:] - cusum[:-1]``).
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # Rescale the random draw so the float counts sum to total_token_count.
    # (renamed from `sum`, which shadowed the Python builtin)
    total = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / total.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Clamp so truncation error never pushes a prefix sum past the total,
    # then pin the final entry to exactly total_token_count.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]

View File

@@ -0,0 +1,8 @@
TMO_VERSION=1.3.2-1
CNCL_VERSION=1.24.1-1
CNNL_VERSION=1.28.3-1
CNNLEXTRA_VERSION=1.12.3-1
CNTOOLKIT_VERSION=3.15.6-1
MLUOPS_VERSION=1.4.2-1
TORCH_VERSION=1.24.1-1
TRITON_VERSION=1.3.1-1

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifdef __GNUC__
#include "stack_exception.h"
#include <cxxabi.h>
#include <dlfcn.h>
#include <execinfo.h>
#include <stdio.h>
#include <stdlib.h>
#define MAX_DEPTH 32
namespace tmo {
namespace stack_exception {
// Captures the current call stack (excluding `num_discard` top frames plus
// this constructor's own frame) by walking backtrace() addresses through
// dladdr() and demangling each symbol name.
call_stack::call_stack(const size_t num_discard /*= 0*/) {
  using namespace abi;
  // retrieve call-stack
  void *trace[MAX_DEPTH];
  const int stack_depth = backtrace(trace, MAX_DEPTH);
  // +1 additionally skips this constructor's own frame.
  for (int i = static_cast<int>(num_discard) + 1; i < stack_depth; i++) {
    Dl_info dlinfo;
    if (!dladdr(trace[i], &dlinfo)) break;
    const char *symname = dlinfo.dli_sname;
    int status = 0;  // initialized: __cxa_demangle only sets it on some paths
    char *demangled = abi::__cxa_demangle(symname, NULL, 0, &status);
    if (status == 0 && demangled) symname = demangled;
    // A frame without module or symbol name marks the region below main().
    const bool usable = dlinfo.dli_fname && symname;
    if (usable) {
      // store entry to stack (copies symname before the buffer is freed)
      entry e;
      e.file = dlinfo.dli_fname;
      e.line = 0;  // line numbers are not recoverable from dladdr
      e.function = symname;
      stack.push_back(e);
    }
    // BUGFIX: free the demangled buffer on every path; the original leaked
    // it when breaking out of the loop below main().
    if (demangled) free(demangled);
    if (!usable) break;  // skip last entries below main
  }
}
call_stack::~call_stack() throw() {
  // automatic cleanup: the `stack` vector releases its entries itself
}
} // namespace stack_exception
} // namespace tmo
#endif // __GNUC__

View File

@@ -0,0 +1,127 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_COMMON_STACK_EXCEPTION_H_
#define CSRC_COMMON_STACK_EXCEPTION_H_
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace tmo {
namespace stack_exception {
/** One frame of a captured call stack. */
struct entry {
  /** Default constructor that clears all fields. */
  entry() : line(0) {}
  std::string file;      ///< filename
  size_t line;           ///< line number (0 when unknown)
  std::string function;  ///< name of function or method
  /** Render this frame as "file[:line] @function". */
  std::string to_string() const {
    std::string text = file;
    if (line > 0) {
      text += ':';
      text += std::to_string(line);
    }
    text += " @";
    text += function;
    return text;
  }
};
/** Stack-trace base class, for retrieving the current call-stack.
 * The constructor is defined out-of-line (it walks the live stack). */
class call_stack {
 public:
  /** Stack-trace constructor.
  \param num_discard - number of stack entries to discard at the top. */
  call_stack(const size_t num_discard = 0);
  virtual ~call_stack() throw();
  /** Serializes the entire call-stack into a text string, one frame per line. */
  std::string to_string() const {
    std::ostringstream os;
    for (size_t i = 0; i < stack.size(); i++)
      os << stack[i].to_string() << std::endl;
    return os.str();
  }
  /** Call stack, outermost discarded frame first. */
  std::vector<entry> stack;
};
/** Abstract base-class for all stack-augmented exception classes.
 * Enables catching of all stack-augmented exception classes. */
class stack_exception_base : public call_stack {
 public:
  // call_stack(2) discards this ctor's and the subclass ctor's frames.
  stack_exception_base(const bool _show_stack) : call_stack(2), show_stack(_show_stack) {}
  virtual ~stack_exception_base() throw() {}
  virtual const char *what() const throw() = 0;
  /// flag to indicate if stack-trace is included in what() messages
  bool show_stack;
};
/** Template for stack-augmented exception classes.
 * T is the std exception type to wrap (e.g. std::runtime_error). */
template <class T>
class stack_exception : public T, public stack_exception_base {
 public:
  stack_exception(const std::string &msg) : T(msg), stack_exception_base(true) {}
  virtual ~stack_exception() throw() {}
  /** Throw-site constructor: prepends the file/line/function of the throw
   * statement as the first stack frame. */
  stack_exception(const char *file, int line, const char *pretty_function, const std::string &msg)
      : T(msg), stack_exception_base(true) {
    entry e;
    e.file = file;
    e.line = line;
    e.function = pretty_function;
    stack.insert(stack.begin(), e);
  }
  virtual const char *what() const throw() {
    if (show_stack) {
      // concatenate message with stack trace; `buffer` keeps the string
      // alive so the returned pointer stays valid
      buffer = "[" + std::string(T::what()) + "]\n" + stack_exception::to_string();
      return buffer.c_str();
    } else {
      return T::what();
    }
  }

 private:
  // mutable: what() is const but must cache the composed message
  mutable std::string buffer;
};
/** Stack-augmented exception classes for all std::exception classes. */
// typedef stack_exception<std::runtime_error> TmoException;
// typedef stack_exception<std::range_error> stack_range_error;
// typedef stack_exception<std::overflow_error> stack_overflow_error;
// typedef stack_exception<std::underflow_error> stack_underflow_error;
// typedef stack_exception<std::logic_error> stack_logic_error;
// typedef stack_exception<std::domain_error> stack_domain_error;
// typedef stack_exception<std::invalid_argument> stack_invalid_argument;
// typedef stack_exception<std::length_error> stack_length_error;
// typedef stack_exception<std::out_of_range> stack_out_of_range;
} // namespace stack_exception
/** Concrete project exception: std::runtime_error augmented with a stack trace.
 * Use the TmoException(msg) macro below so the throw site is recorded. */
class _TmoException : public stack_exception::stack_exception<std::runtime_error> {
 public:
  _TmoException(const std::string &msg) : stack_exception<std::runtime_error>(msg) {}
  _TmoException(const char *file, int line, const char *pretty_function, const std::string &msg)
      : stack_exception<std::runtime_error>(file, line, pretty_function, msg) {}
};
// Captures __FILE__/__LINE__/__PRETTY_FUNCTION__ at the throw site.
#define TmoException(msg) tmo::_TmoException(__FILE__, __LINE__, __PRETTY_FUNCTION__, msg)
} // namespace tmo
#endif // CSRC_COMMON_STACK_EXCEPTION_H_

View File

@@ -0,0 +1,293 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_COMMON_UTILS_H_
#define CSRC_COMMON_UTILS_H_
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <future>  // NOLINT
#include <initializer_list>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>  // NOLINT
#include <tuple>
#include <utility>
#include <vector>
#include "cn_api.h"
#include "cnnl.h"
#include "cnnl_extra.h"
#include "cnrt.h"
#include "stack_exception.h"
namespace tmo {
// Maps a textual quantize-layout name to its CNNL enum value.
// Takes the name by const reference to avoid copying the string per call.
// NOTE(review): an unknown name default-inserts into the static map and
// returns the zero-valued enum; callers must pass valid names.
inline cnnlQuantizeLayout_t strToQuantizeLayout(const std::string &param) {
  static std::map<std::string, cnnlQuantizeLayout_t> quantize_layout_map = {
      {"quantize_none", CNNL_QUANTIZE_NONE},
      {"quantize_per_tensor", CNNL_QUANTIZE_PER_TENSOR},
      {"quantize_per_channel", CNNL_QUANTIZE_PER_CHANNEL},
      {"quantize_per_token", CNNL_QUANTIZE_PER_TOKEN},
      {"quantize_group_wise", CNNL_QUANTIZE_GROUP_WISE}};
  return quantize_layout_map[param];
}
// Maps a textual activation name ("gelu", "relu", "sigmoid", "silu", "none")
// to its CNNL enum value. Takes the name by const reference to avoid a copy.
// NOTE(review): an unknown name default-inserts into the static map and
// returns the zero-valued enum; callers must pass valid names.
inline cnnlActivationMode_t strToActivationMode(const std::string &param) {
  static std::map<std::string, cnnlActivationMode_t> act_mode_map = {
      {"gelu", CNNL_ACTIVATION_GELU},
      {"relu", CNNL_ACTIVATION_RELU},
      {"sigmoid", CNNL_ACTIVATION_SIGMOID},
      {"silu", CNNL_ACTIVATION_SWISH},
      {"none", CNNL_ACTIVATION_IDENTITY}};
  return act_mode_map[param];
}
// Maps a textual quantization-algorithm name to its CNNL enum value.
// Takes the name by const reference to avoid copying the string per call.
// NOTE(review): an unknown name default-inserts into the static map and
// returns the zero-valued enum; callers must pass valid names.
inline cnnlLLMQuantAlgo_t strToQuantizeAlgo(const std::string &param) {
  static std::map<std::string, cnnlLLMQuantAlgo_t> quant_algo_map = {
      {"weight_only", CNNL_WEIGHT_ONLY},
      {"smooth_quant", CNNL_SMOOTH_QUANT},
      {"none", CNNL_NO_QUANT}};
  return quant_algo_map[param];
}
namespace lnres {
namespace internal {
using LnresEnum = cnnlTransformerLayernormResidualStructure_t;
// Decomposes a CNNL layernorm/residual structure enum into two orthogonal
// positions so the predicates below can reason about them independently.
struct Helper {
  int layernorm_position;  // 0: no layernorm, 1: pre layernorm, 2: post layernorm
  int residual_position;   // 0: no residual, 1: layernorm inside residual, 2: layernorm outside
                           // residual
  // Converting constructor from the CNNL enum (defined after `from`).
  constexpr Helper(cnnlTransformerLayernormResidualStructure_t mode);
  constexpr Helper(int layernorm_position, int residual_position)
      : layernorm_position(layernorm_position), residual_position(residual_position) {}
  constexpr bool operator==(const Helper &other) const {
    return layernorm_position == other.layernorm_position &&
           residual_position == other.residual_position;
  }
  // Converting back to the CNNL enum (defined after `to`).
  constexpr operator cnnlTransformerLayernormResidualStructure_t() const;
};
// Symbolic values for Helper::layernorm_position (NO/PRE/POST) and
// Helper::residual_position (NO/CONTAIN/EXCLUDE).
constexpr int NO = 0;
constexpr int PRE = 1;
constexpr int POST = 2;
constexpr int CONTAIN = 1;
constexpr int EXCLUDE = 2;
using TPair = std::pair<Helper, LnresEnum>;
// Bidirectional lookup table between (layernorm, residual) position pairs and
// the CNNL enum; scanned linearly by from()/to() below.
constexpr std::array<TPair, 9> pairs = {
    TPair{{NO, NO}, CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL},        // noResidual
    {{NO, CONTAIN}, CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL},      // useInputAsResidual
    {{NO, EXCLUDE}, CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL},      // useInputAsResidual
    {{PRE, NO}, CNNL_TRANSFORMER_PRE_LAYERNORM_NO_RESIDUAL},           // noResidual
    {{PRE, CONTAIN}, CNNL_TRANSFORMER_PRE_LAYERNORM_INSIDE_RESIDUAL},  // useInputAsResidual
    // residualThenLayernorm
    {{PRE, EXCLUDE}, CNNL_TRANSFORMER_PRE_LAYERNORM_OUTSIDE_RESIDUAL},  // useLayernormAsResidual
    // residualThenLayernorm
    {{POST, NO}, CNNL_TRANSFORMER_POST_LAYERNORM_NO_RESIDUAL},           // noResidual
    {{POST, CONTAIN}, CNNL_TRANSFORMER_POST_LAYERNORM_INSIDE_RESIDUAL},  // useInputAsResidual
    // layernormThenResidual
    {{POST, EXCLUDE}, CNNL_TRANSFORMER_POST_LAYERNORM_OUTSIDE_RESIDUAL},  // useInputAsResidual
    // residualThenLayernorm
};
constexpr Helper from(LnresEnum mode) {
for (size_t i = 0; i < pairs.size(); ++i) {
if (pairs[i].second == mode) {
return pairs[i].first;
}
}
// throw TmoException("Invalid cnnlTransformerLayernormResidualStructure_t");
return Helper(NO, NO);
}
constexpr LnresEnum to(Helper mode) {
for (size_t i = 0; i < pairs.size(); ++i) {
if (pairs[i].first == mode) {
return pairs[i].second;
}
}
return CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL;
// throw TmoException("Invalid Helper");
}
constexpr Helper::Helper(LnresEnum mode) : Helper(from(mode)) {}
constexpr Helper::operator LnresEnum() const {
return to(*this);
}
} // namespace internal
using namespace internal;
// Builds a layernorm/residual structure enum from three boolean properties.
// Note this helper only ever produces pre-layernorm (never POST).
inline LnresEnum makeLnresEnum(bool has_ln, bool has_residual, bool residual_is_input) {
  const int ln_pos = has_ln ? PRE : NO;
  int res_pos = NO;
  if (has_residual) {
    res_pos = residual_is_input ? CONTAIN : EXCLUDE;
  }
  return Helper(ln_pos, res_pos);
}
// Returns `mode` with its residual component stripped (layernorm kept).
inline LnresEnum removeResidual(LnresEnum mode) {
  const Helper decomposed(mode);
  return Helper(decomposed.layernorm_position, NO);
}
// Returns `mode` with its layernorm component stripped (residual kept).
inline LnresEnum removeLayernorm(LnresEnum mode) {
  const Helper decomposed(mode);
  return Helper(NO, decomposed.residual_position);
}
// True only for the pre-layernorm-outside-residual structure, where the
// layernorm output (not the raw input) feeds the residual add.
inline bool useLayernormAsResidual(LnresEnum mode) {
  return Helper(mode) == Helper(PRE, EXCLUDE);
}
// True when the raw layer input is what gets added back as the residual:
// any CONTAIN structure, any residual without layernorm, or a
// post-layernorm placed outside the residual.
inline bool useInputAsResidual(LnresEnum mode) {
  const Helper h(mode);
  if (h.residual_position == CONTAIN) {
    return true;
  }
  if (h.layernorm_position == NO) {
    return h.residual_position != NO;
  }
  return h.layernorm_position == POST && h.residual_position == EXCLUDE;
}
// True when `mode` contains any residual connection.
inline bool hasResidual(LnresEnum mode) {
  return Helper(mode).residual_position != NO;
}
// True when `mode` contains any layernorm (pre or post).
inline bool hasLayernorm(LnresEnum mode) {
  return Helper(mode).layernorm_position != NO;
}
// True when `mode` places the layernorm after the layer (POST position).
inline bool isPostLayernorm(LnresEnum mode) {
  return Helper(mode).layernorm_position == POST;
}
// True when `mode` places the layernorm before the layer (PRE position).
inline bool isPreLayernorm(LnresEnum mode) {
  return Helper(mode).layernorm_position == PRE;
}
// True when, across the boundary between two consecutive layers, the
// dataflow is "add residual first, then apply a layernorm": either layer 1
// carries a residual and layer 2 starts with a pre-layernorm, or layer 1
// ends with a post-layernorm that sits outside its residual add.
// Throws if layer 1 ends with a post-layernorm while layer 2 starts with a
// pre-layernorm (two layernorms back to back).
inline bool residualThenLayernorm(LnresEnum first_layer, LnresEnum second_layer) {
  Helper h1(first_layer);
  Helper h2(second_layer);
  if (h1.residual_position == NO) { // h1 has no residual
    return false;
  }
  if (h1.layernorm_position == POST && h2.layernorm_position == PRE) {
    throw TmoException("too many layernorms");
  }
  return (h1.residual_position != NO && h1.layernorm_position != POST &&
          h2.layernorm_position == PRE) || // l1 residual + l2 pre layernorm
         (h1.layernorm_position == POST && h2.layernorm_position != PRE &&
          h1.residual_position == EXCLUDE); // l1 post layernorm outside l1's residual
}
// True when layer 1 applies its layernorm before the residual add, i.e. it
// ends with a post-layernorm contained inside the residual connection.
// Performs the same post+pre layernorm conflict check as
// residualThenLayernorm (h2 is only used for that check).
inline bool layernormThenResidual(LnresEnum first_layer, LnresEnum second_layer) {
  Helper h1(first_layer);
  Helper h2(second_layer);
  if (h1.residual_position == NO) { // h1 has no residual
    return false;
  }
  if (h1.layernorm_position == POST && h2.layernorm_position == PRE) {
    throw TmoException("too many layernorms");
  }
  return (h1.layernorm_position == POST && h1.residual_position == CONTAIN);
}
// True when layer 1's residual add crosses the boundary with no layernorm
// on either side of it: layer 1 has a residual but no post-layernorm, and
// layer 2 does not start with a pre-layernorm.
inline bool residualOnly(LnresEnum first_layer, LnresEnum second_layer) {
  const Helper h1(first_layer);
  const Helper h2(second_layer);
  if (h1.residual_position == NO) {
    return false;
  }
  return h1.layernorm_position != POST && h2.layernorm_position != PRE;
}
} // namespace lnres
} // namespace tmo
#ifndef CNNL_CHECK
// Logs (without aborting) when a CNNL call does not return success.
// `expr` is parenthesized so operators inside the argument cannot change the
// comparison, and the body is wrapped in do/while(0) so the macro expands to
// a single statement that is safe inside unbraced if/else. `expr` is
// evaluated exactly once.
#define CNNL_CHECK(expr)                                                       \
  do {                                                                         \
    if ((expr) != CNNL_STATUS_SUCCESS) {                                       \
      std::cerr << __FILE__ << ":" << __LINE__                                 \
                << " Check failed: " #expr " == CNNL_STATUS_SUCCESS. "         \
                << std::endl;                                                  \
    }                                                                          \
  } while (0)
#endif
// Logs and throws TmoException when a CNNL call does not return success.
// do/while(0) makes the macro a single statement (safe inside unbraced
// if/else); `expr` is evaluated exactly once.
#define CNNL_CHECK_FATAL(expr)                                                 \
  do {                                                                         \
    if ((expr) != CNNL_STATUS_SUCCESS) {                                       \
      std::cerr << __FILE__ << ":" << __LINE__ << ": "                         \
                << " Check failed: " #expr " == CNNL_STATUS_SUCCESS. "         \
                << std::endl;                                                  \
      throw TmoException("Check failed: " #expr " == CNNL_STATUS_SUCCESS.");   \
    }                                                                          \
  } while (0)
// Logs and throws TmoException when a TMO kernel call does not return
// KERNEL_STATUS_SUCCESS. do/while(0) makes the macro a single statement
// (safe inside unbraced if/else); `expr` is evaluated exactly once.
#define TMO_KERNEL_CHECK_FATAL(expr)                                                           \
  do {                                                                                         \
    if ((expr) != tmo::KernelStatus::KERNEL_STATUS_SUCCESS) {                                  \
      std::cerr << __FILE__ << ":" << __LINE__ << ": "                                         \
                << " Check failed: " #expr " == KernelStatus::KERNEL_STATUS_SUCCESS. "         \
                << std::endl;                                                                  \
      throw TmoException("Check failed: " #expr " == KernelStatus::KERNEL_STATUS_SUCCESS.");   \
    }                                                                                          \
  } while (0)
// General assertion: logs the failed condition plus a stringized message and
// throws TmoException. do/while(0) makes the macro a single statement (safe
// inside unbraced if/else); `expr` is evaluated exactly once.
#define CHECK_FATAL(expr, ...)                                                       \
  do {                                                                               \
    if (!(expr)) {                                                                   \
      std::cerr << __FILE__ << ":" << __LINE__ << ": "                               \
                << " Check failed: " #expr ". " << tmo::stringize(__VA_ARGS__)       \
                << std::endl;                                                        \
      throw TmoException("Check failed: " #expr ". " + tmo::stringize(__VA_ARGS__)); \
    }                                                                                \
  } while (0)
#undef CNRT_CHECK
#define CNRT_CHECK(val) \
do { \
cnrtRet_t __ret = val; \
if (__ret) { \
printf("[%s:%d] CNRT error, code=%d(%s) \"%s\" \n", __FILE__, __LINE__, (unsigned int)__ret, \
cnrtGetErrorStr(__ret), #val); \
throw TmoException(cnrtGetErrorStr(__ret)); \
} \
} while (0)
#define CN_CHECK(val) \
do { \
CNresult __ret = val; \
if (__ret) { \
const char *cn_err_string = nullptr; \
cnGetErrorString(__ret, &cn_err_string); \
printf("[%s:%d] CN error, code=%d(%s) \"%s\" \n", __FILE__, __LINE__, (unsigned int)__ret, \
cn_err_string, #val); \
throw TmoException(cn_err_string); \
} \
} while (0)
// Integer ceiling division: number of y-sized chunks needed to cover x.
#define PAD_UP_DIV(x, y) (((x) + (y) - 1) / (y))
// Symbol visibility helpers for the shared-library API surface.
#define TMO_EXPORT __attribute__((__visibility__("default")))
#define TMO_HIDDEN __attribute__((__visibility__("hidden")))
// Deletes copy/move constructors and assignment operators for CLASSNAME,
// making it neither copyable nor movable.
#define DELETE_COPY_ASSIGN_CONSTRUCT(CLASSNAME)      \
  CLASSNAME(const CLASSNAME &) = delete;             \
  CLASSNAME(CLASSNAME &&) = delete;                  \
  CLASSNAME &operator=(const CLASSNAME &) = delete;  \
  CLASSNAME &operator=(CLASSNAME &&) = delete;
// Note: Return type without const when const object called.
// Defines implicit conversions from the enclosing wrapper class to the raw
// descriptor type DESCNAME, exposing member DESCOBJECT. The const overload
// casts away const so C APIs taking a mutable descriptor can be called on
// const wrapper objects.
#define CLASS_CAST_TYPE_OPERATOR_DEFINE(DESCNAME, DESCOBJECT) \
  inline operator DESCNAME() const {                          \
    return const_cast<DESCNAME>(DESCOBJECT);                  \
  }                                                           \
  inline operator DESCNAME() {                                \
    return DESCOBJECT;                                        \
  }
#endif // CSRC_COMMON_UTILS_H_

View File

@@ -0,0 +1,103 @@
cmake_minimum_required(VERSION 3.8)
project(tmo_kernels)
message(STATUS "project name: ${PROJECT_NAME}")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
################################################################################
# Build Environment
################################################################################
# Cross/host build knobs are passed in from build.sh via -D options.
set(BANG_TARGET_CPU_ARCH ${TARGET_CPU_ARCH})
message("-- TARGET_CPU_ARCH=${TARGET_CPU_ARCH}")
set(TARGET_MLU_ARCH ${TARGET_MLU_ARCH})
message("-- TARGET_MLU_ARCH=${TARGET_MLU_ARCH}")
set(NEUWARE_HOME ${NEUWARE_HOME})
message("-- NEUWARE_HOME=${NEUWARE_HOME}")
# Search both the project's and the Neuware toolkit's cmake module dirs
# (FindBANG.cmake ships with the Neuware installation).
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
  "${CMAKE_SOURCE_DIR}/cmake"
  "${NEUWARE_HOME}/cmake"
  "${NEUWARE_HOME}/cmake/modules"
)
find_package(BANG)
if(NOT BANG_FOUND)
  message(FATAL_ERROR "BANG cannot be found.")
else ()
  if (NOT BANG_CNCC_EXECUTABLE)
    message(FATAL_ERROR "cncc not found, please ensure cncc is in your PATH env or set variable BANG_CNCC_EXECUTABLE from cmake. Otherwise you should check path used by find_program(BANG_CNCC_EXECUTABLE) in FindBANG.cmake")
  endif()
endif()
set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}/test")
set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -pthread -pipe")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${CMAKE_C_FLAGS} -g3 -O0")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${CMAKE_C_FLAGS} -O3")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fPIC -std=c++17 -pthread -pipe")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${CMAKE_CXX_FLAGS} -g3 -O0")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS_RELEASE} -Wl,--gc-sections -fPIC")
# Device-compiler (cncc) flags shared by all .mlu sources.
set(BANG_CNCC_FLAGS "-Wall -Werror -Wdeprecated-declarations -fPIC -std=c++17 -pthread --target=${TARGET_CPU_ARCH}")
# NOTE(review): _cncc_version is never assigned in this file, so the empty
# string always compares VERSION_LESS "5.0.0" and this branch is always
# taken -- confirm the variable is meant to be set by FindBANG.cmake.
if ( "${_cncc_version}" VERSION_LESS "5.0.0") # [CNNLCORE-19128]
  message(STATUS "Default rounding mode will be rn when computing float numbers, otherwise will be tz when computing int numbers")
  # This compile option was enabled by JIRA: CNNLCORE-12027
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -Xbang-cnas --deprecated-cvt-default-round-mode-rn")
endif()
if(${TARGET_CPU_ARCH} MATCHES ".*x86_64.*")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -mcmodel=large")
endif()
string(TOLOWER ${CMAKE_BUILD_TYPE} _CMAKE_BUILD_TYPE_LOWER)
if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "debug")
  message(STATUS "Build debug mode")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g3 -O0")
endif()
if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "release")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3 -DNDEBUG")
endif()
# CNFATBIN bundles device code for several MLU archs into one fat binary;
# otherwise compile for the single requested arch.
if(${TARGET_MLU_ARCH} MATCHES "CNFATBIN")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=mtp_592 --bang-mlu-arch=mtp_613 --no-neuware-version-check")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64")
else()
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=${TARGET_MLU_ARCH}")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64")
endif()
# setup predefined macro for host sources, only for single mlu arch, useful for edge
if (${TARGET_MLU_ARCH} MATCHES "^(m?tp_)?([0-9]+)$")
  # convert mtp_xxx or tp_xxx to xxx
  string(REGEX REPLACE "^(m?tp_)?([0-9]+)$" "\\2" _TARGET_MLU_ARCH ${TARGET_MLU_ARCH})
  add_definitions(-DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH})
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH}")
endif()
################################################################################
# Neuware Environment
################################################################################
if(EXISTS ${NEUWARE_HOME})
  include_directories("${NEUWARE_HOME}/include")
  link_directories("${NEUWARE_HOME}/lib64")
  link_directories("${NEUWARE_HOME}/lib")
else()
  message(FATAL_ERROR "NEUWARE cannot be found, refer README.md to prepare NEUWARE_HOME environment.")
endif()
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
################################################################################
# Build TMO kernels
################################################################################
# aux_source_directory(src DIR_SRCS)
# Every .mlu source under the project tree is compiled into one static lib.
file(GLOB_RECURSE bang_src_files FOLLOW_SYMLINKS "${CMAKE_CURRENT_SOURCE_DIR}/*.mlu")
bang_add_library(tmo_kernels STATIC "${bang_src_files}")
target_link_libraries(tmo_kernels cnnl cnrt cndrv dl)

View File

@@ -0,0 +1,28 @@
#include "add_scalar.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Number of int elements that fit in the per-core NRAM tile (3/4 of NRAM).
#define ONCHIP_DATA_NUM ((int)(__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)))
// Per-core on-chip staging buffer used by MLUBlockAddScalar.
__nram__ int nram_buffer[ONCHIP_DATA_NUM];
// Each task handles one ONCHIP_DATA_NUM-wide tile of `src`: stage the tile
// into NRAM, add `scalar` element-wise, and write the result back to `dst`.
// Tasks whose tile starts past `count` (the last, partial launch) exit early.
__mlu_global__ void MLUBlockAddScalar(int *dst, int *src, int count, int scalar) {
  const int tile_start = ONCHIP_DATA_NUM * taskId;
  const int tile_elems = std::min(ONCHIP_DATA_NUM, count - tile_start);
  if (tile_elems <= 0) {
    return;
  }
  __memcpy(nram_buffer, src + tile_start, tile_elems * sizeof(int), GDRAM2NRAM);
  __bang_add_scalar(nram_buffer, nram_buffer, scalar, tile_elems);
  __memcpy(dst + tile_start, nram_buffer, tile_elems * sizeof(int), NRAM2GDRAM);
}
} // namespace kernels
// Host launcher: enqueues MLUBlockAddScalar on `queue` with one Block task
// per ONCHIP_DATA_NUM-sized tile of the input. Always reports success (the
// launch itself is asynchronous).
KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar) {
  const uint32_t num_tiles = (count + ONCHIP_DATA_NUM - 1) / ONCHIP_DATA_NUM;
  cnrtDim3_t launch_dim;
  launch_dim.x = num_tiles;
  launch_dim.y = 1;
  launch_dim.z = 1;
  kernels::MLUBlockAddScalar<<<launch_dim, cnrtFuncTypeBlock, queue>>>(dst, src, count, scalar);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,29 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_ADD_SCALAR_MLUH_
#define CSRC_KERNELS_ADD_SCALAR_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Add src with a scalar and save the result to dst.
 * @param queue: The queue for mlu; the kernel is enqueued on it.
 * @param dst: Pointer to the MLU memory of dst.
 * @param src: Pointer to the MLU memory of src.
 * @param count: The elements number in src.
 * @param scalar: The scalar to add.
 * @return KERNEL_STATUS_SUCCESS once the kernel launch is enqueued.
 * @note: only support int. dst can overlap with src.
 */
KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar);
} // namespace tmo
#endif // CSRC_KERNELS_ADD_SCALAR_MLUH_

View File

@@ -0,0 +1,205 @@
#!/bin/bash
set -e
# Resolve the directory containing this script and build from there.
TOP_DIR="$( cd "$( dirname "$0" )" && pwd )"
# NOTE(review): unquoted expansion -- breaks if the path contains spaces.
cd ${TOP_DIR}
################################################################################
# Environment Variables
# BUILD_MODE: release/debug
# BUILD_DIR: build(default)
# TARGET_MLU_ARCH: CNFATBIN/MLU590
# TARGET_CPU_ARCH: x86_64-linux-gnu
# TARGET_C_COMPILER: C compiler full-path
# TARGET_CXX_COMPILER: CXX compiler full-path
# STRIP strip tool path
################################################################################
BUILD_MODE=${BUILD_MODE:-release}
BUILD_DIR="${BUILD_DIR:-build}"
BUILD_JOBS=${BUILD_JOBS:-32}
TARGET_MLU_ARCH=${TARGET_MLU_ARCH:-CNFATBIN}
TARGET_CPU_ARCH=${TARGET_CPU_ARCH:-$(uname -m)-linux-gnu}
TARGET_C_COMPILER=${TARGET_C_COMPILER:-gcc}
TARGET_CXX_COMPILER=${TARGET_CXX_COMPILER:-g++}
STRIP="${STRIP}" # empty by default, check later
# to forward variable to other scripts
export BUILD_DIR
################################################################################
# Shell Common Functions
################################################################################
# Abort with an install hint if the given Debian package is not installed.
# $1 - package name. Quoted to avoid word splitting/globbing, and grep -q
# avoids capturing the whole dpkg listing into a subshell string.
check_deb_package() {
  if ! dpkg -l | grep -q "${1}"; then
    echo "-- Please sudo apt install ${1}"
    exit 1
  fi
}
# Abort with an install hint if the given RPM package is not installed.
# $1 - package name. Quoted to avoid word splitting/globbing, and grep -q
# avoids capturing the whole rpm listing into a subshell string.
check_rpm_package() {
  if ! rpm -qa | grep -q "${1}"; then
    echo "-- Please sudo yum install ${1}"
    exit 1
  fi
}
# Print the command-line help text for this script.
usage () {
  cat <<'HELP_EOF'
USAGE: build.sh <options>

 If need specify neuware path, please:
 export NEUWARE_HOME=/path/of/your/neuware

OPTIONS:
 -h, --help Print usage
 <null> If no --mluxxx specified, default arch is cnfatbin which contain all mlu arch
 --mlu590 Build for target product MLU590: __BANG_ARCH__ = 592
 cncc --bang-mlu-arch=mtp_592, cnas --mlu-arch mtp_592
 -d, --debug Build test case with debug mode
 -v, --verbose Build with verbose output
 -j, --jobs=* Build parallel jobs
 --cache Build without deleting BUILD_DIR contents first
HELP_EOF
}
################################################################################
# Build Main Entry
################################################################################
# 1. Check cmake tool for build, cmake-3.23.1 is recommended
if [ -f "/etc/os-release" ]; then
  # /etc/os-release defines NAME, ID, VERSION_ID, ...
  source /etc/os-release
  if [[ "${NAME}" == Ubuntu* ]] || [[ "${NAME}" == Debian* ]]; then
    check_deb_package cmake
    CMAKE=cmake
  elif [[ "${NAME}" == CentOS* ]] || [[ "${NAME}" == Kylin* ]]; then
    if [[ "${VERSION_ID}" == 7 ]]; then
      check_rpm_package cmake3
      CMAKE=cmake3
    else
      check_rpm_package cmake
      CMAKE=cmake
    fi
  elif [[ "${NAME}" == Anolis* ]];then
    check_rpm_package cmake
    CMAKE=cmake
  else
    echo "-- Not support build on this os!"
    exit -1
  fi
else
  echo "-- Not support build on this os!"
  exit -1
fi
# 2. Create build dir
if [ ! -d "$BUILD_DIR" ]; then
  mkdir "$BUILD_DIR"
fi
# 3. Handle build options
cmdline_args=$(getopt -o h,d,v,j: --long help,debug,verbose,jobs:,mlu590,cache -n 'build.sh' -- "$@")
eval set -- "$cmdline_args"
# NOTE(review): $? here is the status of `eval`, not of getopt; with `set -e`
# a getopt failure already aborts at the assignment above, so this guard is
# effectively dead code -- confirm intent.
if [ $? != 0 ]; then echo "Unknown options, use -h or --help" >&2 ; exit -1; fi
if [ $# != 0 ]; then
  while true; do
    case "$1" in
      --mlu590)
        TARGET_MLU_ARCH="mtp_592"
        shift
        ;;
      -h | --help)
        usage
        exit 0
        ;;
      -d | --debug)
        BUILD_MODE="debug"
        echo "-- Using debug mode."
        shift
        ;;
      -v | --verbose)
        BUILD_VERBOSE="VERBOSE=1"
        shift
        ;;
      -j | --jobs)
        shift
        BUILD_JOBS=$1
        shift
        ;;
      --cache)
        FLAG_KEEP_CACHE=1
        shift
        ;;
      --)
        shift
        break
        ;;
      *)
        echo "-- Unknown options ${1}, use -h or --help"
        usage
        exit -1
        ;;
    esac
  done
fi
# 5. Check NEUWARE_HOME and cncc
if [ ! -z "${NEUWARE_HOME}" ]; then
  echo "-- using NEUWARE_HOME = ${NEUWARE_HOME}"
else
  echo "-- NEUWARE_HOME is null, refer README.md to prepare NEUWARE_HOME environment."
  exit -1
fi
# 6. Check device compiler
export PATH="${NEUWARE_HOME}/bin":$PATH
export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64":$LD_LIBRARY_PATH
if [ -z $(which cncc) ]; then
  echo "-- ERROR: cannot find cncc"
  exit -1
fi
cncc --version || ( echo "-- ERROR: cncc is not for current CPU target" && exit -1 )
echo "-- cncc: $(which cncc)"
# Check host compiler
## check compiler version and consider activate devtoolset for CentOS 7
# NOTE(review): OS_RELEASE_ID / OS_RELEASE_VERSION_ID are never assigned in
# this script (/etc/os-release sets ID / VERSION_ID), so this CentOS 7 branch
# can never trigger -- confirm which variables were intended.
if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then
  if [ ! -f "/opt/rh/devtoolset-7/enable" ]; then
    echo "You are using CentOS 7 but without 'devtoolset-7' installed."
    echo "Please install devtoolset-7 or gnu-g++ that verion >= 5."
    sleep 2
  else
    source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \
      || echo "devtoolset-7 has installed on your server, but source failed."
  fi
fi
if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then
  echo "we do not support g++<5, try to use higher version"
  exit 1
fi
TARGET_C_COMPILER=$(which gcc)
TARGET_CXX_COMPILER=$(which g++)
echo "-- TARGET_C_COMPILER: " ${TARGET_C_COMPILER}
echo "-- TARGET_CXX_COMPILER: " ${TARGET_CXX_COMPILER}
export CC=$(basename ${TARGET_C_COMPILER})
export CXX=$(basename ${TARGET_CXX_COMPILER})
################################################################################
# Project Build
################################################################################
CMAKE_EXTRA_OPTIONS=()
SOURCE_DIR=${TOP_DIR}
# Configure inside BUILD_DIR, then drive the build from the top directory.
pushd ${BUILD_DIR}
if [[ -z "${FLAG_KEEP_CACHE}" ]]; then
  echo "Remove cmake cache ${PWD}"
  rm -rf ./*
fi
${CMAKE} -DCMAKE_BUILD_TYPE="${BUILD_MODE}" \
  -DNEUWARE_HOME="${NEUWARE_HOME}" \
  -DTARGET_MLU_ARCH="${TARGET_MLU_ARCH}" \
  -DTARGET_CPU_ARCH="${TARGET_CPU_ARCH}" \
  -DCMAKE_C_COMPILER="$(basename ${TARGET_C_COMPILER})" \
  -DCMAKE_CXX_COMPILER="$(basename ${TARGET_CXX_COMPILER})" \
  -DCMAKE_STRIP="${STRIP}" \
  ${CMAKE_EXTRA_OPTIONS[@]} ${SOURCE_DIR}
popd
${CMAKE} --build ${BUILD_DIR} -- ${BUILD_VERBOSE} -j${BUILD_JOBS}

View File

@@ -0,0 +1,192 @@
#include <stdint.h>
#include <cmath>
#include <iostream>
#include <vector>
#include "cnnl.h"
#include "cnrt.h"
#include "copy_blocks.mluh"
#include "kernel_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
// NRAM bytes reserved (not used as scratch space by these kernels).
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
// Blocks at least this large take the direct GDRAM->GDRAM memcpy path
// instead of the NRAM gather/scatter path.
#define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753
// Max layers / block pairs shipped to the device per kernel launch; the host
// side loops when the model exceeds these.
#define LAYER_SIZE 128
#define BLOCK_PAIR_SIZE 512
#define ALIGN_BYTES 64
// Kernel launch argument (passed by value): per-layer KV cache base pointers
// plus a flat (src, dst) block-index mapping.
struct CopyBlocksInfo {
  void *key_addrs[LAYER_SIZE];
  void *value_addrs[LAYER_SIZE];
  unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2]; // interleaved (src, dst) index pairs
  bool has_value_cache = true; // false when the caller passed no value caches
};
namespace kernels {
// Per-core NRAM staging buffer shared by the kernels below.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Fallback path: copy each (src, dst) block pair directly GDRAM->GDRAM,
// layer by layer, without staging through NRAM.
__mlu_func__ void copyBlocksNodld(CopyBlocksInfo info,
                                  uint32_t num_per_core,
                                  uint32_t block_mapping_offset,
                                  int32_t num_layers,
                                  uint32_t block_size_in_bytes) {
  for (uint32_t i = 0; i < num_per_core; i++) {
    uint32_t map_offset = block_mapping_offset + i * 2;
    uint32_t src_idx = info.mapping_addrs[map_offset];
    uint32_t dst_idx = info.mapping_addrs[map_offset + 1];
    // NOTE(review): the product is computed in 32-bit before widening to
    // int64_t, so it overflows for caches larger than 4 GiB per layer --
    // confirm cache sizes stay below that.
    int64_t src_offset = block_size_in_bytes * src_idx;
    int64_t dst_offset = block_size_in_bytes * dst_idx;
    for (uint32_t j = 0; j < num_layers; j++) {
      __memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset,
               block_size_in_bytes, GDRAM2GDRAM);
      if (info.has_value_cache) {
        __memcpy((int8_t *)info.value_addrs[j] + dst_offset,
                 (int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM);
      }
    }
  }
}
// Distributes num_pairs (src, dst) copies across taskDim cores. On arch 592+
// small blocks are batched through NRAM with gather/scatter DMA; otherwise
// (and on older archs) the direct GDRAM->GDRAM path above is used.
__mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info,
                                           int32_t num_pairs,
                                           int32_t num_layers,
                                           uint32_t block_size_in_bytes) {
  // Even split of pairs across cores; the first `remain_for_core` cores take
  // one extra pair.
  uint32_t num_per_core = num_pairs / taskDim;
  uint32_t remain_for_core = num_pairs % taskDim;
  num_per_core += ((taskId < remain_for_core) ? 1 : 0);
  uint32_t block_mapping_offset =
      num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core);
  block_mapping_offset *= 2; // mapping entries are interleaved (src, dst) pairs
#if (__BANG_ARCH__ >= 592)
  if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) {
    auto num_pair_data_width = sizeof(int32_t);
    uint32_t align_num = ALIGN_BYTES / num_pair_data_width;
    unsigned int num_per_core_2 = num_per_core * 2;
    unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num;
    // NRAM layout: [byte offsets | raw src+dst indices | data tiles].
    unsigned int *gather_src_offset = (unsigned int *)nram_buffer;
    unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align;
    int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align);
    // NOTE(review): sizeof(unsigned int *) (8) is used where sizeof(unsigned
    // int) (4) looks intended; this only over-reserves NRAM, so it is safe
    // but wasteful -- confirm.
    uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int *) * num_per_core_2_align * 2;
    // The first num_per_core offsets are sources, the next num_per_core are
    // destinations (see the de-interleaving loop below).
    unsigned int *scatter_dst_offset = gather_src_offset + num_per_core;
    uint32_t num_per_loop = nram_remain / block_size_in_bytes;
    uint32_t repeat = num_per_core / num_per_loop;
    uint32_t remain = num_per_core % num_per_loop;
    // De-interleave the (src, dst) pairs: sources first, destinations after.
    for (int i = 0; i < num_per_core; i++) {
      unsigned int mapping_addrs_idx = block_mapping_offset + i * 2;
      block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx];
      block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1];
    }
    // Convert block indices to byte offsets in one vector multiply.
    __bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes,
                      num_per_core_2);
    __sync();
    for (uint32_t k = 0; k < num_layers; k++) {
      // Full tiles of num_per_loop blocks each.
      for (uint32_t i = 0; i < repeat; i++) {
        __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop,
                       (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                       (unsigned int)block_size_in_bytes, num_per_loop);
        __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
                        (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                        (unsigned int)block_size_in_bytes, num_per_loop);
        if (info.has_value_cache) {
          __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop,
                         (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                         (unsigned int)block_size_in_bytes, num_per_loop);
          __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
                          (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                          (unsigned int)block_size_in_bytes, num_per_loop);
        }
      }
      // Trailing partial tile.
      if (remain != 0) {
        uint32_t repeat_nums = repeat * num_per_loop;
        __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums,
                       (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                       (unsigned int)block_size_in_bytes, remain);
        __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
                        (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                        (unsigned int)block_size_in_bytes, remain);
        if (info.has_value_cache) {
          __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums,
                         (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                         (unsigned int)block_size_in_bytes, remain);
          __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
                          (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                          (unsigned int)block_size_in_bytes, remain);
        }
      }
    }
  } else {
    copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
  }
#else
  copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
#endif
}
} // namespace kernels
/**
 * @brief Host launcher for the copy_blocks operation on a paged KV cache.
 * Splits the work into chunks of at most LAYER_SIZE layers and
 * BLOCK_PAIR_SIZE block pairs -- the capacity of the by-value
 * CopyBlocksInfo launch argument -- and enqueues one kernel per chunk.
 * value_caches may be empty, in which case only key caches are copied.
 * @return KERNEL_STATUS_FAILED on invalid arguments, otherwise
 * KERNEL_STATUS_SUCCESS after all launches are enqueued on `queue`.
 */
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
                                    const std::vector<void *> &key_caches,
                                    const std::vector<void *> &value_caches,
                                    const std::vector<int32_t> &block_mapping_vec,
                                    const size_t block_size_in_bytes) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
  if (key_caches.empty()) {
    std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (!value_caches.empty() && key_caches.size() != value_caches.size()) {
    std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches "
              << "size if value_caches is not empty." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Each mapping entry is a flattened (src_block, dst_block) pair.
  int32_t mapping_size = block_mapping_vec.size();
  int32_t num_pairs = mapping_size / 2;
  // One Block task per pair, capped at the number of physical cores.
  uint32_t task_dim = std::min(num_pairs, cluster_num * core_num);
  cnrtDim3_t k_dim{task_dim, 1, 1};
  int32_t num_layers = key_caches.size();
  // Balance layers/pairs evenly over the minimum number of launches.
  // NOTE(review): an empty block_mapping_vec makes pair_loop_num 0 and the
  // next division 0/0 -- presumably callers never pass an empty mapping;
  // confirm.
  int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE);
  int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num);
  int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE);
  int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num);
  CopyBlocksInfo info;
  if (value_caches.empty()) {
    // Key-only cache: the kernel skips all value copies.
    info.has_value_cache = false;
  }
  for (int32_t i = 0; i < layer_loop_num; i++) {
    int32_t sub_num_layers =
        std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop);
    for (int32_t l = 0; l < sub_num_layers; l++) {
      info.key_addrs[l] = key_caches[l + i * layer_num_per_loop];
      if (info.has_value_cache) {
        info.value_addrs[l] = value_caches[l + i * layer_num_per_loop];
      }
    }
    for (int32_t j = 0; j < pair_loop_num; j++) {
      int32_t sub_num_pairs =
          std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop);
      int32_t lens_block_mapping = sub_num_pairs * 2;
      int32_t block_vec_offset = j * pair_num_per_loop * 2;
      for (int32_t m = 0; m < lens_block_mapping; m++) {
        info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset];
      }
      kernels::launchCopyBlocksKernel<<<k_dim, k_type, queue>>>(info, sub_num_pairs, sub_num_layers,
                                                                block_size_in_bytes);
    }
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,37 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_COPY_BLOCKS_MLUH_
#define CSRC_KERNELS_COPY_BLOCKS_MLUH_
#include <vector>
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Perform copy_blocks operation.
 * @param queue: The queue for mlu.
 * @param key_caches: Output/Input. Pointer to the MLU memory that stores the key_caches
 * vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
 * @param value_caches: Output/Input. Pointer to the MLU memory that stores the value_caches
 * vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
 * May be empty, in which case only the key caches are copied.
 * @param block_mapping_vec: block_mapping vector, flattened (src, dst) block-index pairs.
 * @param block_size_in_bytes: one block data size.
 * @return KERNEL_STATUS_FAILED on invalid arguments, KERNEL_STATUS_SUCCESS otherwise.
 */
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
                                    const std::vector<void *> &key_caches,
                                    const std::vector<void *> &value_caches,
                                    const std::vector<int32_t> &block_mapping_vec,
                                    const size_t block_size_in_bytes);
} // namespace tmo
#endif // CSRC_KERNELS_COPY_BLOCKS_MLUH_

View File

@@ -0,0 +1,271 @@
#include <cmath>
#include <cstddef>
#include "cnnl.h"
#include "cnrt.h"
#include "create_cos_sin_table.mluh"
namespace {
// Rotary-embedding scaling strategies; only dynamic NTK scaling requires
// special handling in the kernels below.
// constexpr int LINEAR_SCALING = 0;
// constexpr int FIX_NTK_SCALING = 1;
constexpr int DYNAMIC_NTK_SCALING = 2;
} // namespace
namespace tmo {
namespace kernels {
// Per-core NRAM scratch space (whole NRAM minus 32 KiB reserved).
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - 32 * 1024];
// Seed table of the first 64 even numbers (0, 2, ..., 126); genRangeDims
// extends it to arbitrary length by repeated doubling.
__nram__ const float range[64] = {
    0.0F, 2.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 20.0F,
    22.0F, 24.0F, 26.0F, 28.0F, 30.0F, 32.0F, 34.0F, 36.0F, 38.0F, 40.0F, 42.0F,
    44.0F, 46.0F, 48.0F, 50.0F, 52.0F, 54.0F, 56.0F, 58.0F, 60.0F, 62.0F, 64.0F,
    66.0F, 68.0F, 70.0F, 72.0F, 74.0F, 76.0F, 78.0F, 80.0F, 82.0F, 84.0F, 86.0F,
    88.0F, 90.0F, 92.0F, 94.0F, 96.0F, 98.0F, 100.0F, 102.0F, 104.0F, 106.0F, 108.0F,
    110.0F, 112.0F, 114.0F, 116.0F, 118.0F, 120.0F, 122.0F, 124.0F, 126.0F};
// Fills range_nram with 0, 2, 4, ..., 2*(elem_count-1): copy the 64-entry
// seed table, then double the filled length each iteration by adding a
// constant offset (count * 2) to the already-generated prefix.
__mlu_func__ void genRangeDims(float *range_nram, int elem_count) {
  int count = 64;
  __bang_move(range_nram, range, std::min(count, elem_count) * sizeof(float));
  while (count < elem_count) {
    __bang_add_scalar(range_nram + count, range_nram, (float)count * 2.0F,
                      std::min(count, elem_count - count));
    count *= 2;
  }
}
// Returns the maximum of seq_lens[0..batch), staging the values into
// seq_lens_nram as a side effect.
// NOTE(review): __bang_argmax runs on the int data reinterpreted as float;
// for non-negative ints the IEEE-754 bit pattern preserves ordering, so the
// bitwise max is the integer max -- confirm seq lens are always >= 0.
__mlu_func__ int getBatchMaxSeqLen(int *seq_lens_nram, int *seq_lens, int batch) {
  __memcpy(seq_lens_nram, seq_lens, batch * sizeof(int), GDRAM2NRAM);
  __bang_argmax((float *)seq_lens_nram, (float *)seq_lens_nram, batch);
  return __load_nram(seq_lens_nram);
}
// Dynamic-NTK alpha: once the effective sequence length exceeds
// max_position_embeddings, grow alpha as pow(2, ceil(log2(len/max) + 1)) - 1;
// clamped to >= 1 so the rope base never shrinks.
__mlu_func__ float getNTKAlpha(int curr_seq_len, int max_position_embeddings, int kv_seq_len) {
  int seq_len = kv_seq_len > max_position_embeddings ? curr_seq_len : kv_seq_len;
  float context_value = std::log2((float)seq_len / (float)max_position_embeddings) + 1.0F;
  float ntk_alpha = std::pow(2.0F, std::ceil(context_value)) - 1.0F;
  return std::max(ntk_alpha, 1.0F);
}
// Computes the standard RoPE inverse frequencies element-wise:
// inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)).
// base^x is evaluated as pow2(x * log(base)).
// NOTE(review): this is exact only if __bang_log is base-2 (so that
// pow2(log(base) * x) == base^x) -- confirm against the BANG intrinsic docs.
// base_nram and range_nram are clobbered as scratch.
__mlu_func__ void getRotaryInvFreq(float *inv_freq_nram,
                                   float *base_nram,
                                   float *range_nram,
                                   float base,
                                   int rotary_dim,
                                   int elem_count) {
  // inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
  __bang_write_value(base_nram, elem_count, base);
  __bang_mul_scalar(inv_freq_nram, range_nram, 1.0F / (float)rotary_dim, elem_count);
  __bang_log(base_nram, base_nram, elem_count);
  __bang_mul(inv_freq_nram, inv_freq_nram, base_nram, elem_count);
  __bang_pow2(inv_freq_nram, inv_freq_nram, elem_count);
  __bang_recip(inv_freq_nram, inv_freq_nram, elem_count);
}
// In-place down-conversion of the float cos/sin tables to the output dtype T.
// Generic version (T == float): nothing to convert.
template <typename T>
__mlu_func__ void convertCosSinTable(float *cos_table, float *sin_table, int elem_count) {}
// half output: narrow in place (the halves occupy the front of each buffer).
template <>
__mlu_func__ void convertCosSinTable<half>(float *cos_table, float *sin_table, int elem_count) {
  __bang_float2half((half *)cos_table, cos_table, elem_count);
  __bang_float2half((half *)sin_table, sin_table, elem_count);
}
// bfloat16 output: narrow in place.
template <>
__mlu_func__ void convertCosSinTable<bfloat16_t>(float *cos_table,
                                                 float *sin_table,
                                                 int elem_count) {
  __bang_float2bfloat16((bfloat16_t *)cos_table, cos_table, elem_count);
  __bang_float2bfloat16((bfloat16_t *)sin_table, sin_table, elem_count);
}
// One task per batch element (taskIdY): recompute the dynamic-NTK alpha for
// that sequence and store it into rotary_emb_alpha_cached[taskIdY].
__mlu_global__ void MLUUpdateCachedAlpha(float *rotary_emb_alpha_cached,
                                         int *seq_lens,
                                         int max_position_embeddings,
                                         int batch) {
  int *seq_lens_nram = (int *)nram_buffer; // [batch]
  int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch);
  rotary_emb_alpha_cached[taskIdY] =
      getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len);
}
// Builds the rotary-embedding cos/sin lookup table.
// Grid mapping: taskIdY selects the batch element, taskIdX a segment of
// seq_seg sequence positions. Each position writes one row of
// 2 * rotary_dim values of type T at cos_sin_table +
// taskIdY * batch_stride + idx * rotary_stride, laid out as
// [cos(rotary_dim) | sin(rotary_dim)].
// With DYNAMIC_NTK_SCALING the rope base is adjusted by the per-sequence NTK
// alpha, and the row is skipped when the cached alpha already matches (the
// cache itself is refreshed by MLUUpdateCachedAlpha).
// NOTE(review): rotary_scaling and dtype are unused in this kernel --
// presumably kept for signature symmetry with other scaling types; confirm.
template <typename T>
__mlu_global__ void MLUCreateCosSinTableKernel(void *cos_sin_table,
                                               float *rotary_emb_alpha_cached,
                                               int *seq_lens,
                                               int max_position_embeddings,
                                               int batch,
                                               int batch_stride,
                                               int rotary_seq_len,
                                               int rotary_dim,
                                               int rotary_stride,
                                               float rotary_base,
                                               float rotary_scaling,
                                               int rotary_scaling_type,
                                               int seq_seg,
                                               bool interleaved,
                                               cnnlDataType_t dtype) {
  int half_rotary_dim = rotary_dim / 2;
  // NRAM scratch layout (all float unless noted).
  float *base_nram = (float *)nram_buffer; // [rotary_dim / 2]
  float *range_nram = base_nram + half_rotary_dim; // [rotary_dim / 2]
  float *inv_freq_nram = range_nram + half_rotary_dim; // [rotary_dim / 2]
  float *freqs_nram = inv_freq_nram + half_rotary_dim; // [rotary_dim / 2]
  float *cos_nram = freqs_nram + half_rotary_dim; // [rotary_dim]
  float *sin_nram = cos_nram + rotary_dim; // [rotary_dim]
  float *swap_nram = sin_nram + rotary_dim; // [rotary_dim]
  int *seq_lens_nram = (int *)(swap_nram + rotary_dim); // [batch]
  genRangeDims(range_nram, half_rotary_dim);
  float adjust_base = rotary_base;
  if (rotary_scaling_type == DYNAMIC_NTK_SCALING) {
    int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch);
    float ntk_alpha = getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len);
    // Table already matches this alpha: nothing to rebuild for this row.
    if (rotary_emb_alpha_cached[taskIdY] == ntk_alpha) {
      return;
    }
    adjust_base = rotary_base * std::pow(ntk_alpha, (float)rotary_dim / (float)(rotary_dim - 2));
  }
  getRotaryInvFreq(inv_freq_nram, base_nram, range_nram, adjust_base, rotary_dim, half_rotary_dim);
  // This task's slice of the sequence, clamped to rotary_seq_len.
  int seq_start = taskIdX * seq_seg;
  int seq_end = (taskIdX + 1) * seq_seg > rotary_seq_len ? rotary_seq_len : (taskIdX + 1) * seq_seg;
  T *cos_table = (T *)cos_sin_table + (size_t)taskIdY * batch_stride;
  T *sin_table = cos_table + rotary_dim;
  for (int idx = seq_start; idx < seq_end; ++idx) {
    // freqs = idx * inv_freq, then cos/sin of each, narrowed to T in place.
    __bang_mul_scalar(freqs_nram, inv_freq_nram, idx, half_rotary_dim);
    __bang_cos(cos_nram, freqs_nram, half_rotary_dim);
    __bang_sin(sin_nram, freqs_nram, half_rotary_dim);
    convertCosSinTable<T>(cos_nram, sin_nram, half_rotary_dim);
    if (!interleaved) {
      // Strided copy with src stride 0 -- appears to replicate the half-dim
      // table into both halves of the row; confirm against the BANG strided
      // __memcpy semantics.
      __memcpy(cos_table + idx * rotary_stride, cos_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM,
               half_rotary_dim * sizeof(T), 0, 1);
      __memcpy(sin_table + idx * rotary_stride, sin_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM,
               half_rotary_dim * sizeof(T), 0, 1);
    } else {
      // Interleaved layout: duplicate the half table, then transpose the
      // 2 x half matrix to half x 2, yielding [c0, c0, c1, c1, ...].
      __bang_move((T *)cos_nram + half_rotary_dim, (T *)cos_nram, half_rotary_dim * sizeof(T));
      __bang_transpose((T *)swap_nram, (T *)cos_nram, 2, half_rotary_dim);
      __memcpy(cos_table + idx * rotary_stride, (T *)swap_nram, half_rotary_dim * 2 * sizeof(T),
               NRAM2GDRAM);
      __bang_move((T *)sin_nram + half_rotary_dim, (T *)sin_nram, half_rotary_dim * sizeof(T));
      __bang_transpose((T *)cos_nram, (T *)sin_nram, 2, half_rotary_dim);
      __memcpy((T *)sin_table + idx * rotary_stride, (T *)cos_nram, half_rotary_dim * 2 * sizeof(T),
               NRAM2GDRAM);
    }
  }
}
#if __BANG_ARCH__ < 592
// Stub specialization: architectures below 592 lack bfloat16 support, so the
// bf16 table builder compiles to a no-op there. The host-side launcher
// rejects bf16 on such devices (isBf16Supported) before launching.
template <>
__mlu_global__ void MLUCreateCosSinTableKernel<bfloat16_t>(void *cos_sin_table,
                                                           float *rotary_emb_alpha_cached,
                                                           int *seq_lens,
                                                           int max_position_embeddings,
                                                           int batch,
                                                           int batch_stride,
                                                           int rotary_seq_len,
                                                           int rotary_dim,
                                                           int rotary_stride,
                                                           float rotary_base,
                                                           float rotary_scaling,
                                                           int rotary_scaling_type,
                                                           int seq_seg,
                                                           bool interleaved,
                                                           cnnlDataType_t dtype) {}
#endif
} // namespace kernels
/**
 * Host-side launcher that builds the rotary-embedding cos/sin table on device.
 * Validates the dtype, splits rotary_seq_len across all available MLU cores,
 * launches MLUCreateCosSinTableKernel<T>, and — for dynamic NTK — follows up
 * with MLUUpdateCachedAlpha so the per-batch alpha cache stays current.
 * Returns KERNEL_STATUS_FAILED for unsupported dtypes, SUCCESS otherwise.
 */
KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue,
                                     void *cos_sin_table,
                                     float *rotary_emb_alpha_cached,
                                     int *seq_lens,
                                     int max_position_embeddings,
                                     int batch,
                                     int batch_stride,
                                     int rotary_seq_len,
                                     int rotary_dim,
                                     int rotary_stride,
                                     float rotary_base,
                                     float rotary_scaling,
                                     int rotary_scaling_type,
                                     bool interleaved,
                                     cnnlDataType_t data_type) {
  // ---- Input validation: fail fast before any device queries. ----
  bool is_supported_dtype = data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_FLOAT ||
                            data_type == CNNL_DTYPE_BFLOAT16;
  if (!is_supported_dtype) {
    std::cerr << "[invokeCreateCosSinTable]: unsupport data type for create cos sin table kernel."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // The bf16 kernel is an empty stub when __BANG_ARCH__ < 592, so launching it
  // would silently produce no table. Reject here (previously this check sat
  // after the device queries, right before launch).
  if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
    std::cerr << "[invokeCreateCosSinTable]: MLU300 devices do not support bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // clang-format off
  void (*create_sin_cos_kernels[])(
      void*,           /* cos_sin_table */
      float*,          /* rotary_emb_alpha_cached */
      int*,            /* seq_lens */
      int,             /* max_position_embeddings */
      int,             /* batch */
      int,             /* batch_stride */
      int,             /* rotary_seq_len */
      int,             /* rotary_dim */
      int,             /* rotary_stride */
      float,           /* rotary_base */
      float,           /* rotary_scaling */
      int,             /* rotary_scaling_type */
      int,             /* seq_seg */
      bool,            /* interleaved */
      cnnlDataType_t   /* data_type */
      ) = {
      kernels::MLUCreateCosSinTableKernel<half>,
      kernels::MLUCreateCosSinTableKernel<float>,
      kernels::MLUCreateCosSinTableKernel<bfloat16_t>
  };
  // clang-format on
  // Pick the kernel matching the element type (order mirrors the table above).
  int kernel_index = 0;
  if (data_type == CNNL_DTYPE_HALF) {
    kernel_index = 0;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    kernel_index = 1;
  } else {
    kernel_index = 2;
  }
  // Split the sequence dimension evenly over every core of the device.
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num = 1;
  int core_num = 1;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  int used_core_num = std::min(rotary_seq_len, cluster_num * core_num);
  int seq_seg = (rotary_seq_len + used_core_num - 1) / used_core_num;
  cnrtDim3_t dim1;
  dim1.x = used_core_num;
  // Dynamic NTK builds one table per batch row; other modes share one table.
  dim1.y = rotary_scaling_type == DYNAMIC_NTK_SCALING ? batch : 1;
  dim1.z = 1;
  create_sin_cos_kernels[kernel_index]<<<dim1, cnrtFuncTypeBlock, queue>>>(
      cos_sin_table, rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch,
      batch_stride, rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling,
      rotary_scaling_type, seq_seg, interleaved, data_type);
  if (rotary_scaling_type == DYNAMIC_NTK_SCALING) {
    // Must run after the table kernel (same queue) since that kernel compares
    // against the cached alpha to skip up-to-date rows.
    cnrtDim3_t dim2;
    dim2.x = 1;
    dim2.y = batch;
    dim2.z = 1;
    kernels::MLUUpdateCachedAlpha<<<dim2, cnrtFuncTypeBlock, queue>>>(
        rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch);
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,62 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
#define CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Create cos and sin table for rotary embedding.
 * @param queue: The queue for mlu.
 * @param cos_sin_table: Output. Pointer to the MLU memory that stores the cos and sin table.
 *        If rotary_scaling_type is linear, the shape is [rotary_seq_len, rotary_stride].
 *        If rotary_scaling_type is dynamic ntk, the shape is [batch, rotary_seq_len,
 *        rotary_stride].
 * @param rotary_emb_alpha_cached: Output/Input. Pointer to the MLU memory that
 *        stores the ntk alpha cache. Only used in dynamic ntk, the shape is [batch].
 * @param seq_lens: Input. Pointer to the MLU memory that stores the true sequence len.
 *        The shape is [batch].
 * @param max_position_embeddings: The maximum rotary embedding positions.
 * @param batch: Batch size.
 * @param batch_stride: The stride for batch dim of cos_sin_table.
 *        Only used in dynamic ntk, the value is rotary_seq_len * rotary_stride.
 * @param rotary_seq_len: The rotary sequence length of cos and sin table.
 * @param rotary_dim: The rotary dim value of cos and sin table.
 * @param rotary_stride: The stride of rotary_seq_len dim for cos and sin table.
 * @param rotary_base: The rotary base, value is usually 10000.
 * @param rotary_scaling: The rotary scaling, value is usually 1.
 * @param rotary_scaling_type: The rotary scaling type, value is linear or dynamic ntk.
 * @param interleaved: A boolean value indicates compute mode of rotary embedding.
 * @param dtype: Data type of cos and sin table generated.
 * @return Kernel execution status: KERNEL_STATUS_FAILED when dtype is
 *         unsupported on the current device, otherwise KERNEL_STATUS_SUCCESS.
 */
KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue,
                                     void *cos_sin_table,
                                     float *rotary_emb_alpha_cached,
                                     int *seq_lens,
                                     int max_position_embeddings,
                                     int batch,
                                     int batch_stride,
                                     int rotary_seq_len,
                                     int rotary_dim,
                                     int rotary_stride,
                                     float rotary_base,
                                     float rotary_scaling,
                                     int rotary_scaling_type,
                                     bool interleaved,
                                     cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_

View File

@@ -0,0 +1,812 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include "dequant_from_linear_cache.mluh"
#include "quant_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#pragma bang walign(16)
// On-chip scratch sizing: reserve part of NRAM for the call stack, use the
// rest (plus all of WRAM) as static dequant workspace.
#define REM_FOR_STACK (32 * 1024)
#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK)
// Short aliases for the two kernel entry points declared below.
#define DEQUANT_LINEAR_PERHEAD kernels::MLUDequantFromLinearCacheKernelPerHead
#define DEQUANT_LINEAR_PERCHANNEL kernels::MLUDequantFromLinearCacheKernelPerChannel
#define DEQUANT_FUNC_LEN (24)
#define DEQUANT_BATCH_NUM (1024)
__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE];
__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE];
// Lookup table for the nhwc->nchw small-channel transpose helpers.
__nram__ uint8_t pre_table_nram[TRANS_TABLE_SIZE];
// Uses 8K = 1K * (4 + 4) to process offsets
__nram__ int32_t n_lens[DEQUANT_BATCH_NUM];
__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM];
// Resolve the cache slot for one batch row and derive the read/write offsets
// used by the per-channel dequant path. On any invalid mapping (negative slot
// id, negative sequence offset, or a sequence that would overrun the cache
// memory) cache_id is set to -1 so the caller can skip the row.
__mlu_func__ void calcu_offsets_per_channel(int32_t &cache_id,
                                            size_t &context_offset,
                                            size_t &cache_offset,
                                            size_t &scale_offset,
                                            const int32_t *cache_bs_id,
                                            const int32_t *cache_seq_offsets,
                                            const int32_t cache_mem_len,
                                            const int32_t seq_len,
                                            const int32_t seq_begin,
                                            const int32_t seq_offset,
                                            const int32_t batch_idx,
                                            const size_t context_seq_stride,
                                            const size_t cache_bs_stride,
                                            const size_t cache_seq_stride,
                                            const size_t scale_bs_stride) {
  // Identity mapping when no explicit slot-id table is supplied.
  cache_id = (cache_bs_id != nullptr) ? __load_gdram((int32_t *)cache_bs_id + batch_idx)
                                      : batch_idx;
  const int32_t slot_seq_offset = (cache_seq_offsets != nullptr)
                                      ? __load_gdram((int32_t *)cache_seq_offsets + batch_idx)
                                      : 0;
  if (cache_id < 0 || slot_seq_offset < 0 || (slot_seq_offset + seq_len) > cache_mem_len) {
    cache_id = -1;  // mark the row as skippable
    return;
  }
  context_offset = context_seq_stride * (seq_offset + seq_begin);
  cache_offset = cache_bs_stride * cache_id + cache_seq_stride * (slot_seq_offset + seq_begin);
  scale_offset = scale_bs_stride * cache_id;
}
// Resolve the cache slot for one batch row and derive the read/write offsets
// used by the per-head dequant path (separate key/value cache offsets, and a
// scale offset that also advances along the sequence dimension). On any
// invalid mapping cache_id is set to -1 so the caller can skip the row.
__mlu_func__ void calcu_offsets_per_head(int32_t &cache_id,
                                         size_t &context_offset,
                                         size_t &key_cache_offset,
                                         size_t &value_cache_offset,
                                         size_t &scale_offset,
                                         const int32_t *cache_bs_id,
                                         const int32_t *cache_seq_offsets,
                                         const int32_t cache_mem_len,
                                         const int32_t seq_len,
                                         const int32_t seq_begin,
                                         const int32_t seq_offset,
                                         const int32_t batch_idx,
                                         const size_t context_seq_stride,
                                         const size_t cache_bs_stride,
                                         const size_t key_cache_seq_stride,
                                         const size_t value_cache_seq_stride,
                                         const size_t scale_bs_stride) {
  // Identity mapping when no explicit slot-id table is supplied.
  cache_id = (cache_bs_id != nullptr) ? __load_gdram((int32_t *)cache_bs_id + batch_idx)
                                      : batch_idx;
  const int32_t slot_seq_offset = (cache_seq_offsets != nullptr)
                                      ? __load_gdram((int32_t *)cache_seq_offsets + batch_idx)
                                      : 0;
  if (cache_id < 0 || slot_seq_offset < 0 || (slot_seq_offset + seq_len) > cache_mem_len) {
    cache_id = -1;  // mark the row as skippable
    return;
  }
  const int32_t slot_pos = slot_seq_offset + seq_begin;
  context_offset = context_seq_stride * (seq_offset + seq_begin);
  key_cache_offset = cache_bs_stride * cache_id + key_cache_seq_stride * slot_pos;
  value_cache_offset = cache_bs_stride * cache_id + value_cache_seq_stride * slot_pos;
  // Per-head scales are stored per sequence position within the slot.
  scale_offset = slot_pos + scale_bs_stride * cache_id;
}
// Dequantize a [seq_num, head_num, head_size] tile of a per-channel-quantized
// cache back to T and scatter it into the context tensor `data`.
// scale_num == head_num * head_size (one scale per channel). int4x2_t caches
// pack two values per byte, hence the `>> 1` byte counts on that path.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_channel(T *output_nram,
                                         Tc *input_nram,
                                         Ts *scale_nram,
                                         T *data,
                                         Tc *cache,
                                         Ts *scale,
                                         const int32_t scale_num,
                                         const int32_t seq_num,
                                         const int32_t head_num,
                                         const int32_t head_size,
                                         const size_t context_offset,
                                         const size_t cache_offset,
                                         const size_t scale_offset,
                                         const size_t context_seq_stride,
                                         const size_t context_head_stride,
                                         const size_t cache_head_stride,
                                         const size_t cache_seq_stride,
                                         const size_t scale_bs_stride,
                                         const size_t scale_head_stride) {
  // Scales are only (re)loaded when batched (scale_bs_stride != 0); otherwise
  // the caller has already staged them once via load_scale_once.
  if (scale_bs_stride != 0) {
    __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM,
             head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  }
  // Strided gather of the quantized tile from GDRAM into NRAM.
  if (std::is_same<Tc, int4x2_t>::value) {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM,
             head_size >> 1, head_num - 1, scale_num >> 1, seq_num - 1, cache_head_stride,
             head_num - 1, cache_seq_stride, seq_num - 1);
  } else {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM,
             head_size * sizeof_(Tc), head_num - 1, scale_num * sizeof_(Tc), seq_num - 1,
             cache_head_stride * sizeof_(Tc), head_num - 1, cache_seq_stride * sizeof_(Tc),
             seq_num - 1);
  }
  // Elementwise dequant: out = convert(in) * scale (broadcast over seq_num).
  dequantize<T, Tc, Ts>((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf,
                        seq_num * scale_num, scale_num);
  // Strided scatter of the dequantized tile back to the context tensor.
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
           seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1);
}
// int4-value variant of the per-channel dequant: the value cache stores two
// sequence positions packed per byte, so pairs are loaded, unpacked to int8,
// transposed back to sequence-major order, then dequantized. When the tile
// starts on an odd cache position (pad_front), the extra leading position
// loaded for alignment is dropped before the store.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_value_per_channel(T *output_nram,
                                               Tc *input_nram,
                                               Ts *scale_nram,
                                               int8_t *temp_nram,
                                               T *data,
                                               Tc *cache,
                                               Ts *scale,
                                               const int32_t scale_num,
                                               const int32_t seq_num,
                                               const int32_t head_num,
                                               const int32_t head_size,
                                               const size_t context_offset,
                                               const size_t cache_offset,
                                               const size_t scale_offset,
                                               const size_t context_seq_stride,
                                               const size_t context_head_stride,
                                               const size_t cache_head_stride,
                                               const size_t cache_seq_stride,
                                               const size_t scale_bs_stride,
                                               const size_t scale_head_stride,
                                               const bool pad_front) {
  /* Step 1. load scale [head_num, head_size]*/
  if (scale_bs_stride != 0) {
    __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM,
             head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  }
  /* Step 2. load cache [load_seq_num, head_num, head_size] */
  // Each loaded "row" is one packed byte pair == two sequence positions.
  int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2);
  int32_t deal_seq_num = load_seq_num << 1;
  __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size,
           head_num - 1, scale_num, load_seq_num - 1, cache_head_stride, head_num - 1,
           cache_seq_stride, load_seq_num - 1);
  /* Step 3. convert into int8 [load_seq_num, head_num, head_size, 2] */
  convert((int8_t *)output_nram, (int4x2_t *)input_nram, deal_seq_num * scale_num);
  /* Step 4. transpose to [deal_seq_num (load_seq_num, 2), head_num, head_size] */
  trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram,
                         load_seq_num, head_num, head_size, 2);
  /* Step 5. dequantize [save_seq_num, head_num, head_size] */
  // pad_front: skip the alignment position that precedes the real tile start.
  int save_seq_num = pad_front ? seq_num - 1 : seq_num;
  dequantize<T, int8_t, Ts>((T *)output_nram, (int8_t *)temp_nram + (pad_front ? scale_num : 0),
                            (Ts *)scale_nram, (Ts *)nbuf, save_seq_num * scale_num, scale_num);
  /* Step 6. store [save_seq_num, head_num, head_size]*/
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
           save_seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T),
           save_seq_num - 1);
}
// Dequantize a tile of a per-head-quantized cache (one scale per head per
// sequence position, scale_num == seq_block * head_num) back to T and scatter
// it into the context tensor. The scale multiply is fused into a conv via the
// all-ones WRAM kernel staged by the caller (see mvNram2WramLT16).
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_head(T *output_nram,
                                      Tc *input_nram,
                                      Ts *scale_nram,
                                      T *temp_nram,
                                      T *data,
                                      Tc *cache,
                                      Ts *scale,
                                      const int32_t scale_num,
                                      const int32_t seq_num,
                                      const int32_t head_num,
                                      const int32_t head_size,
                                      const size_t context_offset,
                                      const size_t cache_offset,
                                      const size_t scale_offset,
                                      const size_t context_seq_stride,
                                      const size_t context_head_stride,
                                      const size_t cache_head_stride,
                                      const size_t cache_seq_stride,
                                      const size_t scale_head_stride) {
  // Load per-(head, seq) scales: [head_num, seq_num].
  __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, seq_num * sizeof_(Ts), GDRAM2NRAM,
           seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  // Gather the quantized tile as [head_num, seq_num, head_size].
  if (std::is_same<Tc, int4x2_t>::value) {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM,
             head_size >> 1, seq_num - 1, seq_num * (head_size >> 1), head_num - 1,
             cache_seq_stride, seq_num - 1, cache_head_stride, head_num - 1);
  } else {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM,
             head_size * sizeof_(Tc), seq_num - 1, seq_num * head_size * sizeof_(Tc), head_num - 1,
             cache_seq_stride * sizeof_(Tc), seq_num - 1, cache_head_stride * sizeof_(Tc),
             head_num - 1);
  }
  // Widen to float, then fuse the scale multiply + cast to T.
  convert((float *)output_nram, (Tc *)input_nram, head_num * seq_num * head_size);
  if (std::is_same<T, float>::value) {
    conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * seq_num, head_size, 1);
  } else {
    // Non-float T needs a separate destination; output_nram is retargeted.
    conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * seq_num, head_size, 1);
    output_nram = (T *)temp_nram;
  }
  // Scatter [head_num, seq_num, head_size] back into the context layout.
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_seq_stride * sizeof_(T), seq_num - 1, context_head_stride * sizeof_(T),
           head_num - 1, head_size * sizeof_(T), seq_num - 1, head_size * seq_num * sizeof_(T),
           head_num - 1);
}
// int4-value variant of the per-head dequant: the value cache packs two
// sequence positions per byte, so pairs are loaded, unpacked, transposed into
// sequence-major order, then scale-multiplied via the fused conv. When the
// tile starts on an odd cache position (pad_front), the extra leading
// position loaded for alignment is dropped before the store.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_value_per_head(T *output_nram,
                                            Tc *input_nram,
                                            Ts *scale_nram,
                                            T *temp_nram,
                                            T *data,
                                            Tc *cache,
                                            Ts *scale,
                                            const int32_t scale_num,
                                            const int32_t seq_num,
                                            const int32_t head_num,
                                            const int32_t head_size,
                                            const size_t context_offset,
                                            const size_t cache_offset,
                                            const size_t scale_offset,
                                            const size_t context_seq_stride,
                                            const size_t context_head_stride,
                                            const size_t cache_head_stride,
                                            const size_t cache_seq_stride,
                                            const size_t scale_bs_stride,
                                            const size_t scale_head_stride,
                                            const bool pad_front) {
  // Each loaded "row" is one packed byte pair == two sequence positions.
  int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2);
  int32_t deal_seq_num = load_seq_num << 1;
  /* Step1. load scale first, [head_num, deal_seq_num] */
  __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, deal_seq_num * sizeof_(Ts), GDRAM2NRAM,
           deal_seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  /* Step2. load cache input, [head_num, load_seq_num, head_size, 2] for int4 */
  __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size,
           load_seq_num - 1, load_seq_num * head_size, head_num - 1, cache_seq_stride,
           load_seq_num - 1, cache_head_stride, head_num - 1);
  convert((int8_t *)output_nram, (Tc *)input_nram, head_num * head_size * deal_seq_num);
  /* Step3. trans to [head_num, load_seq_num, 2, head_size]*/
  trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram,
                         head_num * load_seq_num, head_size, 1, 2);
  /* Step4. dequant to T [head_num, deal_seq_num, head_size] */
  convert((float *)output_nram, (int8_t *)temp_nram, head_num * deal_seq_num * head_size);
  if (std::is_same<T, float>::value) {
    conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * deal_seq_num, head_size, 1);
  } else {
    // Non-float T needs a separate destination; output_nram is retargeted.
    conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * deal_seq_num, head_size, 1);
    output_nram = (T *)temp_nram;
  }
  /* Step5. save [head_num, save_seq_num, head_size]*/
  // pad_front: skip the alignment position that precedes the real tile start.
  int32_t save_seq_num = pad_front ? seq_num - 1 : seq_num;
  __memcpy((T *)data + context_offset, (T *)output_nram + (pad_front ? head_size : 0),
           head_size * sizeof_(T), NRAM2GDRAM, context_seq_stride * sizeof_(T), save_seq_num - 1,
           context_head_stride * sizeof_(T), head_num - 1, head_size * sizeof_(T), save_seq_num - 1,
           head_size * deal_seq_num * sizeof_(T), head_num - 1);
}
// Kernel: dequantize a linear (contiguous) KV cache quantized per channel
// back into the key/value context tensors. Batch rows are strided across
// taskIdY; taskIdZ tiles the sequence dimension in chunks of seq_block.
// Either side (key or value) may be absent — the corresponding pointers are
// null and that side is skipped. The int4x2_t value path packs two sequence
// positions per byte and needs the odd-offset (pad_front) handling below.
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromLinearCacheKernelPerChannel(void *key,
                                                              void *value,
                                                              const void *key_cache,
                                                              const void *value_cache,
                                                              const void *key_scale,
                                                              const void *value_scale,
                                                              const int32_t *context_lens,
                                                              const int32_t *context_seq_offsets,
                                                              const int32_t *cache_bs_id,
                                                              const int32_t *cache_seq_offsets,
                                                              const int32_t max_context_len,
                                                              const int32_t batch_size,
                                                              const int32_t head_num,
                                                              const int32_t key_group_num,
                                                              const int32_t value_group_num,
                                                              const int32_t cache_mem_len,
                                                              const int32_t head_size,
                                                              const int32_t seq_block,
                                                              const size_t context_head_stride,
                                                              const size_t context_seq_stride,
                                                              const size_t cache_bs_stride,
                                                              const size_t cache_head_stride,
                                                              const size_t key_cache_seq_stride,
                                                              const size_t value_cache_seq_stride,
                                                              const size_t scale_bs_stride,
                                                              const size_t scale_head_stride) {
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }
  /* *********************************nram space **************************************
   * NRAM |scale[head_num, head_size]|output/input[seq_block, head_num, head_size]|
   */
  int32_t scale_num = head_num * head_size;
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram = (Tc *)output_nram +
                   (std::is_same<Tc, int4x2_t>::value
                        ? (7 * seq_block * (scale_num >> 1))
                        : (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)));
  // temp_nram for nram to store int8 input
  int8_t *temp_nram = (int8_t *)output_nram + seq_block * scale_num * 3;
  int32_t seq_offset;
  int32_t seq_len;
  int32_t seq_begin;
  int32_t deal_seq_num;
  int32_t cache_id;
  int32_t cache_seq_offset;
  size_t context_offset;
  size_t cache_offset;
  size_t scale_offset;
  // When ProcessOffsets, per-batch lens/offsets are precomputed into
  // n_lens/n_offsets NRAM arrays; otherwise they are read per row.
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  if (has_key) {
    // Key scales are identical across batch rows here, so stage them once.
    load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      calcu_offsets_per_channel(cache_id, context_offset, cache_offset, scale_offset,
                                (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets, cache_mem_len,
                                seq_len, seq_begin, seq_offset, batch_idx, context_seq_stride,
                                cache_bs_stride, key_cache_seq_stride, scale_bs_stride);
      if (cache_id < 0) continue;
      dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key,
                             (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num, head_num,
                             head_size, context_offset, cache_offset, scale_offset,
                             context_seq_stride, context_head_stride, cache_head_stride,
                             key_cache_seq_stride, scale_bs_stride, scale_head_stride);
    }
  }
  if (has_value) {
    if (std::is_same<Tc, int4x2_t>::value) {
      __reshape_nhwc2nchw_smallc_init<int8_t>(pre_table_nram, 2);
    }
    load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      if (std::is_same<Tc, int4x2_t>::value) {
        // int4 path: offsets computed inline because the packed layout needs
        // seq_begin shifted when the cache offset is odd (pad_front below).
        seq_begin = taskIdZ * seq_block;
        cache_id =
            cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
        cache_seq_offset = cache_seq_offsets == nullptr
                               ? 0
                               : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
        // move seq_begin left by 1 when cache_seq_offset is odd
        seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin;
        deal_seq_num = std::min(seq_len - seq_begin, seq_block);
        if (deal_seq_num <= 0 || seq_offset < 0) continue;
        if (cache_id >= 0 && cache_seq_offset >= 0 &&
            (cache_seq_offset + seq_len) <= cache_mem_len) {
          context_offset =
              context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0));
          // value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t
          cache_offset = cache_bs_stride * cache_id +
                         value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2);
          scale_offset = scale_bs_stride * cache_id;
        } else {
          cache_id = -1;
        }
        if (cache_id < 0) continue;
        dequantize_value_per_channel(
            (T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (int8_t *)temp_nram, (T *)value,
            (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num, head_num, head_size,
            context_offset, cache_offset, scale_offset, context_seq_stride, context_head_stride,
            cache_head_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride,
            seq_begin == -1);
      } else {
        seq_begin = taskIdZ * seq_block;
        deal_seq_num = std::min(seq_len - seq_begin, seq_block);
        if (deal_seq_num <= 0 || seq_offset < 0) continue;
        calcu_offsets_per_channel(
            cache_id, context_offset, cache_offset, scale_offset, (int32_t *)cache_bs_id,
            (int32_t *)cache_seq_offsets, cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx,
            context_seq_stride, cache_bs_stride, value_cache_seq_stride, scale_bs_stride);
        if (cache_id < 0) continue;
        dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value,
                               (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num,
                               head_num, head_size, context_offset, cache_offset, scale_offset,
                               context_seq_stride, context_head_stride, cache_head_stride,
                               value_cache_seq_stride, scale_bs_stride, scale_head_stride);
      }
    }
  }
}
// Kernel: dequantize a linear (contiguous) KV cache quantized per head back
// into the key/value context tensors. Batch rows are strided across taskIdY;
// taskIdZ tiles the sequence dimension in chunks of seq_block. int8 key and
// value share one loop; the int4 value cache needs its own second loop with
// packed-pair (pad_front) offset handling.
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromLinearCacheKernelPerHead(void *key,
                                                           void *value,
                                                           const void *key_cache,
                                                           const void *value_cache,
                                                           const void *key_scale,
                                                           const void *value_scale,
                                                           const int32_t *context_lens,
                                                           const int32_t *context_seq_offsets,
                                                           const int32_t *cache_bs_id,
                                                           const int32_t *cache_seq_offsets,
                                                           const int32_t max_context_len,
                                                           const int32_t batch_size,
                                                           const int32_t head_num,
                                                           const int32_t key_group_num,
                                                           const int32_t value_group_num,
                                                           const int32_t cache_mem_len,
                                                           const int32_t head_size,
                                                           const int32_t seq_block,
                                                           const size_t context_head_stride,
                                                           const size_t context_seq_stride,
                                                           const size_t cache_bs_stride,
                                                           const size_t cache_head_stride,
                                                           const size_t key_cache_seq_stride,
                                                           const size_t value_cache_seq_stride,
                                                           const size_t scale_bs_stride,
                                                           const size_t scale_head_stride) {
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }
  /* *********************************nram space **************************************
   * NRAM |scale[seq_block, head_num]|output/input[head_size, seq_block, head_num]|temp|
   */
  int32_t scale_num = seq_block * head_num;
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram = (Tc *)output_nram +
                   (std::is_same<Tc, int4x2_t>::value
                        ? (7 * head_size * (scale_num >> 1))
                        : head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc));
  // temp_nram for nram to store converted output
  float *temp_nram = (float *)output_nram + head_size * scale_num;
  int32_t seq_offset;
  int32_t seq_len;
  int32_t seq_begin;
  int32_t deal_seq_num;
  int32_t cache_id;
  size_t context_offset;
  size_t key_cache_offset;
  size_t value_cache_offset;
  size_t scale_offset;
  // Stage an all-ones conv kernel into WRAM so conv_fuse_mul_cvt can apply
  // the per-head scale as a fused multiply.
  __bang_write_value((float *)nbuf, head_size * 16, 1.0f);
  mvNram2WramLT16<float>((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16);
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
    load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                    (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                    batch_idx);
    seq_begin = taskIdZ * seq_block;
    deal_seq_num = std::min(seq_len - seq_begin, seq_block);
    if (deal_seq_num <= 0 || seq_offset < 0) continue;
    calcu_offsets_per_head(cache_id, context_offset, key_cache_offset, value_cache_offset,
                           scale_offset, (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets,
                           cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx,
                           context_seq_stride, cache_bs_stride, key_cache_seq_stride,
                           value_cache_seq_stride, scale_bs_stride);
    if (cache_id < 0) continue;
    if (has_key) {
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)key, (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num,
                          head_num, head_size, context_offset, key_cache_offset, scale_offset,
                          context_seq_stride, context_head_stride, cache_head_stride,
                          key_cache_seq_stride, scale_head_stride);
    }
    if (has_value && std::is_same<Tc, int8_t>::value) {
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)value, (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num,
                          head_num, head_size, context_offset, value_cache_offset, scale_offset,
                          context_seq_stride, context_head_stride, cache_head_stride,
                          value_cache_seq_stride, scale_head_stride);
    }
  }
  // process value int4 differently
  if (has_value && std::is_same<Tc, int4x2_t>::value) {
    int32_t cache_seq_offset;
    __reshape_nhwc2nchw_smallc_init<int8_t>(pre_table_nram, 2);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      cache_id =
          cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
      cache_seq_offset =
          cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
      // move seq_begin left by 1 when cache_seq_offset is odd
      seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) {
        context_offset =
            context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0));
        // value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t
        value_cache_offset = cache_bs_stride * cache_id +
                             value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2);
        scale_offset = cache_seq_offset + seq_begin + scale_bs_stride * cache_id;
      } else {
        cache_id = -1;
      }
      if (cache_id < 0) continue;
      dequantize_value_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram,
                                (T *)temp_nram, (T *)value, (Tc *)value_cache, (Ts *)value_scale,
                                scale_num, deal_seq_num, head_num, head_size, context_offset,
                                value_cache_offset, scale_offset, context_seq_stride,
                                context_head_stride, cache_head_stride, value_cache_seq_stride,
                                scale_bs_stride, scale_head_stride, seq_begin == -1);
    }
  }
}
} // namespace kernels
// Explicit instantiations of the dequant kernels for every supported combo of
// activation dtype T (half/bf16/float), cache dtype Tc (int8/int4x2),
// scale dtype Ts (float), offsets-preprocessing flag C, and quantization
// granularity Name (PerChannel/PerHead).
#define DEQUANT_LINEAR_INIT(T, Tc, Ts, C, Name)                                                   \
  template __mlu_global__ void kernels::MLUDequantFromLinearCacheKernel##Name<T, Tc, Ts, C>(      \
      void *key, void *value, const void *key_cache, const void *value_cache,                     \
      const void *key_scale, const void *value_scale, const int32_t *context_lens,                \
      const int32_t *context_seq_offsets, const int32_t *cache_bs_id,                             \
      const int32_t *cache_seq_offsets, const int32_t max_context_len, const int32_t batch_size,  \
      const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num,         \
      const int32_t cache_mem_len, const int32_t head_size, const int32_t seq_block,              \
      const size_t context_head_stride, const size_t context_seq_stride,                          \
      const size_t cache_bs_stride, const size_t cache_head_stride,                               \
      const size_t key_cache_seq_stride, const size_t value_cache_seq_stride,                     \
      const size_t scale_bs_stride, const size_t scale_head_stride);
DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerHead)
// Common signature of every instantiated linear-cache dequant kernel; entries
// of the dispatch table below are stored through this pointer type.
typedef void (*DequantFromLinearCachePointer)(void *,          // key
                                              void *,          // value
                                              const void *,    // key_cache
                                              const void *,    // value_cache
                                              const void *,    // key_scale
                                              const void *,    // value_scale
                                              const int32_t *, // context_lens
                                              const int32_t *, // context_seq_offsets
                                              const int32_t *, // cache_bs_id
                                              const int32_t *, // cache_seq_offsets
                                              const int32_t,   // max_context_len
                                              const int32_t,   // batch_size
                                              const int32_t,   // head_num
                                              const int32_t,   // key_group_num
                                              const int32_t,   // value_group_num
                                              const int32_t,   // cache_mem_len
                                              const int32_t,   // head_size
                                              const int32_t,   // seq_block
                                              const size_t,    // context_head_stride
                                              const size_t,    // context_seq_stride
                                              const size_t,    // cache_bs_stride
                                              const size_t,    // cache_head_stride
                                              const size_t,    // key_cache_seq_stride
                                              const size_t,    // value_cache_seq_stride
                                              const size_t,    // scale_bs_stride
                                              const size_t);   // scale_head_stride
static DequantFromLinearCachePointer DequantFromLinearCacheFuncArr[DEQUANT_FUNC_LEN] = {
DEQUANT_LINEAR_PERCHANNEL<half, int8_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int8_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<float, int8_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<half, int4x2_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int4x2_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<float, int4x2_t, float, false>,
DEQUANT_LINEAR_PERHEAD<half, int8_t, float, false>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int8_t, float, false>,
DEQUANT_LINEAR_PERHEAD<float, int8_t, float, false>,
DEQUANT_LINEAR_PERHEAD<half, int4x2_t, float, false>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int4x2_t, float, false>,
DEQUANT_LINEAR_PERHEAD<float, int4x2_t, float, false>,
DEQUANT_LINEAR_PERCHANNEL<half, int8_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int8_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<float, int8_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<half, int4x2_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int4x2_t, float, true>,
DEQUANT_LINEAR_PERCHANNEL<float, int4x2_t, float, true>,
DEQUANT_LINEAR_PERHEAD<half, int8_t, float, true>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int8_t, float, true>,
DEQUANT_LINEAR_PERHEAD<float, int8_t, float, true>,
DEQUANT_LINEAR_PERHEAD<half, int4x2_t, float, true>,
DEQUANT_LINEAR_PERHEAD<bfloat16_t, int4x2_t, float, true>,
DEQUANT_LINEAR_PERHEAD<float, int4x2_t, float, true>};
// Computes the index into DequantFromLinearCacheFuncArr for a given
// configuration. Layout (must match the table's entry order):
//   base   : dtype        -> half = 0, bfloat16 = 1, float = 2
//   +3     : quant_bit != 8 (int4x2 cache)
//   +6     : quant_mode != 0 (per-head/per-token scales)
//   +12    : context_seq_offset == nullptr (kernel computes offsets itself)
uint32_t getDequantLinearIdx(cnnlDataType_t dtype,
                             int32_t quant_mode,
                             int32_t quant_bit,
                             const void *context_seq_offset) {
  uint32_t dtype_off = 0;
  if (dtype == CNNL_DTYPE_BFLOAT16) {
    dtype_off = 1;
  } else if (dtype == CNNL_DTYPE_FLOAT) {
    dtype_off = 2;
  }
  const uint32_t bit_off = (quant_bit != 8) ? 3u : 0u;
  const uint32_t mode_off = (quant_mode != 0) ? 6u : 0u;
  const uint32_t offsets_off = (context_seq_offset == nullptr) ? 12u : 0u;
  return dtype_off + bit_off + mode_off + offsets_off;
}
// Chooses the per-task sequence tile size (seq_block) and the launch
// dimensions (task_dim/task_type) for the linear-cache dequant kernels,
// based on available NRAM and the quantization configuration.
// seq_block is derived from how many tokens' worth of kernel working set
// fit in NRAM; errors are reported to stderr but the function still returns.
void getBlockAndDimForLinear(int32_t &seq_block,
                             cnrtDim3_t &task_dim,
                             cnrtFunctionType_t &task_type,
                             const int32_t max_context_len,
                             const int32_t head_num,
                             const int32_t batch_size,
                             const int32_t head_size,
                             const int32_t quant_mode,
                             const int32_t quant_bit,
                             const cnnlDataType_t dtype) {
  int32_t core_dim;
  int32_t cluster_dim;
  // Defaults; getDeviceCoreAndRam overwrites them with the actual device
  // topology and on-chip memory sizes (minus REM_FOR_STACK).
  int32_t nram_size = 480 * 1024;
  int32_t wram_size = 512 * 1024;
  int32_t sram_size = 2016 * 1024;
  getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK);
  if (quant_mode == 0) {
    // Per-channel: NRAM must hold one [head_num, head_size] fp32 scale plane
    // plus seq_block token planes.
    seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1);
    if (quant_bit == 4) {
      // int4 needs seq_block >= 2 (tokens are packed in pairs).
      if (seq_block <= 1) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * head_size * sizeof_(float) should be less than "
                  << (nram_size >> 1) << " when quant_mode is 0." << std::endl;
      }
    } else {
      if (seq_block <= 0) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * head_size * sizeof_(float) should be less than " << nram_size
                  << " when quant_mode is 0." << std::endl;
      }
    }
  } else {
    // Per-token: per token NRAM holds head_num scales + a float work plane +
    // an output plane in the context dtype.
    int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half);
    seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) +
                                              (int64_t)head_num * head_size * dtype_size));
    if (quant_bit == 4) {
      if (seq_block <= 1) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
                     "context_dtype_size) "
                  << "should be less than " << (nram_size >> 1) << " when quant_mode is 1."
                  << std::endl;
      }
    } else {
      if (seq_block <= 0) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
                     "context_dtype_size) "
                  << "should be less than " << nram_size << " when quant_mode is 1." << std::endl;
      }
    }
    /* head_size * 64B put in the wram. */
    if (head_size * ONE_LINE >= wram_size) {
      std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than "
                << wram_size << " when quant_mode is 1." << std::endl;
    }
  }
  seq_block = std::min(seq_block, max_context_len);
  // Round partial tiles down to a multiple of 16 for aligned copies.
  if (seq_block > 16 && seq_block < max_context_len) {
    seq_block = PAD_DOWN(seq_block, 16);
  }
  // int4 tokens come in pairs, so the tile must be even.
  if (quant_bit == 4) {
    seq_block = PAD_DOWN(seq_block, 2);
  }
  int seq_seg = DIV_UP(max_context_len, seq_block);
  // need an extra seg block to deal with int4 value_cache [...,seq_len/2, head_size]
  if (quant_bit == 4) {
    seq_seg += 1;
  }
  uint32_t core_num = cluster_dim * core_dim;
  // If the launch would leave most cores idle, shrink seq_block to create
  // more (batch, tile) work items.
  if (batch_size * seq_seg <= (core_num / 2)) {
    int times = core_num / batch_size / seq_seg;
    seq_block = std::max(seq_block / times, 2);
    if (quant_bit == 4) {
      seq_block = PAD_DOWN(seq_block, 2);
    }
    seq_seg = DIV_UP(max_context_len, seq_block);
    // same as above: int4 value_cache needs an extra seg block
    if (quant_bit == 4) {
      seq_seg += 1;
    }
  }
  // Launch grid: y iterates batches, z iterates sequence tiles.
  task_dim.x = 1;
  task_dim.y = uint32_t(std::min(batch_size, cluster_dim * core_dim));
  task_dim.z = uint32_t(seq_seg);
  task_type = cnrtFuncTypeBlock;
}
/**
 * @brief Launches the linear-cache KV dequantization kernel on `queue`.
 *
 * Computes the sequence tile size and launch dimensions via
 * getBlockAndDimForLinear(), selects the kernel instantiation matching
 * (dtype, quant_mode, quant_bit, offsets mode) from
 * DequantFromLinearCacheFuncArr via getDequantLinearIdx(), and launches it
 * asynchronously on `queue`.
 *
 * @return KERNEL_STATUS_FAILED if bfloat16 is requested on a device without
 *         bf16 support; KERNEL_STATUS_SUCCESS otherwise. The launch itself is
 *         asynchronous and is not error-checked here.
 */
KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue,
                                          void *key,
                                          void *value,
                                          const void *key_cache,
                                          const void *value_cache,
                                          const void *key_scale,
                                          const void *value_scale,
                                          const void *context_lens,
                                          const void *context_seq_offsets,
                                          const void *cache_bs_id,
                                          const void *cache_seq_offsets,
                                          const int32_t max_context_len,
                                          const int32_t batch_size,
                                          const int32_t head_num,
                                          const int32_t key_group_num,
                                          const int32_t value_group_num,
                                          const int32_t cache_mem_len,
                                          const int32_t head_size,
                                          const int32_t quant_mode,
                                          const int32_t quant_bit,
                                          const size_t context_head_stride,
                                          const size_t context_seq_stride,
                                          const size_t cache_bs_stride,
                                          const size_t cache_head_stride,
                                          const size_t key_cache_seq_stride,
                                          const size_t value_cache_seq_stride,
                                          const size_t scale_bs_stride,
                                          const size_t scale_head_stride,
                                          const cnnlDataType_t dtype) {
  if (dtype == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
    // Fix: the tag previously said "[invokeDequantFromPagedCache]", which is a
    // different entry point and made this error misleading to trace.
    std::cerr << "[invokeDequantFromLinearCache]: "
                 "MLU300 devices do not support bfloat16."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int32_t seq_block;
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  getBlockAndDimForLinear(seq_block, k_dim, k_type, max_context_len, head_num, batch_size,
                          head_size, quant_mode, quant_bit, dtype);
  // getDequantLinearIdx returns uint32_t; keep the index unsigned to avoid a
  // narrowing conversion (previously stored in an int32_t).
  const uint32_t index = getDequantLinearIdx(dtype, quant_mode, quant_bit, context_seq_offsets);
  auto dequant_linear_func = DequantFromLinearCacheFuncArr[index];
  dequant_linear_func<<<k_dim, k_type, queue>>>(
      (void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache,
      (const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens,
      (const int32_t *)context_seq_offsets, (const int32_t *)cache_bs_id,
      (const int32_t *)cache_seq_offsets, max_context_len, batch_size, head_num, key_group_num,
      value_group_num, cache_mem_len, head_size, seq_block, context_head_stride, context_seq_stride,
      cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride,
      scale_bs_stride, scale_head_stride);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,108 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief De-quantizes the key and value tensors from the provided linear cache and scale.
 * @param queue: The queue for mlu.
 * @param key: Pointer to the MLU memory that stores the key tensor,
 *     with shape [total_seqlen, head_num, head_size]. Data type can be float32, half,
 *     or bfloat16. This parameter can be nullptr.
 * @param value: Pointer to the MLU memory that stores the value tensor,
 *     with shape [total_seqlen, head_num, head_size]. Data type can be float32,
 *     half, or bfloat16. This parameter can be nullptr.
 * @param key_cache: Pointer to the MLU memory that stores the key cache tensor,
 *     with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization
 *     or [max_bs, head_num, cache_mem_len, head_size//2] for 4-bit quantization.
 *     Data type must be int8. This parameter can be nullptr.
 * @param value_cache: Pointer to the MLU memory that stores the value cache tensor,
 *     with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization
 *     or [max_bs, head_num, cache_mem_len//2, head_size] for 4-bit quantization.
 *     Data type must be int8. This parameter can be nullptr.
 * @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale.
 *     Shape depends on quantization mode:
 *     - For per-channel quantization (quant_mode = 0): [head_num, head_size].
 *     - For per-token quantization (quant_mode = 1): [max_batch_size, head_num, cache_mem_len].
 *     Data type must be float32. This parameter can be nullptr.
 * @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale,
 *     with the same shape as key_scale. Data type must be float32. This parameter can be
 *     nullptr.
 * @param context_lens: Pointer to the MLU memory that stores the sequence lengths.
 *     The shape must be [batch].
 * @param context_seq_offsets: Pointer to the MLU memory that stores the sequence offset in the
 *     context. The shape must be [batch]. If nullptr, the default value is the cumulative sum
 *     of context_lens.
 * @param cache_bs_ids: Pointer to the MLU memory that stores the batch index in the cache.
 *     The shape must be [batch]. If nullptr, the default value is {0, 1, 2, ..., batch - 1}.
 * @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence offset in the
 *     cache. The shape must be [batch]. If nullptr, the default value is 0 for every batch.
 * @param max_context_len: The maximum sequence length of context.
 * @param batch: Batch size.
 * @param head_num: Head number.
 * @param key_group_num: Group number of key group-wise quantization.
 * @param value_group_num: Group number of value group-wise quantization.
 * @param cache_mem_len: The maximum sequence length of cache.
 * @param head_size: Head size.
 * @param quant_mode: An integer value indicating the quantization mode:
 *     0 for per-channel quantization and 1 for per-token quantization.
 * @param quant_bit: An integer value indicating the quantization bit width:
 *     8 for 8-bit quantization and 4 for 4-bit quantization.
 * @param context_head_stride: The stride of head_num in context.
 * @param context_seq_stride: The stride of max_context_len in context.
 * @param cache_bs_stride: The stride of batch in cache.
 * @param cache_head_stride: The stride of head_num in cache.
 * @param key_cache_seq_stride: The stride of cache_mem_len in key cache.
 * @param value_cache_seq_stride: The stride of cache_mem_len in value cache.
 * @param cache_scale_bs_stride: The stride of batch in cache scale; only meaningful for
 *     per-token quantization (quant_mode = 1).
 * @param cache_scale_head_stride: The stride of head in cache scale.
 * @param dtype: The data type of the key and value tensors.
 * @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key.
 *     If any of value/value_cache/value_scale is nullptr, no operation is performed on the value.
 */
KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue,
                                          void *key,
                                          void *value,
                                          const void *key_cache,
                                          const void *value_cache,
                                          const void *key_scale,
                                          const void *value_scale,
                                          const void *context_lens,
                                          const void *context_seq_offsets,
                                          const void *cache_bs_ids,
                                          const void *cache_seq_offsets,
                                          const int32_t max_context_len,
                                          const int32_t batch,
                                          const int32_t head_num,
                                          const int32_t key_group_num,
                                          const int32_t value_group_num,
                                          const int32_t cache_mem_len,
                                          const int32_t head_size,
                                          const int32_t quant_mode,
                                          const int32_t quant_bit,
                                          const size_t context_head_stride,
                                          const size_t context_seq_stride,
                                          const size_t cache_bs_stride,
                                          const size_t cache_head_stride,
                                          const size_t key_cache_seq_stride,
                                          const size_t value_cache_seq_stride,
                                          const size_t cache_scale_bs_stride,
                                          const size_t cache_scale_head_stride,
                                          const cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_

View File

@@ -0,0 +1,616 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include "dequant_from_paged_cache.mluh"
#include "quant_utils.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#pragma bang walign(16)
// NRAM bytes reserved for the kernel call stack.
#define REM_FOR_STACK (32 * 1024)
#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK)
#define DEQUANT_PAGED_PERHEAD kernels::MLUDequantFromPagedCacheKernelPerHead
#define DEQUANT_PAGED_PERCHANNEL kernels::MLUDequantFromPagedCacheKernelPerChannel
// Number of entries in the kernel dispatch table.
#define DEQUANT_FUNC_LEN (24)
// Maximum batch size supported by the on-chip length/offset staging buffers.
#define DEQUANT_BATCH_NUM (1024)
// On-chip scratch shared by all kernels in this file.
__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE];
__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE];
// Uses 8K = 1K * (4 + 4) to process per-batch lengths and offsets.
__nram__ int32_t n_lens[DEQUANT_BATCH_NUM];
__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM];
// Dequantizes `seq_num` tokens in NRAM using per-channel scales and writes the
// result to GDRAM at data + context_offset, laid out as
// [seq_num, head_num, head_size] with the given seq/head strides.
// scale_num == head_num * head_size (one scale per channel).
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_channel(T *output_nram,
                                         Tc *input_nram,
                                         Ts *scale_nram,
                                         T *data,
                                         const int32_t scale_num,
                                         const int32_t seq_num,
                                         const int32_t head_num,
                                         const int32_t head_size,
                                         const size_t context_offset,
                                         const size_t context_seq_stride,
                                         const size_t context_head_stride) {
  // dequantize() (project helper) expands input_nram by scale_nram into
  // output_nram, using the start of nbuf as float scratch.
  dequantize<T, Tc, Ts>((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf,
                        seq_num * scale_num, scale_num);
  // Strided copy NRAM -> GDRAM: head_size elements per head, head_num heads
  // per token, seq_num tokens.
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
           seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1);
}
// Dequantizes per-(block, head) quantized tokens: converts the cached values
// to float, applies the per-token scales via conv_fuse_mul_cvt (weights staged
// in wbuf), then scatters the NRAM layout
// [block_count, head_num, block_size, head_size] into the GDRAM context layout
// [seq_num, head_num, head_size]. Whole blocks and the trailing partial block
// (rem_token) are copied separately because their strides differ.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_head(T *output_nram,
                                      Tc *input_nram,
                                      Ts *scale_nram,
                                      T *temp_nram,
                                      T *data,
                                      const int32_t seq_num,
                                      const int32_t head_num,
                                      const int32_t block_size,
                                      const int32_t head_size,
                                      const size_t context_offset,
                                      const size_t context_seq_stride,
                                      const size_t context_head_stride) {
  int block_count = DIV_UP(seq_num, block_size);
  int rem_token = seq_num % block_size;
  convert((float *)output_nram, (Tc *)input_nram, block_count * head_num * block_size * head_size);
  // When T is float the multiply can run in place; otherwise a separate
  // buffer holds the converted result.
  T *res_nram = std::is_same<T, float>::value ? (T *)output_nram : (T *)temp_nram;
  // dequantize [block_count, head_num, block_size, head_size]
  conv_fuse_mul_cvt((T *)res_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                    block_count * head_num * block_size, head_size, 1);
  // copy to [seq_num, head_num, head_size]
  int whole_block_count = block_count - int(rem_token > 0);
  if (whole_block_count) {
    for (int i = 0; i < head_num; ++i) {
      // copy from [whole_block_count, i, block_size, head_size]
      // to [whole_block_count * block_size, 1, head_size]
      __memcpy((T *)data + context_offset + i * context_head_stride,
               (T *)res_nram + i * block_size * head_size, head_size * sizeof_(T), NRAM2GDRAM,
               context_seq_stride * sizeof_(T), block_size - 1,
               block_size * context_seq_stride * sizeof_(T), whole_block_count - 1,
               head_size * sizeof_(T), block_size - 1,
               head_num * block_size * head_size * sizeof_(T), whole_block_count - 1);
    }
  }
  if (rem_token) {
    // copy from [last, head_num, block_size(rem_token), head_size]
    // to [rem_token, head_num, head_size]
    __memcpy((T *)data + context_offset + whole_block_count * block_size * context_seq_stride,
             (T *)res_nram + whole_block_count * head_num * block_size * head_size,
             head_size * sizeof_(T), NRAM2GDRAM, context_head_stride * sizeof_(T), head_num - 1,
             context_seq_stride * sizeof_(T), rem_token - 1, block_size * head_size * sizeof_(T),
             head_num - 1, head_size * sizeof_(T), rem_token - 1);
  }
}
// Loads the paged-cache blocks covering [seq_begin, seq_begin + deal_seq_num)
// of one batch into NRAM. Physical block ids come from block_tables; byte
// offsets are formed with __bang_mul_scalar and used by __gather.
// NOTE(review): the gather and layout transpose are compiled only for
// __BANG_ARCH__ >= 500 — on older archs this function loads nothing;
// presumably those targets are unsupported here. Confirm.
template <typename Tc>
__mlu_func__ void load_input_per_channel(Tc *input_nram,
                                         Tc *cache,
                                         Tc *temp_nram,
                                         int32_t *block_offsets,
                                         uint32_t *cache_offsets,
                                         const int32_t *block_tables,
                                         const int32_t batch_idx,
                                         const int32_t scale_num,
                                         const int32_t max_block_num,
                                         const int32_t head_num,
                                         const int32_t block_size,
                                         const int32_t head_size,
                                         const int32_t seq_begin,
                                         const int32_t deal_seq_num,
                                         const size_t cache_bs_stride,
                                         const size_t cache_head_stride) {
  // Range of block-table entries touched by this sequence span.
  int32_t block_start = batch_idx * max_block_num + seq_begin / block_size;
  int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size;
  int32_t block_count = block_end - block_start + 1;
  // make sure elements in block_tables >= 0
  __memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start,
           block_count * sizeof_(int32_t), GDRAM2NRAM);
  // Convert block ids to byte offsets into the cache.
  __bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets,
                    (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
#if __BANG_ARCH__ >= 500
  // gather [block_count, head_num, block_size, head_size]
  __gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets,
           (uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM,
           (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
  if (head_num != 1 && block_size != 1) {
    // mv to [head_num, whole_block_count, block_size, head_size]
    __memcpy((Tc *)temp_nram, (Tc *)input_nram, block_size * head_size * sizeof(Tc), NRAM2NRAM,
             block_size * head_size * sizeof(Tc), block_count - 1,
             block_count * block_size * head_size * sizeof(Tc), head_num - 1,
             head_num * block_size * head_size * sizeof(Tc), block_count - 1,
             block_size * head_size * sizeof(Tc), head_num - 1);
    // mv to [whole_block_count, block_size, head_num, head_size]
    __memcpy((Tc *)input_nram, (Tc *)temp_nram, head_size * sizeof(Tc), NRAM2NRAM,
             head_size * sizeof(Tc), head_num - 1, head_num * head_size * sizeof(Tc),
             block_count * block_size - 1, block_count * block_size * head_size * sizeof(Tc),
             head_num - 1, head_size * sizeof(Tc), block_count * block_size - 1);
  }
#endif
}
// Per-head/per-token variant of the cache loader: gathers both the quantized
// blocks and their per-token scales into NRAM, one entry per physical block
// from block_tables. Like load_input_per_channel, the gathers are compiled
// only for __BANG_ARCH__ >= 500; the offset vectors are built unconditionally.
template <typename Tc, typename Ts>
__mlu_func__ void load_input_per_head(Tc *input_nram,
                                      Ts *scale_nram,
                                      Tc *cache,
                                      Ts *scale,
                                      Tc *temp_nram,
                                      int32_t *block_offsets,
                                      uint32_t *cache_offsets,
                                      uint32_t *scale_offsets,
                                      const int32_t *block_tables,
                                      const int32_t batch_idx,
                                      const int32_t max_block_num,
                                      const int32_t head_num,
                                      const int32_t block_size,
                                      const int32_t head_size,
                                      const int32_t seq_begin,
                                      const int32_t deal_seq_num,
                                      const size_t cache_bs_stride,
                                      const size_t cache_head_stride,
                                      const size_t scale_bs_stride,
                                      const size_t scale_head_stride) {
  // Range of block-table entries touched by this sequence span.
  int32_t block_start = batch_idx * max_block_num + seq_begin / block_size;
  int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size;
  int32_t block_count = block_end - block_start + 1;
  // make sure elements in block_tables >= 0
  __memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start,
           block_count * sizeof_(int32_t), GDRAM2NRAM);
  // Block ids -> byte offsets for cache data and for scales.
  __bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets,
                    (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
  __bang_mul_scalar((uint32_t *)scale_offsets, (uint32_t *)block_offsets,
                    (uint32_t)scale_bs_stride * sizeof(Ts), block_count);
#if __BANG_ARCH__ >= 500
  // gather [block_count, head_num, block_size, head_size]
  __gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets,
           (uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM,
           (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
  // gather [block_count, head_num, block_size]
  __gather((Ts *)scale_nram, (Ts *)scale, (uint32_t *)scale_offsets,
           (uint32_t)scale_bs_stride * sizeof(Ts), GDRAM2NRAM,
           (uint32_t)scale_bs_stride * sizeof(Ts), block_count);
#endif
}
// Device entry point: dequantizes key and/or value from a paged cache using
// per-channel scales ([head_num, head_size], loaded once per kernel).
// Work split: taskIdY strides over batches, taskIdZ selects the sequence tile
// of seq_block tokens. A side (key or value) runs only when its tensor,
// cache, and scale pointers are all non-null.
// ProcessOffsets=true: context offsets are derived on device from
// context_lens into n_lens/n_offsets; otherwise context_seq_offsets is used.
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromPagedCacheKernelPerChannel(void *key,
                                                             void *value,
                                                             const void *key_cache,
                                                             const void *value_cache,
                                                             const void *key_scale,
                                                             const void *value_scale,
                                                             const int32_t *context_lens,
                                                             const int32_t *context_seq_offsets,
                                                             const int32_t *block_tables,
                                                             const int32_t max_context_len,
                                                             const int32_t max_block_num,
                                                             const int32_t batch_size,
                                                             const int32_t head_num,
                                                             const int32_t key_group_num,
                                                             const int32_t value_group_num,
                                                             const int32_t block_size,
                                                             const int32_t head_size,
                                                             const int32_t seq_block,
                                                             const size_t context_head_stride,
                                                             const size_t context_seq_stride,
                                                             const size_t cache_bs_stride,
                                                             const size_t cache_head_stride,
                                                             const size_t key_cache_seq_stride,
                                                             const size_t value_cache_seq_stride,
                                                             const size_t scale_bs_stride,
                                                             const size_t scale_head_stride) {
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }
  /* *********************************nram space **************************************
   * NRAM |scale[head_num, head_size] fp32|output/input[seq_block, head_num, head_size] fp32|
   */
  int32_t scale_num = head_num * head_size;
  // Scales sit at the top of the first fp32 plane; input shares the
  // output plane, placed so dequantization can expand in place.
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram = (Tc *)output_nram +
                   (std::is_same<Tc, int4x2_t>::value
                        ? (7 * seq_block * (scale_num >> 1))
                        : (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)));
  // Block-table staging area lies past the output plane.
  int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + seq_block * scale_num);
  uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size);
  int32_t seq_offset;
  int32_t seq_len;
  int32_t seq_begin;
  int32_t deal_seq_num;
  size_t context_offset;
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  if (has_key) {
    // Per-channel scales are batch-independent: load once for all batches.
    load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      // seq_begin % block_size != 0 only when seq_block < block_size
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_channel((Tc *)input_nram, (Tc *)key_cache, (Tc *)output_nram,
                             (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                             (int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num,
                             block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride,
                             cache_head_stride);
      dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key,
                             scale_num, deal_seq_num, head_num, head_size, context_offset,
                             context_seq_stride, context_head_stride);
    }
  }
  if (has_value) {
    // Same pipeline as the key path, with the value cache and scales.
    load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      // seq_begin % block_size != 0 only when seq_block < block_size
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_channel((Tc *)input_nram, (Tc *)value_cache, (Tc *)output_nram,
                             (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                             (int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num,
                             block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride,
                             cache_head_stride);
      dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value,
                             scale_num, deal_seq_num, head_num, head_size, context_offset,
                             context_seq_stride, context_head_stride);
    }
  }
}
// Device entry point: dequantizes key and/or value from a paged cache using
// per-token scales gathered alongside the data (one scale per (head, token)).
// Work split mirrors the per-channel kernel: taskIdY strides over batches,
// taskIdZ picks the seq_block tile. The all-ones tile staged into WRAM below
// is consumed by conv_fuse_mul_cvt inside dequantize_per_head — presumably as
// the broadcast weight; confirm against that helper's definition.
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromPagedCacheKernelPerHead(void *key,
                                                          void *value,
                                                          const void *key_cache,
                                                          const void *value_cache,
                                                          const void *key_scale,
                                                          const void *value_scale,
                                                          const int32_t *context_lens,
                                                          const int32_t *context_seq_offsets,
                                                          const int32_t *block_tables,
                                                          const int32_t max_context_len,
                                                          const int32_t max_block_num,
                                                          const int32_t batch_size,
                                                          const int32_t head_num,
                                                          const int32_t key_group_num,
                                                          const int32_t value_group_num,
                                                          const int32_t block_size,
                                                          const int32_t head_size,
                                                          const int32_t seq_block,
                                                          const size_t context_head_stride,
                                                          const size_t context_seq_stride,
                                                          const size_t cache_bs_stride,
                                                          const size_t cache_head_stride,
                                                          const size_t key_cache_seq_stride,
                                                          const size_t value_cache_seq_stride,
                                                          const size_t scale_bs_stride,
                                                          const size_t scale_head_stride) {
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }
  /* *********************************nram space **************************************
   * NRAM |scale[seq_block, head_num] * fp32|output/input[seq_block, head_num, head_size] * fp32|
   *      |temp[seq_block, head_num, head_size] * output_dtype|
   * WRAM |head_size * 64B|
   */
  int32_t scale_num = seq_block * head_num;
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram =
      (Tc *)output_nram + (head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc));
  // temp_nram for nram to store temp output
  float *temp_nram = (float *)output_nram + head_size * scale_num;
  int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + head_size * scale_num);
  uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size);
  uint32_t *scale_offsets = (uint32_t *)cache_offsets + DIV_UP(seq_block, block_size);
  int32_t seq_len;
  int32_t seq_begin;
  int32_t seq_offset;
  int32_t deal_seq_num;
  size_t context_offset;
  // Stage a head_size x 16 tile of 1.0f into WRAM for conv_fuse_mul_cvt.
  __bang_write_value((float *)nbuf, head_size * 16, 1.0f);
  mvNram2WramLT16<float>((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16);
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  if (has_key) {
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)key_cache, (Ts *)key_scale,
                          (Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                          (uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx,
                          max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num,
                          cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride);
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)key, deal_seq_num, head_num, block_size, head_size, context_offset,
                          context_seq_stride, context_head_stride);
    }
  }
  if (has_value) {
    // Same pipeline as the key path, with the value cache and scales.
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)value_cache, (Ts *)value_scale,
                          (Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                          (uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx,
                          max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num,
                          cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride);
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)value, deal_seq_num, head_num, block_size, head_size, context_offset,
                          context_seq_stride, context_head_stride);
    }
  }
}
} // namespace kernels
// Explicit instantiation of one paged-cache dequant kernel variant.
// (T, Tc, Ts, C) = (context dtype, cache dtype, scale dtype, ProcessOffsets
// flag); Name selects the PerChannel or PerHead kernel.
#define DEQUANT_PAGED_INIT(T, Tc, Ts, C, Name) \
template __mlu_global__ void kernels::MLUDequantFromPagedCacheKernel##Name<T, Tc, Ts, C>( \
    void *key, void *value, const void *key_cache, const void *value_cache, \
    const void *key_scale, const void *value_scale, const int32_t *context_lens, \
    const int32_t *context_seq_offsets, const int32_t *block_tables, \
    const int32_t max_context_len, const int32_t max_block_num, const int32_t batch_size, \
    const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num, \
    const int32_t block_size, const int32_t head_size, const int32_t seq_block, \
    const size_t context_head_stride, const size_t context_seq_stride, \
    const size_t cache_bs_stride, const size_t cache_head_stride, \
    const size_t key_cache_seq_stride, const size_t value_cache_seq_stride, \
    const size_t scale_bs_stride, const size_t scale_head_stride);
// int8 variants only; the int4x2_t paged variants are currently disabled.
DEQUANT_PAGED_INIT(half, int8_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(float, int8_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(half, int8_t, float, false, PerHead)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerHead)
DEQUANT_PAGED_INIT(float, int8_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerHead)
DEQUANT_PAGED_INIT(half, int8_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(float, int8_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(half, int8_t, float, true, PerHead)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerHead)
DEQUANT_PAGED_INIT(float, int8_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerHead)
// Common signature shared by every paged-cache dequant kernel entry point;
// used for the DequantFromPagedCacheFuncArr dispatch table below (nullptr
// entries mark configurations with no compiled kernel).
typedef void (*DequantFromPagedCachePointer)(void *,           // key
                                             void *,           // value
                                             const void *,     // key_cache
                                             const void *,     // value_cache
                                             const void *,     // key_scale
                                             const void *,     // value_scale
                                             const int32_t *,  // context_lens
                                             const int32_t *,  // context_seq_offsets
                                             const int32_t *,  // block_tables
                                             const int32_t,    // max_context_len
                                             const int32_t,    // max_block_num
                                             const int32_t,    // batch_size
                                             const int32_t,    // head_num
                                             const int32_t,    // key_group_num
                                             const int32_t,    // value_group_num
                                             const int32_t,    // block_size
                                             const int32_t,    // head_size
                                             const int32_t,    // seq_block
                                             const size_t,     // context_head_stride
                                             const size_t,     // context_seq_stride
                                             const size_t,     // cache_bs_stride
                                             const size_t,     // cache_head_stride
                                             const size_t,     // key_cache_seq_stride
                                             const size_t,     // value_cache_seq_stride
                                             const size_t,     // scale_bs_stride
                                             const size_t);    // scale_head_stride
// Kernel dispatch table, indexed by getDequantPagedIdx():
//   + 12 if context_seq_offsets == nullptr (selects the C = true instantiations)
//   +  6 if quant_mode != 0 (PerHead instead of PerChannel kernels)
//   +  3 if quant_bit != 8 (int4 slots — currently nullptr / unsupported)
//   +  0 for half / + 1 for bfloat16 / + 2 for float output dtype.
// Callers must check for nullptr before launching.
static DequantFromPagedCachePointer DequantFromPagedCacheFuncArr[DEQUANT_FUNC_LEN] = {
    DEQUANT_PAGED_PERCHANNEL<half, int8_t, float, false>,
    DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int8_t, float, false>,
    DEQUANT_PAGED_PERCHANNEL<float, int8_t, float, false>, nullptr, nullptr, nullptr,
    // DEQUANT_PAGED_PERCHANNEL<half, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERCHANNEL<float, int4x2_t, float, false>,
    DEQUANT_PAGED_PERHEAD<half, int8_t, float, false>,
    DEQUANT_PAGED_PERHEAD<bfloat16_t, int8_t, float, false>,
    DEQUANT_PAGED_PERHEAD<float, int8_t, float, false>, nullptr, nullptr, nullptr,
    // DEQUANT_PAGED_PERHEAD<half, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERHEAD<bfloat16_t, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERHEAD<float, int4x2_t, float, false>,
    DEQUANT_PAGED_PERCHANNEL<half, int8_t, float, true>,
    DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int8_t, float, true>,
    DEQUANT_PAGED_PERCHANNEL<float, int8_t, float, true>, nullptr, nullptr, nullptr,
    // DEQUANT_PAGED_PERCHANNEL<half, int4x2_t, float, true>,
    // DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int4x2_t, float, true>,
    // DEQUANT_PAGED_PERCHANNEL<float, int4x2_t, float, true>,
    DEQUANT_PAGED_PERHEAD<half, int8_t, float, true>,
    DEQUANT_PAGED_PERHEAD<bfloat16_t, int8_t, float, true>,
    DEQUANT_PAGED_PERHEAD<float, int8_t, float, true>, nullptr, nullptr, nullptr};
// DEQUANT_PAGED_PERHEAD<half, int4x2_t, float, true>,
// DEQUANT_PAGED_PERHEAD<bfloat16_t, int4x2_t, float, true>,
// DEQUANT_PAGED_PERHEAD<float, int4x2_t, float, true>};
// Computes the index into DequantFromPagedCacheFuncArr for the requested
// configuration. Layout: blocks of 12 by offsets mode, blocks of 6 by scale
// granularity, blocks of 3 by quantization bit width, then the output dtype.
uint32_t getDequantPagedIdx(cnnlDataType_t dtype,
                            int32_t quant_mode,
                            int32_t quant_bit,
                            const void *context_seq_offset) {
  // No explicit offsets -> the kernels that derive offsets themselves.
  uint32_t idx = (context_seq_offset == nullptr) ? 12u : 0u;
  if (quant_mode != 0) {
    idx += 6;  // per-head (per-token) scale kernels
  }
  if (quant_bit != 8) {
    idx += 3;  // int4 slots (currently nullptr in the table)
  }
  // Output dtype: half = +0, bfloat16 = +1, float = +2.
  if (dtype == CNNL_DTYPE_BFLOAT16) {
    idx += 1;
  } else if (dtype == CNNL_DTYPE_FLOAT) {
    idx += 2;
  }
  return idx;
}
// Chooses the per-task sequence tile (seq_block) and the launch dimensions for
// the paged-cache dequant kernels, based on NRAM capacity and the number of
// available cores. Errors are reported to stderr only; the outputs are still
// written. task_dim.y indexes batches, task_dim.z indexes sequence segments.
void getBlockAndDimForPaged(int32_t &seq_block,
                            cnrtDim3_t &task_dim,
                            cnrtFunctionType_t &task_type,
                            const int32_t max_context_len,
                            const int32_t head_num,
                            const int32_t batch_size,
                            const int32_t head_size,
                            const int32_t block_size,
                            const int32_t quant_mode,
                            const int32_t quant_bit,
                            const cnnlDataType_t dtype) {
  // The sizes below are upper bounds which getDeviceCoreAndRam clamps to the
  // actual hardware limits (minus the REM_FOR_STACK reserve).
  int32_t nram_size = 480 * 1024;
  int32_t wram_size = 512 * 1024;
  int32_t sram_size = 2016 * 1024;
  int32_t cluster_dim = 0;
  int32_t core_dim = 0;
  getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK);
  if (quant_mode == 0) {
    // Per-channel scales: NRAM holds seq_block rows of head_num * head_size
    // floats, keeping one row spare.
    const int64_t bytes_per_token = (int64_t)head_num * head_size * sizeof_(float);
    seq_block = int32_t((int64_t)nram_size / bytes_per_token - 1);
    if (seq_block < block_size) {
      std::cerr << __func__ << "," << __LINE__
                << " :head_num * head_size * sizeof_(float) should be less than "
                << nram_size / block_size << " when quant_mode is 0." << std::endl;
    }
  } else {
    // Per-token scales: each token additionally needs head_num scale floats
    // plus a row in the context dtype.
    const int32_t dtype_size = (dtype == CNNL_DTYPE_FLOAT) ? sizeof_(float) : sizeof_(half);
    const int64_t bytes_per_token = (int64_t)head_num * (head_size + 1) * sizeof_(float) +
                                    (int64_t)head_num * head_size * dtype_size;
    seq_block = int32_t((int64_t)nram_size / bytes_per_token);
    if (seq_block < block_size) {
      std::cerr << __func__ << "," << __LINE__
                << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
                   "context_dtype_size) "
                << "should be less than " << nram_size / block_size << " when quant_mode is 1."
                << std::endl;
    }
    /* head_size * 64B put in the wram. */
    if (head_size * ONE_LINE >= wram_size) {
      std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than "
                << wram_size << " when quant_mode is 1." << std::endl;
    }
  }
  // seq_block must be a multiple of block_size so a tile never splits a block.
  seq_block = PAD_DOWN(seq_block, block_size);
  int seq_seg = DIV_UP(max_context_len, seq_block);
  const int32_t core_num = cluster_dim * core_dim;
  // Device under-subscribed: shrink the tile to spread work over more
  // (batch, segment) tasks.
  if (batch_size * seq_seg <= (core_num / 2)) {
    const int times = core_num / batch_size / seq_seg;
    seq_block = std::max(seq_block / times, 2);
    seq_block = (seq_block > block_size) ? PAD_DOWN(seq_block, block_size) : block_size;
    seq_seg = DIV_UP(max_context_len, seq_block);
  }
  task_dim.x = 1;
  task_dim.y = uint32_t(std::min(batch_size, core_num));
  task_dim.z = uint32_t(seq_seg);
  task_type = cnrtFuncTypeBlock;
}
// Host-side launcher: dequantizes key/value tensors out of the paged KV cache.
// Picks the kernel instantiation from DequantFromPagedCacheFuncArr based on
// dtype, quant_mode, quant_bit and whether explicit sequence offsets are
// provided, computes the tiling with getBlockAndDimForPaged, then launches.
// Returns KERNEL_STATUS_FAILED on MLU300 devices or unsupported
// configurations, KERNEL_STATUS_SUCCESS otherwise.
KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue,
                                         void *key,
                                         void *value,
                                         const void *key_cache,
                                         const void *value_cache,
                                         const void *key_scale,
                                         const void *value_scale,
                                         const void *context_lens,
                                         const void *context_seq_offsets,
                                         const void *block_tables,
                                         const int32_t max_context_len,
                                         const int32_t max_block_num,
                                         const int32_t batch_size,
                                         const int32_t head_num,
                                         const int32_t key_group_num,
                                         const int32_t value_group_num,
                                         const int32_t block_size,
                                         const int32_t head_size,
                                         const int32_t quant_mode,
                                         const int32_t quant_bit,
                                         const size_t context_head_stride,
                                         const size_t context_seq_stride,
                                         const size_t cache_bs_stride,
                                         const size_t cache_head_stride,
                                         const size_t key_cache_seq_stride,
                                         const size_t value_cache_seq_stride,
                                         const size_t scale_bs_stride,
                                         const size_t scale_head_stride,
                                         const cnnlDataType_t dtype) {
  if (is_arch300()) {
    std::cerr << "[invokeDequantFromPagedCache]: kernel does not support MLU300 devices."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int32_t seq_block;
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  getBlockAndDimForPaged(seq_block, k_dim, k_type, max_context_len, head_num, batch_size, head_size,
                         block_size, quant_mode, quant_bit, dtype);
  const uint32_t index = getDequantPagedIdx(dtype, quant_mode, quant_bit, context_seq_offsets);
  auto dequant_paged_func = DequantFromPagedCacheFuncArr[index];
  // BUGFIX: the int4 (quant_bit != 8) table slots are nullptr; launching
  // through a null function pointer crashed instead of failing gracefully.
  if (dequant_paged_func == nullptr) {
    std::cerr << "[invokeDequantFromPagedCache]: unsupported quant_bit (" << quant_bit
              << ") / dtype (" << dtype << ") combination." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  dequant_paged_func<<<k_dim, k_type, queue>>>(
      (void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache,
      (const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens,
      (const int32_t *)context_seq_offsets, (const int32_t *)block_tables, max_context_len,
      max_block_num, batch_size, head_num, key_group_num, value_group_num, block_size, head_size,
      seq_block, context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
      key_cache_seq_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,106 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief De-quantizes the key and value tensors from the provided paged cache and scale.
* @param queue: The queue for mlu.
* @param key: Pointer to the MLU memory that stores the key tensor,
* with shape [total_seqlen, head_num, head_size]. Data type can be half
* or bfloat16. This parameter can be nullptr.
* @param value: Pointer to the MLU memory that stores the value tensor,
* with shape [total_seqlen, head_num, head_size]. Data type can be
 *        half or bfloat16. This parameter can be nullptr.
* @param key_cache: Pointer to the MLU memory that stores the key cache tensor,
* with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization.
* Data type must be int8. This parameter can be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache tensor,
* with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization.
* Data type must be int8. This parameter can be nullptr.
* @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale.
* Shape depends on quantization mode:
* - For per-channel quantization (quant_mode = 0): [head_num, head_size].
* - For per-token quantization (quant_mode = 1): [total_blocks, head_num, block_size].
* Data type must be float32. This parameter can be nullptr.
* @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale,
 *        with the same shape as key_scale. Data type must be float32. This parameter can be
 *        nullptr.
* @param context_lens: Pointer to the MLU memory that stores the sequence lengths.
* The shape must be [batch].
 * @param context_seq_offsets: Pointer to the MLU memory that stores the per-batch sequence
 *        offsets into the context. The shape must be [batch]. If nullptr, the cumulative sum of
 *        context_lens is used instead.
* @param block_tables: Pointer to the MLU memory that stores the block tables for indexing.
* The shape must be [batch, max_block_num].
 * @param max_context_len: The maximum sequence length of the context.
* @param max_block_num: The maximum block number of each batch.
* @param batch: Batch size.
* @param head_num: Head number.
* @param key_group_num: group number of key group-wise quantization.
* @param value_group_num: group number of value group-wise quantization.
* @param block_size: The block size of the cache.
* @param head_size: Head size.
* @param quant_mode: An integer value indicating the quantization mode:
* 0 for per-channel quantization and 1 for per-token quantization.
* @param quant_bit: An integer value indicating the quantization bit width:
* 8 for 8-bit quantization.
 * @param context_head_stride: The stride of head_num in context.
 * @param context_seq_stride: The stride of max_context_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param key_cache_seq_stride: The stride of cache_mem_len in key cache.
* @param value_cache_seq_stride: The stride of cache_mem_len in value cache.
 * @param cache_scale_bs_stride: The stride of batch in cache scale; only used for per-token
 *        quantization (quant_mode = 1).
* @param cache_scale_head_stride: The stride of head in cache scale.
* @param dtype: The data type of the key and value tensors.
* @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key.
* If any of value/value_cache/value_scale is nullptr, no operation is performed on the value.
*/
KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue,
void *key,
void *value,
const void *key_cache,
const void *value_cache,
const void *key_scale,
const void *value_scale,
const void *context_lens,
const void *context_seq_offsets,
const void *block_tables,
const int32_t max_context_len,
const int32_t max_block_num,
const int32_t batch,
const int32_t head_num,
const int32_t key_group_num,
const int32_t value_group_num,
const int32_t block_size,
const int32_t head_size,
const int32_t quant_mode,
const int32_t quant_bit,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t key_cache_seq_stride,
const size_t value_cache_seq_stride,
const size_t cache_scale_bs_stride,
const size_t cache_scale_head_stride,
const cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_

View File

@@ -0,0 +1,254 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "dequantify.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Number of dequantized output values produced per stored source element:
// one for every type except int4x2_t, which packs two 4-bit values per byte.
template <typename T>
struct PackValueNum {
  static constexpr int value = 1;
};
template <>
struct PackValueNum<int4x2_t> {
  static constexpr int value = 2;
};
// NRAM scratch: the main working buffer (3/8 of NRAM) plus a small buffer
// reserved for per-channel scales.
__nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)];
__nram__ int8_t nram_buf_scale[8192];
// Dequantize `count` values from `src` into `dst`: widen with the matching
// __bang integer-to-float conversion, then multiply by `scale` (callers pass
// in the reciprocal of the quantization scale). For int4x2_t sources, `count`
// is the number of unpacked output values.
// NOTE(review): the half overloads narrow `scale` to half before multiplying;
// confirm the precision loss is acceptable for very small scales.
__mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) {
  __bang_int82float(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, scale, count);
}
__mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) {
  __bang_int42float_rn(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, scale, count);
}
__mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) {
  __bang_int82half(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, (half)scale, count);
}
__mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) {
  __bang_int42half_rn(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, (half)scale, count);
}
// Exchange two pointers; used to rotate the ping/pong NRAM buffers between
// pipeline stages.
template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
  T *held = pong;
  pong = ping;
  ping = held;
}
// Whole-tensor dequantization with a single scale: dst[i] = src[i] / scale.
// Elements are split evenly across cores; each core runs a three-stage
// software pipeline (Load -> Convert -> Store) over ping/pong halves of
// nram_buf, overlapping GDRAM traffic with compute.
// NOTE(review): the pipeline prologue/epilogue unconditionally issue two loads
// and two stores, so when a core's src_count < 2 * src_num_unit (loop_count
// < 2) the `loop_count - 2` offsets go negative / out of range — confirm
// callers only reach this kernel with inputs large enough per core.
template <typename TDst, typename TSrc>
__mlu_global__ void dequantifyPerTensor(void *all_dst,
                                        const void *all_src,
                                        size_t all_src_count,
                                        float scale) {
  scale = 1.0f / scale;  // invert once; the pipeline multiplies
  // Even split of source elements over cores (the first `src_remain` cores
  // take one extra element).
  size_t src_per_core = all_src_count / taskDim;
  size_t src_remain = all_src_count % taskDim;
  size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain);
  const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0);
  // int4x2 packs two output values per source element, hence the PackValueNum
  // scaling of the destination pointer and tile sizes.
  TDst *dst = reinterpret_cast<TDst *>(all_dst) + start * PackValueNum<TSrc>::value;
  const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + start;
  constexpr int size_unit = sizeof(nram_buf) / 2 /  // divide by 2 for ping pong
                            (sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
                            128;  // align to 128
  constexpr int src_num_unit = size_unit / sizeof(TSrc);
  constexpr int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
  int8_t *nram_buf_ping = nram_buf;
  int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2;
  TSrc *nram_src_ping = reinterpret_cast<TSrc *>(nram_buf_ping);
  TDst *nram_dst_ping =
      reinterpret_cast<TDst *>(nram_buf_ping + static_cast<int>(sizeof(TSrc)) * size_unit);
  TSrc *nram_src_pong = reinterpret_cast<TSrc *>(nram_buf_pong);
  TDst *nram_dst_pong =
      reinterpret_cast<TDst *>(nram_buf_pong + static_cast<int>(sizeof(TSrc)) * size_unit);
  int loop_count = src_count / src_num_unit;
  int remain_count = src_count % src_num_unit;
  // L: prologue — load tile 0 only.
  __memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
  swap(nram_src_ping, nram_src_pong);
  swap(nram_dst_ping, nram_dst_pong);
  __sync_io_move_compute();
  // L C: load tile 1 while converting tile 0.
  __memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
  convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
  swap(nram_src_ping, nram_src_pong);
  swap(nram_dst_ping, nram_dst_pong);
  __sync_io_move_compute();
  // L C S: steady state — load tile i+2, store tile i, convert tile i+1.
  for (int i = 0; i < loop_count - 2; ++i) {
    __memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit,
                   GDRAM2NRAM);
    __memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
    convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
    swap(nram_src_ping, nram_src_pong);
    swap(nram_dst_ping, nram_dst_pong);
    __sync_io_move_compute();
  }
  // C S: drain — store tile loop_count-2 while converting the last full tile.
  __memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
                 NRAM2GDRAM);
  convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
  swap(nram_src_ping, nram_src_pong);
  swap(nram_dst_ping, nram_dst_pong);
  __sync_io_move_compute();
  // S: store the last full tile.
  __memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
                 NRAM2GDRAM);
  __sync_io_move_compute();
  // Tail: handle the final partial tile synchronously.
  if (remain_count > 0) {
    __memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count,
             GDRAM2NRAM);
    convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum<TSrc>::value, scale);
    __memcpy(dst + loop_count * dst_num_unit, nram_dst_ping,
             sizeof(TDst) * remain_count * PackValueNum<TSrc>::value, NRAM2GDRAM);
  }
}
// Per-channel dequantization: dst[o][i] = src[o][i] / scale[o].
// Does not use a pipeline because per-channel layout is more complicated, but
// it's a one-time operation, so performance doesn't matter.
template <typename TDst, typename TSrc>
__mlu_global__ void dequantifyPerChannel(void *all_dst,
                                         const void *all_src,
                                         int src_ci,
                                         int all_co,
                                         const void *scale) {
  // Split the output channels across cores (first `co_remain` cores take one
  // extra channel).
  const int co_per_core = all_co / taskDim;
  const int co_remain = all_co % taskDim;
  const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain);
  const int co_count = co_per_core + (taskId < co_remain ? 1 : 0);
  assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst));
  if (co_count <= 0) {
    return;
  }
  // BUGFIX: `scale` was never copied into nram_buf_scale, so dequantization
  // read uninitialized NRAM. Stage this core's scale window here and index it
  // relative to start_co below.
  // NOTE(review): assumes the GDRAM scale table is stored as TDst, matching
  // the nram_buf_scale cast — confirm against the host-side callers.
  __memcpy(nram_buf_scale, reinterpret_cast<const TDst *>(scale) + start_co,
           co_count * sizeof(TDst), GDRAM2NRAM);
  constexpr int size_unit = sizeof(nram_buf) /
                            (sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
                            128;  // align to 128
  // yes, we only deal with 1 channel at a time
  // no, there's no need to optimize a one-time operation
  const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci);
  const int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
  TSrc *const nram_src = reinterpret_cast<TSrc *>(nram_buf);
  TDst *const nram_dst =
      reinterpret_cast<TDst *>(nram_buf + static_cast<int>(sizeof(TSrc)) * size_unit);
  const TDst *nram_scale = reinterpret_cast<const TDst *>(nram_buf_scale);
  const int loop_one_channel = src_ci / src_num_unit;
  const int remain_one_channel = src_ci % src_num_unit;
  for (int o = start_co; o < start_co + co_count; ++o) {
    const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + o * src_ci;
    // BUGFIX: the destination channel stride must account for value packing
    // (int4x2 yields 2 outputs per stored element); without the PackValueNum
    // factor adjacent output channels overlapped for 4-bit sources (the
    // per-tensor kernel above already applies this factor).
    TDst *dst = reinterpret_cast<TDst *>(all_dst) + (size_t)o * src_ci * PackValueNum<TSrc>::value;
    const TDst scale_value = 1. / nram_scale[o - start_co];
    for (int i = 0; i < loop_one_channel; ++i) {
      __memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
      convert(nram_dst, nram_src, dst_num_unit, scale_value);
      __memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
    }
    if (remain_one_channel > 0) {
      __memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel,
               GDRAM2NRAM);
      convert(nram_dst, nram_src, remain_one_channel * PackValueNum<TSrc>::value, scale_value);
      __memcpy(dst + loop_one_channel * dst_num_unit, nram_dst,
               sizeof(TDst) * remain_one_channel * PackValueNum<TSrc>::value, NRAM2GDRAM);
    }
  }
}
} // namespace kernels
// Maps (src_bitwidth, dst dtype) -> per-tensor dequant kernel instantiation.
// Supported: 4-bit (int4x2_t) and 8-bit sources to half or float outputs.
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerTensor<half, int4x2_t>)>
    per_tensor_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int8_t>},
};
// Host launcher for whole-tensor dequantization: dispatches on
// (src_bitwidth, dst_dtype) and launches one Union1 task per cluster.
// Returns KERNEL_STATUS_FAILED for unsupported combinations.
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
                                       const void *src,
                                       int src_bitwidth,
                                       void *dst,
                                       cnnlDataType_t dst_dtype,
                                       size_t src_count,
                                       float scale) {
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // NOTE(review): x = 4 hard-codes 4 cores per cluster; invokeEmbedding queries
  // cnrtAttrMcorePerCluster instead — confirm all targets have 4.
  const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};
  auto iter = per_tensor_func_map.find(std::make_pair(src_bitwidth, dst_dtype));
  if (iter == per_tensor_func_map.end()) {
    // BUGFIX: terminate the diagnostic with std::endl (newline + flush) so the
    // message is not lost or glued to later output, matching the other
    // invoke* helpers in this library.
    std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  iter->second<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_count, scale);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
// Maps (src_bitwidth, dst dtype) -> per-channel dequant kernel instantiation.
// Supported: 4-bit (int4x2_t) and 8-bit sources to half or float outputs.
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerChannel<half, int4x2_t>)>
    per_channel_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int8_t>},
};
// Host launcher for per-channel dequantization: dispatches on
// (src_bitwidth, dst_dtype) and launches one Union1 task per cluster.
// Returns KERNEL_STATUS_FAILED for unsupported combinations.
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
                                        const void *src,
                                        int src_bitwidth,
                                        void *dst,
                                        cnnlDataType_t dst_dtype,
                                        int src_ci,
                                        int co,
                                        const void *scale) {
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // NOTE(review): x = 4 hard-codes 4 cores per cluster; invokeEmbedding queries
  // cnrtAttrMcorePerCluster instead — confirm all targets have 4.
  const cnrtDim3_t dim = {.x = 4, .y = (uint32_t)cluster_num, .z = 1};
  auto iter = per_channel_func_map.find(std::make_pair(src_bitwidth, dst_dtype));
  if (iter == per_channel_func_map.end()) {
    // BUGFIX: terminate the diagnostic with std::endl (newline + flush) so the
    // message is not lost or glued to later output, matching the other
    // invoke* helpers in this library.
    std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  iter->second<<<dim, cnrtFuncTypeUnion1, queue>>>(dst, src, src_ci, co, scale);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,57 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_DEQUANTIFY_MLUH_
#define CSRC_KERNELS_DEQUANTIFY_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Dequantify per tensor.
* @param handle: The handle of cnnl.
* @param src: Input. Pointer to the MLU memory that stores the input.
* @param src_bitwidth: The bitwidth of input quantized data.
* @param dst: Output. Pointer to the MLU memory that stores the output.
* @param dst_dtype: The data type of output.
* @param src_count: The number of elements in input.
 * @param scale: The scale used for dequantization.
*/
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
const void *src,
int src_bitwidth,
void *dst,
cnnlDataType_t dst_dtype,
size_t src_count,
float scale);
/**
* @brief Dequantify per channel.
* @param handle: The handle of cnnl.
* @param src: Input. Pointer to the MLU memory that stores the input.
* @param src_bitwidth: The bitwidth of input quantized data.
* @param dst: Output. Pointer to the MLU memory that stores the output.
* @param dst_dtype: The data type of output.
 * @param src_ci: The number of elements per channel (ci) of the input.
 * @param co: The number of channels (co) of the input.
 * @param scale: Pointer to the MLU memory that stores the per-channel scales for dequantization.
*/
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
const void *src,
int src_bitwidth,
void *dst,
cnnlDataType_t dst_dtype,
int src_ci,
int co,
const void *scale);
} // namespace tmo
#endif // CSRC_KERNELS_DEQUANTIFY_MLUH_

View File

@@ -0,0 +1,310 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define MAX_UINT32 (4294967295)  // max byte offset addressable by __gather (unsigned int index)
#define MAX_SINT32 (2147483647)  // max 2D __memcpy source stride (int)
// Working NRAM buffer; 32 KB of NRAM is kept back for stack/scratch.
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];
// Evenly partition `total` items over `num` workers: worker `id` receives
// `every` items starting at `offset`, with the remainder going to the
// lowest-numbered workers (one extra item each).
__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
  const int quotient = total / num;
  const int leftover = total % num;
  const bool takes_extra = id < leftover;
  every = takes_extra ? quotient + 1 : quotient;
  offset = quotient * id + (takes_extra ? id : leftover);
}
// Round x down / up to a multiple of y (integer arithmetic, y > 0).
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
// Embedding lookup for 500-series parts: gathers `total_seq` rows of
// `input_size` elements from `filter` according to `input_ids`. Ids outside
// [vocab_offset, vocab_offset + vocab_size) produce all-zero rows (this rank
// may hold only a vocabulary shard). Relies on hardware __gather, so filter
// byte offsets must fit in unsigned int (checked by MLUEmbeddingKernel).
template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  if (__is_mpu()) {
    return;
  };
  int bs_core = 0;
  int bs_offset = 0;
  // Evenly divide the token rows across cores.
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex <elem_count> must be divisible
  // by 8
  int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
              (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;  // inclusive
  // NRAM layout: one zero row followed by the per-tile working buffers.
  T *zeros_nram = (T *)nram_buffer;                                 // input_size * sizeof(T)
  T *emb_nram = zeros_nram + input_size;                            // limit * input_size * sizeof(T)
  int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  int *idxs_nram = ones_nram + limit;                               // limit * sizeof(int)
  int *mask_nram = idxs_nram + limit;                               // limit_pad * sizeof(int)
  int *temp_nram = mask_nram + PAD_UP(limit, 8);                    // limit * sizeof(int)
  uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit);      // limit * sizeof(int8_t)
  __bang_write_zero(zeros_nram, input_size);
  __bang_write_zero(zeros_offset_nram, limit);
  __bang_write_value(ones_nram, limit, 1);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    int num_pad = PAD_UP(num, 8);  // for __bang_eq_bitindex
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // Validity mask: vocab_start <= id <= vocab_end.
    __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
    __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
    __bang_mul(mask_nram, mask_nram, temp_nram, num);  // logical AND of both tests
    __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
                       num_pad);  // gather valid mask
    __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num);  // gather invalid mask
    __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num);  // true index
    __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
                      (unsigned int)input_size * sizeof(T), num);  // gather offset
    __sync();
    // Valid rows come from filter; invalid rows are overwritten with the zero row.
    __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
                   GDRAM2NRAM, input_size * sizeof(T), num);
    __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
                   NRAM2NRAM, input_size * sizeof(T), num);
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
// Zero-fill `elem_count` elements of NRAM.
template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
  __bang_write_zero(dst, elem_count);
}
// bfloat16 overload: __bang_write_zero is only compiled for arch >= 500 here,
// so on older architectures this is a silent no-op.
// NOTE(review): pre-500, bf16 rows that should be zeroed keep stale NRAM
// contents — confirm bf16 embedding is never dispatched on those devices.
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
  __bang_write_zero(dst, elem_count);
#endif
}
// Embedding lookup fallback for 300-series parts (no hardware gather): rows
// are fetched with __memcpy_async, two ids per step. When both ids of a pair
// are in-vocab, a single 2D copy (count 1, source stride = (idx2 - idx1)
// rows) fetches both rows at once. Out-of-vocab ids yield zero rows.
// NOTE(review): when idx2 < idx1 the 2D copy uses a negative source stride —
// confirm __memcpy_async supports that on this arch.
// NOTE(review): the loop pre-reads idxs_nram one pair past `num` on its last
// iteration; those values are unused, and the `- 64` slack in `limit` appears
// to keep the read inside nram_buffer — confirm.
template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  if (__is_mpu()) {
    return;
  };
  int bs_core = 0;
  int bs_offset = 0;
  // Evenly divide the token rows across cores.
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);  // ids are consumed in pairs
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;  // inclusive
  T *emb_nram = (T *)nram_buffer;                                   // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // Software-pipelined pair processing: pre-read the pair, test in-vocab.
    int idx1 = idxs_nram[0];
    int idx2 = idxs_nram[1];
    bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
    bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
    for (int n = 0; n < num / 2 * 2; n += 2) {
      if (first && second) {
        // Both rows valid: one 2D copy fetches rows idx1 and idx2.
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T),
                       1);
      } else if (!first && !second) {
        write_zero(emb_nram + n * input_size, 2 * input_size);
      } else if (first && !second) {
        write_zero(emb_nram + (n + 1) * input_size, input_size);
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
        __memcpy_async(emb_nram + (n + 1) * input_size,
                       filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      }
      // Prefetch the next pair (may read past `num`; see NOTE above).
      idx1 = idxs_nram[n + 2];
      idx2 = idxs_nram[n + 3];
      first = (idx1 >= vocab_start && idx1 <= vocab_end);
      second = (idx2 >= vocab_start && idx2 <= vocab_end);
    }  // copy loop
    // last idx copy (odd tail: only idx1/first are meaningful here)
    if (num % 2 == 1) {
      if (first) {
        __memcpy_async(emb_nram + (num - 1) * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + (num - 1) * input_size, input_size);
      }
    }
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
// Generic (arch-independent) embedding lookup: one __memcpy_async per row,
// used when filter byte offsets overflow the fast paths' index types.
// Out-of-vocab ids produce zero rows.
// NOTE(review): the loop pre-reads idxs_nram[num] on its final iteration; the
// value is unused, and the `- 64` slack in `limit` appears to keep the read
// inside nram_buffer — confirm.
template <typename T>
__mlu_func__ void embeddingImpl_generic(T *filter,
                                        int *input_ids,
                                        T *output,
                                        int vocab_offset,
                                        int vocab_size,
                                        int input_size,
                                        int total_seq) {
  if (__is_mpu()) {
    return;
  };
  int bs_core = 0;
  int bs_offset = 0;
  // Evenly divide the token rows across cores.
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;  // inclusive
  T *emb_nram = (T *)nram_buffer;                                   // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    int idx = idxs_nram[0];
    bool hit = (idx >= vocab_start && idx <= vocab_end);
    for (int n = 0; n < num; n++) {
      if (hit) {
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
      }
      // Prefetch the next id (may read idxs_nram[num]; see NOTE above).
      idx = idxs_nram[n + 1];
      hit = (idx >= vocab_start && idx <= vocab_end);
    }
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
// Entry kernel: picks the fastest embedding implementation the current
// architecture and table size allow, falling back to the generic
// row-by-row copy when the table's byte offsets overflow the index type
// of the fast path.
template <typename T>
__mlu_global__ void MLUEmbeddingKernel(T *filter,
                                       int *input_ids,
                                       T *output,
                                       int vocab_offset,
                                       int vocab_size,
                                       int total_vocab_size,
                                       int input_size,
                                       int total_seq) {
  // Largest byte offset that can ever be formed into the full table.
  const size_t max_byte_offset = (size_t)(total_vocab_size - 1) * input_size * sizeof(T);
#if __BANG_ARCH__ > 372
  // __gather index maximum dtype is unsigned int
  if (max_byte_offset <= (size_t)(MAX_UINT32)) {
    embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
    return;
  }
#else
  // __memcpy 2D src_stride dtype is int
  if (max_byte_offset <= (size_t)(MAX_SINT32)) {
    embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
    return;
  }
#endif
  embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                        total_seq);
}
} // namespace kernels
// Host-side launcher for the embedding lookup kernel.
// Launches one Union1 task per (cluster, core) pair on the current device.
// Supported dtypes: float, half, and (on non-MLU300 devices) bfloat16.
// Returns KERNEL_STATUS_FAILED for any other dtype, or for bfloat16 on MLU300.
KernelStatus invokeEmbedding(cnrtQueue_t queue,
                             void *filter,
                             void *input_ids,
                             void *output,
                             const cnnlDataType_t dtype,
                             int vocab_offset,
                             int vocab_size,
                             int total_vocab_size,
                             int input_size,
                             int total_seq) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else {
    // Fix: unsupported dtypes previously fell through and reported SUCCESS
    // without launching any kernel, leaving `output` untouched.
    std::cerr << "[invokeEmbedding]: unsupported data type, only float/half/bfloat16 are supported."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,63 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
#define CSRC_KERNELS_EMBEDDING_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Look up the embedding table for ids that are greater than or equal to
 * vocab_offset and less than vocab_offset + vocab_size, and write the results back to the position
* corresponding to the ids. For ids that are not in the range, write 0
* to the corresponding position.
* @example
* filter:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [4, 3, 2, 1]]
* input_ids:
* [[1, 5, 6, 7, 8, 9]]
* vocab_offset = 5
* vocab_size = 3
* input_size = 4
* total_seq = 6
* output:
* [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
* [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
* @param queue: The queue for mlu.
* @param filter: Input. Pointer to the MLU memory that stores the embedding table,
* the shape must be [vocab_size, input_size].
* @param input_ids: Input. Pointer to the MLU memory that stores the token id,
* the shape must be [batch, seq].
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [batch, seq, input_size].
* @param dtype: Data type.
* @param vocab_offset: embedding table offset.
* @param vocab_size: embedding table size.
* @param total_vocab_size: total embedding table size.
* @param input_size: embedding dim.
* @param total_seq: Total sequence length.
*/
KernelStatus invokeEmbedding(cnrtQueue_t queue,
void *filter,
void *input_ids,
void *output,
const cnnlDataType_t dtype,
int vocab_offset,
int vocab_size,
int total_vocab_size,
int input_size,
int total_seq);
} // namespace tmo
#endif // CSRC_KERNELS_EMBEDDING_MLUH_

View File

@@ -0,0 +1,658 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "cnnl.h"
#include "cnrt.h"
#include "fused_rope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
using bf16 = bfloat16_t;
namespace tmo {
namespace kernels {
#ifndef PAD_UP
// Round x up to the nearest multiple of y (y > 0).
// Fix: the previous expansion ((x) / (y) + (int)((x) % (y) > 0) * (y)) was
// mis-parenthesized — it added y to the truncated quotient instead of
// multiplying the rounded-up quotient by y (e.g. PAD_UP(10, 8) == 9 and
// PAD_UP(16, 8) == 2), which under-sized the __bang_ge_bitindex element
// counts and the mask byte budgets computed from it whenever the project-wide
// PAD_UP from kernel_utils.h was not already defined.
#define PAD_UP(x, y) ((((x) + (y)-1) / (y)) * (y))
#endif
#if __BANG_ARCH__ > 500
#include <bang_fusor.h>
template <typename T>
using bang_cycle_fusor = bang::experimental::cycle_fusor<T>;
#endif
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ float nram_mask[256];
__nram__ int nram_rope_offsets[256];
__nram__ float nram_zeros[1024] = {0.f};
// Widen `num` elements of half/bf16 `src` into float `dst`.
// bf16 widening is only available on arch > 500; other T are a no-op.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int num) {
  if (std::is_same<T, bf16>::value) {
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(dst, (bf16 *)src, num);
#endif
    return;
  }
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, num);
  }
}
// Narrow `num` float elements of `src` into half/bf16 `dst`
// (round-to-nearest); other T are a no-op.
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int num) {
  if (std::is_same<T, bf16>::value) {
    __bang_float2bfloat16_rn((bf16 *)dst, src, num);
    return;
  }
  if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, num);
  }
}
// Compute, per (token, kv-head), the byte/element offsets at which k/v (and,
// for mixed cache, the v on-chip staging copy and the lp scales) will be
// scattered into the kv cache, plus the bit mask that disables scatter lanes
// for invalid tokens. Paged layout derives offsets from slot_mapping; linear
// layout derives them from (batch id, sequence offset). Invalid entries
// (negative slot / batch / seq) get offset -1 and are masked out.
__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin,
                                       int *cache_seq_offsets_begin,
                                       int *slot_mapping_begin,
                                       int *nram_k_cache_offsets,
                                       int *nram_v_cache_offsets,
                                       int *nram_v_onchip_offsets,
                                       int *nram_kv_scale_offsets,
                                       float *nram_cache_mask,
                                       float *nram_zeros,
                                       float *nram_temp,
                                       int task_deal_batch,
                                       int task_begin_batch,
                                       int head_num_k,
                                       int head_size,
                                       int max_decode_len,
                                       int block_size,
                                       int kv_out_size,
                                       int group_num,
                                       bool discrete_batch,
                                       bool paged_cache,
                                       bool mixed_cache) {
  // For now the offsets are computed with scalar code for readability
  // (no performance impact).
  int bh = task_deal_batch * head_num_k;
  if (paged_cache) {
    // Paged cache strides: [num_blocks, head_num_k, block_size, head_size].
    int cache_seq_stride = head_size;
    int cache_head_stride = block_size * head_size;
    int cache_scale_head_stride = block_size * group_num;
    int cache_block_stride = head_num_k * cache_head_stride;
    int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
    int *nram_slot_mapping = (int *)nram_mask;
    __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    for (int i = 0; i < task_deal_batch; i++) {
      int mapping_idx = __load_nram(nram_slot_mapping + i);
      if (mapping_idx < 0) {
        // Invalid slot: mark every head of this token with -1 so the mask
        // computed at the bottom disables its scatter lanes.
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int block_idx = mapping_idx / block_size;
      int seq_idx = mapping_idx % block_size;
      int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
      // lp value cache packs two seq positions per row (int4), hence the /2.
      int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
      int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
      // On-chip staging slot for v: even/odd seq positions go to the two
      // halves of the staging buffer.
      int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        // k offsets are in bytes (scaled by kv_out_size); halved for mixed
        // cache because lp k rows hold packed int4.
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  } else {
    // Linear cache strides: [max_bs, head_num_k, max_decode_len, head_size].
    int *nram_seq_offsets = (int *)nram_mask;
    int *nram_bs_id = nram_seq_offsets + 32;
    int cache_seq_stride = head_size;
    int cache_head_stride = max_decode_len * head_size;
    int cache_scale_head_stride = max_decode_len * group_num;
    int cache_bs_stride = head_num_k * cache_head_stride;
    int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
    __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    if (discrete_batch) {
      __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    }
    for (int i = 0; i < task_deal_batch; i++) {
      // NOTE(review): when !discrete_batch this loads uninitialized NRAM, but
      // the value is discarded by the ternary below — harmless.
      int bs_idx = __load_nram(nram_bs_id + i);
      int seq_idx = __load_nram(nram_seq_offsets + i);
      int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i;
      int temp_seq_idx = seq_idx;
      bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
      if (masked) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
      int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
      int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
      int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  }
  // Build the mask consumed by the scatter instructions: any entry whose
  // bs offset or seq offset was negative (offset == -1) is masked out.
  __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
  __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}
// In-place layernorm over `task_deal_batch` rows of length `k_hidden`:
// y = (x - mean) * rsqrt(var + eps) * gamma + beta, with gamma at
// norm_params[0:k_hidden] and beta at norm_params[k_hidden:2*k_hidden].
// Only compiled for arch > 500; a no-op otherwise.
__mlu_func__ void layernormImpl(float *nram_k,
                                float *norm_params,
                                int task_deal_batch,
                                int k_hidden,
                                float eps) {
#if __BANG_ARCH__ > 500
  // Scratch row lives directly after the rows being normalized.
  float *buffer = nram_k + task_deal_batch * k_hidden;
  for (int i = 0; i < task_deal_batch; i++) {
    float *k_ = nram_k + i * k_hidden;
    __bang_mul(buffer, k_, k_, k_hidden);  // buffer = x^2
    float mean = __bang_sum(k_, k_hidden);
    mean = mean / k_hidden;
    float rstd = __bang_sum(buffer, k_hidden);
    rstd = rstd / k_hidden - mean * mean;  // var = E[x^2] - mean^2
    // Guard against a slightly negative variance from fp rounding.
    rstd = rstd < 0 ? eps : rstd + eps;
    rstd = 1.f / std::sqrt(rstd);
    // Fused (x - mean) * rstd; presumably FUSION_FSM = fused sub-mul —
    // confirm against the BANG intrinsics reference.
    __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
  }
  // Fused elementwise y = x * gamma + beta, cycling the 2*k_hidden params
  // over all rows.
  __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
                task_deal_batch * k_hidden, k_hidden);
#endif
}
// Apply rotary embedding in place: nram_qk = qk * cos + rot(qk) * sign * sin,
// where nram_qk_rot holds the half-swapped copy of qk prepared by the caller,
// nram_table holds cos rows first then sin rows (see the gathers in
// fuseRopeImpl), and the global nram_mask holds the -1/+1 sign pattern.
__mlu_func__ void foldRotaryImpl(float *nram_qk,
                                 float *nram_qk_rot,
                                 float *nram_table,
                                 int task_deal_batch,
                                 int head_num_qk,
                                 int head_size) {
  int rotary_low_dim = task_deal_batch * head_size;
  // qk *= cos (cos rows cycled over all heads)
  __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
  // qk_rot *= [-1 ... -1, +1 ... +1] sign mask (cycled per head_size)
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size);
  // qk_rot *= sin (sin rows start after the cos rows)
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
                   head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}
// Quantize k or v for the caches:
//  - hp path: widen to float, then per-channel int8 via the offline
//    reciprocal scales in scale_hp (fused multiply+round asm, arch > 500);
//  - lp path (mixed cache only): online per-token group quantization to
//    int8 in [-7, 7] (int4 range) with scales written to scale_lp.
// NOTE(review): the lp branch is nested inside quant_kv_hp, so mixed-cache
// lp quantization only runs when hp quantization is also enabled — confirm
// this invariant is guaranteed by the callers.
template <typename T>
__mlu_func__ void quantify(T *input,
                           float *float_input,
                           void *output_hp,
                           void *output_lp,
                           float *nram_trans,
                           float *scale_hp,
                           float *scale_lp,
                           float *scale_lp_temp,
                           int batch,
                           int head_num,
                           int head_size,
                           int group_num,
                           int group_size,
                           bool quant_kv_hp,
                           bool mixed_cache) {
  if (quant_kv_hp) {
    int hidden = head_num * head_size;
    int bh = batch * head_num;
    toFloat<T>(float_input, input, batch * hidden);
    // Invert the per-channel scales once so quantization is a multiply.
    __bang_recip(scale_hp, scale_hp, hidden);
#if __BANG_ARCH__ > 500
    // Fused: output_hp = round_to_int8(float_input * cycle(scale_hp)).
    __asm__ __volatile__(
        "fuse.nram.crn.s8.f32 "
        "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
        ";\n\t" ::[dst] "r"(output_hp),
        [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
        [src1] "r"(scale_hp), [pos] "i"(0));
#endif
    if (mixed_cache) {
      // Transpose so each group is contiguous, take per-group max |x|,
      // scale = max / 7 (int4 symmetric range), quantize, transpose back.
      __bang_transpose(nram_trans, float_input, bh * group_num, group_size);
      __bang_abs(float_input, nram_trans, bh * head_size);
      __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1,
                     1);
      __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
      __bang_recip(scale_lp_temp, scale_lp, bh * group_num);
      __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
                       bh * group_num);
      __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
      __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
    }
  }
}
// Fused decode-step op for `task_deal_batch` tokens starting at
// `task_begin_batch`: applies rotary embedding to q and k, layernorm to k,
// optionally quantizes k/v, and scatters k/v into the hp (and, for mixed
// cache, the lp int4) kv caches, while writing the rotated q back to `input`.
// Requires __BANG_ARCH__ > 500; a no-op otherwise.
template <typename T>
__mlu_func__ void fuseRopeImpl(T *input,
                               void *key_cache_hp,
                               void *value_cache_hp,
                               void *key_cache_lp,
                               void *value_cache_lp,
                               T *sin_table,
                               T *cos_table,
                               int *rope_offsets,
                               T *gamma,
                               T *beta,
                               float *key_scale_hp,
                               float *value_scale_hp,
                               float *key_scale_lp,
                               float *value_scale_lp,
                               int *cache_bs_id_hp,
                               int *cache_seq_offsets_hp,
                               int *cache_bs_id_lp,
                               int *cache_seq_offsets_lp,
                               int *slot_mapping_hp,
                               int *slot_mapping_lp,
                               int rotary_stride,
                               int task_deal_batch,
                               int task_begin_batch,
                               int head_num_q,
                               int head_num_k,
                               int head_size,
                               int max_decode_len_hp,
                               int max_decode_len_lp,
                               int block_size_hp,
                               int block_size_lp,
                               int group_size,
                               int batch_cap,
                               float eps) {
#if __BANG_ARCH__ > 500
  /*
  Because mixed cache must be supported, the kernel would otherwise cover many
  cache feature combinations; only the following are allowed:
  1. Only hp_cache exists (detected by the lp tensors being null): the cache
     supports bf16/fp16, and with quantization supports offline per-channel
     int8. Both linear and paged layouts are supported; key and value caches
     share the same shape; key/value_scale_hp has shape [head_num, head_size].
  2. Mixed cache: hp supports offline per-channel int8 quantization, linear
     and paged layouts, key and value caches share the same shape, and
     key/value_scale_hp has shape [head_num, head_size]. lp supports int4
     online per-token group quantization; key_cache has shape
     [batch, head_num_k, max_decode_len_lp, head_size / 2] (paged is also
     head_size / 2); value_cache has shape
     [batch, head_num_l, max_decode_len_lp / 2, head_size] (the paged cache
     shape is [num_blocks, head_num_k, block_size / 2, head_size]);
     key/value_scale_lp has shape
     [batch, head_num_k, max_decode_len_lp, group_num] (for paged_cache it is
     [num_blocks, head_num_k, block_size, group_num]).
  */
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool discrete_batch_hp = cache_bs_id_hp != nullptr;
  bool discrete_batch_lp = cache_bs_id_lp != nullptr;
  bool paged_cache_hp = slot_mapping_hp != nullptr;
  bool paged_cache_lp = slot_mapping_lp != nullptr;
  int head_num_qk = head_num_q + head_num_k;
  int head_num_qkv = head_num_q + head_num_k * 2;
  int qkv_hidden = head_num_qkv * head_size;
  int qk_hidden = head_num_qk * head_size;
  int q_hidden = head_num_q * head_size;
  int k_hidden = head_num_k * head_size;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
  int group_num = mixed_cache ? head_size / group_size : 1;
  // task ddr offset
  T *input_begin = input + task_begin_batch * qkv_hidden;
  int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
  int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
  int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;
  // nram_buffer carve-up (sizes in floats unless noted; optional regions
  // collapse to zero bytes via the (int)flag multipliers).
  float *nram_qk = (float *)nram_buffer;
  float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
  float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
  float *nram_kv_trans = nram_v + batch_cap * k_hidden;
  float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
  float *norm_params = nram_table + 2 * batch_cap * head_size;
  float *nram_k_scale_hp = norm_params + 2 * head_size;
  float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
  int8_t *nram_kv_hp =
      (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
  int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
  int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
  int *nram_kv_cache_offsets_hp =
      (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
  int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
  int *nram_v_cache_offsets_lp =
      nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
  float *cache_mask_hp =
      (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
  float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);
  // qk and qk_rot are stored adjacently so the widening to float can be done
  // with a single instruction and fewer ops; the same trick is used for the
  // sin/cos tables and the norm gamma/beta.
  T *qk_in = (T *)nram_qk_rot;
  T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
  T *v_in =
      (T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
  T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
  T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
  int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);
  // Generate kv-cache offsets and the mask used to scatter kv into the cache.
  genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
                       nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
                       nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
                       head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
                       discrete_batch_hp, paged_cache_hp, false);
  if (mixed_cache) {
    int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
    int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
    int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
    genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp,
                         nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets,
                         nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch,
                         task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp,
                         1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache);
  }
  /*
  Software pipeline (left column = I/O, right column = compute):
  -----------------------
  load v        |
  -----------------------
  load qk       | quant v
  -----------------------
  store v       | rope qk
  -----------------------
  store_q       | layernorm k
                | quant k
  -----------------------
  store k       |
  */
  // prepare v v_scale cache_v rope_offset
  __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
                 k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
  if (quant_kv_hp) {
    __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
    __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
  }
  __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int),
                 GDRAM2NRAM);
  __sync_io();
  if (mixed_cache) {
    // Gather the existing packed-int4 v rows that this step's tokens share,
    // so the co-packed neighbor values survive the read-modify-write below.
    __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp,
             head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
             task_deal_batch * head_num_k);
  }
  // Turn rope position ids into byte offsets into the sin/cos tables.
  __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
                    task_deal_batch);
  __sync_compute();
  /*==============================================================================================*/
  // load_qk,rope_table | quant v
  __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
                 task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_qk - 1);
  __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
                 GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets,
                 head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
  __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);
  int8_t *nram_temp = (int8_t *)nram_qk;
  if (mixed_cache) {
    // Unpack the gathered int4 pairs to int8 and split even/odd seq halves.
    __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0,
                     0);
    __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
  }
  quantify<T>(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp,
              nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num,
              group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    // Merge the freshly quantized v into its staging slot, repack to int4.
    __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
              head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
              task_deal_batch * head_num_k);
    __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
    __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0,
                        0);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // rope | store v
  // Swap the left and right halves of qk to produce qk_rot.
  __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  toFloat<T>(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
  toFloat<T>(nram_table, table_in, 2 * task_deal_batch * head_size);
  toFloat<T>(norm_params, norm_params_in, 2 * head_size);
  // Sign pattern for the rotated half: -1 for the first half, +1 for the rest.
  __bang_write_value(nram_mask, head_size / 2, (float)-1);
  __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
  foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
  floatTo<T>((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);
  int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
  __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
                  cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
                  head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
                    cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
                    head_size * sizeof(int8_t), head_num_k * task_deal_batch);
    __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
                    cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
                    head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // layernorm k, quant k | store q
  // Extract k from the qk NRAM buffer for layernorm and quantization.
  float *nram_k = nram_qk_rot;
  __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
           head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
           task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
           task_deal_batch - 1);
  layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
  quantify<float>(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
                  nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
                  head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
  }
  if (!quant_kv_hp) {
    floatTo<T>((T *)nram_k, nram_k, task_deal_batch * k_hidden);
  }
  // store q
  __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size,
                 task_deal_batch - 1, head_size * dtype_size, head_num_q - 1,
                 head_size * dtype_size, task_deal_batch - 1,
                 task_deal_batch * head_size * dtype_size, head_num_q - 1);
  // ===============================================================================================
  int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
  __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
            head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
            head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
              head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
              head_num_k * task_deal_batch);
    __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
              group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
              head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
#endif
}
// Entry kernel: each task owns a contiguous slice of `task_avg_batch`
// batches and processes it in NRAM-sized chunks via fuseRopeImpl.
template <typename T>
__mlu_global__ void MLUFuseRope(T *input,
                                void *key_cache_hp,
                                void *value_cache_hp,
                                void *key_cache_lp,
                                void *value_cache_lp,
                                T *sin_table,
                                T *cos_table,
                                int *rope_offsets,
                                T *gamma,
                                T *beta,
                                float *key_scale_hp,
                                float *value_scale_hp,
                                float *key_scale_lp,
                                float *value_scale_lp,
                                int *cache_bs_id_hp,
                                int *cache_seq_offsets_hp,
                                int *cache_bs_id_lp,
                                int *cache_seq_offsets_lp,
                                int *slot_mapping_hp,
                                int *slot_mapping_lp,
                                int rotary_stride,
                                int batch,
                                int head_num_q,
                                int head_num_k,
                                int head_size,
                                int max_decode_len_hp,
                                int max_decode_len_lp,
                                int block_size_hp,
                                int block_size_lp,
                                int group_size,
                                int batch_cap,
                                int task_avg_batch,
                                float eps) {
  const int slice_begin = taskId * task_avg_batch;
  const int slice_size = std::min(batch - slice_begin, task_avg_batch);
  // Tasks past the end of the batch (and the MPU core) have nothing to do.
  if (slice_size <= 0 || __is_mpu()) {
    return;
  }
  // Balance the slice into equal chunks no larger than batch_cap.
  const int chunk_count = (slice_size + batch_cap - 1) / batch_cap;
  const int chunk_size = (slice_size + chunk_count - 1) / chunk_count;
  for (int c = 0; c < chunk_count; ++c) {
    const int chunk_begin = slice_begin + chunk_size * c;
    const int chunk_batch = std::min(slice_size - chunk_size * c, chunk_size);
    fuseRopeImpl<T>(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
                    cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
                    key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
                    cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
                    rotary_stride, chunk_batch, chunk_begin, head_num_q, head_num_k, head_size,
                    max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
                    chunk_size, eps);
  }
}
} // namespace kernels
// Host-side launcher for the fused rope / k-layernorm / kv-cache-scatter
// kernel. Sizes the per-chunk batch capacity (batch_cap) from the NRAM
// layout used by fuseRopeImpl and launches one Block task per core.
// Supported dtypes: half and bfloat16; anything else (or an MLU300 device)
// fails with KERNEL_STATUS_FAILED.
KernelStatus invokeFusedRope(cnrtQueue_t queue,
                             void *input,
                             void *key_cache_hp,
                             void *value_cache_hp,
                             void *key_cache_lp,
                             void *value_cache_lp,
                             const void *sin_table,
                             const void *cos_table,
                             const void *rope_offsets,
                             const void *gamma,
                             const void *beta,
                             const void *key_scale_hp,
                             const void *value_scale_hp,
                             void *key_scale_lp,
                             void *value_scale_lp,
                             const void *cache_bs_id_hp,
                             const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp,
                             const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp,
                             const void *slot_mapping_lp,
                             int rotary_stride,
                             int batch_size,
                             int head_num_q,
                             int head_num_kv,
                             int head_size,
                             int max_decode_len_hp,
                             int max_decode_len_lp,
                             int block_size_hp,
                             int block_size_lp,
                             int group_size,
                             cnnlDataType_t dtype,
                             float eps) {
  // The device code is only compiled for __BANG_ARCH__ > 500.
  if (is_arch300()) {
    std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  uint32_t taskdimx = cluster_num * core_num;
  int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
  int float_size = sizeof(float);
  int group_num = head_size / group_size;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  // Mirror of the per-batch NRAM carve-up in fuseRopeImpl, used to compute
  // how many batches fit in one chunk. (Typo fix: "avalible" -> "available".)
  int nram_available_bytes = 480 * 1024;
  int task_max_batch = 32;
  int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
  int nram_params_bytes = 2 * head_size * float_size;
  int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
  int nram_remain_bytes =
      nram_available_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
  // Per-batch byte costs of each NRAM region.
  int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;
  int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);
  int nram_table_bytes = 2 * head_size * float_size;
  int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
  int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
  int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
  int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
  int nram_cache_offsets_hp = head_num_kv * sizeof(int);
  int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
  int batch_cap =
      nram_remain_bytes /
      (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes +
       nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp);
  batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;
  cnrtDim3_t dim{taskdimx, 1, 1};
  if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else {
    // Fix: unsupported dtypes previously fell through and reported SUCCESS
    // without launching any kernel.
    std::cerr << "[invokeFusedRope]: unsupported data type, only half/bfloat16 are supported."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,119 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Apply query and key rotary embedding, key layernorm and
* quantize key and value to kv cache.
* @param queue: The queue for mlu.
* @param input: Input/Output. Pointer to the MLU memory that stores the input,
* the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size].
* @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key
* cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key
* cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
* head_num_kv, block_size, head_size].
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
 * continuous. The shape must be [rotary_seq_len, rotary_dim].
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
 * continuous. The shape must be [rotary_seq_len, rotary_dim].
 * @param rope_offsets: Input. Pointer to the MLU memory that stores the sequence offsets of each
* batch. The shape must be [batch].
* @param norm_gamma: Input. Pointer to the MLU memory that stores the gamma param of layernorm.
* @param norm_beta: Input. Pointer to the MLU memory that stores the beta param of layernorm.
* @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
* key. The shape must be [head_num_kv, head_size]. If key_scale is nullptr,
* that means key do not need to be quantized.
* @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
* value. The shape must be [head_num_kv, head_size]. If value_scale is nullptr,
* that means value do not need to be quantized.
* @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
 * precision key. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
* [num_blocks, head_num_kv, block_size, group_num].
* @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
 * precision value. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
* [num_blocks, head_num_kv, block_size, group_num].
* @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch
* offset of high precision cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence
* offset of high precision cache, the shape must be [batch].
* @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch
* offset of low precision cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence
* offset of low precision cache, the shape must be [batch].
* @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
* which has shape [batch]. Data type of slot mapping must be int32_t.
* @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
* which has shape [batch]. Data type of slot mapping must be int32_t.
* @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
* @param batch_size: Batch size.
* @param head_num_q: Head number of query.
* @param head_num_kv: Head number of key and value.
* @param head_size: Head size. For simplify, the rotary dim must be the same as head_size.
* @param max_decode_len_hp: The maximum sequence length of high precision cache.
* @param max_decode_len_lp: The maximum sequence length of low precision cache.
* @param block_size_hp: Number of tokens per block of high precision cache.
* @param block_size_lp: Number of tokens per block of low precision cache.
 * @param dtype: Data type of all inputs and outputs.
* @param eps: float number use for layernorm.
* @note: Head_num_q and head_num_kv must be in range [1, 32].
* Head_size must be in range [1, 128], and must be divided by 2.
*/
KernelStatus invokeFusedRope(cnrtQueue_t queue,
void *input,
void *key_cache_hp,
void *value_cache_hp,
void *key_cache_lp,
void *value_cache_lp,
const void *sin_table,
const void *cos_table,
const void *rope_offsets,
const void *gamma,
const void *beta,
const void *key_scale_hp,
const void *value_scale_hp,
void *key_scale_lp,
void *value_scale_lp,
const void *cache_bs_id_hp,
const void *cache_seq_offsets_hp,
const void *cache_bs_id_lp,
const void *cache_seq_offsets_lp,
const void *slot_mapping_hp,
const void *slot_mapping_lp,
int rotary_stride,
int batch_size,
int head_num_q,
int head_num_kv,
int head_size,
int max_decode_len_hp,
int max_decode_len_lp,
int block_size_hp,
int block_size_lp,
int group_size,
cnnlDataType_t dtype,
float eps);
} // namespace tmo
#endif // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_

View File

@@ -0,0 +1,130 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "generate_alibi_slope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];
__nram__ float range_1[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64};
__nram__ float range_2[64] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25,
27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51,
53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77,
79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103,
105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127};
// Fill `fill_num` floats at range_nram with an increasing ramp starting at
// `offset + 1`, by tiling the precomputed base ramp (`range_base_nram`,
// `base_num` elements) and bumping each tile by its starting index.
__mlu_func__ void genRange(float *range_nram,
                           float *range_base_nram,
                           int fill_num,
                           int base_num,
                           int offset = 0) {
  for (int done = 0; done < fill_num; done += base_num) {
    // Last tile may be shorter than the base ramp.
    int chunk = fill_num - done;
    if (chunk > base_num) {
      chunk = base_num;
    }
    float *dst = range_nram + done;
    __bang_move(dst, range_base_nram, chunk * sizeof(float));
    __bang_add_scalar(dst, dst, done + offset, chunk);
  }
}
// One Block task per batch: computes the ALiBi slope of each head owned by
// this card (heads [head_start, head_start + head_num)) for batch `taskIdX`
// and writes head_num floats to alibi_slopes + taskIdX * head_num.
__mlu_global__ void MLUAlibiSlopeKernel(float *alibi_slopes,
                                        int *true_seq_lens,
                                        int batch_num,
                                        int head_start,
                                        int head_num,
                                        int head_num_total,
                                        int max_sequence_length,
                                        bool use_dynamic,
                                        int closest_power_of_2,
                                        int farthest_power_of_2,
                                        float base,
                                        float extra_base) {
  // NRAM layout: [head_num exponents][head_num bases][head_num slopes].
  float *range_nram = (float *)nram_buffer;
  float *base_nram = range_nram + head_num;
  float *slope_nram = base_nram + head_num;
  float scale = 1.0;
  float dynamic_base = base;
  if (use_dynamic) {
    // Dynamic scaling: stretch by the ratio of this batch's true sequence
    // length to the trained maximum (clamped to >= 1), spread across heads.
    float a0 = 1.0;
    float a = a0 * true_seq_lens[taskIdX] / max_sequence_length;
    a = std::max(a, 1.0f);
    scale = powf(a, (1.0 / (head_num_total - 1)));
    dynamic_base = base / scale;
  }
  // Split this card's heads into the "close" group (global index below
  // closest_power_of_2, uses `base`) and the remainder (uses `extra_base`).
  int close_head_num = 0;
  if (head_start >= closest_power_of_2) {
    close_head_num = 0;
  } else if (head_start + head_num <= closest_power_of_2) {
    close_head_num = head_num;
  } else {
    close_head_num = closest_power_of_2 - head_start;
  }
  int far_head_num = head_num - close_head_num;
  // fill range: 1, 2..., n1, 1, 3, (n - n1) * 2 - 1
  if (close_head_num) {
    genRange(range_nram, range_1, close_head_num, 64, head_start);
    __bang_write_value(base_nram, close_head_num, dynamic_base);
  }
  if (far_head_num) {
    // Far heads use odd exponents 1, 3, 5, ... offset by their position past
    // closest_power_of_2.
    genRange(range_nram + close_head_num, range_2, far_head_num, 64,
             (head_start + close_head_num - closest_power_of_2) * 2);
    __bang_write_value(base_nram + close_head_num, far_head_num, extra_base);
  }
  // base_nram ** range_nram, computed via log / multiply / pow2.
  __bang_log(base_nram, base_nram, head_num);
  __bang_mul(slope_nram, base_nram, range_nram, head_num);
  __bang_pow2(slope_nram, slope_nram, head_num);
  if (use_dynamic) {
    // NOTE(review): only the first close_head_num slopes are rescaled here,
    // not all head_num — confirm the far heads are intentionally unscaled.
    __bang_mul_scalar(slope_nram, slope_nram, scale, close_head_num);
  }
  __memcpy(alibi_slopes + taskIdX * head_num, slope_nram, head_num * sizeof(float), NRAM2GDRAM);
}
} // namespace kernels
/// Host-side launcher: enqueues one Block task per batch, each computing
/// that batch's ALiBi slopes for this card's heads.
/// See generate_alibi_slope.mluh for the parameter contract.
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
                                      void *alibi_slopes,
                                      void *true_seq_lens,
                                      int batch_num,
                                      int head_start,
                                      int head_num,
                                      int head_num_total,
                                      int max_sequence_length,
                                      bool use_dynamic) {
  // Largest power of two not exceeding the total head count, and its double.
  const int pow2_floor = pow(2, floor(log2(head_num_total)));
  const int pow2_ceil = pow2_floor * 2;
  // ALiBi bases: 2^(-2^-(log2(n) - 3)) for the in-range heads, and the
  // analogous value at 2n for the overflow heads.
  const float main_base = pow(2, (-pow(2, -(log2(pow2_floor) - 3))));
  const float overflow_base = pow(2, (-pow(2, -(log2(2 * pow2_floor) - 3))));
  cnrtDim3_t launch_dim{.x = (uint32_t)batch_num, .y = 1, .z = 1};
  kernels::MLUAlibiSlopeKernel<<<launch_dim, cnrtFuncTypeBlock, queue>>>(
      (float *)alibi_slopes, (int *)true_seq_lens, batch_num, head_start, head_num, head_num_total,
      max_sequence_length, use_dynamic, pow2_floor, pow2_ceil, main_base, overflow_base);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,43 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
#define CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Generate ALiBi attention slopes for each batch and head.
* @param queue: The queue for mlu.
* @param alibi_slopes: Output. Pointer to the MLU memory that stores the output, the shape must be
* [batch_num, head_num].
* @param true_seq_lens: Input. Pointer to the MLU memory that stores the actual sequence length of
* each batch, the shape must be [batch_num].
* @param batch_num: Batch number.
* @param head_start: The index of first head.
* @param head_num: Head number in this card.
* @param head_num_total: Total head number in all cards.
* @param max_sequence_length: The maximum sequence length used during training.
* @param use_dynamic: A boolean value indicates whether to use dynamic NTK.
*/
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
void *alibi_slopes,
void *true_seq_lens,
int batch_num,
int head_start,
int head_num,
int head_num_total,
int max_sequence_length,
bool use_dynamic);
} // namespace tmo
#endif // CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_

View File

@@ -0,0 +1,214 @@
#include <cstddef>
#include <iostream>
#include "cn_api.h"
#include "cnnl.h"
#include "cnrt.h"
#include "generate_mask.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Fill `elem_count` elements at `dst` with `value` using the on-chip vector
// write primitive.
template <typename T>
__mlu_func__ void write_value(void *dst, unsigned int elem_count, T value) {
  __bang_write_value(dst, elem_count, value);
}
// bfloat16 writes are only available on MLU500+ ISAs; on older targets this
// specialization compiles to a no-op so the template still instantiates.
template <>
__mlu_func__ void write_value(void *dst, unsigned int elem_count, bfloat16_t value) {
#if __BANG_ARCH__ >= 500
  __bang_write_value(dst, elem_count, value);
#endif
}
// [once_len, once_len]
__nram__ int8_t nram_small[(__MLU_NRAM_SIZE__ * 1 / 4 * 1024)];
// [1 + once_len, 2 * once_len]
__nram__ int8_t nram_large[(__MLU_NRAM_SIZE__ * 2 / 4 * 1024 + 1024)];
// [once_len * 2 + 1]
__nram__ int8_t nram_tiny[2048];
// Writes a causal attention mask per batch: zero at and below the diagonal,
// fill_value above it and in the padding beyond seq_len. Works in
// once_len x once_len tiles staged in NRAM and streamed out asynchronously.
template <typename T>
class GenerateMask {
  // Tile edge length; smaller for 4-byte types so the tiles fit NRAM.
  constexpr static int once_len = sizeof(T) == 4 ? 160 : 256;
  // [once_len, once_len]
  T *nram_upper = (T *)(nram_small);
  // [1 + once_len, 2 * once_len]
  T *nram_buf = (T *)(nram_large);
  // [once_len, once_len], reuse upper part of nram_buf
  T *nram_filled = nram_buf;
  // [once_len, once_len], reuse lower part of nram_buf
  T *nram_zeros = nram_buf + once_len * once_len;
  // [once_len]
  T *nram_ones_zeros = (T *)nram_tiny;
  // Build the diagonal "staircase" tile (nram_upper), an all-fill tile and an
  // all-zero tile, from the fill/zero ramp seeded in nram_ones_zeros.
  // NOTE(review): execute() calls this with the default fill_value (-10000)
  // rather than its own fill_value argument — confirm that is intended.
  __mlu_func__ void initBuffers(T fill_value = -10000) {
    /* nram_buf:
      |---once_len---||---once_len---|
      0, 1, 1, 1, ..., 1, 0, 0, 0, ...
      0, 0, 1, 1, ..., 1, 1, 0, 0, ...
      0, 0, 0, 1, ..., 1, 1, 1, 0, ...
      ... */
    nram_buf[0] = 0;
    constexpr static int copy_size = (once_len * 2 + 1) * sizeof(T);
    __memcpy(nram_buf + 1, nram_ones_zeros, copy_size, NRAM2NRAM, copy_size, 0, once_len - 1);
    // Re-read nram_buf with a 2*once_len row pitch to carve out the
    // once_len x once_len staircase tile.
    __memcpy(nram_upper, nram_buf, once_len * sizeof(T), NRAM2NRAM, once_len * sizeof(T),
             once_len * 2 * sizeof(T), once_len - 1);
    // nram_buf is no longer needed
    write_value(nram_filled, once_len * once_len, (T)fill_value);
    write_value(nram_zeros, once_len * once_len, (T)0);
  }
  // Emit one [max_seq_len, max_seq_len] mask: zeros left of the diagonal
  // tile, the staircase tile on the diagonal, fill_value to its right, and
  // fill_value in the border beyond seq_len.
  __mlu_func__ void dealOneBatch(T *output, // [max_seq_len, max_seq_len]
                                 int max_seq_len,
                                 int seq_len) {
    /*
      |  once_len |
      +----------+-----------------------------------+
      |          |                              |    |
      |  upper   |      fill_value              |    |
      |          |                              |    |
      +----------+----------+                   |    |
      |          |          |                   |    |
      |          |  upper   |                   |fill|
      |          |          |                  |value|
      |          +----------+----------+        |    |
      |          |          |          |        |    |
      |          |          |  upper   |        |    |
      |    0     |          |          |        |    |
      |          |          +----------+---+    |    |
      |          |          |          | u |    |    |
      |--------------------------------+---+    |    |
      |                                         |    |
      |            fill_value                   |    |
      |                                         |    |
      +----------------------------------------------+
    */
    int tile_count = seq_len / once_len;
    int tile_remain = seq_len % once_len;
    int boarder_len = max_seq_len - seq_len;
    int row = 0;
    // Full diagonal tiles.
    for (; row < tile_count * once_len; row += once_len) {
      // fill left with zeros
      // assume that max_seq_len <= once_len^2
      if (row > 0) {
        __memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
                       max_seq_len * sizeof(T), 0, once_len - 1);
      }
      // fill middle with upper
      __memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, once_len * sizeof(T),
                     NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), once_len - 1);
      // fill right with fill_value
      if (row + once_len < max_seq_len) {
        __memcpy_async(output + (size_t)row * max_seq_len + row + once_len, nram_filled,
                       (max_seq_len - row - once_len) * sizeof(T), NRAM2GDRAM,
                       max_seq_len * sizeof(T), 0, once_len - 1);
      }
    }
    // Partial tile for the last seq_len % once_len rows.
    if (tile_remain) {
      // fill left with zeros
      if (row > 0) {
        __memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
                       max_seq_len * sizeof(T), 0, tile_remain - 1);
      }
      // fill middle with upper
      __memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, tile_remain * sizeof(T),
                     NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), tile_remain - 1);
      // fill right with fill_value
      if (row + tile_remain < max_seq_len) {
        __memcpy_async(output + (size_t)row * max_seq_len + row + tile_remain, nram_filled,
                       (max_seq_len - row - tile_remain) * sizeof(T), NRAM2GDRAM,
                       max_seq_len * sizeof(T), 0, tile_remain - 1);
      }
    }
    if (boarder_len) {
      // fill right border with fill_value
      __memcpy_async(output + seq_len, nram_filled, boarder_len * sizeof(T), NRAM2GDRAM,
                     max_seq_len * sizeof(T), 0, (max_seq_len - boarder_len) - 1);
      // fill bottom border with fill_value
      __memcpy_async(output + (size_t)seq_len * max_seq_len, nram_filled, max_seq_len * sizeof(T),
                     NRAM2GDRAM, max_seq_len * sizeof(T), 0, boarder_len - 1);
    }
    // Drain all queued async copies before the NRAM tiles are reused.
    __sync_io();
  }
public:
  // Split batches across taskDimY and write one mask per assigned batch.
  __mlu_func__ void execute(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
                            int *batch_seq_len,
                            int total_batch,
                            int max_seq_len,
                            T fill_value = -10000) {
    // Even split; the first `batch_remain` tasks take one extra batch.
    int batch_each = total_batch / taskDimY;
    int batch_remain = total_batch % taskDimY;
    int batch_start = taskIdY * batch_each + (taskIdY < batch_remain ? taskIdY : batch_remain);
    int batch_count = batch_each + (taskIdY < batch_remain ? 1 : 0);
    // Seed the ramp: once_len fill_values followed by once_len + 1 zeros.
    write_value(nram_ones_zeros, once_len, (T)fill_value);
    write_value(nram_ones_zeros + once_len, once_len + 1, (T)0);
    initBuffers();
    for (int n = batch_start; n < batch_start + batch_count; n++) {
      T *output = output_ddr + (size_t)n * max_seq_len * max_seq_len;
      int seq_len = batch_seq_len[n];
      dealOneBatch(output, max_seq_len, seq_len);
    }
  }
};
// Union1 entry point: only core 0 of each cluster does work; batches are
// distributed across clusters via taskDimY inside GenerateMask::execute.
template <typename T>
__mlu_global__ void MLUUnion1GenerateMask(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
                                          int *batch_seq_len,
                                          int total_batch,
                                          int max_seq_len,
                                          T fill_value = -10000) {
  if (coreId != 0) {
    return; // we only use 1 core in a cluster
  }
  GenerateMask<T>().execute(output_ddr, batch_seq_len, total_batch, max_seq_len, fill_value);
}
} // namespace kernels
/// Launch the causal-mask generation kernel on the handle's queue.
/// Dispatches on data_type (float / half / bfloat16). Returns
/// KERNEL_STATUS_FAILED for bf16 on MLU300-series devices or for any other
/// unsupported dtype; otherwise enqueues the kernel and returns success.
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
                                void *output_ddr,
                                int *batch_seq_len,
                                int total_batch,
                                int max_seq_len,
                                cnnlDataType_t data_type,
                                float fill_value) {
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // Union1 job: 4 cores per cluster are scheduled, but the kernel keeps only
  // core 0 of each; taskDimY spreads batches across clusters.
  cnrtDim3_t dim;
  dim.x = 4;
  dim.y = cluster_num;
  dim.z = 1;
  if (data_type == CNNL_DTYPE_FLOAT) {
    kernels::MLUUnion1GenerateMask<float><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<float>(fill_value));
  } else if (data_type == CNNL_DTYPE_HALF) {
    kernels::MLUUnion1GenerateMask<half><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<half>(fill_value));
  } else if (data_type == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeGenerateMask]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUUnion1GenerateMask<bfloat16_t><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<bfloat16_t>(fill_value));
  } else {
    // Fixed: previous message duplicated the "[invokeGenerateMask]:" prefix.
    std::cerr << "[invokeGenerateMask]: data_type is not supported" << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,37 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GENERATE_MASK_MLUH_
#define CSRC_KERNELS_GENERATE_MASK_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Generate causal mask for context stage.
* @param handle: The handle of cnnl.
* @param output_ddr: Output. Pointer to the MLU memory that stores the output.
* @param batch_seq_len: Input. Pointer to the MLU memory that stores the sequence length.
* @param total_batch: Batch size.
* @param max_seq_len: The maximum sequence length of context.
* @param data_type: Data type.
* @param fill_value: The fill value of the pad part.
*/
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
void *output_ddr,
int *batch_seq_len,
int total_batch,
int max_seq_len,
cnnlDataType_t data_type,
float fill_value);
} // namespace tmo
#endif // CSRC_KERNELS_GENERATE_MASK_MLUH_

View File

@@ -0,0 +1,60 @@
#include <iostream>
#include <stdexcept>
#include "cnrt.h"
#include "get_glm_position_id.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
__nram__ int nram_buffer[__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)];
// Single-task kernel: gather the last position id of each of the
// pos_id_dimension * batch rows of context_pos_id into NRAM, bump the 2D
// "block" component by one, and store the result as the first
// generate-stage position id.
__mlu_global__ void MLUBlockSliceIndividualPosId(int *context_pos_id,
                                                 int *generate_pos_id,
                                                 int batch,
                                                 int context_seq_len,
                                                 int pos_id_dimension /* 1 for 1D, 2 for 2D */) {
  if (taskId != 0) return;
  // Strided load: one int (the last element) from every row of length
  // context_seq_len, packed contiguously into nram_buffer.
  __memcpy(nram_buffer, context_pos_id + context_seq_len - 1, sizeof(int), GDRAM2NRAM, sizeof(int),
           context_seq_len * sizeof(int), pos_id_dimension * batch - 1);
  if (pos_id_dimension == 2) {
    // Odd slots hold the block position id; generation advances it by one.
    for (int i = 1; i < 2 * batch; i += 2) {
      nram_buffer[i] += 1;
    }
  }
  __memcpy(generate_pos_id, nram_buffer, pos_id_dimension * batch * sizeof(int), NRAM2GDRAM);
}
// Single-task kernel: add one to the block component (every odd slot) of the
// [batch, 2, 1] generate-stage position ids, in place.
__mlu_global__ void MLUBlockIncrement2DPosId(int *generate_pos_id, int batch) {
  if (taskId != 0) return;
  const int total = 2 * batch;
  __memcpy(nram_buffer, generate_pos_id, total * sizeof(int), GDRAM2NRAM);
  for (int slot = 1; slot < total; slot += 2) {
    ++nram_buffer[slot];
  }
  __memcpy(generate_pos_id, nram_buffer, total * sizeof(int), NRAM2GDRAM);
}
} // namespace kernels
/// Validate pos_id_dimension (1 or 2) and launch MLUBlockSliceIndividualPosId
/// to derive the first generate-stage position ids from the context ones.
KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue,
                                        int *context_pos_id,
                                        int *generate_pos_id,
                                        int batch,
                                        int context_seq_len,
                                        int pos_id_dimension /* 1 for 1D, 2 for 2D */) {
  const bool dim_ok = (pos_id_dimension == 1) || (pos_id_dimension == 2);
  if (!dim_ok) {
    std::cerr << "[invokeSliceIndividualPosId]: pos_id_dimension must be 1 or 2, but got "
              << pos_id_dimension << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  cnrtDim3_t launch_dim{4, 1, 1};
  kernels::MLUBlockSliceIndividualPosId<<<launch_dim, cnrtFuncTypeUnion1, queue>>>(
      context_pos_id, generate_pos_id, batch, context_seq_len, pos_id_dimension);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
/// Launch MLUBlockIncrement2DPosId to advance the 2D position ids one step.
KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch) {
  cnrtDim3_t launch_dim{4, 1, 1};
  kernels::MLUBlockIncrement2DPosId<<<launch_dim, cnrtFuncTypeUnion1, queue>>>(generate_pos_id,
                                                                              batch);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,59 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
#define CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Get generate position id from context position id, when position id is 2D,
* increase block position id by one.
* @example
* in GLM network, context_pos_id shape is [batch, 2, context_seq_len], data is
* [[[0, 1, 2, 2, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1]],
* [[0, 1, 2, 3, 4, 5, 5], [0, 0, 0, 0, 0, 0, 1]]]
* after invoke this kernel, the data is
* [[[2], [2]],
* [[5], [2]]]
* @param queue: The queue for mlu.
* @param context_pos_id: Input. Pointer to the MLU memory that stores the position id of
* context.
* @param generate_pos_id: Output. Pointer to the MLU memory that stores the position id of
* generate.
* @param batch: Batch size.
* @param context_seq_len: The sequence length of context.
* @param pos_id_dimension: The dimension of position id, 1 for 1D, 2 for 2D.
*/
KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue,
int *context_pos_id,
int *generate_pos_id,
int batch,
int context_seq_len,
int pos_id_dimension);
/**
* @brief Increase block position id by one in generate stage.
* @example
* in GLM network, generate_pos_id shape is [batch, 2, 1], data is
* [[[2], [1]], [[5], [1]]]
* after invoke this kernel, the data is
* [[[2], [2]], [[5], [2]]]
* @param queue: The queue for mlu.
* @param generate_pos_id: Output/Input. Pointer to the MLU memory that stores the position id of
* generate.
* @param batch: Batch size.
*/
KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch);
} // namespace tmo
#endif // CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_

View File

@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_KERNEL_UTILS_H_
#define CSRC_KERNELS_KERNEL_UTILS_H_
#include <cassert>
#include <iostream>
#include <string>
#include "cnnl.h"
#include "cnrt.h"
namespace tmo {
const std::string arch_370 = "MLU370";
enum class KernelStatus { KERNEL_STATUS_SUCCESS = 0, KERNEL_STATUS_FAILED };
#ifndef PAD_DOWN
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#endif
#ifndef PAD_UP
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif
// True when the device name identifies an MLU 3xx-series board
// (i.e. it contains the substring "MLU3").
inline bool isMlu300(const std::string &dev_name) {
  return dev_name.find("MLU3") != std::string::npos;
}
inline bool is_arch300() {
int card_id = -1;
cnrtDeviceProp_t dev_info;
CNRT_CHECK(cnrtGetDevice(&card_id));
CNRT_CHECK(cnrtGetDeviceProperties(&dev_info, card_id));
std::string dev_name = dev_info.name;
return isMlu300(dev_name);
}
// bfloat16 is supported on every architecture except the MLU 3xx series.
inline bool isBf16Supported() {
  return !is_arch300();
}
} // namespace tmo
#endif // CSRC_KERNELS_KERNEL_UTILS_H_

View File

@@ -0,0 +1,521 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "add_bias_activation.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#define USE_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 20 * 1024)
#define USE_SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 20 * 1024)
__nram__ int8_t nram_buffer[USE_NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[USE_SRAM_SIZE];
// For the token range [tokens_start, tokens_end] handled by this core, find
// the first/last expert ids it touches (using the cumulative token counts in
// nram_count) and build per-token byte offsets into the per-expert bias tile
// (gather_offset) for a later gather. Outputs: expert_start, expert_end,
// tokens_deal_first (tokens of the first expert handled here).
__mlu_func__ void get_expert_info(int *nram_count,
                                  int *count_sram,
                                  uint32_t *gather_offset,
                                  int real_inner,
                                  int tokens_start,
                                  int tokens_end,
                                  int tokens_load,
                                  int expert_deal_start,
                                  int expert_deal_end,
                                  uint32_t &expert_start,
                                  uint32_t &expert_end,
                                  uint32_t &tokens_deal_first,
                                  int dtype_size) {
  bool record_start = false;
  // loop expert to find first and last deal expert in current core
  for (int expert_id = expert_deal_start; expert_id <= expert_deal_end; expert_id++) {
    // nram_count[expert_id + 1] is the cumulative token count through
    // expert_id; the first entry exceeding tokens_start owns our first token.
    if (__load_nram(nram_count + expert_id + 1) > tokens_start && !record_start) {
      expert_start = expert_id;
      tokens_deal_first =
          std::min(__load_nram(nram_count + expert_id + 1) - 1, tokens_end) - tokens_start + 1;
      record_start = true;
    }
    if (__load_nram(nram_count + expert_id + 1) > tokens_end) {
      expert_end = expert_id;
      break;
    }
  }
  // record expert offset to gather bias
  __bang_write_zero(gather_offset, tokens_load);
  int tokens_load_total = 0;
  for (int expert_id = expert_start; expert_id <= expert_end; expert_id++) {
    int tokens_expand = __load_sram((int *)count_sram + expert_id);
    if (expert_id == expert_start) {
      // First expert contributes only the tokens past tokens_start.
      tokens_expand = tokens_deal_first;
    } else if (expert_id == expert_end) {
      // Last expert takes whatever remains of this core's token quota.
      tokens_expand = tokens_load - tokens_load_total;
    }
    if (tokens_expand == 0) {
      continue;
    }
    // All tokens of one expert share the same byte offset into the bias tile.
    __bang_write_value(gather_offset + tokens_load_total, tokens_expand,
                       (int)((expert_id - expert_deal_start) * real_inner * dtype_size));
    tokens_load_total += tokens_expand;
  }
}
/*************** functions for compute basic operation ***************/
// Elementwise dst_src += bias over `number` elements.
template <typename T>
__mlu_func__ void add_bias(T *dst_src, T *bias, int number) {
  // cycle add bias
  __bang_add((T *)dst_src, (T *)dst_src, (T *)bias, number);
}
// bfloat16 add exists only on post-500 ISAs; on older targets this
// specialization compiles to a no-op.
template <>
__mlu_func__ void add_bias(bfloat16_t *dst_src, bfloat16_t *bias, int number) {
#if __BANG_ARCH__ > 500
  __bang_add((bfloat16_t *)dst_src, (bfloat16_t *)dst_src, (bfloat16_t *)bias, number);
#endif
}
// Elementwise left *= right over `number` elements (gated-MLP multiply).
template <typename T>
__mlu_func__ void mul_left_right(T *left, T *right, int number) {
  __bang_mul((T *)left, (T *)left, (T *)right, number);
}
// bfloat16 multiply exists only on post-500 ISAs; on older targets this
// specialization compiles to a no-op.
template <>
__mlu_func__ void mul_left_right(bfloat16_t *left, bfloat16_t *right, int number) {
#if __BANG_ARCH__ > 500
  __bang_mul((bfloat16_t *)left, (bfloat16_t *)left, (bfloat16_t *)right, number);
#endif
}
// Apply the requested activation in fp32, in place on input_left.
// act_space is scratch of at least `number` floats; active_coef is the
// swish beta. Modes other than GELU/SWISH leave the data unchanged.
__mlu_func__ void do_activation(float *input_left,
                                float *act_space,
                                int number,
                                float active_coef,
                                cnnlActivationMode_t act_type) {
  if (act_type == CNNL_ACTIVATION_GELU) {
    __bang_active_gelu((float *)input_left, (float *)input_left, number);
  } else if (act_type == CNNL_ACTIVATION_SWISH) {
    // swish(x) = x * sigmoid(active_coef * x); skip the scale when coef == 1.
    float *tmp = input_left;
    if (active_coef != 1.0f) {
      __bang_mul_scalar(act_space, input_left, active_coef, number);
      tmp = act_space;
    }
    __bang_active_sigmoid((float *)act_space, (float *)tmp, number);
    __bang_mul((float *)input_left, (float *)input_left, (float *)act_space, number);
  }
}
/*************** functions for steps of each loop ***************/
// Gather per-token bias rows from SRAM into NRAM using the per-token byte
// offsets prepared by get_expert_info. For gated MLPs the left and right
// bias halves are gathered into two consecutive NRAM regions of
// tokens_deal * inner_size elements each.
template <typename T>
__mlu_func__ void gather_bias(T *bias_sram,
                              T *bias_nram,
                              uint32_t *gather_offset,
                              int expert_start,
                              int expert_end,
                              int expert_deal_start,
                              int tokens_deal_first,
                              int tokens_deal,
                              int inner_size,
                              bool is_gated) {
#if __BANG_ARCH__ > 500
  // Post-500 ISAs provide a hardware gather primitive.
  if (is_gated) {
    __gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
                   inner_size * sizeof(T), tokens_deal);
    __gather_async((T *)bias_nram + tokens_deal * inner_size, (T *)bias_sram + inner_size,
                   gather_offset, inner_size * sizeof(T), SRAM2NRAM, inner_size * sizeof(T),
                   tokens_deal);
  } else {
    __gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
                   inner_size * sizeof(T), tokens_deal);
  }
#else
  // Fallback: emulate the gather with one async copy per token row.
  for (int i = 0; i < tokens_deal; i++) {
    if (is_gated) {
      __memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
                     inner_size * sizeof(T), SRAM2NRAM);
      __memcpy_async((T *)bias_nram + (tokens_deal * inner_size + i * inner_size),
                     (int8_t *)bias_sram + inner_size * sizeof(T) + gather_offset[i],
                     inner_size * sizeof(T), SRAM2NRAM);
    } else {
      __memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
                     inner_size * sizeof(T), SRAM2NRAM);
    }
  }
#endif
}
// Stage one tile of input tokens (and, when has_bias, their gathered bias)
// from GDRAM/SRAM into NRAM. Gated layouts deinterleave each token row of
// 2 * inner_size elements into separate left/right NRAM buffers.
template <typename T>
__mlu_func__ void loadBiasInput(T *input,
                                T *left,
                                T *right,
                                T *bias_nram,
                                T *bias_sram,
                                uint32_t *gather_offset,
                                size_t input_offset,
                                int tokens_deal,
                                int inner_size,
                                uint32_t expert_start,
                                uint32_t expert_end,
                                int expert_deal_start,
                                uint32_t tokens_deal_first,
                                bool is_gated,
                                bool has_bias) {
  if (is_gated) {
    // if gated, stride io load input, left/right inner to input_left/right
    __memcpy_async((T *)left, (T *)input + input_offset, inner_size * sizeof(T), GDRAM2NRAM,
                   inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
    __memcpy_async((T *)right, (T *)input + input_offset + inner_size, inner_size * sizeof(T),
                   GDRAM2NRAM, inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
  } else {
    // if not gated, load input to input_left total
    __memcpy_async((T *)left, (T *)input + input_offset, tokens_deal * inner_size * sizeof(T),
                   GDRAM2NRAM);
  }
  if (has_bias) {
    // gather_offset must be fully written before the gather reads it.
    __sync_compute();
    gather_bias((T *)bias_sram, (T *)bias_nram, gather_offset, expert_start, expert_end,
                expert_deal_start, tokens_deal_first, tokens_deal, inner_size, is_gated);
  }
}
template <typename T>
// Add bias (optional) and apply the activation to one NRAM tile; in gated
// mode, additionally multiply the activated left half by the right half.
// half/bfloat16 data is widened to float for the activation and narrowed back
// afterwards; for float, left_dst aliases input_left so no cast is needed.
__mlu_func__ void computeAddActivation(T *bias_nram,
                                       T *left_dst,
                                       T *input_right,
                                       float *input_left,
                                       float *act_space,
                                       int tokens_deal,
                                       int inner_size,
                                       bool is_gated,
                                       bool has_bias,
                                       float active_coef,
                                       cnnlActivationMode_t act_type) {
  int number = tokens_deal * inner_size;
  if (has_bias) {
    add_bias((T *)left_dst, (T *)bias_nram, number);
    if (is_gated) {
      // Right halves of the bias rows were gathered right after the left ones.
      add_bias((T *)input_right, (T *)bias_nram + tokens_deal * inner_size, number);
    }
  }
  // cast half/bfloat16 to float for activation; if float, left_dst is same as input_left
  if (std::is_same<T, half>::value) {
    __bang_half2float((float *)input_left, (half *)left_dst, number);
  }
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float((float *)input_left, (bfloat16_t *)left_dst, number);
  }
#endif
  // activation
  do_activation(input_left, act_space, number, active_coef, act_type);
  // if half/bfloat16, cast float back to T before the gated multiply / store
  if (std::is_same<T, half>::value) {
    __bang_float2half((half *)input_left, (float *)input_left, number);
  }
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    __bang_float2bfloat16((bfloat16_t *)input_left, (float *)input_left, number);
  }
#endif
  if (is_gated) {
    // out = act(left) * right, in place in input_left
    mul_left_right((T *)input_left, (T *)input_right, tokens_deal * inner_size);
  }
}
template <typename T>
// Write one computed tile back to GDRAM: tokens_deal rows of inner_size
// elements, placed output_stride elements apart in the destination.
__mlu_func__ void storeOutput(T *output,
                              T *output_nram,
                              size_t output_offset,
                              int output_stride,
                              int tokens_deal,
                              int inner_size) {
  const size_t row_bytes = inner_size * sizeof(T);
  const size_t dst_pitch = output_stride * sizeof(T);
  T *dst = (T *)output + output_offset;
  // Strided async copy: first segment plus (tokens_deal - 1) further segments.
  __memcpy_async(dst, (T *)output_nram, row_bytes, NRAM2GDRAM, dst_pitch, row_bytes,
                 tokens_deal - 1);
}
template <typename T>
// Fused per-expert add-bias + activation (+ gated multiply) over grouped
// tokens.
//
// Stages:
//   1. Stage per-expert token counts (differences of the prefix sum
//      cusum_token_count) into SRAM/NRAM so the bias gather can locate each
//      token's expert.
//   2. Outer SRAM loop: the bias table may not fit SRAM, so biases are staged
//      chunk by chunk (one chunk = max_expert_deal experts).
//   3. The chunk's tokens are split evenly across the taskDim cores.
//   4. Inner NRAM loop: a three-stage ping-pong pipeline — load tile (+ bias
//      gather), compute, store — hence `nram_loop + 2` iterations with the
//      three stages each offset by one iteration.
// When bias == nullptr, all bias staging/gathering is skipped.
__mlu_global__ void MLUAddBiasActivationKernel(T *output,
                                               const T *input,
                                               const T *bias,
                                               const int *cusum_token_count,
                                               int num_expert,
                                               int total_tokens,
                                               int inner_size,
                                               int output_stride,
                                               bool is_gated,
                                               cnnlActivationMode_t act_type,
                                               int start_expert_id,
                                               int expert_size,
                                               float active_coef) {
  // if bias and token_count is nullptr, not add bias, only activation and gated mul.
  bool has_bias = (bias != nullptr);
  // 1. distribute nram space
  /* if not gated
  ----------------------------   ----------------------------------
  |      nram_token_count    |   |        bias_ping/pong          |
  |  num_expert * sizeof(int)|   | 2 * x * inner_size * sizeof(T) |
  ----------------------------   ----------------------------------
  --------------------------------------
  |           input_ping/pong          |
  | 2 * x * inner_size * sizeof(float) |
  --------------------------------------
  -------------------------------------------------   -------------------
  |                  act_space                    |   |  gather_offset  |
  | gelu: 0; silu: x * inner_size * sizeof(float) |   | x * sizeof(int) |
  -------------------------------------------------   -------------------
  */
  /* if gated
  ----------------------------   ----------------------------------
  |      nram_token_count    |   |        bias_ping/pong          |
  |  num_expert * sizeof(int)|   | 2 * x * real_inner * sizeof(T) |
  ----------------------------   ----------------------------------
  --------------------------------------   ----------------------------------
  |        input_left_ping/pong        |   |      input_right_ping/pong     |
  | 2 * x * inner_size * sizeof(float) |   | 2 * x * inner_size * sizeof(T) |
  --------------------------------------   ----------------------------------
  -------------------------------------------------   -------------------
  |                  act_space                    |   |  gather_offset  |
  | gelu: 0; silu: x * inner_size * sizeof(float) |   | x * sizeof(int) |
  -------------------------------------------------   -------------------
  */
  // distribute sram
  int8_t *count_sram = (int8_t *)sram_buffer;
  int8_t *bias_sram = (int8_t *)count_sram + num_expert * sizeof(int);
  // distribute nram
  int real_inner = (is_gated) ? inner_size * 2 : inner_size;
  int bias_nram_size = (has_bias) ? real_inner * sizeof(T) : 0;
  int act_space_size = (act_type == CNNL_ACTIVATION_GELU) ? 0 : inner_size * sizeof(float);
  int gated_ext_size = is_gated ? sizeof(T) : 0;
  // x in the diagrams above: max tokens per ping/pong buffer.
  int max_token_deal = (USE_NRAM_SIZE - (num_expert + 1) * sizeof(int)) /
                       (2 * inner_size * (sizeof(float) + gated_ext_size) + act_space_size +
                        2 * bias_nram_size + sizeof(int));
  int8_t *nram_count = (int8_t *)nram_buffer;
  int8_t *bias_nram = (int8_t *)nram_count + (num_expert + 1) * sizeof(int);
  int8_t *input_left = (int8_t *)bias_nram + 2 * max_token_deal * bias_nram_size;
  int8_t *input_right =
      (int8_t *)input_left + 2 * ((is_gated) ? max_token_deal * inner_size * sizeof(float) : 0);
  int8_t *act_space = (int8_t *)input_right +
                      2 * max_token_deal * inner_size * (is_gated ? sizeof(T) : sizeof(float));
  int8_t *gather_offset = (int8_t *)act_space + max_token_deal * act_space_size;
  // 2. cusum_token_count load to nram, because need to reuse in load bias.
  if (has_bias) {
    __memcpy((int *)nram_count, (int *)cusum_token_count, (num_expert + 1) * sizeof(int),
             GDRAM2NRAM);
    // One core derives per-expert counts (prefix-sum differences) and shares
    // them cluster-wide through SRAM.
    if (taskIdX == 0) {
      __bang_sub((int *)bias_nram, (int *)nram_count + 1, (int *)nram_count, num_expert);
      __sync();
      __memcpy((int *)count_sram, (int *)bias_nram, num_expert * sizeof(int), NRAM2SRAM);
    }
    __sync_cluster();
  }
  // 3. sram loop to compute
  // compute once load bias to sram due to sram_limit
  int max_expert_deal = (USE_SRAM_SIZE - num_expert * sizeof(int)) / (real_inner * sizeof(T));
  int real_expert = cusum_token_count == nullptr ? num_expert : expert_size;
  int sram_loop_rem = real_expert % max_expert_deal;
  int sram_loop = real_expert / max_expert_deal + (int)(sram_loop_rem != 0);
  if (!has_bias) {
    // No bias table to stage: process all experts in a single pass.
    max_expert_deal = real_expert;
    sram_loop = 1;
    sram_loop_rem = 0;
  }
  for (int deal_loop = 0; deal_loop < sram_loop; deal_loop++) {
    // load current bias, compute each core deal number
    int expert_deal =
        (deal_loop == (sram_loop - 1) && sram_loop_rem != 0) ? sram_loop_rem : max_expert_deal;
    int expert_deal_start = deal_loop * max_expert_deal + start_expert_id;
    int expert_deal_end = expert_deal_start + expert_deal - 1;
    __sync_all();
    // The memory core stages this chunk of the bias table into SRAM.
    if (has_bias && __is_mpu()) {
      __memcpy((T *)bias_sram, (T *)bias + deal_loop * max_expert_deal * real_inner,
               expert_deal * real_inner * sizeof(T), GDRAM2SRAM);
    }
    __sync_all();
    // get tokens info of each core
    int tokens_total_cur = total_tokens;
    if (has_bias) {
      tokens_total_cur =
          __load_nram((int *)nram_count + expert_deal_end + 1) -
          (start_expert_id == 0 ? 0 : __load_nram((int *)nram_count + expert_deal_start));
    } else if (cusum_token_count != nullptr) {
      tokens_total_cur = __load_gdram(cusum_token_count + expert_deal_end + 1) -
                         __load_gdram(cusum_token_count + expert_deal_start);
    }
    // NOTE(review): for multi-chunk staging the count is recomputed relative
    // to the chunk start, with start_expert_id only adjusted on the first
    // chunk — confirm intended for start_expert_id != 0 with deal_loop > 0.
    if (sram_loop != 1) {
      tokens_total_cur = ((int *)nram_count)[expert_deal_end + 1] -
                         ((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
      if (deal_loop == 0 && start_expert_id != 0) {
        tokens_total_cur -= __load_nram((int *)nram_count + expert_deal_start);
      }
    }
    int tokens_core_rem = tokens_total_cur % taskDim;
    int tokens_cur_core = tokens_total_cur / taskDim + (taskId < tokens_core_rem);
    if (tokens_cur_core == 0) {
      continue;
    }
    // if ep, input start in current token, have a real start in total network
    int real_start =
        cusum_token_count != nullptr ? __load_gdram((int *)cusum_token_count + start_expert_id) : 0;
    int tokens_core_start = tokens_cur_core * taskId +
                            (taskId < tokens_core_rem ? 0 : tokens_core_rem) +
                            ((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
    if (deal_loop != 0) {
      tokens_core_start -= real_start;
    }
    uint32_t expert_start = 0;
    uint32_t expert_end = 0;
    uint32_t tokens_deal_first = 0;
    // 4. nram loop compute
    int nram_loop_rem = tokens_cur_core % max_token_deal;
    int nram_loop = tokens_cur_core / max_token_deal + (int)(nram_loop_rem != 0);
    int tokens_load = max_token_deal;
    int tokens_compute = max_token_deal;
    int tokens_store = max_token_deal;
    // Three-stage pipeline: iteration `loop` loads tile `loop`, computes tile
    // `loop - 1`, stores tile `loop - 2`; ping/pong buffers alternate by
    // parity of `loop`.
    for (int loop = 0; loop < nram_loop + 2; loop++) {
      int inner_io_offset = (loop % 2) * max_token_deal * inner_size;
      int inner_com_offset = ((loop + 1) % 2) * max_token_deal * inner_size;
      int real_io_offset = (loop % 2) * max_token_deal * real_inner;
      int real_com_offset = ((loop + 1) % 2) * max_token_deal * real_inner;
      // The remainder tile reaches each stage one iteration apart.
      if (nram_loop_rem != 0) {
        if (loop > 1 && (loop - 2) == (nram_loop - 1)) {
          tokens_store = nram_loop_rem;
        }
        if (loop > 0 && (loop - 1) == (nram_loop - 1)) {
          tokens_compute = nram_loop_rem;
        }
        if (loop == (nram_loop - 1)) {
          tokens_load = nram_loop_rem;
        }
      }
      int tokens_cur_start = tokens_core_start + loop * max_token_deal;
      int tokens_cur_end = tokens_cur_start + tokens_load - 1;
      // get current load info
      if (loop < nram_loop && has_bias) {
        get_expert_info((int *)nram_count, (int *)count_sram, (uint32_t *)gather_offset, real_inner,
                        tokens_cur_start + real_start, tokens_cur_end + real_start, tokens_load,
                        expert_deal_start, expert_deal_end, expert_start, expert_end,
                        tokens_deal_first, sizeof(T));
      }
      // store
      if (loop > 1) {
        size_t output_offset = (tokens_core_start + (loop - 2) * max_token_deal) * output_stride;
        storeOutput((T *)output, (T *)((float *)input_left + inner_io_offset), output_offset,
                    output_stride, tokens_store, inner_size);
      }
      // compute
      if (loop > 0 && loop <= nram_loop) {
        // For half/bfloat16 the raw tile sits in the upper half of the float
        // buffer so the widening cast can expand in place.
        T *left_dst = (T *)((float *)input_left + inner_com_offset) +
                      ((std::is_same<T, float>::value) ? 0 : tokens_compute * inner_size);
        computeAddActivation((T *)bias_nram + real_com_offset, (T *)left_dst,
                             (T *)input_right + inner_com_offset,
                             (float *)input_left + inner_com_offset, (float *)act_space,
                             tokens_compute, inner_size, is_gated, has_bias, active_coef, act_type);
      }
      // load
      if (loop < nram_loop) {
        T *left_dst = (T *)((float *)input_left + inner_io_offset) +
                      ((std::is_same<T, float>::value) ? 0 : tokens_load * inner_size);
        size_t input_offset = tokens_cur_start * real_inner;
        loadBiasInput((T *)input, (T *)left_dst, (T *)input_right + inner_io_offset,
                      (T *)bias_nram + real_io_offset, (T *)bias_sram, (uint32_t *)gather_offset,
                      input_offset, tokens_load, inner_size, expert_start, expert_end,
                      expert_deal_start, tokens_deal_first, is_gated, has_bias);
      }
      // Synchronize before the ping/pong buffers swap roles next iteration.
      __sync();
    }
  }
}
} // namespace kernels
// Host-side launcher: validates arguments, sizes the launch grid to cover the
// whole device (cores-per-cluster x cluster-count, Union1), and dispatches
// MLUAddBiasActivationKernel by dtype. See the header for the full parameter
// contract.
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
                                                void *output,
                                                const void *input,
                                                const void *bias,
                                                const int *cusum_token_count,
                                                int num_expert,
                                                int total_tokens,
                                                int inner_size,
                                                int output_stride,
                                                cnnlDataType_t dtype,
                                                bool is_gated,
                                                cnnlActivationMode_t act_type,
                                                int start_expert_id,
                                                int expert_size,
                                                float active_coef) {
  // The bias gather path locates each token's expert through
  // cusum_token_count, so bias without the prefix-sum array is unsupported.
  if (bias != nullptr && cusum_token_count == nullptr) {
    std::cerr << "[invokeGroupAddBiasActivationKernel]: "
              << "when have bias, cusum_token_count can not be nullptr." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (act_type != CNNL_ACTIVATION_GELU && act_type != CNNL_ACTIVATION_SWISH) {
    std::cerr << "[invokeGroupAddBiasActivationKernel]: "
              << "activation mode only supports gelu and swish now." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Launch one task per (core, cluster) pair on the current device.
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (float *)output, (const float *)input, (const float *)bias, cusum_token_count, num_expert,
        total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
        active_coef);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (half *)output, (const half *)(input), (const half *)bias, cusum_token_count, num_expert,
        total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
        active_coef);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeGroupAddBiasActivationKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)bias,
        cusum_token_count, num_expert, total_tokens, inner_size, output_stride, is_gated, act_type,
        start_expert_id, expert_size, active_coef);
  } else {
    std::cerr << "[invokeGroupAddBiasActivationKernel]: add_bias_activation data_type not support, "
              << "only support float/half/bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,84 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
#define CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Add bias and activate to all tokens. Different expert with different bias.
* If in gated mode, add bias to all input. But only activate in the
* input of [:, :inner_size]. Then multiply it to the input of [:, inner_size:].
* Else, add bias and activate to all input, has no multiply process.
* @example
* is_gated = true, num_expert = 4, total_tokens = 6, inner_size = 2
* input: (6, 4) = [[2, 4, 5, 6], [1, 4, 5, 3],
* [3, 5, 7, 8], [6, 8, 5, 3],
* [2, 3, 4 ,5], [2, 9, 2, 3]]
* bias: (4, 4) = [[1, 0, 1, 0], [0, 1, 2, 2], [2, 3, 2, 3], [1, 2, 3, 4]]
* token_count = [2, 2, 1, 1]
* first step: add bias
* [[2+1, 4+0, 5+1, 6+0], [1+1, 4+0, 5+1, 3+0],
* [3+0, 5+1, 7+2, 8+2], [6+0, 8+1, 5+2, 3+2],
* [2+2, 3+3, 4+2, 5+3], [2+1, 9+2, 2+3, 3+4]]
* second step: act and mul
* output: (6, 2) = [[act(3)*6, act(4)*6], [act(2)*6, act(4)*3],
* [act(3)*9, act(6)*10], [act(6)*7, act(9)*5],
* [act(4)*6, act(6)*8], [act(3)*5, act(11)*7]]
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the result.
* When is_gated is true, The shape is [total_tokens, input_size / 2].
* In this case, the input_size must be even. Otherwise the shape is [total_tokens,
* input_size]. The memory can be discontinuous in total_tokens dim. The stride is output_stride.
* @param input: Input. Pointer to the MLU memory that stores the input tokens.
* The shape is [total_tokens, input_size].
* When is_gated is true, the shape is [total_tokens, 2 * inner_size].
* Otherwise the shape is [total_tokens, inner_size].
* @param bias: Input. Pointer to the MLU memory that stores the bias. The memory must be
* continuous. When is_gated is true, the shape is [num_expert, 2 * inner_size]. Otherwise the shape
* is [num_expert, inner_size]. Bias can be nullptr. If bias is nullptr, has no add bias process.
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of token
* counts. The shape is [num_expert + 1]. If cusum_token_count
* is not nullptr, cusum_token_count, start_expert_id and
* expert_size together determine which tokens to process.
* If cusum_token_count is nullptr, process all tokens,
* the number of which is total_tokens. When bias is not nullptr,
* cusum_token_count must also not be nullptr.
* @param num_expert: The number of expert.
* @param total_tokens: The total number of tokens.
* @param inner_size: The inner size of output.
* @param output_stride: The stride of output, must be greater than or equal to inner_size.
* @param dtype: Data type.
* @param is_gated: Gated or not.
* @param act_type: The type of activation. Support gelu and swish.
* @param start_expert_id: The index of the start expert.
* @param expert_size: The number of experts to process.
* @param active_coef: The coefficient used in the swish activation.
*/
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
cnnlDataType_t dtype,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_

View File

@@ -0,0 +1,646 @@
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
#include "cast_gating.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))
#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)
#ifndef ONE_LINE
#define ONE_LINE 64
#endif
#ifndef LT_NUM
#define LT_NUM 64
#endif
// Tiling parameters for the cast-gating GEMM kernel (MLUCastGating).
struct castGatingTileInfo {
  int32_t block = 64;       // tile length along the cluster-partitioned M (or N when EXCHANGE_AB)
  int32_t split_k_num = 8;  // number of clusters splitting the K dimension (split-K)
  int32_t block_k = 256;    // K elements consumed per inner-loop step
};
namespace kernels {
#pragma bang walign(16)
#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif
#ifndef LT_SIZE
#define LT_SIZE 16
#endif
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
// SRAM -> NRAM async tiled move with on-the-fly dtype conversion, emitted as
// raw `move.tiling.async.nram.sram.b16` instructions. `size` is the source
// byte count; the 64-byte-aligned bulk is moved first, any tail by a second
// instruction. `convert_type` selects the hardware cast (e.g.
// ".cvt.rn.f32.f16()").
// NOTE(review): the tail path offsets dst/src by `n` (elements) while `rem`
// is computed as size % 64 (bytes) — verify handling of sizes that are not
// multiples of 64 bytes.
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
  do {                                                                             \
    uint32_t align_num = 64 / src_dsize;                                           \
    uint32_t n = PAD_DOWN(size / src_dsize, align_num);                            \
    uint32_t rem = size % 64;                                                      \
    if (n) {                                                                       \
      __asm__ __volatile__(                                                        \
          "move.tiling.async.nram.sram.b16"                                        \
          " [%[dst_addr]], [%[src_addr]], "                                        \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], "                \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], "     \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], "                \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4],"                            \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst),      \
          [src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num),           \
          [src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1),  \
          [src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1),    \
          [dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize),       \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0));     \
    }                                                                              \
                                                                                   \
    if (rem) {                                                                     \
      __asm__ __volatile__(                                                        \
          "move.tiling.async.nram.sram.b16"                                        \
          " [%[dst_addr]], [%[src_addr]], "                                        \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], "                \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], "     \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], "                \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4],"                            \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n),  \
          [src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0),            \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1),     \
          [src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1),   \
          [dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0),     \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0));     \
    }                                                                              \
  } while (false)
// Promote an input tile from SRAM to NRAM, converting half -> float with the
// tiled-move instruction (arch >= 500 only; no-op otherwise).
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}
// Same as above for bfloat16 -> float.
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
// float -> float: no conversion needed, plain async copy (`size` in bytes).
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
  __memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM);
}
// SRAM -> WRAM async tiled move of an [n, k] weight tile into the LT-tiled
// WRAM layout (shapes/strides derived from ONE_LINE / LT_NUM / ROW_PER_LT /
// LT_SIZE / WRAM_LT_MAP16_STRIDE), with optional dtype conversion. Emitted as
// raw `move.tiling.async.wram.sram.b16` instructions in three phases:
//   1. rows in full LT_NUM groups over the 64-byte-aligned part of k,
//   2. the remaining (< LT_NUM) rows, padded up to ROW_PER_LT,
//   3. the k tail (k % (ONE_LINE / src_dsize)) for all rows.
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
                               convert_type)                                            \
  int align_n = PAD_DOWN(n, LT_NUM);                                                    \
  int sn0 = ONE_LINE;                                                                   \
  int size_sn0 = sn0 / src_dsize;                                                       \
  int sn1 = ONE_LINE / src_dsize;                                                       \
  int ss1 = total_k * src_dsize;                                                        \
  int sn3 = k / size_sn0;                                                               \
  int sn4 = align_n / sn1;                                                              \
  int ss4 = sn1 * ss1;                                                                  \
  int dn0 = sn0;                                                                        \
  int dn1 = ROW_PER_LT;                                                                 \
  int dst_k = PAD_UP(k, ONE_LINE / dst_dsize);                                          \
  int ds1 = dst_k * dst_dsize;                                                          \
  int dn2 = sn1 / ROW_PER_LT;                                                           \
  int ds2 = WRAM_LT_MAP16_STRIDE;                                                       \
  int ds3 = sn0 * dst_dsize / src_dsize;                                                \
  int dn4 = LT_SIZE / dn2;                                                              \
  int ds4 = dn2 * WRAM_LT_MAP16_STRIDE;                                                 \
  int dn5 = align_n / LT_NUM;                                                           \
  int ds5 = ROW_PER_LT * dst_k * dst_dsize;                                             \
  int rem_k = k % size_sn0;                                                             \
  int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize;                  \
  int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize;                  \
  if (align_n > 0 && sn3 > 0) {                                                         \
    __asm__ __volatile__(                                                               \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], "     \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], "                            \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], "                                       \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], "            \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type                     \
        ";\n\t" ::[dst_addr] "r"(wram_dst),                                             \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1),         \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4),     \
        [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2),     \
        [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4),     \
        [dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5));                       \
    sram_src += align_n * total_k;                                                      \
    wram_dst += align_n / LT_SIZE * dst_k;                                              \
  }                                                                                     \
  align_n = PAD_UP(n % LT_NUM, ROW_PER_LT);                                             \
  if (align_n > 0 && sn3 > 0) {                                                         \
    sn1 = align_n;                                                                      \
    dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT;                                          \
    __asm__ __volatile__(                                                               \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], "     \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, "                \
        "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst),                 \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1),         \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1),     \
        [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3),     \
        [dst_s3] "r"(ds3));                                                             \
    sram_src += align_n * total_k;                                                      \
    wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize;                \
  }                                                                                     \
  if (rem_k > 0) {                                                                      \
    align_n = PAD_UP(n, LT_NUM);                                                        \
    sn0 = rem_k * src_dsize;                                                            \
    dn0 = sn0;                                                                          \
    __asm__ __volatile__(                                                               \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], "     \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], "                            \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], "                                       \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], "            \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type                     \
        ";\n\t" ::[dst_addr] "r"(wram_dst2),                                            \
        [src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
        [src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1),              \
        [src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0),  \
        [dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0),  \
        [dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE),          \
        [dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1),           \
        [dst_s5] "r"(0));                                                               \
  }
// Stage a [warp_n, len_k] weight tile from SRAM into WRAM (LT layout),
// converting half -> float via the tiled-move instruction (arch >= 500 only).
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     half *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
                         ".cvt.rn.f32.f16()");
#endif
}
// Same as above for bfloat16 -> float.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     bfloat16_t *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
                         sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
template <typename T>
// Same-dtype weight staging: copy an [n, len_k] tile from SRAM into WRAM in
// the LT-tiled layout (ROW_PER_LT rows per LT, LT_SIZE LTs per LT_NUM-row
// group) using strided async copies. Handles, in order: full LT_NUM-row
// groups, remaining whole ROW_PER_LT groups, then leftover rows.
__mlu_func__ void warp_prompt_weight(T *wram_dst,
                                     T *sram_src,
                                     int32_t n,
                                     int32_t len_k,
                                     int32_t total_k) {
  int32_t type_size = sizeof(T);
  int32_t data_size = len_k * type_size;
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);  // WRAM row pitch, ONE_LINE aligned
  int32_t ss0 = total_k * type_size;          // SRAM row pitch
  int32_t count = n / LT_NUM;
  for (int32_t i = 0; i < count; ++i) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
    sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
  }
  count = n % LT_NUM / ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
    sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
  }
  count = n % ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
  }
}
// Evenly partition num_total_task items across taskdim workers. Worker
// `taskid` receives num_cur_task items starting at task_offset; the first
// (num_total_task % taskdim) workers each take one extra item.
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
                                   const int32_t &taskid,
                                   const int32_t &taskdim,
                                   int32_t &task_offset,
                                   int32_t &num_cur_task) {
  const int32_t base = num_total_task / taskdim;
  const int32_t extra = num_total_task % taskdim;
  const bool takes_extra = taskid < extra;
  num_cur_task = base + (takes_extra ? 1 : 0);
  // Workers before us each hold `base` items plus one extra for the first
  // min(taskid, extra) of them.
  task_offset = taskid * base + (takes_extra ? taskid : extra);
}
// Bidirectional rendezvous between the cluster's compute cores (IPU) and its
// memory core (MPU), built from raw barrier instructions: IPUs arrive at
// barrier 5 and wait on barrier 3, while the MPU waits on 5 and arrives at 3.
// Used in MLUCastGating to hand SRAM tiles between DMA and compute stages.
// NOTE(review): semantics inferred from the arrive/sync mirror pairing —
// confirm against the BANG ISA barrier documentation.
__mlu_func__ void bidirectionBarrierOp() {
  int32_t bcnt = coreDim + 1;  // all compute cores plus the memory core
  if (__is_ipu()) {
    __asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  } else {
    __asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  }
}
// Thin matmul-accumulate wrapper: expresses the m x k by k x n product as a
// 1x1 convolution so it runs on the matrix unit via __bang_conv_partial;
// `c` is passed as both accumulator input and output.
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
  __bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}
// Store `count` rows of `data_num` elements (dt_size bytes each) from NRAM to
// GDRAM. Collapses to one contiguous copy when both strides equal the row
// length; otherwise issues a strided segmented copy.
__mlu_func__ void warp_store(void *ddr_dst,
                             void *nram_src,
                             const int32_t data_num,
                             const int32_t dst_stride,
                             const int32_t src_stride,
                             const int32_t count,
                             const int32_t dt_size) {
  const bool contiguous = (src_stride == data_num) && (dst_stride == data_num);
  if (contiguous) {
    __memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
    return;
  }
  const size_t dst_pitch = (size_t)dst_stride * dt_size;
  const int32_t src_pitch = src_stride * dt_size;
  __memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM, dst_pitch, src_pitch,
                 count - 1);
}
template <typename Tc, typename Tcc>
// Sum the split-K partial results. `workspace` holds split_k_num stacked
// [M, N] partial matrices in accumulator precision Tcc; each task reduces its
// share of rows with __bang_sumpool and writes a [cta_m, N] slice to `output`
// (type Tc, row stride ldc).
// NOTE(review): the reduced data is written out through the same nram_buffer
// reinterpreted as Tc, with no explicit Tcc -> Tc cast op — correct only when
// Tc and Tcc share a representation (e.g. both float); confirm instantiations.
__mlu_func__ void splitKReduce(Tcc *workspace,
                               Tc *output,
                               int32_t M,
                               int32_t N,
                               int32_t split_k_num,
                               int32_t ldc) {
  int32_t offset_m, cta_m;
  assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
  if (cta_m <= 0) return;
  // Rows per pass such that all split_k_num copies of the block fit in NRAM.
  int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
  int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
  int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
  Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
  Tc *output_ddr = (Tc *)output + offset_m * ldc;
  for (int32_t i = 0; i < repeat; i++) {
    int32_t current_m = i == repeat - 1 ? rem_m : block_m;
    int32_t data_size = N * sizeof(Tc);
    int32_t data_num = current_m - 1;
    if (ldc == N) {
      // Contiguous output: one copy instead of per-row segments.
      data_size = current_m * N * sizeof(Tc);
      data_num = 0;
    }
    // Load the same row block from every split-K partial, stacked in NRAM,
    // then reduce across the split_k dimension with a sum pool.
    __memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
             current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
    __bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
                   split_k_num, 1, 1, 1);
    __memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
             N * sizeof(Tc), data_num);
    workspace_ddr = workspace_ddr + block_m * N;
    output_ddr = output_ddr + block_m * ldc;
  }
}
template <typename Ta,
          typename Tac,
          typename Tb,
          typename Tbc,
          typename Tc,
          typename Tcc,
          bool EXCHANGE_AB>
// Split-K GEMM with on-the-fly dtype promotion (arch >= 500 only).
// Ta/Tb/Tc are the storage types of A/B/C; Tac/Tbc are the promoted compute
// types (see warp_prompt_input / warp_prompt_weight); Tcc is the accumulator
// type. Cluster grid: x (grid_dimx = split_k_num) partitions K, y partitions
// M — or N when EXCHANGE_AB, in which case the A/B streams swap roles and the
// per-warp result is transposed before the store.
// Per cluster, the MPU streams K-tiles of A and B into double-buffered SRAM,
// while the IPUs promote A into NRAM and B into WRAM and accumulate with
// __wmma; bidirectionBarrierOp hands tiles between the two sides. With
// split_k_num > 1, partials go to `workspace` and splitKReduce produces C.
__mlu_global__ void MLUCastGating(Ta *A,
                                  Tb *B,
                                  Tc *C,
                                  Tcc *workspace,
                                  int32_t M,
                                  int32_t N,
                                  int32_t K,
                                  int32_t lda,
                                  int32_t ldb,
                                  int32_t ldc,
                                  castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
  int32_t block_k = split_info.block_k;
  int32_t grid_dimx = split_info.split_k_num;
  int32_t block = split_info.block;
  int32_t grid_idx = clusterId % grid_dimx;
  int32_t grid_idy = clusterId / grid_dimx;
  // This cluster's slice of K.
  int32_t offset_k = 0, problem_k = 0;
  assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
  int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
  int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
  int32_t cta_k = k_loop == 1 ? rem_k : block_k;
  int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
  int32_t warp_m = cta_m, warp_offset_m = 0;
  int32_t warp_n = cta_n, warp_offset_n = 0;
  int32_t outer_loop = 0, outer_rem = 0;
  // Partition the non-K output dimension over clusters (y) and cores, and
  // rebalance `block` so the outer-loop tiles come out near-equal.
  if (EXCHANGE_AB) {
    assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
    if (cta_n > block && cta_n % block != 0) {
      int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_n + block - 1) / block;
    outer_rem = cta_n % block == 0 ? block : cta_n % block;
  } else {
    assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
    if (cta_m > block && cta_m % block != 0) {
      int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_m + block - 1) / block;
    outer_rem = cta_m % block == 0 ? block : cta_m % block;
  }
  // Carve NRAM: double-buffered promoted A tiles plus the per-warp
  // accumulator (and a transpose scratch when EXCHANGE_AB).
  int32_t size_nram_buf =
      NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
  int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
  Tac *nbuf_a = (Tac *)nram_buffer;
  Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
  Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;
  int32_t size_sram_buf = SRAM_BUFFER_SIZE;
  int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
  int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
  Ta *sbuf_a = (Ta *)sram_buffer;
  Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));
  int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
  Tbc *wbuf_b = (Tbc *)wram_buffer;
  int32_t a_dsize = sizeof(Ta);
  int32_t b_dsize = sizeof(Tb);
  int32_t k_loop_count = 0;
  for (int32_t j = 0; j < outer_loop; j++) {
    Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
    Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
    int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
    if (EXCHANGE_AB) {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
    } else {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
    }
    int32_t compute_total = warp_m * warp_n;
    if (compute_total > 0 && __is_ipu()) {
      if (!EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
      }
      // Reset the accumulator for this output tile.
      __bang_write_zero((Tcc *)nbuf_c, compute_total);
    }
    int32_t i = 0;
    // Inner K loop: MPU stages tile i into SRAM while IPUs promote tile i and
    // multiply tile i-1 (double-buffered in SRAM/NRAM/WRAM by k_loop_count).
    for (; i < k_loop; i++) {
      Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
      Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
      cta_k = i == k_loop - 1 ? rem_k : block_k;
      if (__is_mpu()) {
        // MPU: stage the next K-tile of A and B into SRAM.
        if (EXCHANGE_AB) {
          __memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
              [src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
              [src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
        } else {
          __memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
              [src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
              [src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
        }
        a_ddr = (Ta *)a_ddr + block_k;
        b_ddr = (Tb *)b_ddr + block_k;
      }
      // Rendezvous with the other side before the ping-pong SRAM buffers are
      // consumed / refilled.
      bidirectionBarrierOp();
      if (__is_ipu() && compute_total > 0) {
        __sync_io_move_compute(false, true, false, false, false, true);
        __sync_io_move_compute(false, false, true, false, true, false);
        if (i >= 1) {
          __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
                 (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
        }
        warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
                          sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
        // mvdma bound for EXCHANGE_AB when n==32
        warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
                           (Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
      }
      k_loop_count += 1;
    }
    // Drain: multiply the last staged tile, then store this block's result.
    if (compute_total > 0) {
      __sync_io_move_compute(false, true, false, false, false, true);
      __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
             (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
      if (EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
        __bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
      }
      int32_t total_offset =
          grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
                                          : (offset_m + warp_offset_m + block * j) * N);
      Tcc *wks = (Tcc *)workspace + total_offset;
      int32_t store_c_size = sizeof(Tcc);
      int8_t *store_ddr = (int8_t *)wks;
      int32_t dst_str = EXCHANGE_AB ? M : N;
      if (grid_dimx == 1) {
        // Single K partition: write straight to C, skipping the workspace.
        dst_str = ldc;
        store_ddr =
            (int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
                                              : (offset_m + warp_offset_m + block * j) * ldc));
      }
      __asm__ volatile("sync.psimd.cio;\n\t");
      if (EXCHANGE_AB) {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
      } else {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
      }
    }
  }
  // Multi-partition K: combine the per-cluster partials into C.
  if (grid_dimx != 1) {
    __sync_all();
    splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
                 split_info.split_k_num, ldc);
  }
#endif  // __BANG_ARCH__ >= 500
}
} // namespace kernels
/// Pick the tile size for the non-fixed output dimension so that one tile of
/// A/B/partial-C fits in the on-chip memories (NRAM, WRAM, SRAM).
///
/// When EXCHANGE_AB is true the roles of A and B are swapped: n acts as the
/// fixed dimension and the returned value tiles the other dimension; the tile
/// is rounded down to a multiple of core_num * LT_NUM when that leaves a
/// positive size. When EXCHANGE_AB is false the tile covers M and is limited
/// by both NRAM (per-core) and SRAM (per-cluster) capacity.
/// (`m` is not referenced here; it is kept for signature parity with callers.)
int32_t getBlock(int32_t m,
                 int32_t n,
                 int32_t core_num,
                 int32_t block_k,
                 int32_t a_dtype_size,
                 int32_t b_dtype_size,
                 int32_t compute_dtype_size,
                 bool EXCHANGE_AB) {
  if (EXCHANGE_AB) {
    // A and B swapped: n is the fixed dimension of every tile.
    const int32_t fixed_m = n;
    // Capacity limits imposed by each memory level (double-buffered, spread
    // over core_num cores).
    const int32_t cap_nram = (NRAM_BUFFER_SIZE - fixed_m * block_k * compute_dtype_size * 2) /
                             (2 * fixed_m * compute_dtype_size) * core_num;
    const int32_t cap_wram =
        WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
    const int32_t cap_sram =
        (SRAM_BUFFER_SIZE - fixed_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
    const int32_t raw_block = std::min(cap_nram, std::min(cap_wram, cap_sram));
    // Prefer an LT_NUM-per-core aligned tile; fall back to the raw bound when
    // alignment would round it down to zero (or below).
    const int32_t aligned_block = PAD_DOWN(raw_block, core_num * LT_NUM);
    return aligned_block > 0 ? aligned_block : raw_block;
  }
  // Regular layout: n is fixed, tile the M dimension.
  const int32_t fixed_n = n;
  const int32_t cap_nram =
      NRAM_BUFFER_SIZE / (fixed_n * compute_dtype_size + block_k * compute_dtype_size * 2);
  const int32_t cap_sram =
      (SRAM_BUFFER_SIZE - fixed_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
  return std::min(cap_nram * core_num, PAD_DOWN(cap_sram, core_num));
}
// Choose tiling parameters for the gating matmul (C[m,n] = A[m,k] * B[k,n]).
//
// Outputs (by reference):
//   block       - tile size of the split output dimension (M, or N when
//                 EXCHANGE_AB), as computed by getBlock and possibly shrunk
//                 when K is split.
//   split_k_num - number of ways K is split across unions (1 = no split;
//                 > 1 requires a workspace for the partial sums).
//   block_k     - K-tile size in elements, capped so one row of an A tile is
//                 at most 512 bytes.
//   EXCHANGE_AB - set to true when swapping A and B wastes less compute.
void gatingTiling(int32_t m,
                  int32_t n,
                  int32_t k,
                  size_t a_dtype_size,
                  size_t b_dtype_size,
                  size_t compute_dtype_size,
                  size_t workspace_size,
                  int32_t union_number,
                  int32_t core_num,
                  int32_t &block,
                  int32_t &split_k_num,
                  int32_t &block_k,
                  bool &EXCHANGE_AB) {
  block_k = std::min(k, int32_t(512 / a_dtype_size));
  split_k_num = 1;
  // swap A and B to reduce computing waste caused by LT_NUM-align of co dimension
  if (m >= core_num * LT_NUM &&
      float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
    EXCHANGE_AB = true;
  }
  int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
                               compute_dtype_size, EXCHANGE_AB);
  int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
  block = tmp_block;
  // If the output tiles do not occupy every union and K is large, try to
  // split K instead: find a block count i that divides union_number evenly
  // and whose partial-sum workspace fits.
  if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
    for (int32_t i = total_blocks; i <= union_number; i++) {
      if (union_number % i == 0) {
        int32_t tmp_split_k = union_number / i;
        // Workspace holds split_k_num full partial C matrices.
        size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
        if (workspace_size >= workspace_size_need) {
          split_k_num = tmp_split_k;
          // NOTE(review): this divides by total_blocks rather than the chosen
          // block count i (DIV_UP(m, total_blocks), the same value rounded);
          // dividing by i would match the i-way split selected above — confirm
          // whether this is intentional.
          block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
          if (EXCHANGE_AB && block > LT_NUM * core_num) {
            block = PAD_DOWN(block, LT_NUM * core_num);
          }
          break;
        }
      }
    }
  }
}
// Query the current device for the maximum clusters per union task and the
// number of compute cores per cluster.
// NOTE(review): the cnCtxGetDevice return code is not checked, unlike the
// cnrt calls below — confirm a failure there cannot leave `dev` unset.
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
// Host-side launcher for the cast-gating kernel: validates arguments, derives
// the tiling, and dispatches MLUCastGating for the input dtype.
//
// @param queue                queue the kernel is enqueued on.
// @param input               [input_row, hidden_size], half or bfloat16.
// @param filter              [expert_num, hidden_size], float.
// @param output              [input_row, expert_num], float.
// @param workspace           scratch for split-K partial sums (may be NULL).
// @param workspace_size_bytes size of workspace; must be >= 16 MiB when a
//                            workspace is supplied.
// @return KERNEL_STATUS_SUCCESS, or KERNEL_STATUS_FAILED on any rejected
//         argument / unsupported device or dtype.
KernelStatus invokeCastGating(cnrtQueue_t queue,
                              void *input,
                              void *filter,
                              void *output,
                              int input_row,
                              int expert_num,
                              int hidden_size,
                              cnnlDataType_t a_dtype,
                              void *workspace,
                              size_t workspace_size_bytes) {
  // ---- argument validation -------------------------------------------------
  if (is_arch300()) {
    std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (expert_num > 128) {
    std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
    std::cerr
        << "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (workspace_size_bytes > 0 && workspace == NULL) {
    std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
                 "greater than 0."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // ---- launch geometry: one task spanning every core of every union --------
  int32_t union_number, core_num;
  getContxtInfo(&union_number, &core_num);
  cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
  cnrtDim3_t dim;
  dim.x = (int32_t)func_type;
  dim.y = 1;
  dim.z = 1;
  // Filter and accumulation are always fp32; only the activation dtype varies.
  cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
  cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
  size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
  cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
  cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
  cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
  castGatingTileInfo split_info;
  bool EXCHANGE_AB = false;
  gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
               workspace_size_bytes, union_number, core_num, split_info.block,
               split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
  // ---- dtype dispatch ------------------------------------------------------
  // When EXCHANGE_AB is chosen, filter plays the role of A and input of B, so
  // the pointer arguments and the (m, n) extents are swapped accordingly.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else if (a_dtype == CNNL_DTYPE_HALF) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, half, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<half, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else {
    std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,50 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_CAST_GATING_MLUH_
#define CSRC_KERNELS_CAST_GATING_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Convert input to float32 and do gating operation.
* @param queue: The queue for mlu.
* @param input: Input. Pointer to the MLU memory that stores the input,
* the shape must be [input_row, hidden_size].
* @param filter: Input. Pointer to the MLU memory that stores the weight,
* the shape must be [expert_num, hidden_size].
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [input_row, expert_num].
* @param input_row: Input.
* @param expert_num: Input.
* @param hidden_size: Input.
* @param a_dtype: Input. The data-type of input.
* @param workspace: Input. Pointer to the MLU workspace.
* @param workspace_size_bytes: Input. The size of workspace in bytes.
* @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF.
* expert_num must be in range [1, 128].
* If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024.
* The data-type of filter and output must be float.
* cast_gating only supports MLU500 device or higher.
*/
KernelStatus invokeCastGating(cnrtQueue_t queue,
void *input,
void *filter,
void *output,
int input_row,
int expert_num,
int hidden_size,
cnnlDataType_t a_dtype,
void *workspace,
size_t workspace_size_bytes);
} // namespace tmo
#endif // CSRC_KERNELS_CAST_GATING_MLUH_

View File

@@ -0,0 +1,760 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <iostream>
#include "cnrt.h"
#include "combine_result.mluh"
// clang-format off
#include <bang_device_functions_extra.h>
#include <mlu.h>
// clang-format on
#if __BANG_ARCH__ >= 592
#include <bang_fusor.h>
template <typename SrcT>
using bang_fusor = bang::experimental::fusor<SrcT>;
#endif
namespace tmo {
namespace kernels {
// Reserve 32 KiB of NRAM (for stack/compiler use); everything else is carved
// out as one big scratch buffer shared by the helpers in this file.
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
// Per-core on-chip scratch space; partitioned manually inside the kernel.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Exchange two pointers; used to flip the ping/pong NRAM buffer pairs
// between pipeline iterations.
template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
  T *const saved = pong;
  pong = ping;
  ping = saved;
}
// Async gather from GDRAM into NRAM on IO queue 0; `offset_type` selects
// 32-bit (u32) or 64-bit (u64) byte offsets. Expands against the caller's
// locals: dst, src_gdram, nram_offset, transfer_size, token_count.
#define GATHER_ASYNC_IO0(offset_type)                                                      \
  __asm__ __volatile__(                                                                    \
      "gather.vector.async.nram.gdram.nram." #offset_type                                  \
      ".io0 [%[dst]], [%[src]], [%[offset]], "                                             \
      "%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst),                \
      [src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
      [transfer_num] "r"(token_count), [stride] "r"(transfer_size))
// Fused multiply-by-scalar with down-conversion from fp32 to `dst_dtype`.
// Expands against the caller's locals: dst, nram_input_buffer, expert_coeff,
// size.
#define FUSE_MUL_CVT(dst_dtype)                                             \
  __asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype                   \
                       ".f32 [%[dst]], [%[src0]], %[src1],"                 \
                       " %[size];\n\t" ::[dst] "r"(dst),                    \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
// Fused multiply-add (dst += src0 * scalar) with down-conversion from fp32 to
// `dst_dtype`; same implicit operands as FUSE_MUL_CVT.
#define FUSE_MULADD_CVT(dst_dtype)                                          \
  __asm__ __volatile__("muladd.nram.crn." #dst_dtype                        \
                       ".f32 [%[dst]], [%[src0]], %[src1], [%[dst]],"       \
                       " %[size], %[size];\n\t" ::[dst] "r"(dst),           \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
// Widen `count` elements of `src` (half / bfloat16 / float) into fp32 `dst`.
// For T == float this degenerates to a copy, implemented as `x + 0` so it
// goes through the same vector pipe as the conversions.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
  if (std::is_same<T, float>::value) {
    __bang_add_scalar((float *)dst, (float *)src, (float)0, count);
  } else if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, count);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(dst, (bfloat16_t *)src, count);
  }
}
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
} else if (std::is_same<T, float>::value) {
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
// Async 2-D strided load GDRAM -> NRAM: seg_num + 1 segments of `size` bytes,
// advancing dst/src by the given strides per segment (callers pass rows - 1).
// On arch > 500 it is pinned to IO queue 0 via inline asm; otherwise it falls
// back to __memcpy_async.
__mlu_func__ void loadAsync2d(void *dst,
                              void *src,
                              int size,
                              int dststride,
                              int srcstride,
                              int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}
// Async 2-D strided store NRAM -> GDRAM; mirror of loadAsync2d but pinned to
// IO queue 1 on arch > 500 so loads and stores can overlap.
__mlu_func__ void storeAsync2d(void *dst,
                               void *src,
                               int size,
                               int dststride,
                               int srcstride,
                               int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}
// Gather `token_count` rows of `transfer_size` bytes each from `src_gdram`
// into contiguous `dst`, using per-row byte offsets held in NRAM
// (`nram_offset`, uint32 or uint64). No-op when there is nothing to gather
// or the source is null. On arch > 500 a single async gather instruction is
// issued; otherwise each row becomes its own async memcpy.
template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst,
                                    void *src_gdram,
                                    T_IDX *nram_offset,
                                    int transfer_size,
                                    int token_count) {
  if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
  if (std::is_same<T_IDX, uint32_t>::value) {
    GATHER_ASYNC_IO0(u32);
  } else {
    GATHER_ASYNC_IO0(u64);
  }
#else
  for (int k = 0; k < token_count; k++) {
    // __load_nram reads the k-th offset out of NRAM on the scalar side.
    __memcpy_async((int8_t *)dst + k * transfer_size,
                   (int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
  }
#endif
}
// For expert parallelism, keep only the token slots whose scatter index falls
// in [begin_expert_acc_tokens, end_expert_acc_tokens), i.e. tokens routed to
// this rank's experts. Builds a 0/1 mask (as floats) in nram_mask, compacts
// nram_token_idx in place with __bang_filter, and returns the number kept.
// Without expert parallelism every slot is active and token_count is returned
// unchanged.
// NOTE(review): nram_mask_char is not used in this function — presumably kept
// for interface symmetry with callers; confirm.
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
                                            int *nram_mask,
                                            uint8_t *nram_mask_char,
                                            int *nram_mask_buffer,
                                            int begin_expert_acc_tokens,
                                            int end_expert_acc_tokens,
                                            int token_count,
                                            bool expert_parallelism) {
  if (!expert_parallelism) {
    return token_count;
  }
  // mask_buffer = (idx < end); mask = (idx >= begin) && mask_buffer, as float.
  __bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
  // Fused ge / and / int->float on newer archs.
  bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
      .ge(begin_expert_acc_tokens)
      .land(nram_mask_buffer)
      .cvt<float>(0);
#else
  __bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
  __bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
  __bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
  // Compact the surviving indices to the front of nram_token_idx.
  __bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
  int active_token_count = __bang_count((float *)nram_mask, token_count);
  return active_token_count;
}
// 64-bit offset variant: nram_offset[i] = nram_idx[i] * mul_scalar +
// add_scalar. The int32 indices are first widened to int64 so large tensors
// do not overflow the multiply.
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
                                 int *nram_idx,
                                 uint64_t mul_scalar,
                                 int64_t add_scalar,
                                 uint32_t token_count) {
#if __BANG_ARCH__ > 592
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
  __bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
  __bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
// 32-bit offset variant: single fused multiply-add,
// nram_offset[i] = nram_idx[i] * mul_scalar + add_scalar (add truncated to
// 32 bits — only valid when the full offset fits in uint32).
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
                                 int *nram_idx,
                                 uint32_t mul_scalar,
                                 int64_t add_scalar,
                                 uint32_t token_count) {
  __bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
                token_count);
}
// Compute per-token gather byte offsets for the input rows
// (nram_token_offset) and, when a bias is present, per-token bias row offsets
// (nram_bias_offset).
// The bias offset is derived by counting, for each token index, how many
// expert boundaries in nram_expert_tables (cumulative token counts) it has
// passed within this rank's [start_expert_id, start_expert_id + expert_size)
// range — that count is the token's local expert id.
// The token offset rebases the global scatter index against this rank's
// first token (begin_expert_acc_tokens) and adds the hidden-dim byte start.
template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
                                T_IDX *nram_bias_offset,
                                int *nram_token_idx,
                                int *nram_expert_tables,
                                int expert_num,
                                int token_count,
                                int active_token_count,
                                int hidden_size,
                                int local_hidden_begin,
                                int dtype_size,
                                int start_expert_id,
                                int expert_size,
                                int begin_expert_acc_tokens,
                                bool has_bias) {
  // for large tensor, convert int322int64 then do multiply and add seperately.
  if (active_token_count <= 0) return;
  if (has_bias) {
    int *nram_bias_offset_temp = (int *)nram_token_offset;
    __bang_write_zero(nram_bias_offset, active_token_count);
    // Accumulate 1 for every expert boundary the token index has crossed.
    for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
      __bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
                       active_token_count);
      __bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
                 active_token_count);
    }
    __bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
    computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
                   (T_IDX)local_hidden_begin * dtype_size, active_token_count);
  }
  // Rebase global token index to this rank's local input buffer and add the
  // byte offset of the hidden-dim slice this task handles.
  int64_t offset =
      ((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
  computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
                 active_token_count);
}
// dst = nram_input_buffer * expert_coeff, narrowed from fp32 to T.
// On arch > 500 the multiply and the conversion are fused into one
// instruction (FUSE_MUL_CVT); otherwise multiply then convert in place.
template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MUL_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MUL_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  }
#else
  __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
// dst = nram_input_buffer * expert_coeff + dst (dst read as fp32), narrowed
// from fp32 to T. On arch > 500 the FMA and the conversion are fused
// (FUSE_MULADD_CVT); otherwise FMA then convert in place.
template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MULADD_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MULADD_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                  size);
  }
#else
  __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Under expert parallelism only the (token, expert) pairs whose og_mask byte
// is set were gathered, packed consecutively in `input`; token_use_count
// walks that packed layout. Accumulation happens in fp32 inside the output
// row, and for half-precision T the output rows are written to a staggered
// (is_ping-dependent) base so the fp32 intermediates fit; the final __memcpy
// is presumably re-packing those staggered T rows back to a contiguous layout
// — verify the stride arithmetic against the caller's buffer sizing.
template <typename T,
          bool expert_parallelism,
          typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
                                    T *input,
                                    float *weight,
                                    T *input_buffer,
                                    int8_t *og_mask,
                                    int topk,
                                    int hidden_size,
                                    int token_count,
                                    bool &is_ping) {
  // Scratch fp32 row; offset alternates with is_ping for half-precision T.
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  // Scalar-register copies of the mask bytes and weights for one token
  // (topk <= 128, so 32 int32 words cover the packed mask bytes).
  int32_t index[32];
  float reg_weight[128];
  int8_t *index_ = (int8_t *)index;
  int topk_divide_4 = PAD_UP(topk, 4) / 4;
  int token_use_count = 0;
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    // Pull this token's mask bytes (4 at a time) and weights into registers.
    for (int i = 0; i < topk_divide_4; i++) {
      index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
      float *weight_begin = weight + t_i * topk + i * 4;
      reg_weight[i * 4] = __load_nram(weight_begin);
      if (i * 4 + 1 < topk) {
        reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
      }
      if (i * 4 + 2 < topk) {
        reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
      }
      if (i * 4 + 3 < topk) {
        reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
      }
    }
    // First active expert initializes the accumulator (mul, not FMA).
    int first_in_expert = 0;
    float expert_coeff;
    for (; first_in_expert < topk - 1; first_in_expert++) {
      bool in_expert_range = index_[first_in_expert];
      if (!in_expert_range) continue;
      expert_coeff = reg_weight[first_in_expert];
      toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
      __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
      token_use_count++;
      break;
    }
    if (first_in_expert == topk - 1) {
      // No active expert before the last slot: either the last one is the
      // only contribution, or the whole row is zero.
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        __bang_write_zero((T *)output_begin, hidden_size);
      }
    } else {
      // Accumulate the remaining active experts; the last slot uses the
      // fused FMA-and-convert so the row ends up back in T.
      for (int j = first_in_expert + 1; j < topk - 1; j++) {
        bool in_expert_range = index_[j];
        if (!in_expert_range) continue;
        expert_coeff = reg_weight[j];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                      hidden_size, hidden_size);
      }
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        floatTo((T *)output_begin, (float *)output_begin, hidden_size);
      }
    }
  }
  // Re-pack the staggered half-precision rows (see header comment).
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Every (token, expert) row is present, so this walks input densely. topk == 1
// is a pure scale-and-convert fast path. For topk >= 2 the weight load and
// toFloat for slot k are software-pipelined one step ahead of the FMA for
// slot k-1. The staggered-output / final-__memcpy trick for half-precision T
// is the same as in the EP variant above.
template <typename T,
          bool expert_parallelism,
          typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
                                    T *input,
                                    float *weight,
                                    T *input_buffer,
                                    int8_t *og_mask,
                                    int topk,
                                    int hidden_size,
                                    int token_count,
                                    bool &is_ping) {
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  if (topk == 1) {
    for (int i = 0; i < token_count; i++) {
      float expert_coeff = __load_nram(weight + i);
      toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
      mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
    }
    return;
  }
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    // Slot 0 initializes the accumulator; slot 1 is pre-loaded.
    float expert_coeff = __load_nram(weight + t_i * topk);
    toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
    toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
    __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
    expert_coeff = __load_nram(weight + t_i * topk + 1);
    for (int k_i = 2; k_i < topk; k_i++) {
      // FMA for slot k_i - 1 while loading slot k_i.
      __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                    hidden_size, hidden_size);
      expert_coeff = __load_nram(weight + t_i * topk + k_i);
      toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
    }
    // Final slot: fused FMA + convert back to T.
    mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
  }
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
// Combine expert outputs back into per-token results:
// output[t] = sum_k reduce_weight[t, k] * input[gather_idx[t * topk + k]]
//             (+ bias[expert_of(slot)]) (+ residual[t]).
// T_IDX selects 32- vs 64-bit gather offsets. The grid splits the hidden
// dimension over taskIdX (HIDDEN_BLOCK columns each) and tokens over taskIdY;
// each task runs a 3-stage software pipeline over TOKEN_BLOCK-sized tiles:
// iteration t loads tile t+1, reduces tile t, and stores tile t-1.
template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
                                              T *input,
                                              T *bias,
                                              T *residual,
                                              float *reduce_weight,
                                              int *cusum_token_count,
                                              int *gather_idx,
                                              int num_token,
                                              int topk,
                                              int num_expert,
                                              int hidden_size,
                                              int start_expert_id,
                                              int expert_size,
                                              int HIDDEN_BLOCK,
                                              int TOKEN_BLOCK) {
  // The MPU core does no work in this kernel.
  if (__is_mpu()) {
    return;
  }
  // Hidden-dim slice for this task (taskIdX) ...
  int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
  int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
  // ... and token range for this task (taskIdY), remainder spread over the
  // first task_remain_tokens tasks.
  int task_avg_tokens = num_token / taskDimY;
  int task_remain_tokens = num_token % taskDimY;
  int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
  int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
  if (local_hidden_size <= 0) return;
  if (task_tokens <= 0) return;
  constexpr int int32_dtype_size = (int)sizeof(int);
  constexpr int fp32_dtype_size = (int)sizeof(float);
  int pad_num_expert = PAD_UP(num_expert + 1, 32);
  bool has_bias = bias != nullptr;
  bool has_residual = residual != nullptr;
  bool using_acc_sum = cusum_token_count != nullptr;
  // Expert parallelism: this rank only holds expert_size of num_expert experts.
  bool expert_parallelism = expert_size < num_expert;
  int block_size = TOKEN_BLOCK * topk;
  int pad_block_size = PAD_UP(block_size, 64);
  // ---- manual carve-up of nram_buffer --------------------------------------
  int *nram_expert_tables = (int *)nram_buffer;
  int *nram_token_idx = nram_expert_tables + pad_num_expert;
  T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
  T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
  int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
  T *nram_input_ping = (T *)(nram_mask + pad_block_size);
  T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
  T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
  T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
  float *nram_weight_ping =
      (float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
  float *nram_weight_pong = nram_weight_ping + pad_block_size;
  // Extra half-sized conversion scratch for fp16/bf16 (see weightedReduceSum).
  int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
  T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
  T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
  T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
                              buffer_block_num * HIDDEN_BLOCK * sizeof(half));
  int *nram_mask_buffer = (int *)nram_token_offset;
  uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);
  // ---- prologue: indices / cumulative table for the first tile -------------
  int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
  int begin_expert_acc_tokens = 0;
  int end_expert_acc_tokens = num_token * topk;
  if (using_acc_sum) {
    __memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
                   GDRAM2NRAM);
  }
  __memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
                 init_token_count * sizeof(int), GDRAM2NRAM);
  __sync_io();
  if (expert_parallelism) {
    // This rank owns scatter positions [begin, end) of the flattened
    // token*topk space.
    begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
    end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
  }
  int active_token_count = getMaskAndActiveTokenCount(
      nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
      end_expert_acc_tokens, init_token_count, expert_parallelism);
  computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
                init_token_count, active_token_count, hidden_size, local_hidden_begin,
                (int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
  __sync_io_move_compute(true, false, false, false, false, true);
  __sync_io_move_compute(false, false, true, true, false, false);
  int next_active_token_count = active_token_count;
  int previous_global_token_begin = 0;
  int previous_token_count = 0;
  bool is_ping = false;
  // ---- 3-stage pipeline; task_begin = -1 is a pure-load warm-up pass -------
  for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
    int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
    int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
    bool is_last_loop = next_token_begin >= task_tokens;
    bool is_last_2_loop = next_next_token_begin >= task_tokens;
    int current_token_begin = task_begin * TOKEN_BLOCK;
    int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
    int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
    int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
    int current_global_token_begin = task_token_begin + current_token_begin;
    int next_global_token_begin = task_token_begin + next_token_begin;
    int next_next_global_token_begin = task_token_begin + next_next_token_begin;
    // Stage 1: async loads for tile t+1 (and indices for tile t+2).
    if (!is_last_loop) {
      if (!is_last_2_loop) {
        loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
                    next_next_token_count * topk * sizeof(int), 0, 0, 0);
      }
      loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
                  next_token_count * topk * fp32_dtype_size, 0, 0, 0);
      if (has_residual) {
        loadAsync2d(nram_residual_ping,
                    residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
                    local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
                    hidden_size * sizeof(T), next_token_count - 1);
      }
      gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
                               local_hidden_size * sizeof(T), next_active_token_count);
      gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
                               local_hidden_size * sizeof(T), next_active_token_count);
    }
    // Stage 3: async store of tile t-1.
    if (task_begin >= 1) {
      storeAsync2d(
          output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
          nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
          local_hidden_size * sizeof(T), previous_token_count - 1);
    }
    // Stage 2: compute tile t from the pong buffers.
    if (task_begin >= 0) {
      if (has_bias && active_token_count) {
        __bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
                   active_token_count * local_hidden_size);
      }
      if (expert_parallelism) {
        weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                   nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                   local_hidden_size, current_token_count, is_ping);
      } else {
        weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                    nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                    local_hidden_size, current_token_count, is_ping);
      }
      if (has_residual) {
        __bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
                   current_token_count * local_hidden_size);
      }
    }
    __sync_io_move_compute();
    active_token_count = next_active_token_count;
    // Pack the float mask into per-slot bytes consumed by weightedReduceSum.
    if (expert_parallelism && !is_last_loop) {
      __bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
    }
    // Mask/offset prep for tile t+2 (its indices were loaded above).
    if (!is_last_2_loop) {
      next_active_token_count = getMaskAndActiveTokenCount(
          nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
          end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
      computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
                    num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
                    local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
                    begin_expert_acc_tokens, has_bias);
    }
    swap(nram_input_ping, nram_input_pong);
    swap(nram_bias_ping, nram_bias_pong);
    swap(nram_residual_ping, nram_residual_pong);
    swap(nram_weight_ping, nram_weight_pong);
    swap(nram_output_ping, nram_output_pong);
    previous_global_token_begin = current_global_token_begin;
    previous_token_count = current_token_count;
  }
  // Epilogue: drain the last computed tile.
  storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
               nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
               local_hidden_size * sizeof(T), previous_token_count - 1);
}
// Empty bfloat16 specializations for pre-500 architectures, where the primary
// template would not build — presumably because the bf16 intrinsics it uses
// are unavailable there; the host launcher is expected never to dispatch
// bf16 on those devices.
#if __BANG_ARCH__ < 500
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
                                                                    bfloat16_t *input,
                                                                    bfloat16_t *bias,
                                                                    bfloat16_t *residual,
                                                                    float *reduce_weight,
                                                                    int *cusum_token_count,
                                                                    int *gather_ids,
                                                                    int num_token,
                                                                    int topk,
                                                                    int num_expert,
                                                                    int hidden_size,
                                                                    int start_expert_id,
                                                                    int expert_size,
                                                                    int HIDDEN_BLOCK,
                                                                    int TOKEN_BLOCK) {}
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
                                                                    bfloat16_t *input,
                                                                    bfloat16_t *bias,
                                                                    bfloat16_t *residual,
                                                                    float *reduce_weight,
                                                                    int *cusum_token_count,
                                                                    int *gather_ids,
                                                                    int num_token,
                                                                    int topk,
                                                                    int num_expert,
                                                                    int hidden_size,
                                                                    int start_expert_id,
                                                                    int expert_size,
                                                                    int HIDDEN_BLOCK,
                                                                    int TOKEN_BLOCK) {}
#endif
} // namespace kernels
// Host-side launcher for the MoE combine-result kernel: validates arguments,
// derives the NRAM tiling (HIDDEN_BLOCK / TOKEN_BLOCK) and launch dimensions
// from the device topology, and dispatches the dtype/offset-width variant.
// Returns KERNEL_STATUS_FAILED on unsupported arguments, SUCCESS otherwise.
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
                                          void *output,
                                          const void *input,
                                          const void *bias,
                                          const void *residual,
                                          const float *reduce_weight,
                                          const int *cusum_token_count,
                                          const int *gather_idx,
                                          int num_token,
                                          int topk,
                                          int num_expert,
                                          int hidden_size,
                                          int start_expert_id,
                                          int expert_size,
                                          cnnlDataType_t dtype) {
  // ---- Argument validation --------------------------------------------------
  if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
    std::cerr << "[invokeMoeCombineResultKernel]: "
              << "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (bias != nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
              << "cusum_token_count can not be nullptr.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  size_t data_bytes = 0;
  cnnlGetSizeOfDataType(dtype, &data_bytes);
  // ---- Device topology ------------------------------------------------------
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // ---- NRAM tiling ----------------------------------------------------------
  // 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
  // TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
  int convert_buffer = data_bytes == 2
                           ? 3 * hidden_size * data_bytes
                           : 2 * hidden_size * data_bytes;  // buffer for convert bf16/fp16->fp32
  int max_input_size = (432 * 1024 - convert_buffer) /
                       (2 * topk * data_bytes +                    /*input size, double buffer*/
                        (bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
                        (residual != nullptr) * 2 * data_bytes +    /*residual size, double buffer*/
                        2 * data_bytes);                            /*output size, one buffer*/
  int TOKEN_BLOCK = 1;
  int HIDDEN_BLOCK = 1;
  // Per-core budget (in elements) shared between the hidden and token tiles,
  // rounded down to a multiple of 64 elements.
  int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
    HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
    TOKEN_BLOCK = 1;
  } else {
    HIDDEN_BLOCK = hidden_size;
  }
  // for latency case, hidden_size is large but token is small.
  if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
    HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
  }
  HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
  // ---- Launch dimensions ----------------------------------------------------
  // task_dim_x cores split the hidden dimension; pad it so that it evenly
  // divides the total core count when possible.
  uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
  task_dim_x =
      (task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
  uint32_t pad_dim_x = task_dim_x;
  while (pad_dim_x <= cluster_num * core_num) {
    if ((cluster_num * core_num % pad_dim_x == 0)) {
      task_dim_x = pad_dim_x;
      break;
    }
    pad_dim_x += core_num;
  }
  HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
  HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
    TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
  }
  TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);
  float max_cluster_num = core_num * cluster_num / task_dim_x;
  // FIX: std::min cannot deduce a common type from mixed float/int arguments
  // (ill-formed call); cast num_token to float explicitly.
  uint32_t task_dim_y = std::min(max_cluster_num, (float)num_token);
  task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
  cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};
  // Use 64-bit element offsets when the flattened input exceeds uint32 range.
  bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
  // ---- Dispatch by dtype and offset width -----------------------------------
  if (dtype == CNNL_DTYPE_FLOAT) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_HALF) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else {
    std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
              << "among float/half/bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,85 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Sort tokens grouped by different experts based on index. Each token
* selects the topk hidden vectors, multiplies them by corresponding weights,
* and finally reduces the topk vectors for each token. This process involves
* bias and residual, calculated as (x + bias) * weight + residual.
* @example
* input:
* [[[1, 2, 1, 1],
* [1, 1, 1, 2]],
* [[2, 1, 1, 1],
* [1, 1, 1, 1]]]
* num_token = 2, topk = 2
* cusum_token_count = [0, 2, 4]
* index:
* [0, 1, 2, 3]
* weight:
* [0, 0, 1, 1]
* bias:
* [[0, 0, 0, 0],
* [1, 1, 1, 1]]
* residual:
* [[1, 1, 1, 1],
* [0, 0, 0, 0]]
* output:
* [[1, 1, 1, 1],
* [5, 4, 4, 4]]
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the result.
* The shape is [num_token, hidden_size].
* @param input: Input. Pointer to the MLU memory that stores input tokens.
* The shape is [num_token * topk, hidden_size].
* @param bias: Input. Pointer to the MLU memory that stores bias.
* The shape is [num_expert, hidden_size].
* @param residual: Input. Pointer to the MLU memory that stores residual.
* The shape is [num_token, hidden_size].
* @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight.
* The shape is [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
* token number of each expert. The shape is [num_expert + 1].
* @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx.
* The shape is [num_token * topk].
* @param num_token: The total number of tokens.
 * @param topk: The number of experts selected per token.
* @param num_expert: The number of expert.
* @param hidden_size: The size of lowest dimension.
* @param start_expert_id: The id of the first processed expert.
* @param expert_size: The number of processed experts.
* @param dtype: Data type.
* @note Currently does not support bias.
*/
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const void *residual,
const float *reduce_weight,
const int *cusum_token_count,
const int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_

View File

@@ -0,0 +1,219 @@
#include <bang_device_functions_extra.h>
#include <mlu.h>
#include "cnnl.h"
#include "cnrt.h"
#include "expand_input.mluh"
namespace tmo {
namespace kernels {
#define RESERVED_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define MEMCPY_BURST_SIZE 128
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// T_offset: uint32_t or uint64_t
// Gathers `num_index` rows of `hidden_size * data_size` bytes from `input`
// into `output`, using the int32 row indices in `index`. Small 1-element rows
// may be staged through SRAM; large rows fall back to GDRAM->GDRAM copies.
template <size_t data_size, typename T_offset>
__mlu_func__ void ExpandInputKernel(void *output,
                                    void *input,
                                    int *index,
                                    int num_token,
                                    int hidden_size,
                                    int num_index) {
  void *input_ptr = input;
  uint64_t input_size = data_size * num_token * hidden_size;
  // whether SRAM_BUFFER_SIZE can hold the input data
  // sram_enable is true ==> is_nram_output is true
  bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
  if (sram_enable) {
    input_ptr = (void *)sram_buffer;
    __memcpy(input_ptr, input, input_size, GDRAM2SRAM);
    __sync_cluster();
  }
  if (__is_mpu()) {
    return;
  }
  // Each ipu core processes no less than 128B of data, and the remaining cores can idle
  int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
  uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
  uint32_t total_num = num_index;
  uint32_t base = total_num / maxTaskDim;
  uint32_t tail = total_num - base * maxTaskDim;
  if (taskId >= maxTaskDim) {
    return;
  }
  // The first `tail` cores take one extra row each.
  uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
  uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
  // nram
  /*
   * first: compute offset: index[i] * data_size
   * second: gather data
   * -------------------------------------------------
   * addr || index/offset | output |
   * type || int32_t/T_offset | T |
   * num || n | n * hidden_size |
   * -------------------------------------------------
   */
  uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
  // whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM
  bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
  uint32_t per_num =
      is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
  int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
  int *index_base = index + batch_step;
  T_offset *nram_offset = (T_offset *)nram_buffer;
  int32_t *nram_index;
  // FIX: was `std::is_same<T_offset, int64_t>` which never matches (T_offset
  // is uint32_t or uint64_t), so the 64-bit path loaded indices at the LOW
  // half of nram_offset and the in-place int32 -> int64 widening overwrote
  // elements before reading them. With the indices parked in the upper half
  // of the offset buffer, element i's int64 write (bytes [8i, 8i+8)) always
  // stays below its int32 read position (byte 4*per_num + 4i).
  if (std::is_same<T_offset, uint64_t>::value) {
    nram_index = (int32_t *)nram_offset + per_num;
  } else {
    nram_index = (int32_t *)nram_offset;
  }
  int8_t *nram_output = (int8_t *)(nram_offset + per_num);
  uint32_t repeat = batch_per_core / per_num;
  uint32_t remain = batch_per_core - repeat * per_num;
  uint32_t deal_num = per_num;
  uint32_t is_remain = remain != 0 ? 1 : 0;
  for (int32_t i = 0; i < repeat + is_remain; i++) {
    if (i == repeat) {
      deal_num = remain;
    }
    int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
    int32_t *index_ptr = index_base + i * per_num;
    // index -> offset
    __memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
    if (std::is_same<T_offset, uint64_t>::value) {
#if __BANG_ARCH__ > 592
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
#else
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
#endif
    }
    // Row index -> byte offset.
    __bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
    // copy
    if (is_nram_output) {
      // NOTE(review): the zero-fill count looks like elements * hidden_size
      // rather than bytes; the gather below overwrites the buffer anyway —
      // confirm the intended count.
      __bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
      mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
      // GDRAM or SRAM -> NRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // gather requires
      __gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
               hidden_size * data_size, deal_num);
#else
      for (int32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = *(nram_offset + j);
        int8_t *input_offset = (int8_t *)input_ptr + offset_value;
        __memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
                 dir);
      }
#endif  // __BANG_ARCH__
      __memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
    } else {
      // GDRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // gather requires
      // NOTE(review): offsets are cast to uint64_t* unconditionally here;
      // verify the !is_nram_output path is only reached with 64-bit offsets.
      __gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
               hidden_size * data_size, deal_num);
#else
      for (int32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = *(nram_offset + j);
        int8_t *input_offset = (int8_t *)input + offset_value;
        __memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
                 hidden_size * data_size, GDRAM2GDRAM);
      }
#endif  // __BANG_ARCH__
    }
  }
}
// T_offset: uint32_t or uint64_t
// Entry kernel: selects which slice of gather_idx to process (all of it, or
// only the rows belonging to the local experts when cusum_token_count is
// given), then delegates the actual gather to ExpandInputKernel.
template <size_t data_size, typename T_offset>
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
                                         void *hidden_state,
                                         int *gather_idx,
                                         int *cusum_token_count,
                                         int num_token,
                                         int hidden_size,
                                         int topk,
                                         int total_expert_num,
                                         int start_expert_id,
                                         int expert_count) {
  // Default: every one of the num_token * topk gather entries is processed.
  int32_t index_num = num_token * topk;
  int *index_begin = gather_idx;
  if (cusum_token_count != nullptr) {
    // Expert-parallel mode: restrict to the index range of experts
    // [start_expert_id, start_expert_id + expert_count).
    const int slice_begin = cusum_token_count[start_expert_id];
    const int slice_end = cusum_token_count[start_expert_id + expert_count];
    index_num = slice_end - slice_begin;
    index_begin = gather_idx + slice_begin;
  }
  ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, index_begin, num_token,
                                         hidden_size, index_num);
}
// instantiate kernels
// Explicit instantiations for every supported element size (1/2/4/8 bytes),
// once with 32-bit offsets and once with 64-bit offsets for tensors whose
// total byte size exceeds INT32_MAX (see invokeMoeExpandInputKernel).
#define INSTANTIATE_ONE(data_size, T_offset)                              \
  template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
      void *, void *, int *, int *, int, int, int, int, int, int);
INSTANTIATE_ONE(1, uint32_t)
INSTANTIATE_ONE(2, uint32_t)
INSTANTIATE_ONE(4, uint32_t)
INSTANTIATE_ONE(8, uint32_t)
// large tensor
INSTANTIATE_ONE(1, uint64_t)
INSTANTIATE_ONE(2, uint64_t)
INSTANTIATE_ONE(4, uint64_t)
INSTANTIATE_ONE(8, uint64_t)
} // namespace kernels
// Host-side launcher for the MoE expand-input (gather) kernel: sizes the
// launch grid from the device topology and amount of data to move, then
// dispatches the right (element size, offset width) template instantiation.
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
                                        void *expand_hidden_state,
                                        const void *hidden_state,
                                        const int *gather_idx,
                                        const int *cusum_token_count,
                                        int num_token,
                                        int hidden_size,
                                        int topk,
                                        cnnlDataType_t data_type,
                                        int total_expert_num,
                                        int start_expert_id,
                                        int expert_count) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  size_t type_size = 0;
  cnnlGetSizeOfDataType(data_type, &type_size);
  // Use only as many clusters as needed so each core moves at least one
  // MEMCPY_BURST_SIZE (128B) chunk; clamp to [1, cluster_num].
  int max_cluster_num =
      (uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
  cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  // Kernel table: rows 0-3 are 32-bit-offset variants for element sizes
  // 1/2/4/8 bytes; rows 4-7 are the matching 64-bit-offset variants.
  void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
      kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
      kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
      kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
      kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
  // INT32_MAX (not UINT32_MAX) threshold — presumably because the 32-bit path
  // computes byte offsets in signed int32 lanes on-device; confirm before
  // relaxing this bound.
  bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
  // Maps type_size {1,2,4,8} -> index {0,1,2,3}, plus 4 for the 64-bit table half.
  int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
  expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
      (void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
      (int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
      expert_count);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,81 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Gathers slices from hidden_state at axis 1 according to gather_idx and cusum_token_count.
* @example
* hidden_state:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [9, 10, 11, 12]]
* gather_idx:
* [[1, 0, 2, 2, 1, 0]]
* cusum_token_count: NULL
* num_token = 3
* hidden_size = 4
* topk = 2
* expand_hidden_state:
* [[5, 6, 7, 8],
* [1, 2, 3, 4],
* [9, 10, 11, 12],
* [9, 10, 11, 12],
* [5, 6, 7, 8],
* [1, 2, 3, 4]]
* @param queue: The queue for mlu.
* @param hidden_state: Input. Pointer to the MLU memory that store the input,
* the shape must be [num_token, hidden_size].
* @param gather_idx: Input. Pointer to the MLU memory that stores the index,
* the shape must be [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
* token_count. If cusum_token_count is not NULL, the shape must be [total_expert_num + 1]. The
* gather operation will be performed as follows: if cusum_token_count is not NULL: index =
* gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]]
* expand_hidden_state = hidden_state[index]
* else:
* index = gather_idx[:]
* expand_hidden_state = hidden_state[index]
* @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output,
 * if cusum_token_count is not NULL, the shape should be [num_index * topk, hidden_size] in
* which num_index =
* cusum_token_count[start_expert_id+expert_count]-cusum_token_count[start_expert_id]. Otherwise,
* the shape should be [num_token * topk, hidden_size].
* @param num_token: the number of token.
* @param hidden_size: the slice size.
* @param topk: the number of topk.
* @param data_type: Data type of hidden_state.
* @param total_expert_num: the total number of expert.
* @param start_expert_id: the first expert id.
* @param expert_count: the number of experts currently being processed.
*/
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
void *expand_hidden_state,
const void *hidden_state,
const int *gather_idx,
const int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
cnnlDataType_t data_type,
int total_expert_num,
int start_expert_id,
int expert_count);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_

View File

@@ -0,0 +1,935 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>
#include "gen_idx.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
#define ALIGN_16 (16)
#define EXPERT_AVG_COUNT_TEST (0)
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
// Generate integer sequence data from 0 to length-1
__mlu_func__ void generateIntSeq(int *dst, int length) {
  // Seed dst with the precomputed 0..63 ramp, then repeatedly double the
  // filled region: dst[filled..] = dst[0..] + filled.
  int filled = 64;
  __bang_move(dst, range, std::min(filled, length) * sizeof(int));
  while (filled < length) {
    __bang_add_scalar(dst + filled, dst, (int)filled, std::min(filled, length - filled));
    filled *= 2;
  }
}
// genIdx Block kernel, use only 1 core to process
// Single-core path: sorts the num_token*topk expert assignments by expert id
// (stable, via per-expert filter passes) and emits, per expert, the token
// counts, their prefix sum, and the gather indices used by the expand
// (expand_idx = sorted position / topk) and combine (inverse permutation,
// built by scatter) stages.
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
                                               int *gather_combine_idx,
                                               int *token_count,
                                               int *cusum_token_count,
                                               const void *expert_id,
                                               const int num_token,
                                               const int num_expert,
                                               const int topk) {
  /* NRAM space */
  // Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
  // --------------------------------------------------------------
  // | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
  // | combine_idx | expand_idx | | scatter_offset |
  // |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
  // --------------------------------------------------------------
  // ------------------------------
  // |token_count|token_count_presum|
  // | | |
  // | num_expert| num_expert |
  // ------------------------------
  uint32_t token_total_num = num_token * topk;
  // num align to 16, size align to 64B
  uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
  int8_t *expert_id_onchip = (int8_t *)nram_buffer;
  int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
  int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
  int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
  int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
  int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
  int8_t *scatter_offset = cur_expert_result;  // reuse cur_expert space
#if __BANG_ARCH__ >= 592
  int8_t *combine_idx_onchip = expert_id_onchip;  // reuse expert_it space
#endif
  int8_t *expand_idx_onchip = sorted_idx_onchip;  // reuse sorted_idx space
  // Load current core input expert_id and generate int sequence
  __memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
                 GDRAM2NRAM);
  generateIntSeq((int *)gen_idx_onchip, token_total_num);
  __sync();
  // Initialize sort idx offset
  uint32_t sorted_idx_offset = 0;
  // Initialize token count first presum with 0
  ((int *)token_count_presum_onchip)[0] = 0;
  bool need_cusum_token_count = bool(cusum_token_count != nullptr);
  // Loop on each expert, eq, count, filter index
  // NOTE(review): the filter below runs on int32 data reinterpreted as float;
  // it appears to rely on __bang_filter moving lanes by mask without
  // arithmetic, so the bit patterns survive — confirm for all index values.
  for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
    __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
                     token_total_num);
    // Use filter to sort gen_idx, output with sorted_idx_offset
    uint32_t cur_expert_count =
        __bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
                      (float *)cur_expert_result, token_total_num);
    sorted_idx_offset += cur_expert_count;
    ((int *)token_count_onchip)[cur_expert] = cur_expert_count;
    // Compute cusum token count and store
    if (need_cusum_token_count) {
      ((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
    }
  }
#if EXPERT_AVG_COUNT_TEST
  // NOTE: test avg expert code here:
  uint32_t token_count_avg = token_total_num / num_expert;
  uint32_t expert_remain_num = token_total_num % num_expert;
  for (int i = 0; i < num_expert; i++) {
    ((int *)token_count_onchip)[i] =
        (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
    ((int *)token_count_presum_onchip)[i + 1] =
        ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
  }
#endif
  __sync_compute();
  // Store token_count and cusum token count
  __memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
                 NRAM2GDRAM);
  if (need_cusum_token_count) {
    __memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
                   (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }
  // Use sorted idx to generate gather idx for expand and combine
#if __BANG_ARCH__ >= 592
  // scatter_offset = sorted_idx mul_scalar sizeof(int);
  __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                    token_total_num);
#else
  // scatter dst GDRAM addr should align to 64B
  int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
  __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
#endif
  __sync_compute();
#if __BANG_ARCH__ >= 592
  // scatter_async to NRAM
  __scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
                  sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
#endif
  // expand_idx = sorted_idx div(topk)
  __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
  // Store expand_idx and combine_idx
  __sync_compute();
  __memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
                 NRAM2GDRAM);
#if __BANG_ARCH__ >= 592
  __sync_move();
  __memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
                 token_total_num * sizeof(int), NRAM2GDRAM);
#else
  // 370 directly scatter to GDRAM
  __scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
            sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
#endif
}
// Only MLU500 series support NRAM2SRAM scatter direction
// Scatters `length` int elements from NRAM `src` into SRAM `dst` at the byte
// offsets in `offset`. No-op on architectures below 592.
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
#if __BANG_ARCH__ >= 592
  // __scatter's segment count is an unsigned short (max 65535), and the
  // src/offset addresses must stay 64B aligned, so issue 32768-element chunks.
  for (int done = 0; done < length; done += 32768) {
    int cur = std::min(32768, length - done);
    __scatter_async((int *)dst, ((int *)src) + done, ((uint32_t *)offset) + done, sizeof(int),
                    NRAM2SRAM, sizeof(int), (unsigned short)cur);
  }
#endif
}
// Scatter sequence, transfer size is sizeof(int)
// Scatters `length` int elements from NRAM `src` into GDRAM `dst` at the byte
// offsets in `offset`; async on >= 592, synchronous scatter otherwise.
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
  // __scatter's segment count is an unsigned short (max 65535), and the
  // src/offset addresses must stay 64B aligned, so issue 32768-element chunks.
  for (int done = 0; done < length; done += 32768) {
    unsigned short cur = (unsigned short)std::min(32768, length - done);
#if __BANG_ARCH__ >= 592
    __scatter_async((int *)dst, ((int *)src) + done, ((uint32_t *)offset) + done, sizeof(int),
                    NRAM2GDRAM, sizeof(int), cur);
#else
    __scatter((int *)dst, ((int *)src) + done, ((uint32_t *)offset) + done, sizeof(int),
              NRAM2GDRAM, sizeof(int), cur);
#endif
  }
}
// 1. Get token count
// Each core counts, over its slice of expert_id, how many entries match each
// expert, then atomically accumulates the per-expert counts into GDRAM
// token_count. All cores rendezvous before and after so token_count is
// zeroed first and complete afterwards.
__mlu_func__ void getTokenCount(int *token_count,
                                int *expert_id,
                                int token_cur_core,
                                int cur_token_start,
                                int num_expert) {
  // 1. Partition on [num_token*topk],
  // each core for-loop on all expert_id, use eq and count instructions,
  // use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
  // And sync for all cores.
  // NRAM:
  // ------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|expert_count_onchip|
  // | deal_num | deal_num | num_expert |
  // ------------------------------------------------------
  uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
  int8_t *expert_id_onchip = (int8_t *)nram_buffer;
  int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
  int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
  // Current core data loop
  uint32_t repeat = token_cur_core / deal_num;
  uint32_t remain = token_cur_core % deal_num;
  uint32_t total_repeat = repeat + (int)(remain > 0);
  uint32_t token_addr_offset = cur_token_start;
  // Initialize token_count with 0
  if (taskId == 0) {
    __gdramset((int *)token_count, num_expert, 0);
  }
  // Sync for initialize token_count
  __sync_all_ipu();
  // Initialize expert count onchip with 0
  if (token_cur_core > 0) {
    __bang_write_zero((int *)expert_count_onchip, num_expert);
  }
  // actual num in loop
  int cur_deal_num = deal_num;
  for (int i = 0; i < total_repeat; i++) {
    // Last iteration handles the partial tail, if any.
    if (i == total_repeat - 1 && remain > 0) {
      cur_deal_num = remain;
    }
    // Load current core input expert_id
    __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
             cur_deal_num * sizeof(int), GDRAM2NRAM);
    token_addr_offset += cur_deal_num;
    // Loop on each expert, eq, count
    for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
      __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
      // NOTE: __bang_count() only support floating data type
      uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
      ((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
    }
  }
  // AtomicAdd(reduce) all cores token count results
  if (token_cur_core > 0) {
    __bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
  }
  // Sync for all cores, get accumulate of token_count
  __sync_all_ipu();
}
// 2. Get token count presum, for each expert index start address after sorting
// 2. Compute the inclusive prefix sum of token_count (with a leading 0), so
// that entry e gives the start offset of expert e after sorting and entry
// num_expert is the grand total. Only core 0 does the work; all cores
// rendezvous afterwards so everyone can rely on token_count_presum in GDRAM.
// NRAM layout: a single (num_expert + 1)-int buffer at the base of nram_buffer.
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
                                      int *token_count,
                                      const int num_expert) {
  if (taskId == 0) {
    int *presum = (int *)nram_buffer;
    presum[0] = 0;  // first expert always starts at offset 0
    // Load the per-expert counts shifted by one slot, then accumulate in place.
    __memcpy(presum + 1, token_count, num_expert * sizeof(int), GDRAM2NRAM);
    for (int e = 1; e <= num_expert; e++) {
      presum[e] += presum[e - 1];
    }
    // Publish the (num_expert + 1)-entry prefix sum to the workspace.
    __memcpy(token_count_presum, presum, (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }
  // All cores wait until the presum is visible in GDRAM.
  __sync_all_ipu();
}
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
                                            int *token_count,
                                            const uint32_t token_total_num,
                                            const int num_expert) {
  // Test-only helper (EXPERT_AVG_COUNT_TEST): overwrite the measured
  // per-expert token counts and their presum with a perfectly balanced
  // distribution of token_total_num tokens across num_expert experts.
  const uint32_t base_count = token_total_num / num_expert;
  const uint32_t extra_count = token_total_num % num_expert;
  int *count_onchip = (int *)nram_buffer;
  int *presum_onchip = (int *)(nram_buffer + num_expert * sizeof(int));
  presum_onchip[0] = 0;
  for (int e = 0; e < num_expert; e++) {
    // The first `extra_count` experts absorb one extra token each.
    count_onchip[e] = (e < extra_count) ? base_count + 1 : base_count;
    presum_onchip[e + 1] = presum_onchip[e] + count_onchip[e];
  }
  __memcpy(token_count, count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
  __memcpy(token_count_presum, presum_onchip, (num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// 3. Get expert position index after sorting
__mlu_func__ void getSortedIdx(int *sorted_idx,
                               int *expert_id,
                               int *token_count_presum,
                               const int token_total_num,
                               const int num_expert,
                               const int expert_cur_core,
                               const int cur_expert_start,
                               const int cur_expert_end) {
  // 3. Partition on num_expert: each core handles the expert-id range
  // [cur_expert_start, cur_expert_end].  Every core scans ALL expert_id data,
  // compares (eq) against each of its own expert ids, filters the generated
  // position indices with the match mask, and appends the surviving indices
  // at that expert's current start address inside sorted_idx (workspace).
  // Finishes with an all-core sync.
  // NRAM:
  // -------------------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
  // | deal_num | deal_num | deal_num | deal_num |
  // -------------------------------------------------------------------
  // |expert_start_addr|
  // | num_expert |
  // -----------------
  // Calculate new deal_num of sorting process:
  // four int32 buffers of deal_num elements plus num_expert ints that hold
  // the per-expert running write cursor.
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
  // Each core deal with whole token expert_id data
  int repeat = token_total_num / deal_num;
  int remain = token_total_num % deal_num;
  int token_addr_offset = 0;
  int8_t *expert_id_onchip = nram_buffer;
  int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
  int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
  int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
  int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
  // When num_expert < taskDim, not all cores need to sort
  if (expert_cur_core > 0) {
    // Generate position index from 0.  The sequence is seeded once here and
    // advanced by deal_num after each full chunk below, so it always mirrors
    // the global token position of the currently loaded chunk.
    if (deal_num <= token_total_num) {
      generateIntSeq((int *)gen_idx_onchip, deal_num);
    } else {  // only remainder part
      generateIntSeq((int *)gen_idx_onchip, token_total_num);
    }
    // Initialize expert start address with presum of token count
    __memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
             GDRAM2NRAM);
    // repeat part
    for (int i = 0; i < repeat; i++) {
      // Load current core expert_id
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               deal_num * sizeof(int), GDRAM2NRAM);
      token_addr_offset += deal_num;
      // Loop for current core expert, eq, filter position index
      // use filter, store to sorted_idx[expert_start_addr]
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
        // NOTE: __bang_filter() only support floating data type
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, deal_num);
        // Store to the corresponding address of sorted_idx
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
                   cur_expert_count * sizeof(int), NRAM2GDRAM);
          // Update address offset of current expert so the next chunk appends
          // after the indices just written.
          ((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
        }
      }
      // Update position index for each data loop
      __bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
    }
    // remainder part
    if (remain > 0) {
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               remain * sizeof(int), GDRAM2NRAM);
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
        // NOTE: __bang_filter() only support floating data type
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, remain);
        // Store to the corresponding address of sorted_idx.
        // The write cursor is not advanced here: this is the final chunk.
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
                   cur_expert_count * sizeof(int), NRAM2GDRAM);
        }
      }
    }
  }
  // Sync for all cores, get position index after sorting
  __sync_all_ipu();
}
// 4. Get gather index for expand and combine.
// Template parameter is_sram_scatter selects where the combine index is
// scattered to: the cluster-shared SRAM buffer (true) or directly to the
// GDRAM output gather_combine_idx (false).
template <bool is_sram_scatter>
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
                               int *gather_combine_idx,
                               int *sorted_idx,
                               const int token_cur_core,
                               const int cur_token_start,
                               const int topk) {
  // 4. Partition on [num_token*topk]:
  // load sorted_idx onchip,
  // generate sequence according to position index from 0, add token offset,
  // gather_combine_idx = scatter(seq, sorted_idx)
  // gather_expand_idx = sorted_idx / topk
  // then update the sequence for the next chunk.
  // NRAM:
  // -------------------------------------------------------------------
  // |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
  // | deal_num | deal_num | deal_num | deal_num |
  // -------------------------------------------------------------------
  // Calculate new deal_num of generate gather index
  // NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  // scatter dst GDRAM addr should align to 64B: round the output pointer
  // down to a 64-byte boundary and remember the element offset that was
  // rounded off, so it can be folded back into each scatter offset.
  int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
  int8_t *sorted_idx_onchip = nram_buffer;
  int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
  // Generate position index from 0
  // Add base offset to sequence according to current core token start address
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        deal_num);
    } else {  // only remainder part
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        token_cur_core);
    }
  }
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             deal_num * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                        deal_num);
    } else {
      // GDRAM addr should align to 64B: fused op folds the alignment element
      // offset into the byte offsets (FUSION_FAM — presumably
      // (src + align_offset) * sizeof(int); verify against BANG docs).
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), deal_num);
    }
    // Sync for scatter
    __sync_compute();
    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                     deal_num);
    } else {
      // Scatter to output gather_combine_idx
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, deal_num);
    }
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
             deal_num * sizeof(int), NRAM2GDRAM);
    if (is_sram_scatter) {
      // if scatter to SRAM, need to sync compute with mv
      __sync_move();
    }
    // Add offset to sequence and token_address
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }
  // remainder part (same pipeline as above, with chunk size `remain`;
  // the sequence needs no further advance since this is the final chunk)
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             remain * sizeof(int), GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                        remain);
    } else {
      // GDRAM addr should align to 64B
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), remain);
    }
    // Sync for scatter
    __sync_compute();
    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                     remain);
    } else {
      // Scatter to output gather_combine_idx
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, remain);
    }
    // expand_idx_onchip = sorted_idx / topk
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
    // Store expand idx
    __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
             remain * sizeof(int), NRAM2GDRAM);
  }
}
// 4.1 Get gather combine index on SRAM.
// Variant of step 4 used when taskDim > 4: the first union (cores 0..3)
// only scatters the combine index into shared SRAM, while other unions
// compute the expand index in parallel (see getExpandIdx).
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
                                    const int token_cur_core,
                                    const int cur_token_start) {
  // 4.1 Partition on [num_token*topk], with only 1 union:
  // load sorted_idx onchip,
  // generate sequence according to position index from 0, add token offset,
  // gather_combine_idx = scatter(seq, sorted_idx),
  // then update the sequence for the next chunk.
  // NRAM:
  // -------------------------------
  // |scatter_offset|scatter_sequence|
  // | deal_num | deal_num |
  // -------------------------------
  // Calculate new deal_num of generate gather index:
  // only two buffers are needed here, so half of NRAM each.
  // NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;
  int8_t *scatter_offset = nram_buffer;
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
  // Generate position index from 0
  // Add base offset to sequence according to current core token start address
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        deal_num);
    } else {  // only remainder part
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        token_cur_core);
    }
  }
  // repeat part
  for (int i = 0; i < repeat; i++) {
    // Load current core sorted_idx (reusing scatter_offset as the load buffer)
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
             GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
    // Sync for scatter
    __sync_compute();
    // Scatter to SRAM
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                   deal_num);
    __sync_move();
    // Add offset to sequence and token_address
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }
  // remainder part
  if (remain > 0) {
    // Load current core sorted_idx
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
             GDRAM2NRAM);
    // offset = sorted_idx * sizeof(int), counted in bytes
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
    // Sync for scatter
    __sync_compute();
    // NOTE(review): no trailing __sync_move() after this final scatter —
    // presumably the caller's __sync_cluster() covers it; confirm.
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
  }
}
// 4.2 Get gather expand index
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
                               int *sorted_idx,
                               const int token_cur_core,
                               const int cur_token_start,
                               const int topk) {
  // 4.2 Partition on [num_token*topk]: map every sorted position back to
  // its source token row, gather_expand_idx[i] = sorted_idx[i] / topk.
  // NRAM:
  // -----------------------------------
  // |sorted_idx_onchip|expand_idx_onchip|
  // | deal_num | deal_num |
  // -----------------------------------
  // Two int32 buffers of deal_num elements (input slice | result slice).
  const int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int *idx_in = (int *)nram_buffer;
  int *idx_out = (int *)(nram_buffer + deal_num * sizeof(int));
  int remaining = token_cur_core;
  int offset = cur_token_start;
  // Process this core's token range in full-sized chunks plus one tail chunk.
  while (remaining > 0) {
    const int chunk = (remaining < deal_num) ? remaining : deal_num;
    // Load the current slice of sorted_idx.
    __memcpy(idx_in, sorted_idx + offset, chunk * sizeof(int), GDRAM2NRAM);
    // Integer divide by topk recovers the original token index.
    __bang_div(idx_out, idx_in, topk, chunk);
    // Store the expand gather index for this slice.
    __memcpy(gather_expand_idx + offset, idx_out, chunk * sizeof(int), NRAM2GDRAM);
    remaining -= chunk;
    offset += chunk;
  }
}
// Device entry point for MOE index generation.
// Pipeline: (1) count tokens per expert, (2) prefix-sum the counts,
// (3) stable-sort token positions by expert id, (4) derive the expand and
// combine gather indices from the sorted positions.
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
                                          int *gather_combine_idx,
                                          int *token_count,
                                          int *cusum_token_count,
                                          void *workspace,
                                          const void *expert_id,
                                          const int num_token,
                                          const int num_expert,
                                          const int topk) {
  // Store token count presum result, shape [num_expert + 1].
  // If the caller asked for the cusum output, write the presum there
  // directly; otherwise use the head of the workspace.
  int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
  // Store position index after sorting, shape [num_token*topk]
  int *sorted_idx = ((int *)workspace) + num_expert + 1;
  // Calculate partition information for different processes
  // Partition on [num_token*topk]: each core gets floor(total/taskDim)
  // tokens, and the first `token_remain_num` cores get one extra.
  uint32_t token_total_num = num_token * topk;
  uint32_t token_cur_core = token_total_num / taskDim;
  uint32_t token_remain_num = token_total_num % taskDim;
  token_cur_core += (uint32_t)(taskId < token_remain_num);
  // Current core range according to partition on [num_token*topk]
  uint32_t cur_token_start = (taskId < token_remain_num)
                                 ? token_cur_core * taskId
                                 : token_cur_core * taskId + token_remain_num;
  // Partition on [num_expert] (same even-split scheme as above)
  uint32_t expert_cur_core = num_expert / taskDim;
  uint32_t expert_remain_num = num_expert % taskDim;
  expert_cur_core += (uint32_t)(taskId < expert_remain_num);
  // Current core range according to partition on [num_expert]
  uint32_t cur_expert_start = (taskId < expert_remain_num)
                                  ? expert_cur_core * taskId
                                  : expert_cur_core * taskId + expert_remain_num;
  uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
  // Use Union1 SRAM to scatter, only MLU500 series support now:
  // the whole combine index must fit in one cluster's SRAM buffer.
#if __BANG_ARCH__ >= 592
  bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
#else
  bool is_sram_scatter = false;
#endif
  if (__is_ipu()) {
    // 1. Get token count
    getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
                  num_expert);
    // 2. Get presum of token count
    getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
    // 3. Get expert position index after sorting
    getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
                 num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
  }
#if EXPERT_AVG_COUNT_TEST
  // NOTE: test avg expert code here: force a perfectly balanced per-expert
  // distribution (debug/benchmark aid, compiled out by default).
  if (__is_ipu() && taskId == 0) {
    modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
                              num_expert);
  }
  __sync_cluster();
#endif
  // 4. Get gather index for expand and combine
  if (is_sram_scatter) {
    // Only use Union1 SRAM: the scatter itself runs on cores 0..3 (the
    // first cluster), so re-partition the work across those 4 cores.
    uint32_t scatter_idx_cur_core = token_total_num / 4;
    uint32_t scatter_idx_remain_num = token_total_num % 4;
    scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
    uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
                                 ? scatter_idx_cur_core * taskId
                                 : scatter_idx_cur_core * taskId + scatter_idx_remain_num;
    // Only Union1 task type,
    // deal once num is same with deal_num in getGatherIdx,
    // which means only 1 repeat to generate both expand and combine idx on NRAM
    const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
    if (taskDim <= 4 || token_total_num < deal_once_num) {
      // Small enough: the first union generates BOTH indices, then the MPU
      // copies the SRAM-resident combine index to GDRAM.
      if (taskId < 4) {
        if (__is_ipu()) {
          getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
                             scatter_idx_cur_core, cur_idx_start, topk);
          // sync for ipu and mpu
          __sync_cluster();
        } else {
          // sync for ipu and mpu
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      }
    } else {
      // If taskDim > 4, use first union to generate combine idx,
      // use other union to generate expand idx
      if (taskId < 4) {
        if (__is_ipu()) {
          // Scatter combine idx to SRAM
          getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
          __sync_cluster();
        } else {
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      } else {
        // Other union generate expand idx: re-partition [num_token*topk]
        // over the taskDim - 4 remaining cores.
        if (__is_ipu()) {
          uint32_t expand_dim = taskDim - 4;
          uint32_t expand_id = taskId - 4;
          uint32_t expand_token_cur_core = token_total_num / expand_dim;
          uint32_t expand_token_remain_num = token_total_num % expand_dim;
          expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
          uint32_t expand_cur_token_start =
              (expand_id < expand_token_remain_num)
                  ? expand_token_cur_core * expand_id
                  : expand_token_cur_core * expand_id + expand_token_remain_num;
          getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
                       expand_cur_token_start, topk);
        }
      }
    }
  } else {
    // not use SRAM to generate both expand and combine idx:
    // every IPU scatters its combine index straight to GDRAM.
    if (__is_ipu()) {
      getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
                          token_cur_core, cur_token_start, topk);
    }
  }
  // step 5 does not need MPU
  if (__is_mpu()) {
    return;
  }
}  // end of kernel
} // namespace kernels
// Host-side launcher for the MOE index-generation kernels.
// Chooses the launch configuration from the total work size
// (num_token * topk): small problems run the single-core Block kernel
// (no workspace required); larger ones pick the smallest UnionX type
// that provides enough cores, capped by the device's cluster count.
// See gen_idx.mluh for the full parameter contract.
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
                                   int *gather_expand_idx,
                                   int *gather_combine_idx,
                                   int *token_count,
                                   int *cusum_token_count,
                                   void *workspace,
                                   const void *expert_id,
                                   const int num_token,
                                   const int num_expert,
                                   const int topk) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  const int token_total_num = num_token * topk;
  // For the partition on num_token*topk, a single core should process at
  // least single_core_num_limit elements before it pays to add more cores.
  const int single_core_num_limit = 1024;
  int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit);
  // When partitioning on num_expert, each core processes at least one expert
  need_core_num = std::max(num_expert, need_core_num);
  // When consider UnionX cnrt func type, reset cluster_num
  if (token_total_num <= 4096) {  // Block
    cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
    cnrtDim3_t k_dim{1, 1, 1};
    // Block kernel does not need workspace
    kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
        gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
        num_expert, topk);
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  } else if (need_core_num <= 4) {  // Union1
    cluster_num = 1;
  } else if (need_core_num <= 8) {  // Union2
    cluster_num = std::min(cluster_num, 2);
  } else if (need_core_num <= 16) {  // Union4
    cluster_num = std::min(cluster_num, 4);
  } else if (need_core_num <= 32) {  // Union8
    cluster_num = std::min(cluster_num, 8);
  }
  cnrtFunctionType_t k_type;
  cnrtDim3_t k_dim{1, 1, 1};
  // Find max UnionX cnrt func type (4 IPU cores per cluster)
  if (cluster_num == 1) {
    k_type = cnrtFuncTypeUnion1;
    k_dim.x = 4;
  } else if (cluster_num < 4) {  // cluster num is 2 or 3
    k_type = cnrtFuncTypeUnion2;
    k_dim.x = 8;
  } else if (cluster_num < 8) {  // cluster num is 4,5,6,7
    k_type = cnrtFuncTypeUnion4;
    k_dim.x = 16;
  } else {  // cluster num larger than 8
    k_type = cnrtFuncTypeUnion8;
    k_dim.x = 32;
  }
  // The expert_id is int data type
  kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
      gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
      num_token, num_expert, topk);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
#undef EXPERT_AVG_COUNT_TEST // undef test macro
} // namespace tmo

View File

@@ -0,0 +1,58 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#include <vector>
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Apply generate MOE index operation, which performs the following
* tasks:
* - 1. Generate gather_expand_idx and gather_combine_idx.
* - 2. Output token_count, the token number of each expert.
 * - 3. Prepare input and output addresses for group_gemm.
* @param queue: The queue of mlu.
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
* gather index for expand hidden state operation, the shape must be
* [num_token * topk].
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
* gather index for combine MOE operation, the shape must be
* [num_token * topk].
* @param token_count: Output. Pointer to the MLU memory that stores the token
* number of each expert, the shape must be [num_expert].
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
* cumulative sum of the token number of each expert, the shape must be
 * [num_expert + 1]. It can be set to nullptr if the cusum output is not needed.
* @param workspace: Input. A pointer to the extra workspace required in the
* operation, the size must be larger than
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
* of each token, the shape must be [num_token, topk].
* @param num_token: The number of tokens.
* @param num_expert: The number of experts.
 * @param topk: The number of experts selected by each token.
*/
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_

View File

@@ -0,0 +1,21 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_MOE_MLUH_
#define CSRC_KERNELS_MOE_MOE_MLUH_
#include "add_bias_activation.mluh"
#include "combine_result.mluh"
#include "expand_input.mluh"
#include "gen_idx.mluh"
#include "softmax_topk.mluh"
#endif // CSRC_KERNELS_MOE_MOE_MLUH_

View File

@@ -0,0 +1,602 @@
#include <mlu.h>
#include <cassert>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "softmax_topk.mluh"
namespace tmo {
namespace kernels {
// Alignment in bytes required by __scatter() destinations.
#define SCATTER_ALIGN (64)  // align for __scatter()
// Usable on-chip scratch sizes: reserve 32 KB of NRAM/SRAM
// (presumably for stack/compiler scratch — TODO confirm).
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
// Row alignment in bytes for the trans.tiling instruction tiles.
#define TILING_ALIGN (64)
// Ceiling integer division.
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
// Per-core NRAM scratch and per-cluster shared SRAM scratch used by the
// softmax-topk kernels below.
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
// Emits one raw trans.tiling instruction (tiled transpose, optionally with
// a dtype conversion in CONVERT).  Relies on the caller scope providing the
// variables dst, src, in0..in5, is1..is5, dn0..dn5, ds1..ds5.
#define __TRANS_TILING(TYPE, CONVERT) \
  __asm__ volatile("trans.tiling." TYPE \
                   " [%[dst]], [%[src]]," \
                   "%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
                   "%[is4], %[in5], %[is5]," \
                   "%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
                   "%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
                   [src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
                   [is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
                   [in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
                   [dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
                   [ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
// Typed wrapper around the __TRANS_TILING asm macro.
// Performs a tiled SRAM -> NRAM transpose, converting the element type to
// float on the fly for half / bfloat16 sources.  The in*/is* arguments
// describe the source tile counts/strides and dn*/ds* the destination ones
// (consumed by the macro via these parameter names).
// NOTE: for any (dir, DST_DTYPE, SRC_DTYPE) combination other than
// SRAM2NRAM-with-float-destination this function silently does nothing.
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
                               const SRC_DTYPE *src,
                               const uint32_t in0,
                               const uint32_t in1,
                               const uint32_t is1,
                               const uint32_t in2,
                               const uint32_t is2,
                               const uint32_t in3,
                               const uint32_t is3,
                               const uint32_t in4,
                               const uint32_t is4,
                               const uint32_t in5,
                               const uint32_t is5,
                               const uint32_t dn0,
                               const uint32_t dn1,
                               const uint32_t ds1,
                               const uint32_t dn2,
                               const uint32_t ds2,
                               const uint32_t dn3,
                               const uint32_t ds3,
                               const uint32_t dn4,
                               const uint32_t ds4,
                               const uint32_t dn5,
                               const uint32_t ds5) {
  if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
    if (std::is_same<SRC_DTYPE, float>::value) {
      // Same-width copy, no conversion.
      __TRANS_TILING("nram.sram.b32", ";")
    } else if (std::is_same<SRC_DTYPE, half>::value) {
      // 16-bit load with fp16 -> fp32 convert.
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
    } else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
      // bf16 sources only exist on MLU500+ architectures.
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
    }
  }
}
/* Transpose data of shape [h, w] into [w, h] (with optional dtype
 * conversion, see __mlvm_trans).  The matrix is split into 4 regions —
 * the TILING_ALIGN-aligned body plus the ragged right / bottom / corner
 * remainders — each handled by its own trans.tiling call.
 * dst: destination address
 * src: source address
 * h:   extent along h
 * w:   extent along w
 */
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
  // Elements per aligned tile row; w/h split into aligned part + remainder.
  uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
  uint32_t w_align = w / align_num;
  uint32_t w_rem = w % align_num;
  uint32_t h_align = h / align_num;
  uint32_t h_rem = h % align_num;
  // Tile descriptors for the aligned body: counts (in*) and byte strides
  // (is*) on the source side, mirrored (dn*/ds*) on the destination side.
  uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
  uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
  uint32_t in3 = w_align, is3 = TILING_ALIGN;
  uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
  uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
  uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
  uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
  /* 1. h_align * w_align: fully aligned body */
  if (w_align > 0 && h_align > 0) {
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
                                            dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
  }
  /* 2. h_align * w_rem: ragged right edge */
  if (w_rem > 0 && h_align > 0) {
    SRC_DTYPE *src_temp = src + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = TILING_ALIGN;
    in1 = align_num;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = h_align;
    is4 = w * TILING_ALIGN;
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = align_num * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
                                            1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 3. h_rem * w_align: ragged bottom edge */
  if (w_align > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w;
    DST_DTYPE *dst_temp = dst + h_align * align_num;
    in0 = TILING_ALIGN;
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = w_align;
    is4 = TILING_ALIGN;
    dn1 = align_num;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = h * align_num * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
                                            1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 4. h_rem * w_rem: bottom-right corner */
  if (w_rem > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
                                            0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
  }
}
// Iterative column-wise top-k selection: each of the `topk` rounds takes
// the per-column maximum out of max_buffer, records its value/index, then
// overwrites that element with -inf so the next round yields the
// next-largest.  When is_deal_group is true it operates per expert group
// (num_expert_group lanes) and additionally copies the winning groups of
// `group_size` elements from src_buffer into compute_buffer.
// NOTE(review): buffer shapes/strides are assumed from usage
// (max_buffer presumably [col, row] after transpose) — confirm with caller.
__mlu_func__ void getTopk(float *value_buffer,
                          uint32_t *index_buffer,
                          float *src_buffer,
                          float *compute_buffer,
                          float *max_buffer,
                          float *temp_buffer,
                          uint32_t *i_buffer,
                          uint32_t *col_buffer,
                          uint32_t topk,
                          uint32_t num_expert_group,
                          uint32_t col,
                          uint32_t row,
                          uint32_t value_index_stride,
                          uint32_t group_size,
                          bool is_deal_group) {
  __bang_write_value((float *)temp_buffer, col, -INFINITY);  // set -inf vector
  for (int k = 0; k < topk; k++) {
    if (is_deal_group) {
      // Group mode: per-column argmax over num_expert_group lanes.
      __bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
                           1, num_expert_group, 1, 1);
      // Flatten to an absolute element index: idx * col + column (FMA fusion).
      __bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
                    col);
    } else {
      // Plain mode: per-column max value + index over `row` lanes.
      __bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
                                 value_index_stride);
      __bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
    }
#if __BANG_ARCH__ >= 592
    // Vector scatter knocks out all winners at once.
    __bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col);  // index in byte
    __scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
              col);  // replace max value with -inf
#else
    // Older arch: knock out winners one by one with scalar stores.
    for (int i = 0; i < col; i++) {
      uint32_t index = __load_nram(col_buffer + i);
      max_buffer[index] = -INFINITY;
    }
#endif
#if __BANG_ARCH__ < 500
    if (is_deal_group) {
      // Copy the selected group's scores so later stages only see kept groups.
      for (int i = 0; i < col; i++) {
        uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
        __memcpy(compute_buffer + i * row + index * group_size,
                 src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
      }
    }
#endif
  }
#if __BANG_ARCH__ >= 592
  if (is_deal_group) {
    // Vectorized group masking: gather the topk winning groups per column,
    // fill src_buffer with -inf, then scatter the kept groups back in place.
    __bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
    __bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
    __bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
                topk - 1);
    __bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
    // Byte offset of each kept group: group_idx * group_size * 4 + column base.
    __bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
                  (uint32_t *)compute_buffer, col * topk, col * topk);
    __gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
             NRAM2NRAM, group_size * sizeof(float), col * topk);
    __bang_write_value(src_buffer, row * col, -INFINITY);
    __scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
              group_size * sizeof(float), col * topk);
  }
#endif
}
// Softmax over the expert dimension followed by (optionally grouped) top-k
// selection for one NRAM tile of gating logits.
//
// The tile arrives in load_buffer as [nram_compute_col_num, row] values of
// type T (row = expert count, columns = tokens). It is converted to float and
// transposed to [row, col] so max/sum reductions can run as maxpool/sumpool.
// On return the caller reads the final top-k values from compute_buffer and
// the top-k indices from nramout_value (both are transposed back to
// [col, topk] in the last two statements of this function).
//
// Several buffers are deliberately aliased/reused between stages (src_buffer
// doubles as max/sum scratch; for T == float, load_buffer aliases src_buffer
// so no conversion is needed), which makes the statement order load-bearing.
//
// Parameter notes (see MLUSoftmaxTopkKernel for the NRAM partitioning):
//  - sram_buffer: SRAM staging area holding the (optional) mask tile, copied
//    in asynchronously by the caller before this function is invoked.
//  - softmax_buffer: per-column sums of softmax*mask; used only when
//    normalize_mode == 2.
//  - normalize_mode: 0 = no normalization; 1 = divide by the sum of the
//    selected top-k values; 2 = divide by the per-column masked softmax sum.
//  - split_mask: true when a single mask row is split across tasks; then
//    nram_col_offset is this task's column offset into the SRAM mask tile.
template <typename T>
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
                                     T *load_buffer,
                                     float *src_buffer,
                                     float *compute_buffer,
                                     float *group_max_buffer,
                                     float *nramout_value,
                                     uint32_t *nramout_index,
                                     uint32_t *i_buffer,
                                     uint32_t *col_buffer,
                                     float *softmax_buffer,
                                     uint32_t row,
                                     uint32_t nram_compute_col_num,
                                     uint32_t mask_num,
                                     uint32_t nram_max_col_num,
                                     uint32_t topk,
                                     int num_expert_group,
                                     uint32_t topk_group,
                                     uint32_t top_num,
                                     uint32_t nram_col_offset,
                                     int normalize_mode,
                                     bool valid_mask,
                                     bool split_mask) {
  uint32_t nram_compute_num = nram_compute_col_num * row;
  // convert to float for half/bf16 datatype
  if (std::is_same<T, half>::value) {
    __bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
  }
  // transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
  __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
  // compute softmax
  // 0x3fb8aa3b is the float bit pattern of log2(e); exp(x - max) is evaluated
  // below as 2^((x - max) * log2e) via __bang_pow2.
  int tmp = 0x3fb8aa3b;
  float log2e = *(float *)&tmp; // for exp
  // src_buffer reuse as buffer for max/sum.
  __bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
  __bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
                nram_compute_col_num);
  __bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
  __bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
  __bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
  __bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
                   nram_compute_col_num);
  // Make sure the caller's async GDRAM->SRAM mask copy has completed before
  // the mask is read below.
  __sync_cluster();
  // move mask and compute
  if (valid_mask) {
    if (!split_mask) {
      // Whole mask rows fit in the tile: transpose back to [col, row], load
      // the [mask_num, row] mask into NRAM, widen to float if needed, and
      // apply it cyclically (the mask repeats every mask_num * row elements
      // across the batch segment).
      __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
      if (std::is_same<T, half>::value) {
        __memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
                 SRAM2NRAM);
        __bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
                          mask_num * row);
      } else if (std::is_same<T, bfloat16_t>::value) {
        __memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
                 mask_num * row * sizeof(T), SRAM2NRAM);
        __bang_bfloat162float((float *)compute_buffer,
                              (bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
      } else {
        __memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
      }
      __bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
                       mask_num * row);
      __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    } else {
      // Mask row split across tasks: load+transpose just this task's slice of
      // the mask, then multiply element-wise.
      transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
                                      nram_compute_col_num, row);
      __sync();
      __bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
    }
  }
  if (normalize_mode == 2) {
    // Record per-column masked softmax sums for the normalization at the end.
    __bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
  }
  if (num_expert_group <= 1) {
    // num_expert_group <= 1, maintain original topk calculation logic
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * topk * sizeof(float), 0, false);
  } else {
    // num_expert_group > 1, use grouped_topk calculation logic:
    // first reduce each group to its max, select the topk_group best groups,
    // then run an ordinary top-k over the scores of the surviving groups.
    uint32_t group_size = row / num_expert_group;
    __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
    __bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
                   group_size, 1, group_size, 1, 1);
    __bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
    // get topk_group
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
            (float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
            nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
    // get topk
#if __BANG_ARCH__ < 500
    __bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * top_num * sizeof(float), 0, false);
#else
    __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
  } // end else
  // normalize result
  if (normalize_mode == 1) {
    // compute_buffer reuse as buffer for sum.
    __bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
    __bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  } else if (normalize_mode == 2) {
    __bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  }
  // transpose back. src and dst of transpose can not be the same address.
  // Final layout: values in compute_buffer, indices in nramout_value, both
  // [col, topk] -- the caller stores from those two buffers.
  __bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
  __bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}
// Kernel entry point: softmax + top-k over input[col, row] where col =
// num_token and row = num_expert (see invokeMoeSoftmaxTopkKernel). The
// optional mask is [mask_num, row] and is applied to the softmax result
// before top-k selection.
//
// The token (column) dimension is tiled across tasks; two code paths:
//  - nram_max_col_num >= mask_num: each NRAM tile covers a whole number of
//    mask rows (split_mask = false) and tiles are striped over taskId.
//  - otherwise a single mask row must be split (split_mask = true); work is
//    striped over taskIdY (batch x SRAM mask segments) and taskIdX (columns
//    within a segment).
//
// Outputs: value_out[col, topk] (float) and index_out[col, topk] (int32).
template <typename T>
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
                                         T *mask,
                                         int *index_out,
                                         float *value_out,
                                         int col,
                                         int row,
                                         int mask_num,
                                         int topk,
                                         int num_expert_group,
                                         int topk_group,
                                         int normalize_mode) {
  bool valid_mask = (mask != nullptr);
  int top_num = topk >= topk_group ? topk : topk_group;
  // NRAM bytes needed per token column (two expert-row copies plus the
  // per-column top-k / scratch entries); this sizes the column tile below and
  // must stay in sync with the capacity check in invokeMoeSoftmaxTopkKernel.
  uint32_t nram_low_space =
      PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
             SCATTER_ALIGN);
  if (num_expert_group <= 1) {
    nram_low_space =
        PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
  }
  uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
  // Do not make the tile larger than one task's share of the columns.
  if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
    nram_max_col_num = col / taskDim + (col % taskDim > 0);
  }
  nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
  if (nram_max_col_num <= 0) {
    nram_max_col_num = SCATTER_ALIGN / sizeof(float);
  }
  uint32_t nram_deal_num = nram_max_col_num * row;
  uint32_t batch = col / mask_num;
  // nram split:
  // |--------------------------|--------------------------|--------------------|...
  // | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
  // | src_buffer | compute_buffer | group_max_buffer |...
  // |--------------------------|--------------------------|--------------------|...
  // |----------------------------------------|---------------|--------------|
  // | nram_col_num*3 | col*topk | col*topk |
  // | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
  // |----------------------------------------|---------------|--------------|
  float *src_buffer = (float *)nram_buffer;
  float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
  float *group_max_buffer = compute_buffer + nram_deal_num;
  uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
  // Optional regions collapse when unused (no group buffer for ungrouped
  // top-k; no softmax sums unless normalize_mode == 2).
  if (num_expert_group <= 1) {
    i_buffer = (uint32_t *)group_max_buffer;
  }
  uint32_t *col_buffer = i_buffer + nram_max_col_num;
  float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
  if (normalize_mode != 2) {
    softmax_buffer = (float *)col_buffer;
  }
  float *nramout_value = softmax_buffer + nram_max_col_num;
  uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
  if (num_expert_group <= 1) {
    nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
  }
  // For float input, load_buffer aliases src_buffer (no conversion pass);
  // half/bf16 are loaded into the second half and widened into src_buffer.
  T *load_buffer = (T *)src_buffer;
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    load_buffer = load_buffer + nram_deal_num;
  }
  // set i_buffer
  for (uint32_t i = 0; i < nram_max_col_num; i++) {
    i_buffer[i] = i;
  }
  // input[batch, mask, low], mask[mask, low]
  if (nram_max_col_num >= mask_num) { // nram can deal complete mask
    bool split_mask = false;
    uint32_t batch_seg = nram_max_col_num / mask_num;
    uint32_t batch_rem = batch % batch_seg;
    uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
    int repeat = DIV_UP(batch_seg_num, taskDim);
    for (int i = 0; i < repeat; i++) {
      uint32_t seg_id = i * taskDim + taskId;
      uint32_t sram_load_num = mask_num * row;
      uint32_t sram_load_offset = 0;
      uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
                                          ? batch_rem * mask_num
                                          : batch_seg * mask_num;
      // Tasks past the end still enter computeSoftmaxTopk (it contains a
      // cluster sync), but load and store nothing.
      uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
      uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
      uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
      uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
      // Load
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }
      // Compute
      computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                            group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                            softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
                            topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
                            valid_mask, split_mask);
      // Store
      // computeSoftmaxTopk leaves the transposed values in compute_buffer and
      // the transposed indices in nramout_value.
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  } else {
    // A single mask row does not fit: split each row into segments spread
    // over the cores of a cluster (taskIdX) and stripe (batch, segment)
    // pairs over taskIdY.
    bool split_mask = true;
    uint32_t mask_seg = nram_max_col_num;
    uint32_t mask_rem = mask_num % mask_seg;
    uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
    uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
    uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
    uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
    for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
      uint32_t batch_idx = i / sram_mask_seg_num;
      uint32_t mask_idx = i % sram_mask_seg_num;
      // Distribute the remainder over the first sram_mask_rem segments.
      uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
      uint32_t sram_load_num = sram_deal_mask_num * row;
      uint32_t sram_mask_offset = mask_idx < sram_mask_rem
                                      ? mask_idx * (sram_average_mask_num + 1)
                                      : mask_idx * sram_average_mask_num + sram_mask_rem;
      uint32_t sram_load_offset = sram_mask_offset * row;
      // Further split this SRAM segment across the cores (taskIdX).
      uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
      uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
      uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
      uint32_t nram_load_num = nram_deal_mask_num * row;
      uint32_t nram_col_offset = taskIdX < nram_mask_rem
                                     ? taskIdX * (nram_average_mask_num + 1)
                                     : taskIdX * nram_average_mask_num + nram_mask_rem;
      uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
      uint32_t nram_store_num = nram_deal_mask_num * topk;
      uint32_t nram_store_offset =
          (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
      // Load
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }
      // Compute
      computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                            group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                            softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
                            topk, num_expert_group, topk_group, top_num, nram_col_offset,
                            normalize_mode, valid_mask, split_mask);
      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  }
}
} // namespace kernels
/**
 * @brief Host-side launcher for the MOE softmax top-k kernel.
 *
 * Validates the configuration against NRAM capacity and the grouped-topk
 * constraints, then dispatches MLUSoftmaxTopkKernel for the requested dtype
 * on a (coresPerCluster x clusterCount) task grid.
 *
 * @return KERNEL_STATUS_SUCCESS after a successful launch; otherwise
 *         KERNEL_STATUS_FAILED with a diagnostic on stderr.
 */
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
                                        float *reduce_weight,
                                        int *expert_id,
                                        const void *input,
                                        const void *mask,
                                        const int num_token,
                                        const int num_expert,
                                        const int num_mask,
                                        const int topk,
                                        const int num_expert_group,
                                        const int topk_group,
                                        const cnnlDataType_t dtype,
                                        const int normalize_mode) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  int top_num = topk >= topk_group ? topk : topk_group;
  // NRAM capacity checks. These expressions must mirror the nram_low_space
  // computation in MLUSoftmaxTopkKernel.
  if (num_expert_group <= 1) {
    if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not "
                   "supported. "
                << "Supported max num_expert:"
                << (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
                << ". Current num_expert:" << num_expert;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  } else {
    if (num_expert >
        (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) {
      // Fixed: report the limit actually enforced on the grouped path (the old
      // message reused the ungrouped formula).
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not "
                   "supported. "
                << "Supported max num_expert:"
                << (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 /
                       sizeof(float)
                << ". Current num_expert:" << num_expert;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  if (topk > num_expert) {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert. "
              << "topk:" << topk << ". num_expert:" << num_expert;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (num_expert_group > 1) {
    // Grouped top-k constraint checks.
    if (mask != nullptr) {
      // Warning only: the launch still proceeds with the mask.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr";
    }
    if (num_expert % num_expert_group != 0) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be "
                << "divisible by num_expert_group, but now num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk_group <= 0 || topk_group > num_expert_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be "
                << "larger than 0 and less than or equal to num_expert_group, but now topk_group:"
                << topk_group << ", num_expert_group:" << num_expert_group;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk > (num_expert / num_expert_group) * topk_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less "
                << "than or equal to (num_expert / num_expert_group) * topk_group, but now "
                << "topk:" << topk << ", num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  // Dispatch on the input dtype.
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
        num_mask, topk, num_expert_group, topk_group, normalize_mode);
  } else {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported ";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Execute MOE Softmax Top-K Kernel.
*
* This function executes the MOE Softmax Top-K Kernel, which computes
* the Top-K values along a specified dimension after applying softmax to the input data.
* It is specifically designed for reduction along the lowest dimension.
*
* @param queue CNRT queue used to specify the queue for execution.
* @param reduce_weight Pointer to store the Top-K values.
* The shape must be [num_token, topk].
* @param expert_id Pointer to store the indices of the Top-K values.
* The shape must be [num_token, topk].
* @param input Pointer to the input data containing the values to be computed.
* The shape must be [num_token, num_expert].
* @param mask Pointer to the input data containing the mask value to be computed after
* computing softmax, Mask can be nullptr, which means no need to compute,
* otherwise the shape and datatype of mask should be the same as input.
 * @param num_token Number of tokens (rows) in the input data.
 * @param num_expert Size of the expert dimension that is reduced over. Note that
 *        num_expert should not exceed 32768.
 * @param num_mask Number of rows in the mask data.
* @param topk Number of Top-K values to compute. topk should not be larger than num_expert.
 * @param num_expert_group Number of expert groups. If num_expert_group > 1, num_expert
 * must be divisible by num_expert_group. Otherwise, num_expert_group and topk_group
 * are not used.
* @param topk_group Number of Top-K group values to compute. Topk_group should not be larger
* than num_expert_group.
* @param dtype Data type of the input data, should match the actual data type.
* float, half, bfloat16 is supported.
 * @param normalize_mode Whether and how to normalize the output. If normalize_mode == 0, no
 *        normalization is performed; if normalize_mode == 1, the normalizing denominator is
 *        the sum of the top-k values; if normalize_mode == 2, the normalizing denominator is
 *        the sum of the element-wise products of the softmax result and the mask.
*/
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
float *reduce_weight,
int *expert_id,
const void *input,
const void *mask,
const int num_token,
const int num_expert,
const int num_mask,
const int topk,
const int num_expert_group,
const int topk_group,
const cnnlDataType_t dtype,
const int normalize_mode);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_

View File

@@ -0,0 +1,425 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cassert>
#include <iostream>
#include "offline_quant_to_linear_cache.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
#define sizeof_(T) (uint32_t)sizeof(T)
// Quantize `input_num` values to int8 with round-to-nearest.
//
// For half/bfloat16 inputs the values are first widened into
// nram_input_float; for float inputs the callers alias nram_input and
// nram_input_float, so no conversion pass is needed. The values are then
// multiplied cyclically by nram_scale — callers pass reciprocal scales (they
// apply __bang_recip before calling) — and converted to int8.
//
// @param nram_output      int8 results, input_num elements.
// @param nram_input_float fp32 staging buffer (also the fp32 input when T == float).
// @param nram_input       raw input values of type T.
// @param nram_scale       reciprocal scales, cycled with period `scale_num`.
// @param input_num        total number of values to quantize.
// @param scale_num        period of the scale vector; assumes input_num is a
//                         multiple of scale_num — TODO confirm at call sites.
template <typename T>
__mlu_func__ void quantify(int8_t *nram_output,
                           float *nram_input_float,
                           T *nram_input,
                           float *nram_scale,
                           int input_num,
                           int scale_num) {
  if (std::is_same<half, T>::value) {
    __bang_half2float(nram_input_float, (half *)nram_input, input_num);
  } else if (std::is_same<bfloat16_t, T>::value) {
    // bf16 widening intrinsic is only available on newer architectures.
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(nram_input_float, (bfloat16_t *)nram_input, input_num);
#endif
  }
  __bang_cycle_mul(nram_input_float, nram_input_float, nram_scale, input_num, scale_num);
  __bang_float2int8_rn(nram_output, (float *)nram_input_float, input_num, 0);
}
// Quantize one [head_num, seq, head_size] context tile into the int8 linear
// KV cache, using one scale per (head, seq) position.
//
// The tile is transposed to [head_size, head_num * seq] so that a single
// per-(head, seq) scale can be broadcast across head_size by the cyclic
// multiply inside quantify(), then transposed back before the strided store
// into the cache.
template <typename T>
__mlu_func__ void quantPerHead(int8_t *output_gdram,
                               int8_t *output_nram,
                               const T *input_gdram,
                               T *input_nram,
                               const float *scale_gdram,
                               float *scale_nram,
                               float *input_nram_float,
                               T *trans_nram,
                               int seq,
                               int head_num,
                               int head_size,
                               size_t in_hstr_bytes,  // context head_num stride bytes
                               size_t in_sstr_bytes,  // context seq stride bytes
                               size_t scale_hstr_bytes,  // scale head_num stride bytes
                               size_t out_hstr_bytes,  // cache head_num stride bytes
                               size_t out_sstr_bytes  // cache seq stride bytes
) {
  constexpr int dtype_size = sizeof_(T);
  // nram_input: (head_num, seq, head_size)
  int io1_size = head_size * dtype_size;
  __memcpy(trans_nram, input_gdram, io1_size, GDRAM2NRAM, seq * io1_size, head_num - 1, io1_size,
           seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);
  // nram_scale:(head_num, seq);
  int io2_size = seq * sizeof_(float);
  __memcpy(scale_nram, scale_gdram, io2_size, GDRAM2NRAM, io2_size, scale_hstr_bytes, head_num - 1);
  // quantify() multiplies, so invert the scales here.
  __bang_recip(scale_nram, scale_nram, head_num * seq);
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    // bf16 is transposed through the half overload (same element width).
    __bang_transpose((half *)input_nram, (half *)trans_nram, head_num * seq, head_size);
  } else {
    __bang_transpose(input_nram_float, (float *)trans_nram, head_num * seq, head_size);
  }
  quantify<T>(output_nram, input_nram_float, input_nram, scale_nram, head_size * head_num * seq,
              head_num * seq);
  // Transpose back to [head_num * seq, head_size] and store with the cache's
  // head/seq strides.
  __bang_transpose((int8_t *)trans_nram, output_nram, head_size, head_num * seq);
  __memcpy(output_gdram, trans_nram, head_size, NRAM2GDRAM, out_hstr_bytes, head_num - 1,
           out_sstr_bytes, seq - 1, seq * head_size, head_num - 1, head_size, seq - 1);
}
// Offline (pre-computed scale) int8 quantization of context K/V into a
// linear KV cache with one scale per (head, seq) position ("per-head" mode).
//
// Work distribution: taskIdY strides over the batch; taskIdZ selects one
// seq_block-sized window of the sequence.
// NOTE(review): each Z task handles only the window starting at
// taskIdZ * seq_block — this assumes taskDimZ * seq_block covers the longest
// sequence; confirm at the launch site (not visible here).
//
// Key and value sides are optional: a side is processed only when its
// context pointer, cache pointer, and scale pointer are all non-null.
template <typename T>
__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerHead(int8_t *key_cache,
                                                              int8_t *value_cache,
                                                              const float *key_cache_scale,
                                                              const float *value_cache_scale,
                                                              const int *cache_bs_offsets,
                                                              const int *cache_seq_offsets,
                                                              const T *key,
                                                              const T *value,
                                                              const int *context_seq_offsets,
                                                              const int *context_lens,
                                                              const int batch,
                                                              const int head_num,
                                                              const int head_size,
                                                              const int max_context_len,
                                                              const int cache_mem_len,
                                                              const size_t context_bs_stride,
                                                              const size_t context_head_stride,
                                                              const size_t context_seq_stride,
                                                              const size_t cache_bs_stride,
                                                              const size_t cache_head_stride,
                                                              const size_t cache_seq_stride,
                                                              const size_t cache_scale_head_stride,
                                                              const bool packed,
                                                              const int seq_block) {
  bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr);
  bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr);
  if ((!handle_key) && (!handle_value)) {
    return;
  }
  constexpr int dtype_size = sizeof_(T);
  size_t in_hstr_bytes = context_head_stride * dtype_size;
  size_t in_sstr_bytes = context_seq_stride * dtype_size;
  size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t);
  size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t);
  size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float);
  /* ***************************nram space ****************************
   * | scale | input/output | trans |
   * scale size:[head_num, seq_block], float
   * input size:[head_num, seq_block, head_size], float
   * trans size:, [head_size, head_num, seq_block], T
   */
  float *scale_nram = (float *)nram_buffer;
  float *input_nram_float = nullptr;
  T *trans_nram = nullptr, *input_nram = nullptr;
  input_nram_float = scale_nram + head_num * seq_block;
  trans_nram = (T *)(input_nram_float + head_num * seq_block * head_size);
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    // need cast from input_nram to input_nram_float
    input_nram = (T *)input_nram_float + seq_block * head_num * head_size;
  } else {
    input_nram = (T *)input_nram_float;
  }
  int8_t *output_nram = (int8_t *)input_nram_float; // output and input share nram space
  for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
    // packed layout: context_lens holds prefix offsets, so
    // lens[bs_idx + 1] - lens[bs_idx] is this sequence's length and
    // lens[bs_idx] its start row; otherwise lens[bs_idx] is the length.
    int context_len = __load_gdram(context_lens + bs_idx);
    int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
    int task_seq_begin = taskIdZ * seq_block;
    if (task_seq_begin >= seq_len) continue;
    int seq = std::min(seq_len - task_seq_begin, seq_block);
    // context offset
    size_t context_offset = 0;
    if (packed) {
      context_offset = (context_len + task_seq_begin) * context_seq_stride;
    } else {
      int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
      context_offset =
          (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
    }
    // cache offset
    int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
    int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
    // negative offsets mark sequences that must be skipped
    if (cache_seq_offset < 0 || cache_bs_offset < 0) {
      continue;
    }
    cache_seq_offset += task_seq_begin;
    size_t cache_offset = (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
    // per_head, nram input[head_num, seq, head_size], nram scale[head_num, seq]
    if (handle_key) {
      quantPerHead(key_cache + cache_offset, output_nram, key + context_offset, input_nram,
                   key_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram,
                   seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes,
                   out_hstr_bytes, out_sstr_bytes);
    }
    if (handle_value) {
      quantPerHead(value_cache + cache_offset, output_nram, value + context_offset, input_nram,
                   value_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram,
                   seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes,
                   out_hstr_bytes, out_sstr_bytes);
    }
  }
}
// Offline int8 quantization of context K/V into a linear KV cache with one
// scale per (head, channel) element ("per-channel" mode).
//
// Unlike the per-head variant, the [head_num, head_size] scale tile is loaded
// once per tensor (key/value) and reused for every batch/seq block, which is
// why the key and value loops are kept separate.
//
// Work distribution: taskIdY strides over the batch; taskIdZ selects one
// seq_block-sized window of the sequence.
// NOTE(review): as in the per-head kernel, this assumes taskDimZ * seq_block
// covers the longest sequence — confirm at the launch site.
template <typename T>
__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerChannel(
    int8_t *key_cache,
    int8_t *value_cache,
    const float *key_cache_scale,
    const float *value_cache_scale,
    const int *cache_bs_offsets,
    const int *cache_seq_offsets,
    const T *key,
    const T *value,
    const int *context_seq_offsets,
    const int *context_lens,
    const int batch,
    const int head_num,
    const int head_size,
    const int max_context_len,
    const int cache_mem_len,
    const size_t context_bs_stride,
    const size_t context_head_stride,
    const size_t context_seq_stride,
    const size_t cache_bs_stride,
    const size_t cache_head_stride,
    const size_t cache_seq_stride,
    const size_t cache_scale_head_stride,
    const bool packed,
    const int seq_block) {
  bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr);
  bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr);
  if ((!handle_key) && (!handle_value)) {
    return;
  }
  constexpr int dtype_size = sizeof_(T);
  size_t in_hstr_bytes = context_head_stride * dtype_size;
  size_t in_sstr_bytes = context_seq_stride * dtype_size;
  size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t);
  size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t);
  size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float);
  /* *********************************nram space **************************************
   * per_chennel: |scale[head_num, head_size]| input[seq_block, head_num, head_size]|
   */
  float *scale_nram = (float *)nram_buffer;
  float *input_nram_float = scale_nram + head_num * head_size;
  T *input_nram = (T *)input_nram_float;
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    // need cast from input_nram to input_nram_float
    input_nram = (T *)input_nram_float + seq_block * head_num * head_size;
  }
  int8_t *output_nram = (int8_t *)input_nram_float; // output and input share nram space
  int size1 = head_size * sizeof_(float);
  int size2 = head_size * dtype_size;
  int scale_num = head_num * head_size;
  if (handle_key) {
    // load offline scale nram_scale:(head_num, head_size);
    __memcpy(scale_nram, key_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes, head_num - 1);
    // quantify() multiplies, so invert the scales once up front.
    __bang_recip(scale_nram, scale_nram, scale_num);
    for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
      // packed layout: context_lens holds prefix offsets (see per-head kernel)
      int context_len = __load_gdram(context_lens + bs_idx);
      int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
      int task_seq_begin = taskIdZ * seq_block;
      if (task_seq_begin >= seq_len) continue;
      int seq = std::min(seq_len - task_seq_begin, seq_block);
      // context offset
      size_t context_offset = 0;
      if (packed) {
        context_offset = (context_len + task_seq_begin) * context_seq_stride;
      } else {
        int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
        context_offset =
            (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
      }
      // cache offset
      int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
      int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
      // negative offsets mark sequences that must be skipped
      if (cache_seq_offset < 0 || cache_bs_offset < 0) {
        continue;
      }
      cache_seq_offset += task_seq_begin;
      size_t cache_offset =
          (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
      // gather [seq, head_num, head_size] into NRAM, quantize, scatter back
      __memcpy(input_nram, key + context_offset, size2, GDRAM2NRAM, size2, head_num - 1,
               head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);
      quantify<T>((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num,
                  scale_num);
      __memcpy(key_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes,
               head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1);
    }
  }
  if (handle_value) {
    // load offline scale nram_scale:(head_num, head_size);
    __memcpy(scale_nram, value_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes,
             head_num - 1);
    __bang_recip(scale_nram, scale_nram, scale_num);
    for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
      int context_len = __load_gdram(context_lens + bs_idx);
      int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
      int task_seq_begin = taskIdZ * seq_block;
      if (task_seq_begin >= seq_len) continue;
      int seq = std::min(seq_len - task_seq_begin, seq_block);
      // context offset
      size_t context_offset = 0;
      if (packed) {
        context_offset = (context_len + task_seq_begin) * context_seq_stride;
      } else {
        int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
        context_offset =
            (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
      }
      // cache offset
      int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
      int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
      if (cache_seq_offset < 0 || cache_bs_offset < 0) {
        continue;
      }
      cache_seq_offset += task_seq_begin;
      size_t cache_offset =
          (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
      __memcpy(input_nram, value + context_offset, size2, GDRAM2NRAM, size2, head_num - 1,
               head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);
      quantify<T>((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num,
                  scale_num);
      __memcpy(value_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes,
               head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1);
    }
  }
}
} // namespace kernels
// Helper macro: instantiates and launches the offline-quant-to-linear-cache
// kernel for element type `Dtype` and quantization flavor `Name`
// (PerChannel / PerHead) as a Block task. Expansion site must have `dim`,
// `queue`, and all forwarded arguments in scope (see the host function below).
#define LAUNCH_OFFLINE_QUANT_KERNEL(Dtype, Name) \
  kernels::MLUOfflineQuantToLinearCacheKernel##Name<Dtype><<<dim, cnrtFuncTypeBlock, queue>>>( \
      (int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, \
      (float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (Dtype *)key, \
      (Dtype *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size, \
      max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride, \
      cache_bs_stride, cache_head_stride, cache_seq_stride, cache_scale_head_stride, packed, \
      seq_block);
// Host-side launcher: quantizes key/value into int8 linear KV caches using
// offline (precomputed) scales.
//
// Derives a per-launch sequence tile `seq_block` from the 480KB NRAM budget
// (the layout differs between per-channel and per-head quantization), splits
// the batch across taskDimY and the sequence tiles across taskDimZ, then
// dispatches the PerChannel (quant_mode == 0) or PerHead kernel specialized
// for `dtype`. Returns KERNEL_STATUS_FAILED when the head geometry cannot
// fit into NRAM or max_context_len is not positive.
KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue,
                                             void *key_cache,
                                             void *value_cache,
                                             const void *key_cache_scale,
                                             const void *value_cache_scale,
                                             const void *cache_bs_offsets,
                                             const void *cache_seq_offsets,
                                             const void *key,
                                             const void *value,
                                             const void *context_seq_offsets,
                                             const void *context_lens,
                                             const cnnlDataType_t dtype,
                                             const int batch,
                                             const int head_num,
                                             const int head_size,
                                             const int max_context_len,
                                             const int cache_mem_len,
                                             const size_t context_bs_stride,
                                             const size_t context_head_stride,
                                             const size_t context_seq_stride,
                                             const size_t cache_bs_stride,
                                             const size_t cache_head_stride,
                                             const size_t cache_seq_stride,
                                             const size_t cache_scale_head_stride,
                                             const bool packed,
                                             const int quant_mode) {
  constexpr int nram_size = 480 * 1024;
  // Element size of the context tensors; bfloat16 shares sizeof(half).
  int dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof(float) : sizeof(half);
  int seq_block = 0;
  if (quant_mode == 0) {
    // Per-channel mode: formula budgets one float (head_num, head_size) row
    // per token plus one extra row (the trailing `- 1`).
    seq_block = nram_size / (head_num * head_size * sizeof(float)) - 1;
    if (seq_block <= 0) {
      std::cerr << __func__ << "," << __LINE__
                << " :head_num * head_size * sizeof(float) should be less than 240KB when "
                   "quant_mode is 0."
                << std::endl;
      // Bug fix: bail out instead of falling through and dividing by a
      // non-positive seq_block when computing seq_seg below.
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  } else {
    // Per-head mode: per-token scale row plus float and source-dtype staging.
    seq_block = nram_size /
                (head_num * sizeof(float) + head_num * head_size * (sizeof(float) + dtype_size));
    if (seq_block <= 0) {
      std::cerr << __func__ << "," << __LINE__
                << " :head_num * sizeof(float) + head_num * head_size * (sizeof(float) + "
                   "context_dtype_size)) "
                << " should be less than 480KB when quant_mode is not 0." << std::endl;
      // Bug fix: same as above — do not continue with an invalid tile size.
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  seq_block = std::min(seq_block, max_context_len);
  if (seq_block <= 0) {
    // Guards max_context_len <= 0, which would otherwise divide by zero when
    // computing seq_seg.
    std::cerr << __func__ << "," << __LINE__ << " :max_context_len should be greater than 0."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Round the tile down to a multiple of 16 unless it already covers the
  // whole context.
  if (seq_block > 16 && seq_block < max_context_len) {
    seq_block = seq_block / 16 * 16;
  }
  int seq_seg = (max_context_len + seq_block - 1) / seq_block;
  CNdev dev;
  int cluster_dim, core_dim;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
  uint32_t core_num = cluster_dim * core_dim;
  // Y dimension: batches (capped at the physical core count); Z: sequence
  // tiles.
  uint32_t task_y_dim = std::min((uint32_t)batch, core_num);
  cnrtDim3_t dim{1, task_y_dim, (uint32_t)seq_seg};
  if (dtype == CNNL_DTYPE_HALF) {
    if (quant_mode == 0) {
      LAUNCH_OFFLINE_QUANT_KERNEL(half, PerChannel);
    } else {
      LAUNCH_OFFLINE_QUANT_KERNEL(half, PerHead);
    }
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (quant_mode == 0) {
      LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerChannel);
    } else {
      LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerHead);
    }
  } else {
    if (quant_mode == 0) {
      LAUNCH_OFFLINE_QUANT_KERNEL(float, PerChannel);
    } else {
      LAUNCH_OFFLINE_QUANT_KERNEL(float, PerHead);
    }
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,103 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Quantize current key and value, Then store key and value to key_cache and value_cache.
* @param queue: The queue for mlu.
* @param key_cache: Pointer to the MLU memory that stores the key cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* Data type of key_cache must be int8. key_cache could be nullptr.
* @param value_cache: Pointer to the MLU memory that stores the value cache,
* the shape must be [max_batch, head_num, cache_mem_len, head_size].
* Data type of value_cache must be int8. value_cache could be nullptr.
 * @param key_cache_scale: Pointer to the MLU memory that stores the key cache scale,
 * the shape must be [head_num, cache_mem_len] when quant_mode is not zero,
 * and [head_num, head_size] when quant_mode is zero. Data type of key_cache_scale
 * must be float. key_cache_scale could be nullptr.
* @param value_cache_scale: Pointer to the MLU memory that stores the value cache scale,
* the shape must be [head_num, cache_mem_len] when quant_mode is not zero,
* and [head_num, head_size] when quant_mode is zero. Data type of value_cache_scale
* must be float. value_cache_scale could be nullptr.
* @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is {0, 1, 2 ... batch - 1}.
* @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence
* offset of cache, the shape must be [batch], if it's nullptr, the
* default value is 0 for every batch.
 * @param key: Pointer to the MLU memory that stores the key,
 * the shape must be [batch, max_context_len, head_num, head_size].
 * Data type of key could be float/half/bfloat16. key could be nullptr.
 * @param value: Pointer to the MLU memory that stores the value,
 * the shape must be [batch, max_context_len, head_num, head_size].
 * Data type of value could be float/half/bfloat16, value could be nullptr.
* @param context_seq_offsets: Pointer to the MLU memory that stores the
* sequence offset of context, the shape must be [batch]. if it's nullptr,
* the default value is 0 for every batch. It must be nullptr when packed is true.
* @param context_lens: Pointer to the MLU memory that stores the sequence length or cumulative
* sequence length of context. when packed is false, the shape must be [batch], which
* indicates sequence length of context. when packed is true, the shape must be [batch + 1],
which
* indicates cumulative sequence length of context.
* @param dtype: Data type.
* @param batch: Batch size.
* @param head_num: Head number.
* @param head_size: Head size.
 * @param max_context_len: The maximum sequence length of context.
 * @param cache_mem_len: The maximum sequence length of cache.
 * @param context_bs_stride: The stride of batch in context, does not work when packed is true.
 * @param context_head_stride: The stride of head_num in context.
 * @param context_seq_stride: The stride of max_context_len in context.
* @param cache_bs_stride: The stride of batch in cache.
* @param cache_head_stride: The stride of head_num in cache.
* @param cache_seq_stride: The stride of cache_mem_len in cache.
* @param cache_scale_bs_stride: The stride of batch in cache scale.
* @param cache_scale_head_stride: The stride of head in cache scale.
* @param packed: A boolean value indicates whether to use pack mode.
* @param quant_mode: A int value indicates the quantify mode, 0 means quantify by per_channel, and
others value means quantify by per_head.
 * @note If one of key/key_cache/key_cache_scale is nullptr, nothing is done for key.
        If one of value/value_cache/value_cache_scale is nullptr, nothing is done for value.
*/
KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue,
void *key_cache,
void *value_cache,
const void *key_cache_scale,
const void *value_cache_scale,
const void *cache_bs_offsets,
const void *cache_seq_offsets,
const void *key,
const void *value,
const void *context_seq_offsets,
const void *context_lens,
const cnnlDataType_t dtype,
const int batch,
const int head_num,
const int head_size,
const int max_context_len,
const int cache_mem_len,
const size_t context_bs_stride,
const size_t context_head_stride,
const size_t context_seq_stride,
const size_t cache_bs_stride,
const size_t cache_head_stride,
const size_t cache_seq_stride,
const size_t cache_scale_head_stride,
const bool packed,
const int quant_mode);
} // namespace tmo
#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_

View File

@@ -0,0 +1,232 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <climits>
#include "offline_quant_to_paged_cache.mluh"
namespace tmo {
namespace kernels {
// Shorthand: sizeof() as uint32_t, avoiding sign-compare noise in offset math.
#define sizeof_(T) (uint32_t)sizeof(T)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Bytes of NRAM reserved for the call stack.
#define REM_FOR_STACK (32 * 1024)
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK];
// Seed iota table {0..31} used to build per-head index vectors on chip.
__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
// Quantize `token_handle * head_len` elements of nram_input to int8 in place,
// multiplying element-wise by nram_scale (one reciprocal scale per channel,
// cycled every head_len elements).
// Layout contract: for 16-bit T the raw input was staged in the SECOND half of
// the buffer (at byte offset token_handle * head_len * sizeof(T)), so the
// widened float result can be written from the buffer start without clobbering
// unread source data; for T == float the data is already at the start.
template <typename T>
__mlu_func__ void quantifyToInt8(T *nram_input, float *nram_scale, int token_handle, int head_len) {
  // quantify
  if (std::is_same<half, T>::value) {
    __bang_half2float((float *)nram_input,
                      (half *)((int8_t *)nram_input + token_handle * head_len * sizeof_(half)),
                      token_handle * head_len);
  }
  if (std::is_same<bfloat16_t, T>::value) {
// bfloat16 widening intrinsic only exists on MLU500+.
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(
        (float *)nram_input,
        (bfloat16_t *)((int8_t *)nram_input + token_handle * head_len * sizeof_(bfloat16_t)),
        token_handle * head_len);
#endif
  }
  // value * (1 / scale), then round-to-nearest into int8 (shift 0).
  __bang_cycle_mul((float *)nram_input, (float *)nram_input, nram_scale, token_handle * head_len,
                   head_len);
  __bang_float2int8_rn((int8_t *)nram_input, (float *)nram_input, token_handle * head_len, 0);
}
// Per-task kernel: quantizes a tile of up to `tokens_block` tokens of key and
// value to int8 and scatters each (token, head) row into the paged caches at
// the slot given by slot_mapping. Rows whose slot index is negative are
// excluded by the scatter bitmask. MLU500+ only (body is compiled out below).
template <typename T>
__mlu_global__ void MLUOfflineQuantToPagedCacheKernel(T *key,
                                                      T *value,
                                                      int8_t *key_cache,
                                                      int8_t *value_cache,
                                                      float *key_cache_scale,
                                                      float *value_cache_scale,
                                                      int *slot_mapping,
                                                      size_t key_stride0,
                                                      size_t value_stride0,
                                                      int tokens_num,
                                                      int head_num,
                                                      int block_size,
                                                      int head_size,
                                                      int tokens_block) {
  /*******************************************************nram space***********
   * nram:| input | scale | cache_offset | scale_offset | mask | temp | index |
   * input size: tokens_block * head_num * head_size * sizeof(float)
   * scale size: head_num * head_size * sizeof(float)
   * cache_offset size: tokens_block * head_num * sizeof(float)
   * scale_offset size: equal to cache_offset size
   * mask size: CEIL_DIV(tokens_size * head_num, 8) * sizeof(int8_t)
   * temp size: CEIL_ALIGN(token_size * head_num, 8) * sizeof(int)
   * index size: head_num * sizeof(int)
   ****************************************************************************/
#if __BANG_ARCH__ > 500
  // Each task owns a contiguous tile of tokens.
  int token_begin = taskId * tokens_block;
  if (token_begin >= tokens_num) return;
  int token_handle = std::min(tokens_block, tokens_num - token_begin);
  int seq_len = token_handle * head_num;
  int head_len = head_num * head_size;
  // Pad (token, head) counts to a byte multiple for the bitmask intrinsics.
  int pad8_num = CEIL_DIV(seq_len, CHAR_BIT) * CHAR_BIT;
  int input_size = seq_len * head_size * sizeof_(float);
  // Carve the NRAM regions per the layout sketch above.
  int8_t *nram_input = nram_buffer;
  float *nram_scale = (float *)(nram_buffer + input_size);
  int *cache_offset = (int *)(nram_scale + head_len);
  int *scale_offset = cache_offset + pad8_num;
  int *nram_mask = scale_offset + pad8_num;
  int *nram_temp = nram_mask + pad8_num;
  int *head_index = nram_temp + pad8_num;
  // generate range: (0, 1, 2, ..., (head_num - 1)) by doubling the seed table
  __memcpy(head_index, nram_range_32, std::min(head_num, 32) * sizeof_(int), NRAM2NRAM);
  int begin = 32;
  while (begin < head_num) {
    int count = std::min(begin, head_num - begin);
    __bang_add_scalar(head_index + begin, head_index, begin, count);
    begin += count;
  }
  // load slot(token_handle) -> expand(head_num, token_handle) ->transpose(token_handle, head_num)
  int token_size = token_handle * sizeof_(int);
  __memcpy(scale_offset, slot_mapping + token_begin, token_size, GDRAM2NRAM);
  __memcpy(nram_temp, scale_offset, token_size, NRAM2NRAM, token_size, 0, head_num - 1);
  __bang_transpose(scale_offset, nram_temp, head_num, token_handle);
  // Bitmask: set where slot >= 0, so negative slots are skipped by __scatter.
  __bang_write_zero((float *)nram_temp, pad8_num);
  __bang_ge_bitindex((float *)nram_mask, (float *)scale_offset, (float *)nram_temp, pad8_num);
  // calculate cache/scale scatter offset:
  //   block = slot / block_size, in_block = slot % block_size
  //   row   = (block * head_num + head) * block_size + in_block
  //   cache byte offset = row * head_size; scale byte offset = row * 4
  __bang_div(cache_offset, scale_offset, (int)block_size, seq_len);
  __bang_rem(scale_offset, scale_offset, (int)block_size, seq_len);
  __bang_mul_scalar(cache_offset, cache_offset, head_num * block_size, seq_len);
  __bang_mul_scalar(head_index, head_index, block_size, head_num);
  __bang_cycle_add(cache_offset, cache_offset, head_index, seq_len, head_num);
  __bang_add(scale_offset, cache_offset, scale_offset, seq_len);
  __bang_mul_scalar(cache_offset, scale_offset, head_size, seq_len);
  __bang_mul_scalar(scale_offset, scale_offset, sizeof_(float), seq_len);
  int hidden_bytes = head_num * head_size * sizeof_(T);
  // 16-bit inputs are staged in the second half of nram_input so the float
  // widening in quantifyToInt8 can write from the start (see its contract).
  bool half_size = (sizeof(T) == sizeof(half));
  if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr) {
    // load key_cache_scale and pre-invert it (quantify multiplies by 1/scale)
    __memcpy(nram_scale, key_cache_scale, head_len * sizeof_(float), GDRAM2NRAM);
    __bang_recip(nram_scale, nram_scale, head_len);
    // load key tile: (token_handle, head_num, head_size)
    __memcpy(nram_input + half_size * token_handle * hidden_bytes, key + token_begin * key_stride0,
             hidden_bytes, GDRAM2NRAM, hidden_bytes, key_stride0 * sizeof_(T), token_handle - 1);
    // quantify
    quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len);
    // scatter to gdram
    __scatter(key_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size,
              NRAM2GDRAM, head_size, seq_len);
  }
  if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr) {
    // load value_cache_scale and pre-invert it
    __memcpy(nram_scale, value_cache_scale, head_len * sizeof_(float), GDRAM2NRAM);
    __bang_recip(nram_scale, nram_scale, head_len);
    // load value tile: (token_handle, head_num, head_size)
    __memcpy(nram_input + half_size * token_handle * hidden_bytes,
             value + token_begin * value_stride0, hidden_bytes, GDRAM2NRAM, hidden_bytes,
             value_stride0 * sizeof_(T), token_handle - 1);
    // quantify
    quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len);
    // scatter to gdram
    __scatter(value_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size,
              NRAM2GDRAM, head_size, seq_len);
  }
#endif
}
} // namespace kernels
// Host-side launcher: quantizes key/value tokens into int8 paged (block) KV
// caches with offline scales; each token's head rows are scattered to the
// slot given by slot_mapping. Splits tokens into tiles of `seq_block` sized
// by the 480KB NRAM budget, one tile per task. Not supported on MLU300.
KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue,
                                            cnnlDataType_t data_type,
                                            void *key,
                                            void *value,
                                            void *key_cache,
                                            void *value_cache,
                                            void *key_cache_scale,
                                            void *value_cache_scale,
                                            void *slot_mapping,
                                            size_t key_stride0,
                                            size_t value_stride0,
                                            int num_tokens,
                                            int num_heads,
                                            int block_num,
                                            int block_size,
                                            int head_size) {
  if (is_arch300()) {
    std::cerr << "[invokeOfflineQuantToPagedCache]: kernel does not support MLU300 devices."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int dtype_size = 1;
  if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
    dtype_size = 2;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    dtype_size = 4;
  } else {
    std::cerr << "invokeOfflineQuantToPagedCache: unsupport data type\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Bug fix: with no tokens there is nothing to store, and the tiling below
  // would otherwise derive seq_block == 0 and divide by zero.
  if (num_tokens <= 0) {
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }
  // Bug fix: widen before multiplying — the product of five ints could
  // overflow 32-bit arithmetic before being stored into int64_t, defeating
  // this 4G addressing-range check.
  int64_t kv_cache_range =
      (int64_t)block_num * block_size * num_heads * head_size * dtype_size;
  if (kv_cache_range > UINT32_MAX) {
    std::cerr
        << "invokeOfflineQuantToPagedCache: The addressing range of kv_cache cannot exceed 4G."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // nram_size_need: token_block * head_num * head_size + head_num * head_size * sizeof(float)
  //                 token_block * head_num * 4 * sizeof(int) + head_num * sizeof(int)
  // nram used: 480KB
  int nram_size = 480 * 1024 - num_heads * sizeof(int) - num_heads * head_size * sizeof(float);
  int hidden_bytes = num_heads * head_size * sizeof(float) +
                     4 * CEIL_DIV(num_heads, CHAR_BIT) * CHAR_BIT * sizeof(int);
  int seq_block = nram_size / hidden_bytes;
  if (seq_block <= 0) {
    // Message fixed to match the actual per-token budget computed above
    // (float staging plus four padded int arrays), not dtype_size.
    std::cerr << "invokeOfflineQuantToPagedCache: "
              << "num_heads * head_size * sizeof(float) + 4 * num_heads * sizeof(int) "
              << "should be less than 480KB.\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Round the tile down to a multiple of 16 where possible.
  if (seq_block > 16) {
    seq_block = seq_block / 16 * 16;
  }
  int cluster_num, core_dim;
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
  int core_num = core_dim * cluster_num;
  // Shrink the tile so the tokens spread across all physical cores.
  seq_block = std::min(seq_block, CEIL_DIV(num_tokens, core_num));
  uint32_t task_dim = CEIL_DIV(num_tokens, seq_block);
  cnrtDim3_t dim{1, task_dim, 1};
  if (data_type == CNNL_DTYPE_FLOAT) {
    kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
        (float *)key, (float *)value, (int8_t *)key_cache, (int8_t *)value_cache,
        (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
        value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
  } else if (data_type == CNNL_DTYPE_HALF) {
    kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)key, (half *)value, (int8_t *)key_cache, (int8_t *)value_cache,
        (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
        value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
  } else {
    kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bfloat16_t *)key, (bfloat16_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
        (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
        value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,62 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
#define CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
#include "kernel_utils.h"
namespace tmo {
/**
* @brief Perform offline_quant_to_paged_cache operation.
* @param queue[in]: The queue for mlu.
* @param data_type[in]: The cnnl data type of key.
* @param key[in]: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
* num_heads, head_size]. Data type of key must be half/bfloat16_t/float.
* @param value[in]: Pointer to the MLU memory that stores the value tensor which has shape
* [num_tokens, num_heads, head_size]. Data type of value must be half/bfloat16_t/float.
* @param key_cache[out]: Pointer to the MLU memory that stores the key_cache tensor which has
* shape [num_blocks, num_heads, block_size, head_size]. Data type of key cache must be int8_t.
* @param value_cache[out]: Pointer to the MLU memory that stores the value_cache tensor which has
* shape [num_blocks, num_heads, block_size, head_size]. Data type of value cache must be int8_t.
* @param key_cache_scale[in]: Pointer to the MLU memory that stores the key_cache_scale tensor
* which has shape [num_heads, head_size]. Data type of key cache scale must be float.
* @param value_cache_scale[in]: Pointer to the MLU memory that stores the value_cache_scale tensor
* which has shape [num_heads, head_size]. Data type of value cache scale must be float.
* @param slot_mapping[in]: Pointer to the MLU memory that stores the slot_mapping tensor which has
* shape [num_tokens]. Data type of slot mapping must be int32_t.
 * @param key_stride0[in]: The first dimension (token) stride of the key tensor, in elements.
 * @param value_stride0[in]: The first dimension (token) stride of the value tensor, in elements.
* @param num_tokens[in]: Total number of tokens.
* @param num_heads[in]: Head number.
* @param block_num[in]: Total number of blocks.
* @param block_size[in]: Number of tokens per block.
* @param head_size[in]: Head size.
* @note: offline_quant_to_paged_cache does not support MLU300 device.
*/
KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue,
cnnlDataType_t data_type,
void *key,
void *value,
void *key_cache,
void *value_cache,
void *key_cache_scale,
void *value_cache_scale,
void *slot_mapping,
size_t key_stride0,
size_t value_stride0,
int num_tokens,
int num_heads,
int block_num,
int block_size,
int head_size);
} // namespace tmo
#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_

View File

@@ -0,0 +1,156 @@
#include <algorithm>
#include "cnrt.h"
#include "operate_cu_seq_lens.mluh"
namespace {
constexpr int pair_elem_num = 2;
}
namespace tmo {
namespace kernels {
// Usable NRAM capacity in ints, keeping 32KB in reserve for the call stack.
#define ONCHIP_DATA_NUM ((int)((__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) / sizeof(int)))
__nram__ int nram_buffer[ONCHIP_DATA_NUM];
// Seed table {1..16} used by genSeqLens to build arithmetic progressions.
__nram__ const int acc_seq_lens[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
// Fill seq_len_nram with the arithmetic progression
//   start + multi * {1, 2, ..., elem_count},
// seeding the first <=16 values from acc_seq_lens, then repeatedly appending
// a shifted copy of the already-generated prefix (prefix doubling).
__mlu_func__ void genSeqLens(int *seq_len_nram, int start, int multi, int elem_count) {
  constexpr int acc_seq_lens_size = 16;
  int count = std::min(acc_seq_lens_size, elem_count);
  int add_on = multi * acc_seq_lens_size;
  __bang_mul_scalar(seq_len_nram, acc_seq_lens, multi, count);
  __bang_add_scalar(seq_len_nram, seq_len_nram, start, count);
  // Each pass copies the current prefix shifted by add_on, doubling the
  // number of valid elements (capped at elem_count).
  while (count < elem_count) {
    __bang_add_scalar(seq_len_nram + count, seq_len_nram, add_on,
                      std::min(count, elem_count - count));
    count *= 2;
    add_on *= 2;
  }
}
// Re-bases a cumulative-sequence-lengths array for batch-parallel attention:
// the batch is split into `loop` slices of `every` batches (`remain` for the
// last slice when non-zero), and each slice's prefix sums are shifted so the
// slice restarts at 0 (see the example in operate_cu_seq_lens.mluh).
// Single-task kernel: everything fits in nram_buffer.
__mlu_global__ void MLUSliceCuSeqlens(int *cu_seq_lens,
                                      int *sliced_cu_seq_lens,
                                      int batch,
                                      int every,
                                      int remain,
                                      int loop) {
  int cu_seq_lens_elem_count = batch + 1;
  // Each slice re-emits its boundary entry, hence batch + loop outputs.
  int sliced_cu_seq_lens_elem_count = batch + loop;
  int *cu_seq_lens_narm = nram_buffer;
  int *sliced_cu_seq_lens_narm = cu_seq_lens_narm + cu_seq_lens_elem_count;
  int *sliced_cu_seq_lens_narm_start = sliced_cu_seq_lens_narm;
  __memcpy(cu_seq_lens_narm, cu_seq_lens, cu_seq_lens_elem_count * sizeof(int), GDRAM2NRAM);
  __bang_write_zero(sliced_cu_seq_lens_narm, sliced_cu_seq_lens_elem_count);
  for (int i = 0; i < loop; ++i) {
    // elem_num counts the slice's entries plus its leading boundary (the 1).
    int elem_num = 1 + (i == loop - 1 && remain != 0 ? remain : every);
    // Subtract the slice's starting prefix so it begins at 0.
    __bang_sub_scalar(sliced_cu_seq_lens_narm, cu_seq_lens_narm, cu_seq_lens_narm[0], elem_num);
    // Input advances by elem_num - 1: the boundary entry is shared between
    // consecutive slices; output advances by the full elem_num.
    cu_seq_lens_narm += elem_num - 1;
    sliced_cu_seq_lens_narm += elem_num;
  }
  __memcpy(sliced_cu_seq_lens, sliced_cu_seq_lens_narm_start,
           sliced_cu_seq_lens_elem_count * sizeof(int), NRAM2GDRAM);
}
// Generates the KV-side cumulative-seq-lens pairs for seq-parallel attention:
// `loop` pairs of [0, kv_len]. With a causal mask kv_len grows by `every` per
// slice (the last pair is shrunk by every - remain when remain != 0); without
// it every pair is [0, seq_len]. Work is segmented over taskIdX in chunks of
// seg_data_num elements.
__mlu_global__ void MLUGenerateKVCuSeqlens(int *gen_cu_seq_lens,
                                           int every,
                                           int remain,
                                           int loop,
                                           int seq_len,
                                           bool is_causal_mask,
                                           int seg_data_num,
                                           int task_num) {
  int offset = seg_data_num * taskIdX;
  int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset);
  int seq_len_elem_num = total_elem_num / pair_elem_num;
  int *gen_cu_seq_lens_narm = nram_buffer;
  // Even positions of each pair stay 0.
  __bang_write_zero(gen_cu_seq_lens_narm, total_elem_num);
  if (is_causal_mask) {
    // Build the growing kv lengths, then strided-copy them into the odd
    // positions of the pairs.
    int *seq_lens_narm = gen_cu_seq_lens_narm + total_elem_num;
    genSeqLens(seq_lens_narm, every * offset / pair_elem_num, every, seq_len_elem_num);
    __memcpy(gen_cu_seq_lens_narm + 1, seq_lens_narm, sizeof(int), NRAM2NRAM,
             pair_elem_num * sizeof(int), sizeof(int), seq_len_elem_num - 1);
    if (remain != 0 && taskIdX == task_num - 1) {
      // Last slice is shorter than `every`; trim its kv length.
      gen_cu_seq_lens_narm[total_elem_num - 1] -= (every - remain);
    }
  } else {
    // Non-causal: every slice sees the full sequence.
    __bang_write_value(gen_cu_seq_lens_narm + 1, 1, seq_len, pair_elem_num * sizeof(int),
                       seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0);
  }
  __memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int),
           NRAM2GDRAM);
}
// Generates the Q-side cumulative-seq-lens pairs for seq-parallel attention:
// `loop` pairs of [0, every], with the very last pair replaced by [0, remain]
// when the sequence does not divide evenly. Work is segmented over taskIdX in
// chunks of seg_data_num elements.
__mlu_global__ void MLUGenerateQCuSeqlens(int *gen_cu_seq_lens,
                                          int every,
                                          int remain,
                                          int loop,
                                          int seg_data_num,
                                          int task_num) {
  int offset = seg_data_num * taskIdX;
  int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset);
  int seq_len_elem_num = total_elem_num / pair_elem_num;
  int *gen_cu_seq_lens_narm = nram_buffer;
  // Even positions of each pair stay 0; odd positions get `every`.
  __bang_write_zero(gen_cu_seq_lens_narm, total_elem_num);
  __bang_write_value(gen_cu_seq_lens_narm + 1, 1, every, pair_elem_num * sizeof(int),
                     seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0);
  if (remain != 0 && taskIdX == task_num - 1) {
    // Last slice only covers the leftover tokens.
    gen_cu_seq_lens_narm[total_elem_num - 1] = remain;
  }
  __memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int),
           NRAM2GDRAM);
}
} // namespace kernels
// Host-side launcher for MLUSliceCuSeqlens: splits `batch` into
// `parallel_num` nearly-equal slices and rewrites cu_seq_lens so each slice
// is re-based at zero (see operate_cu_seq_lens.mluh for an example).
// Runs as a single Block task.
KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue,
                                  int *cu_seq_lens,
                                  int *sliced_cu_seq_lens,
                                  int batch,
                                  int parallel_num) {
  // Bug fix: both counts must be positive, otherwise the `every` and
  // `batch / every` computations below divide by zero.
  if (batch <= 0 || parallel_num <= 0) {
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int every = (batch + parallel_num - 1) / parallel_num;
  int repeat = batch / every;
  int remain = batch % every;
  // Total number of slices, including the short trailing one.
  int loop = repeat + (remain != 0);
  cnrtDim3_t dim{1, 1, 1};
  kernels::MLUSliceCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(cu_seq_lens, sliced_cu_seq_lens,
                                                                batch, every, remain, loop);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
// Host-side launcher: generates cumulative-seq-lens pairs for seq-parallel
// attention — the Q-side variant when is_kv_seq_len is false, the KV-side
// variant (causal-aware) when true. Output layout and examples are documented
// in operate_cu_seq_lens.mluh. Work is segmented across Block tasks.
KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue,
                                     int *gen_cu_seq_lens,
                                     int seq_len,
                                     int parallel_num,
                                     bool is_causal_mask,
                                     bool is_kv_seq_len) {
  // Bug fix: both counts must be positive, otherwise the `every` and
  // `seq_len / every` computations below divide by zero.
  if (seq_len <= 0 || parallel_num <= 0) {
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int every = (seq_len + parallel_num - 1) / parallel_num;
  int repeat = seq_len / every;
  int remain = seq_len % every;
  // Total number of [0, len] pairs, including the short trailing one.
  int loop = repeat + (remain != 0);
  int seg_data_num = ONCHIP_DATA_NUM / 2;
  if (is_kv_seq_len && is_causal_mask) {
    // max segnum for 2d memcpy is 64k
    seg_data_num = std::min(pair_elem_num * 64 * 1024, seg_data_num);
  }
  int total_elem_num = loop * pair_elem_num;
  int task_num = (total_elem_num + seg_data_num - 1) / seg_data_num;
  cnrtDim3_t dim{(unsigned int)task_num, 1, 1};
  if (is_kv_seq_len) {
    kernels::MLUGenerateKVCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
        gen_cu_seq_lens, every, remain, loop, seq_len, is_causal_mask, seg_data_num, task_num);
  } else {
    kernels::MLUGenerateQCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
        gen_cu_seq_lens, every, remain, loop, seg_data_num, task_num);
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
#define CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief slice cu_seq_lens and cu_k_seq_lens for parallel context when attention split batch;
* @example
* cu_seq_lens: [0, 2, 5, 10, 20, 33, 46, 51, 77]
* batch: 8, parallel_num: 3
* sliced_cu_seq_lens: [0, 2, 5, 10, 0, 10, 23, 36, 0, 5, 31]
* @param queue: The queue for mlu.
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the current seq lens, the shape
* is [batch + 1].
* @param sliced_cu_seq_lens: Output. Pointer to the MLU memory that stores the sliced current seq
* lens, the shape is [batch + loop_time].
* @param batch: Batch size.
* @param parallel_num: Parallel num of batch.
*/
KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue,
int *cu_seq_lens,
int *sliced_cu_seq_lens,
int batch,
int parallel_num);
/**
* @brief generate cu_seq_lens and cu_k_seq_lens for parallel context when attention split seq;
* @example
* seq_len: 11, parallel_num: 3
* gen_cu_seq_lens for q: [0, 4, 0, 4, 0, 3]
* @example
* seq_len: 11, parallel_num: 3
* is_causal_mask false, gen_cu_seq_lens for kv: [0, 11, 0, 11, 0, 11]
* is_causal_mask true , gen_cu_seq_lens for kv: [0, 4, 0, 8, 0, 11]
* @param queue: The queue for mlu.
* @param gen_cu_seq_lens: Output. Pointer to the MLU memory that stores the generated current seq
* lens, the shape is [2 * loop_time].
* @param seq_len: Sequence length.
* @param parallel_num: Parallel num of sequence length.
* @param is_causal_mask: Whether self attention use causal mask.
* @param is_kv_seq_len: The gen_cu_seq_lens is for q or kv.
*/
KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue,
int *gen_cu_seq_lens,
int seq_len,
int parallel_num,
bool is_causal_mask,
bool is_kv_seq_len);
} // namespace tmo
#endif // CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_

View File

@@ -0,0 +1,81 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "preload.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Usable SRAM bytes per cluster, keeping 32KB in reserve.
#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024)
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
// Evenly partition `total` items among `num` workers: worker `id` receives
// `every` items starting at `offset`. The first (total % num) workers each
// take one extra item so the distribution is as balanced as possible.
__mlu_func__ void split(const int64_t total,
                        const int64_t num,
                        const int64_t id,
                        size_t &every,
                        size_t &offset) {
  const int64_t quotient = total / num;
  const int64_t leftover = total % num;
  const bool takes_extra = id < leftover;
  every = quotient + (takes_extra ? 1 : 0);
  offset = quotient * id + (takes_extra ? id : leftover);
}
// Streams this cluster's share of the weight buffer from GDRAM into SRAM in
// SRAM_SIZE chunks. The copies are issued purely to pull the bytes through
// the memory hierarchy; the SRAM contents are never consumed.
__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) {
#if __BANG_ARCH__ > 372
  size_t my_bytes = 0;
  size_t my_offset = 0;
  // Divide the preload range evenly across the clusters in this union task.
  split(preload_size, taskDimY, taskIdY, my_bytes, my_offset);
  const size_t full_chunks = my_bytes / SRAM_SIZE;
  const size_t tail_bytes = my_bytes % SRAM_SIZE;
  const size_t chunk_count = full_chunks + (tail_bytes != 0 ? 1 : 0);
  for (size_t chunk = 0; chunk < chunk_count; chunk++) {
    const size_t bytes = (chunk < full_chunks) ? (size_t)SRAM_SIZE : tail_bytes;
    int8_t *src = (int8_t *)filter_ptr + my_offset + chunk * SRAM_SIZE;
    if (bytes > 0) {
      __memcpy(sram_buffer, src, bytes, GDRAM2SRAM);
    }
  }
#endif
}
} // namespace kernels
// Host-side launcher: warms the memory hierarchy with the first
// `preload_size` bytes of a weight buffer (clamped to `filter_size`) by
// streaming them through SRAM on up to two clusters, four cores each.
KernelStatus invokePreload(cnrtQueue_t queue,
                           void *filter_ptr,
                           size_t filter_size,
                           size_t preload_size) {
  // A zero-byte preload request is treated as a caller error.
  if (preload_size == 0) {
    std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Never read past the end of the weight buffer.
  if (preload_size > filter_size) {
    preload_size = filter_size;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // Use at most two clusters regardless of how many the device has.
  uint32_t y_dim = (uint32_t)cluster_num;
  if (cluster_num == 1) {
    y_dim = 1;
  } else if (cluster_num >= 2) {
    y_dim = 2;
  }
  cnrtDim3_t dim{.x = 4, .y = y_dim, .z = 1};
  kernels::MLUUnion1Preload<<<dim, cnrtFuncTypeUnion1, queue>>>(filter_ptr, preload_size);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,34 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_PRELOAD_MLUH_
#define CSRC_KERNELS_PRELOAD_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
* @brief When tp is greater than 1, while executing reducesum, the weight of ffn
* or selfattention to be calculated is loaded into LLC in advance.
* @param queue: The queue for mlu.
* @param filter_ptr: Input. Pointer to the MLU memory that stores the weight of ffn or
* selfattention.
* @param filter_size: The weight size of ffn or selfattention.
* @param preload_size: The size of the preload weight.
* @note The weights of ffn or selfattention must be continuous in filter_ptr.
*/
KernelStatus invokePreload(cnrtQueue_t queue,
void *filter_ptr,
size_t filter_size,
size_t preload_size);
} // namespace tmo
#endif // CSRC_KERNELS_PRELOAD_MLUH_

Some files were not shown because too many files have changed in this diff Show More