Compare commits
73 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd221f3084 | ||
|
|
7b4f7d74c3 | ||
|
|
16d41a8fc1 | ||
|
|
633aa4db30 | ||
|
|
6eae065dd6 | ||
|
|
e752946445 | ||
|
|
7626238695 | ||
|
|
f3a4d10195 | ||
|
|
ed6a2aff91 | ||
|
|
6faa595799 | ||
|
|
50e02f2011 | ||
|
|
c584139543 | ||
|
|
2ad23aa8da | ||
|
|
86fd3b5a92 | ||
|
|
eaeb5169e0 | ||
|
|
44ffd2094a | ||
|
|
5132af6176 | ||
|
|
5c4c2222ba | ||
|
|
026380fddb | ||
|
|
d9d1f3a724 | ||
|
|
d93c740e4d | ||
|
|
153bc4ec7b | ||
|
|
96ed925486 | ||
|
|
8ac7afcbd3 | ||
|
|
128aed196c | ||
|
|
659ef273c8 | ||
|
|
98003e6f8b | ||
|
|
094541296e | ||
|
|
5a05c22162 | ||
|
|
60f3a23d5f | ||
|
|
9c1d7cc9ff | ||
|
|
934ed88691 | ||
|
|
fa0219fbf8 | ||
|
|
efbb06147a | ||
|
|
a26729bf7f | ||
|
|
8a613d15bd | ||
|
|
a6f39375e5 | ||
|
|
afc34d988e | ||
|
|
fa194c215b | ||
|
|
5fbe8b20a7 | ||
|
|
2dad4e71c5 | ||
|
|
cb1846cd4f | ||
|
|
81fc273396 | ||
|
|
3ef89630ab | ||
|
|
40dee08f7b | ||
|
|
1d70f93cfc | ||
|
|
8ecba6115e | ||
|
|
65ad893ee7 | ||
|
|
d08217307d | ||
|
|
8ac4215755 | ||
|
|
a095dede48 | ||
|
|
374826c841 | ||
|
|
ebdc6fed03 | ||
|
|
b702adf015 | ||
|
|
fba02652c8 | ||
|
|
5d2f4000cc | ||
|
|
f088a6b45d | ||
|
|
d31ace279b | ||
|
|
ac2082ff36 | ||
|
|
2068984bde | ||
|
|
df848b4284 | ||
|
|
4d0da98b9e | ||
|
|
05605419e3 | ||
|
|
332e5f71a6 | ||
|
|
6e38461af6 | ||
|
|
b399840b8d | ||
| 808b9b7c97 | |||
|
|
6b650ae280 | ||
|
|
92f0016e6f | ||
|
|
9563c9af0d | ||
|
|
3b3e614cb6 | ||
|
|
3cf13dd8c5 | ||
|
|
79dfc69789 |
240
.gitignore
vendored
Normal file
240
.gitignore
vendored
Normal file
@@ -0,0 +1,240 @@
|
||||
# version file generated by setuptools-scm
|
||||
/vllm/_version.py
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
|
||||
# OpenAI triton kernels copied from source
|
||||
vllm/third_party/triton_kernels/*
|
||||
|
||||
# FlashMLA interface copied from source
|
||||
vllm/third_party/flashmla/flash_mla_interface.py
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
cmake-build-*/
|
||||
CMakeUserPresets.json
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
/.deps/
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# generated files
|
||||
**/generated/**
|
||||
|
||||
# uv
|
||||
uv.lock
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples/*
|
||||
!docs/examples/README.md
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# Claude
|
||||
CLAUDE.md
|
||||
.claude/
|
||||
|
||||
# Codex
|
||||
AGENTS.md
|
||||
.codex/
|
||||
|
||||
# Cursor
|
||||
.cursor/
|
||||
|
||||
# DS Store
|
||||
.DS_Store
|
||||
|
||||
# Results
|
||||
*.csv
|
||||
|
||||
# Python pickle files
|
||||
*.pkl
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
hip_compat.h
|
||||
|
||||
# Benchmark dataset
|
||||
benchmarks/**/*.json
|
||||
|
||||
# Linting
|
||||
actionlint
|
||||
shellcheck*/
|
||||
|
||||
# Ignore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
# Ignore ep_kernels_workspace folder
|
||||
ep_kernels_workspace/
|
||||
|
||||
# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
|
||||
!vllm/benchmarks/lib/
|
||||
|
||||
# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
|
||||
vllm/grpc/vllm_engine_pb2.py
|
||||
vllm/grpc/vllm_engine_pb2_grpc.py
|
||||
vllm/grpc/vllm_engine_pb2.pyi
|
||||
19
README.md
19
README.md
@@ -3,6 +3,7 @@
|
||||
# 寒武纪 mlu370 文本生成
|
||||
该模型测试框架在寒武纪mlu370 (X8/X4)加速卡上,基于vllm 推理引擎,适配了 Qwen1.5-1.8B-Chat 模型。
|
||||
|
||||
|
||||
* Qwen1.5-1.8B-Chat 是通义千问系列中一款约18亿参数、轻量级的中英文对话大模型,专为高效推理和多场景聊天交互设计。
|
||||
* Llama-2-7b-chat-hf:Meta 发布的 LLaMA 2 系列中 70 亿参数的对话优化版开源大模型,适合多轮聊天与通用任务。
|
||||
* ChatGLM3-6B:智谱 AI 推出的第 3 代 ChatGLM 系列中 60 亿参数的中英双语对话大模型,支持推理、代码和多任务能力。
|
||||
@@ -162,5 +163,19 @@ curl http://localhost:80/v1/chat/completions \
|
||||
|
||||
| 模型名称 | mlu370-X8首字延迟(秒) | mlu370-X8输入处理速度(字每秒) | mlu370-X8输出速度(字每秒) | mlu370-X8输出质量 | Nvidia A100字延迟(秒) | Nvidia A100输入处理速度(字每秒) | Nvidia A100输出速度(字每秒) | Nvidia A100输出质量 |
|
||||
| ------------------- | ------------------- | -------------------| ------------------- | ------------------- | ------------------- | ------------------- | ------------------- | ------------------- |
|
||||
| Qwen/Qwen-1_8B |0.203 | 13493.2 | 119.2 | 10.0 | 0.052 | 25591.5 | 165.0 | 15.0|
|
||||
| Qwen/Qwen1.5-0.5B |0.132 | 12366.6 | 106.9 | 15.0 | 0.066 | 24935.4 | 151.4 | 10.0|
|
||||
| Qwen/Qwen-1_8B |0.203 | 13493.2 | 119.2 | 10.0 | 0.052 | 25591.5 | 165.0 | 15.0|
|
||||
| Qwen/Qwen1.5-0.5B |0.132 | 12366.6 | 106.9 | 15.0 | 0.066 | 24935.4 | 151.4 | 10.0|
|
||||
|
||||
## 版本更新记录
|
||||
|
||||
| 版本 | 日期 | 更新内容 |
|
||||
|------|------|----------|
|
||||
| v0.0.2 | 2026-02-04 | **Qwen3 模型支持**:实现 QK Normalization 架构适配,修复 rope_scaling 和 tokenizer 兼容性问题,解决张量连续性导致的 view 操作失败 |
|
||||
| v0.0.3 | 2026-02-06 | **Transformers 通用后端**:支持通过 `auto_map` 加载任意自定义 HuggingFace 模型,新增 registry 回退逻辑、Linear 返回值处理、RMSNorm 维度恢复等 |
|
||||
| v0.0.3.1 | 2026-02-06 | **CNNL Tensor 溢出修复**:解决极小模型在大显存设备上部署时 KV cache 元素数超过 int32 限制的问题,在 mlu_worker 和 cache_engine 中添加双重防护 |
|
||||
| v0.0.4 | 2026-02-10 | **Gemma3 模型支持**:新增 Gemma3ForCausalLM 模型实现(含 QK Normalization、per-layer rope 配置、layer_types 滑动窗口),修复 `patch_rope_scaling_dict` 在 rope_scaling 缺少 `rope_type` 键时崩溃的问题,更新模型注册表及 config.py 中 interleaved attention 和 dtype 自动处理逻辑 |
|
||||
| v0.0.4.1 | 2026-02-10 | **Gemma3 rope 兼容性修复**:修复新版 transformers `Gemma3TextConfig` 缺少 `rope_theta` 属性的问题,从 `rope_parameters` 字典兼容提取 rope 配置(支持 Transformers v4/v5);修复 `rope_scaling` 嵌套字典导致 `get_rope` 缓存 unhashable 的问题;适配 MLU `forward_mlu` 接口,将 q/k 合并为单张量调用 rotary_emb 后再拆分 |
|
||||
| v0.0.5 | 2026-02-10 | **Qwen3MoE 模型支持**:新增 Qwen3MoeForCausalLM 模型实现(含 QK Normalization、ReplicatedLinear shared_expert_gate),修复 FusedMoE `forward_mlu` 签名缺少 `layer` 参数的已有 bug(影响所有 MLU 上的 MoE 模型),更新模型注册表 |
|
||||
| v0.0.6 | 2026-02-11 | **DeepSeek V3 模型支持**:注册 DeepseekV3ForCausalLM(复用 V2 实现),扩展 MLU MLA config 判断支持 `deepseek_v3`,实现 `noaux_tc` 路由方式(`e_score_correction_bias`),跳过 MTP 层权重加载,修复 MLA unpaged 缓存路径使用错误的 paged cache 算子(prefill + decode 均替换为 `reshape_linear_cache`) |
|
||||
| v0.0.6 | 2026-02-11 | **DeepSeek V3 MTP 推测解码**:新建 `deepseek_mtp.py` 实现 MTP draft model(复用 DeepseekV2DecoderLayer,EAGLE 模板适配),SpeculativeConfig 自动检测 `num_nextn_predict_layers` 并改写 draft config,target worker 为 MTP 返回 hidden states,MLU config 三处 model_type 判断扩展支持 `deepseek_mtp` 以匹配 MLA cache 格式 |
|
||||
| v0.0.6 | 2026-02-11 | **Llama4 模型支持**:新建 Llama4ForCausalLM 模型实现(复合 config 处理、sigmoid routing MoE、QK Normalization、交替 dense/MoE 层),新建 MLU hijack 适配(SparseMoeMlp MoE 替换、embedding dtype 修复),处理 `Llama4Config` 嵌套 `text_config` 的 architectures 提取问题。**⚠️ MoE dense 模式(影响所有 MoE 模型)**:原始 `forward_experts_nofused` 包含 `torch.unique`、`torch.tensor` 创建、数据依赖分支等 graph capture 不兼容操作,导致 MLU370 上所有走 `SparseMoeMlp` 的 MoE 模型必须加 `--enforce-eager` 才能运行。现已改为 dense 模式(每个 expert 处理全部 token),解决了 graph capture 兼容性,所有 MoE 模型无需 `--enforce-eager` 即可运行,但计算量增大 num_experts/topk 倍(Mixtral 4x、Llama4 16x、Qwen2-MoE 15x)。DeepSeek V2/V3 不受影响(有独立 MLU MoE hijack)。后续应拆分 `is_use_fused_moe` 标志让 MLU370 走 `forward_group_experts` 路径优化 |
|
||||
|
||||
4
torch_mlu_ops-v1.3.2/.clang-format
Normal file
4
torch_mlu_ops-v1.3.2/.clang-format
Normal file
@@ -0,0 +1,4 @@
|
||||
BasedOnStyle: Chromium
|
||||
ColumnLimit: 100
|
||||
PointerAlignment: Right
|
||||
AllowShortIfStatementsOnASingleLine: true
|
||||
15
torch_mlu_ops-v1.3.2/.clang-tidy
Normal file
15
torch_mlu_ops-v1.3.2/.clang-tidy
Normal file
@@ -0,0 +1,15 @@
|
||||
Checks: "bugprone-*,google-*,-google-explicit-constructor,cppcoreguidelines-avoid-const-or-ref-data-members,cppcoreguidelines-init-variables,cppcoreguidelines-interfaces-global-init,cppcoreguidelines-misleading-capture-default-by-value,cppcoreguidelines-missing-std-forward,cppcoreguidelines-no-malloc,cppcoreguidelines-pro-type-member-init,cppcoreguidelines-rvalue-reference-param-not-moved,cppcoreguidelines-slicing,cppcoreguidelines-virtual-class-destructor,performance-unnecessary-copy-initialization",
|
||||
CheckOptions: [
|
||||
{
|
||||
key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion,
|
||||
value: false
|
||||
},
|
||||
{
|
||||
key: cppcoreguidelines-narrowing-conversions.WarnOnIntegerToFloatingPointNarrowingConversion,
|
||||
value: false
|
||||
},
|
||||
{
|
||||
key: cppcoreguidelines-narrowing-conversions.IgnoreConversionFromTypes,
|
||||
value: size_t;ptrdiff_t;size_type;difference_type
|
||||
}
|
||||
]
|
||||
13
torch_mlu_ops-v1.3.2/.dockerignore
Normal file
13
torch_mlu_ops-v1.3.2/.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
||||
.cache/
|
||||
.devcontainer/
|
||||
build/
|
||||
__pycache__/
|
||||
.vscode/
|
||||
*.so
|
||||
*.out
|
||||
*.o
|
||||
core_dump_*
|
||||
record.txt
|
||||
*.log
|
||||
CMakePresets.json
|
||||
.clangd
|
||||
4
torch_mlu_ops-v1.3.2/.gitattributes
vendored
Normal file
4
torch_mlu_ops-v1.3.2/.gitattributes
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# git archive exclude files
|
||||
tools/ci export-ignore
|
||||
docker export-ignore
|
||||
# legacy export-ignore
|
||||
25
torch_mlu_ops-v1.3.2/.gitignore
vendored
Normal file
25
torch_mlu_ops-v1.3.2/.gitignore
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
.cache/
|
||||
.devcontainer/
|
||||
build/
|
||||
packages/
|
||||
__pycache__/
|
||||
.idea/
|
||||
.vscode/
|
||||
*.so
|
||||
*.out
|
||||
*.o
|
||||
*.pyc
|
||||
*.deb
|
||||
*.pt
|
||||
core_dump_*
|
||||
record.txt
|
||||
*.log
|
||||
CMakePresets.json
|
||||
.clangd
|
||||
compile_flags.txt
|
||||
legacy/samples/pytorch/outputs/
|
||||
legacy/src/thirdparty/bangtransformer.egg-info/
|
||||
legacy/src/thirdparty/dist/
|
||||
torch_mlu_ops.egg-info/
|
||||
torch_mlu_ops/_version.py
|
||||
dist/
|
||||
15
torch_mlu_ops-v1.3.2/.leak.supp
Normal file
15
torch_mlu_ops-v1.3.2/.leak.supp
Normal file
@@ -0,0 +1,15 @@
|
||||
leak:libtorch_cpu.so
|
||||
leak:libtorch_mlu.so
|
||||
leak:libtorch_mlu_python.so
|
||||
leak:libtriton.so
|
||||
leak:libtorch_python.so
|
||||
leak:libcatch_python.so
|
||||
leak:pybind11::cpp_function::initialize_generic
|
||||
leak:pybind11::cpp_function::make_function_record
|
||||
leak:pybind11::detail::process_attribute
|
||||
leak:libstdc++.so
|
||||
leak:numpy
|
||||
leak:Objects
|
||||
leak:Modules
|
||||
leak:Python
|
||||
leak:libcrypto.so
|
||||
29
torch_mlu_ops-v1.3.2/LICENSE
Normal file
29
torch_mlu_ops-v1.3.2/LICENSE
Normal file
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2024, Cambricon Technologies
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
183
torch_mlu_ops-v1.3.2/README.md
Normal file
183
torch_mlu_ops-v1.3.2/README.md
Normal file
@@ -0,0 +1,183 @@
|
||||
<div align="center">
|
||||
|
||||
Torch-MLU-Ops
|
||||
===========================
|
||||
|
||||

|
||||
|
||||
<div align="center">
|
||||
<b>
|
||||
<a href="https://www.cambricon.com/docs/sdk_1.15.0/bangtransformer_0.4.0/index.html">
|
||||
<font size="4"> 📖 Torch_MLU_Ops用户手册</font>
|
||||
</a>
|
||||
</b>
|
||||
|
||||
<b>
|
||||
<a href="https://developer.cambricon.com/">
|
||||
<font size="4"> 🌏 寒武纪开发者社区</font>
|
||||
</a>
|
||||
</b>
|
||||
|
||||
<b>
|
||||
<a href="https://sdk.cambricon.com/download?sdk_version=V1.13.0&component_name=Basis">
|
||||
<font size="4"> 🛠️ 依赖组件获取</font>
|
||||
</a>
|
||||
</b>
|
||||
</div>
|
||||
|
||||
|
||||
---
|
||||
<div align="left">
|
||||
|
||||
## 目录
|
||||
|
||||
- [Torch-MLU-Ops](#torch-mlu-ops)
|
||||
- [目录](#目录)
|
||||
- [简介](#简介)
|
||||
- [安装](#安装)
|
||||
- [环境要求](#环境要求)
|
||||
- [通过docker镜像启动](#通过docker镜像启动)
|
||||
- [编译安装](#编译安装)
|
||||
- [测试](#测试)
|
||||
- [目录结构](#目录结构)
|
||||
- [关键特性](#关键特性)
|
||||
|
||||
## 简介
|
||||
|
||||
Torch-MLU-Ops是寒武纪设计和开发的PyTorch第三方算子库。对于使用PyTorch框架的开发者,通过Torch-MLU-Ops,能够便捷地使用这些自定义算子,进行算子的集成、评测和业务部署。
|
||||
|
||||
Torch-MLU-Ops已全量覆盖LLM(Large Language Model)推理场景下的常见算子。作为后端,已支持Cambricon vLLM、Cambricon TGI、Cambricon Stable Diffusion web UI、Cambricon ComfyUI以及Cambricon Diffusers。
|
||||
|
||||
|
||||
## 安装
|
||||
|
||||
### 环境要求
|
||||
|
||||
Torch-MLU-Ops支持的操作系统以及软件依赖如下。
|
||||
|
||||
不同平台支持的软件版本如下:
|
||||
|
||||
| 操作系统 | 编译器版本 | CMake版本 |
|
||||
| ------------------------| ----------------------| -------------------------- |
|
||||
| Ubuntu22.04 x86-64 | gcc 9.4.0或者更高版本 | 3.10或者更高版本 |
|
||||
| Debian10.11 x86-64 | gcc 9.4.0或者更高版本 | 3.10或者更高版本 |
|
||||
|
||||
|
||||
编译、运行Torch-MLU-Ops同时还需要配置以下依赖。
|
||||
|
||||
| Torch-MLU-Ops | Torch-MLU | CNNL | CNNL_Extra | CNCL | CNToolkit |
|
||||
| -----------------| ------------------| ---------| --------------- | --------| -----------|
|
||||
| v1.3.2 | v1.24.1 | v1.28.3 | v1.12.3 | v1.24.1 | v3.15.6 |
|
||||
| v1.3.1 | v1.24.1 | v1.28.3 | v1.12.2 | v1.24.1 | v3.15.5 |
|
||||
| v1.3.0 | v1.24.0 | v1.28.2 | v1.12.1 | v1.24.0 | v3.15.4 |
|
||||
| v1.2.3 | v1.24.0 | v1.28.1 | v1.12.1 | v1.24.0 | v3.15.3 |
|
||||
| v1.2.2 | v1.23.1 | v1.27.4 | v1.11.3 | v1.23.0 | v3.14.3 |
|
||||
| v1.2.1 | v1.23.1 | v1.27.3 | v1.11.3 | v1.22.3 | v3.14.2 |
|
||||
| v1.2.0 | v1.23.0 | v1.27.1 | v1.11.1 | v1.22.1 | v3.14.1 |
|
||||
| v1.1.4 | v1.22.2 | v1.26.7 | v1.10.4 | v1.21.1 | v3.13.7 |
|
||||
| v1.1.3 | v1.22.2 | v1.26.6 | v1.10.3 | v1.21.1 | v3.13.5 |
|
||||
|
||||
|
||||
此外运行Torch-MLU-Ops还依赖PyTorch环境。Torch-MLU-Ops已支持的PyTorch版本请参考。
|
||||
|
||||
| Torch-MLU-Ops | PyTorch |
|
||||
| -----------------| ------------------|
|
||||
| v1.3.2 | v2.1, v2.4, v2.5 |
|
||||
| v1.3.1 | v2.1, v2.4, v2.5 |
|
||||
| v1.3.0 | v2.1, v2.4, v2.5 |
|
||||
| v1.2.3 | v2.1, v2.4, v2.5 |
|
||||
| v1.2.2 | v2.1, v2.4, v2.5 |
|
||||
| v1.2.1 | v2.1, v2.4, v2.5 |
|
||||
| v1.2.0 | v2.1, v2.4, v2.5 |
|
||||
| v1.1.4 | v2.1, v2.3, v2.4 |
|
||||
| v1.1.3 | v2.1, v2.3 |
|
||||
|
||||
### 通过docker镜像启动
|
||||
|
||||
寒武纪提供的解决方案镜像已提供了Torch-MLU-Ops所需要的依赖,请参考Cambricon PyTorch Container Image使用。您可以从[寒武纪开发者社区](https://account.cambricon.com/interaction/SMmmcGFmw837gghtZ7VR4)获取该镜像。
|
||||
|
||||
### 编译安装
|
||||
|
||||
Torch-MLU-Ops编译脚本依赖于`NEUWARE_HOME`环境变量,该环境变量指向寒武纪CNToolkit的安装路径,通常为`/usr/local/neuware`。
|
||||
|
||||
在运行示例程序前,您需要将`NEUWARE_HOME`中的`lib64`目录添加到`LD_LIBRARY_PATH`环境变量中,例如:
|
||||
|
||||
```shell
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$NEUWARE_HOME/lib64
|
||||
```
|
||||
|
||||
如果您使用寒武纪提供的docker镜像,以上环境变量均已经被设置好了。
|
||||
|
||||
设置好`NEUWARE_HOME`环境变量后,您可以使用以下命令编译Torch-MLU-Ops。
|
||||
|
||||
从远端clone仓库,进入项目的主目录:
|
||||
```shell
|
||||
cd torch_mlu_ops
|
||||
```
|
||||
|
||||
源码安装:
|
||||
```shell
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
wheel包安装:
|
||||
```shell
|
||||
python setup.py bdist_wheel
|
||||
pip install dist/torch_mlu_ops*.whl
|
||||
```
|
||||
|
||||
> _注意,如果安装过程中遇到权限问题,请获取相关权限或切换到拥有权限的user执行pip install 这步操作。_
|
||||
|
||||
移除Torch-MLU-Ops(Optional):
|
||||
```shell
|
||||
pip uninstall torch_mlu_ops
|
||||
```
|
||||
|
||||
### 测试
|
||||
|
||||
在测试前,请确保Torch-MLU-Ops已安装,您可以通过 `pip list` 命令查询Torch-MLU-Ops的安装情况。若有如下类似打印,则表明安装成功。
|
||||
```shell
|
||||
torch-mlu-ops 1.1.0+pt21
|
||||
```
|
||||
|
||||
Torch-MLU-Ops模块的导入方法请参考。
|
||||
```python
|
||||
import torch_mlu_ops as tmo
|
||||
```
|
||||
|
||||
在`./tests/ops_pytest`目录下,提供了`ops`级别的测试示例程序,您可以参考如下命令进行测试。
|
||||
```shell
|
||||
cd tests/ops_pytest/
|
||||
./run_test.sh
|
||||
# or 单独测试某个测例,例如
|
||||
python test_flash_attention.py
|
||||
```
|
||||
在`./tests/kernels_pytest`目录下,提供了`kernels`级别的测试示例程序,您可以参考如下命令进行测试。
|
||||
```shell
|
||||
cd tests/kernels_pytest
|
||||
./build.sh # 需要先编译
|
||||
./run_test.sh
|
||||
```
|
||||
## 目录结构
|
||||
```
|
||||
torch_mlu_ops
|
||||
|-- benchmarks ## 性能benchmarks测试代码目录
|
||||
|-- csrc ## C以及BangC源码目录
|
||||
| |-- common
|
||||
| |-- kernels
|
||||
| |-- ops
|
||||
| `-- torch_api
|
||||
|-- docs ## 文档目录
|
||||
| |-- release_notes
|
||||
| `-- user_guide
|
||||
|-- tests ## 测试代码目录,包含了kernel和ops级别的测试
|
||||
| |-- kernels_pytest
|
||||
| `-- ops_pytest
|
||||
|-- tools ## 工具目录,主要包含CI相关代码
|
||||
| |-- ci
|
||||
`-- torch_mlu_ops ## PyTorch第三方算子接口实现相关代码
|
||||
```
|
||||
|
||||
## 关键特性
|
||||
|
||||
Torch-MLU-Ops提供的自定义算子列表以及接口说明请参考[Torch-MLU-Ops用户手册](./docs/user_guide/)。
|
||||
51
torch_mlu_ops-v1.3.2/benchmarks/README.md
Normal file
51
torch_mlu_ops-v1.3.2/benchmarks/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
## benchmark测试脚本使用方式
|
||||
|
||||
Torch-MLU-Ops benchmark测试脚本为用户提供了进行算子性能测试的便捷入口。
|
||||
用户可通过以下命令获取各个参数的含义。
|
||||
|
||||
```bash
|
||||
# 测试命令帮助
|
||||
python3 benchmark_xxx.py --help
|
||||
```
|
||||
各个参数含义如下:
|
||||
|
||||
`options`:
|
||||
- -h, --help show this help message and exit
|
||||
- --repeat_times REPEAT_TIMES repeat times for testing
|
||||
- --csv write the report data to csv
|
||||
- -o O specify the output folder name under --csv mode
|
||||
|
||||
```bash
|
||||
# 测试命令示例如下
|
||||
python3 benchmark_active.py --repeat_times 10 --csv -o './active/'
|
||||
```
|
||||
支持如下算子:
|
||||
|
||||
| op_name |
|
||||
| ---------------------------------|
|
||||
| active |
|
||||
| apply_rotary |
|
||||
| attention_project |
|
||||
| ffn |
|
||||
| flash_attn |
|
||||
| fused_layer_norm |
|
||||
| fused_moe |
|
||||
| fused_norm_attention_project |
|
||||
| fused_norm_residual_ffn |
|
||||
| fused_rms_norm |
|
||||
| group_gemm |
|
||||
| matmul |
|
||||
| offline_quant_to_linear_cache |
|
||||
| per_token_smooth_quantize |
|
||||
| preload |
|
||||
| quantize |
|
||||
| reshape_linear_cache |
|
||||
| quant_to_linear_cache |
|
||||
| reshape_paged_cache |
|
||||
| single_query_cached_kv_attn |
|
||||
| smooth_quant_matmul |
|
||||
| weight_only_quant_matmul |
|
||||
| moe_gen_idx |
|
||||
| moe_expand_input |
|
||||
| moe_softmax_topk |
|
||||
| moe_combine_result |
|
||||
64
torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py
Normal file
64
torch_mlu_ops-v1.3.2/benchmarks/benchmark_active.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 16, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 72, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 1024, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 4096, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 8192, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 32768, "seq_len": 5, "hidden_size": 1024,
|
||||
"act_mode": "gelu", "is_gated": False, "input_dtype": [torch.bfloat16]}]
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
titles = ["input_shape", "act_mode", "is_gated", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
|
||||
contents = []
|
||||
bd = get_band_width()
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
batch = params_dict["batch"]
|
||||
seq_len = params_dict["seq_len"]
|
||||
hidden_size = params_dict["hidden_size"]
|
||||
act_mode = params_dict["act_mode"]
|
||||
is_gated = params_dict["is_gated"]
|
||||
input_dtype_list = params_dict["input_dtype"]
|
||||
for dtype in input_dtype_list:
|
||||
if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
|
||||
dtype = torch.half
|
||||
input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.active,
|
||||
input,
|
||||
act_mode,
|
||||
is_gated,
|
||||
repeats=args.repeat_times)
|
||||
io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated)
|
||||
io_eff = io_bytes / hardware_time / bd
|
||||
content = [f"{batch,seq_len,hidden_size}", f"{act_mode}", f"{is_gated}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
84
torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py
Normal file
84
torch_mlu_ops-v1.3.2/benchmarks/benchmark_apply_rotary.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
e2e_time_param_dict_list = [{"batch": 1, "seq_len": 1024, "head_num": 32, "head_size": 128, "rotary_dim": 128,
|
||||
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 1, "seq_len": 1024, "head_num": 40, "head_size": 128, "rotary_dim": 64,
|
||||
"interleaved": True, "discrete": False, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 1, "seq_len": 1024, "head_num": 52, "head_size": 128, "rotary_dim": 128,
|
||||
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 128, "rotary_dim": 128,
|
||||
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 1, "seq_len": 1024, "head_num": 25, "head_size": 64, "rotary_dim": 64,
|
||||
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 1, "seq_len": 1024, "head_num": 64, "head_size": 96, "rotary_dim": 96,
|
||||
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 4, "seq_len": 1, "head_num": 80, "head_size": 128, "rotary_dim": 128,
|
||||
"interleaved": False, "discrete": True, "dynamic_ntk": False, "input_dtype": [torch.float16, torch.bfloat16]}]
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
|
||||
titles = ["batch", "seq_len", "head_num", "head_size", "rotary_dim", "interleaved", "discrete", "dynamic_ntk", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
|
||||
contents = []
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
batch = params_dict["batch"]
|
||||
seq_len = params_dict["seq_len"]
|
||||
head_num = params_dict["head_num"]
|
||||
head_size = params_dict["head_size"]
|
||||
# full/partial
|
||||
rotary_dim = params_dict["rotary_dim"]
|
||||
# cross/fold
|
||||
interleaved = params_dict["interleaved"]
|
||||
# discrete
|
||||
discrete = params_dict["discrete"]
|
||||
# dynamic_ntk
|
||||
dynamic_ntk = params_dict["dynamic_ntk"]
|
||||
input_dtype_list = params_dict["input_dtype"]
|
||||
for dtype in input_dtype_list:
|
||||
if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
|
||||
continue
|
||||
input = torch.randn(batch, seq_len, head_num, head_size).to(device).to(dtype) # [batch, seqlen, head_num, head_size]
|
||||
if dynamic_ntk:
|
||||
sin_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
|
||||
cos_cache = torch.randn(batch, seq_len, rotary_dim).to(device).to(dtype)
|
||||
else:
|
||||
sin_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
|
||||
cos_cache = torch.randn(seq_len, rotary_dim).to(device).to(dtype)
|
||||
if discrete:
|
||||
pos_ids = torch.randint(0, seq_len, (batch * seq_len,)).to(device).to(torch.int32)
|
||||
else:
|
||||
pos_ids = None
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.apply_rotary,
|
||||
input,
|
||||
sin_cache,
|
||||
cos_cache,
|
||||
pos_ids,
|
||||
None,
|
||||
interleaved,
|
||||
discrete,
|
||||
dynamic_ntk,
|
||||
seq_len,
|
||||
repeats=args.repeat_times)
|
||||
content = [f"{batch}", f"{seq_len}", f"{head_num}", f"{head_size}", f"{rotary_dim}", f"{interleaved}", f"{discrete}", f"{dynamic_ntk}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
@@ -0,0 +1,74 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os


# Problem sizes to benchmark; each entry sweeps every dtype in "input_dtype".
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "input_size": 1600, "hidden_size": 1600,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 2048, "hidden_size": 2048,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 4096, "hidden_size": 4096,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 6144, "hidden_size": 6144,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 6656, "hidden_size": 6656,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 8192, "hidden_size": 8192,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 12288, "hidden_size": 12288,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "input_size": 14336, "hidden_size": 14336,
     "has_residual": True, "has_bias": True, "input_dtype": [torch.float16, torch.bfloat16]},
]


def main():
    """Benchmark tmo.attention_project over the configured shapes and dtypes,
    printing a latency table and optionally saving it as CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["batch", "seq_len", "input_size", "hidden_size", "has_residual", "has_bias", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        input_size, hidden_size = cfg["input_size"], cfg["hidden_size"]
        has_residual, has_bias = cfg["has_residual"], cfg["has_bias"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device)
            weight = torch.randn(hidden_size, input_size).to(dtype).to(device)
            residual = torch.randn(batch, seq_len, hidden_size).to(dtype).to(device) if has_residual else None
            bias = torch.randn(hidden_size).to(dtype).to(device) if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.attention_project,
                                                        x,
                                                        weight,
                                                        bias,
                                                        residual,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}",
                             f"{has_residual}", f"{has_bias}", f"{dtype}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        # Report file is named after this script.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
60
torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py
Normal file
60
torch_mlu_ops-v1.3.2/benchmarks/benchmark_batch_matmul.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os


# GEMM shapes to benchmark; each entry sweeps every dtype in "input_dtype".
e2e_time_param_dict_list = [
    {"batch": 2, "m": 1024, "k": 1600, "n": 6400, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 2, "m": 1024, "k": 2048, "n": 8192, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 2, "m": 1024, "k": 4096, "n": 11008, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 2, "m": 1024, "k": 5120, "n": 16384, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 2, "m": 1024, "k": 6144, "n": 24576, "has_c": True, "input_dtype": [torch.float16, torch.bfloat16]},
]


def main():
    """Benchmark tmo.batch_matmul over the configured GEMM shapes and dtypes,
    printing a latency table and optionally saving it as CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["batch", "m", "k", "n", "has_c", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, m, k, n = cfg["batch"], cfg["m"], cfg["k"], cfg["n"]
        has_c = cfg["has_c"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # NOTE: b is laid out [batch, n, k] (transposed operand).
            a = torch.randn(batch, m, k).to(device).to(dtype)
            b = torch.randn(batch, n, k).to(device).to(dtype)
            c = torch.randn(batch, m, n).to(device).to(dtype) if has_c else None

            hardware_time, e2e_time = benchmark_forward(tmo.batch_matmul,
                                                        a,
                                                        b,
                                                        c,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{m}", f"{k}", f"{n}", f"{has_c}",
                             f"{dtype}", f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        # Report file is named after this script.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
@@ -0,0 +1,120 @@
|
||||
import argparse
import random
import os

import torch
import torch_mlu
import torch_mlu_ops as tmo

from common import benchmark_forward, save_to_csv
from itertools import product
from tabulate import tabulate

# Benchmark configurations; list-valued entries are swept via itertools.product.
e2e_time_param_dict_list = [
    {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
     "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "head_size": 128,
     "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [4, 8],
     "use_offset": True},
]


def main():
    """Benchmark tmo.dequant_from_linear_cache over the configured parameter
    grid and print (optionally save) a table of latencies.

    Fixes vs. the previous revision:
    - the per-row content listed quant_mode twice, producing 14 fields against
      13 column titles and shifting every column after "quant_mode";
    - "input_dytpe" title typo corrected to "input_dtype".
    """
    parser = argparse.ArgumentParser(description="Benchmark for dequant from linear cache.")
    parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    titles = ["max_batch_size", "batch_size", "max_context_len", "head_num_q", "head_num_kv",
              "cache_mem_len", "head_size", "input_dtype", "quant_mode", "quant_bit",
              "use_offset", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    mlu_name = torch.mlu.get_device_name()
    for params_dict in e2e_time_param_dict_list:
        max_batch_size = params_dict["max_batch_size"]
        batch_size_list = params_dict["batch_size"]
        max_context_len_list = params_dict["max_context_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        cache_mem_len = params_dict["cache_mem_len"]
        input_dtype_list = params_dict["input_dtype"]
        quant_mode_list = params_dict["quant_mode"]
        quant_bit_list = params_dict["quant_bit"]
        use_offset = params_dict["use_offset"]
        for batch_size, max_context_len, quant_mode, quant_bit, dtype in list(product( \
                batch_size_list, max_context_len_list, quant_mode_list, quant_bit_list, \
                input_dtype_list)):
            # Fixed seeds keep the random inputs reproducible across runs.
            torch.manual_seed(2766)
            torch.mlu.manual_seed(2766)
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # MLU300-series cannot address >= 2**31 - 1 element caches; skip.
            if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv \
                    * head_size >= 2**31 - 1):
                print("large tensor is not support on {}, skip".format(mlu_name))
                continue
            total_heads = head_num_q + head_num_kv * 2
            assert max_context_len <= cache_mem_len, "max_context_len should be smaller " \
                "than or equal to cache_mem_len."
            max_seq_offset = cache_mem_len - max_context_len
            # Generates key and value from a packed [total_seqlen, heads, head_size] context.
            context_lens = torch.randint(size=[batch_size], low=max_context_len,
                                         high=max_context_len + 1,
                                         dtype=torch.int32, device="mlu")
            if use_offset:
                context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                                 dtype=torch.int32, device="mlu")
            else:
                context_paddings = torch.zeros_like(context_lens)
            cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
            total_seqlen = cu_context_lens[-1]
            context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
            context_seq_offset[1:] = cu_context_lens[:-1]
            context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
            key = context[..., head_num_q:head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]

            # Generates key_cache and value_cache.
            # NOTE(review): cache slots are sampled from range(0, batch_size + 1)
            # although the cache has max_batch_size slots — confirm this is intended.
            cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu()
            cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size],
                                             dtype=torch.int32, device="mlu")
            if quant_bit == 4:
                # int4: two values packed per int8 byte, hence the halved dims.
                key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
                    head_num_kv, cache_mem_len, head_size // 2), device="mlu")
                value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
                    head_num_kv, cache_mem_len // 2, head_size), device="mlu")
                key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8)
            else:
                cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size),
                                      low=-128, high=127, dtype=torch.int32, device="mlu")
                cache = cache.to(torch.int8)
                key_cache, value_cache = cache[[0, 1]]

            # Generates key_cache_scale and value_cache_scale.
            if quant_mode == 0:  # quant_mode == 0 is per channel
                cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
            else:  # otherwise per head
                cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len),
                                          dtype=torch.float, device="mlu")
            key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
            hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_linear_cache,
                                                        key, value, key_cache, value_cache,
                                                        key_cache_scale, value_cache_scale,
                                                        context_lens, max_context_len,
                                                        context_seq_offset if use_offset else None,
                                                        cache_bs_id, cache_seq_offset, quant_mode,
                                                        quant_bit, repeats=args.repeat_times)
            # One field per title, in title order (13 columns).
            content = [f"{max_batch_size}", f"{batch_size}", f"{max_context_len}", f"{head_num_q}",
                       f"{head_num_kv}", f"{cache_mem_len}", f"{head_size}", f"{dtype}",
                       f"{quant_mode}", f"{quant_bit}", f"{use_offset}", f"{hardware_time}",
                       f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
@@ -0,0 +1,110 @@
|
||||
import argparse
import math
import random
import os

import torch
import torch_mlu
import torch_mlu_ops as tmo

from common import benchmark_forward, save_to_csv
from itertools import product
from tabulate import tabulate

# Benchmark configurations; list-valued entries are swept via itertools.product.
e2e_time_param_dict_list = [
    {"max_batch_size": 128, "batch_size": [1, 32, 64], "max_context_len": [1024, 2048, 3072, 4096],
     "head_num_q": 32, "head_num_kv": 1, "cache_mem_len": 6144, "block_size": 16, "head_size": 128,
     "input_dtype": [torch.float16, torch.bfloat16], "quant_mode": [0, 1], "quant_bit": [8],
     "use_offset": True},
]


def main():
    """Benchmark tmo.dequant_from_paged_cache over the configured parameter
    grid and print (optionally save) a table of latencies.

    Fixes vs. the previous revision: "input_dytpe" title typo corrected and
    the assertion message made grammatical.
    """
    parser = argparse.ArgumentParser(description="Benchmark for dequant from paged cache.")
    parser.add_argument('--repeat_times', type=int, default=100, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    # The op is unavailable on MLU300 devices; exit early.
    if "MLU3" in torch.mlu.get_device_name():
        print("Op dequant_from_paged_cache does not support MLU300 devices.")
        return
    titles = ["batch_size", "max_context_len", "head_num_q", "head_num_kv",
              "cache_mem_len", "block_size", "head_size", "input_dtype", "quant_mode", "quant_bit",
              "use_offset", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch_size_list = params_dict["batch_size"]
        max_context_len_list = params_dict["max_context_len"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        block_size = params_dict["block_size"]
        head_size = params_dict["head_size"]
        cache_mem_len = params_dict["cache_mem_len"]
        input_dtype_list = params_dict["input_dtype"]
        quant_mode_list = params_dict["quant_mode"]
        quant_bit_list = params_dict["quant_bit"]
        use_offset = params_dict["use_offset"]
        for quant_mode, batch_size, max_context_len, quant_bit, dtype in list(product( \
                quant_mode_list, batch_size_list, max_context_len_list, quant_bit_list, \
                input_dtype_list)):
            # Fixed seeds keep the random inputs reproducible across runs.
            torch.manual_seed(2766)
            torch.mlu.manual_seed(2766)
            total_heads = head_num_q + head_num_kv * 2

            assert max_context_len <= cache_mem_len, "max_context_len should be smaller " \
                "than or equal to cache_mem_len."
            max_seq_offset = cache_mem_len - max_context_len
            max_block_num = int(math.ceil(max_context_len / block_size))
            total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size
            # Unique random block ids, one table row per sequence.
            block_tables = random.sample(range(0, total_blocks), batch_size * max_block_num)
            block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch_size,
                                                                                    max_block_num)
            # Generates key and value from a packed [total_seqlen, heads, head_size] context.
            context_lens = torch.randint(size=[batch_size], low=max_context_len, high=max_context_len + 1,
                                         dtype=torch.int32, device="mlu")
            if use_offset:
                context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                                 dtype=torch.int32, device="mlu")
            else:
                context_paddings = torch.zeros_like(context_lens)
            cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
            total_seqlen = cu_context_lens[-1]
            context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
            context_seq_offset[1:] = cu_context_lens[:-1]
            context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
            key = context[..., head_num_q:head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
            # Generates key_cache and value_cache.
            cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size),
                                  low=-128, high=127, dtype=torch.int32, device="mlu")
            cache = cache.to(torch.int8)
            key_cache, value_cache = cache[[0, 1]]

            # Generates key_cache_scale and value_cache_scale.
            if quant_mode == 0:  # quant_mode == 0 is per channel
                cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
            else:  # otherwise per head
                cache_scale = torch.randn((2, total_blocks, head_num_kv, block_size),
                                          dtype=torch.float, device="mlu")
            key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
            hardware_time, e2e_time = benchmark_forward(tmo.dequant_from_paged_cache,
                                                        key, value, key_cache, value_cache,
                                                        key_cache_scale, value_cache_scale,
                                                        context_lens, max_context_len,
                                                        context_seq_offset if use_offset else None,
                                                        block_tables, quant_mode,
                                                        quant_bit, repeats=args.repeat_times)
            content = [f"{batch_size}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}",
                       f"{cache_mem_len}", f"{block_size}", f"{head_size}", f"{dtype}", f"{quant_mode}",
                       f"{quant_bit}", f"{use_offset}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
89
torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py
Normal file
89
torch_mlu_ops-v1.3.2/benchmarks/benchmark_ffn.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os


# FFN shapes to benchmark; each entry sweeps every dtype in "input_dtype".
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "hidden_size": 1600, "inner_size": 6400,
     "gated_ffn": False, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 2048, "inner_size": 8192,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 11008,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 14336,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 4096, "inner_size": 16384,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 13824,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 5120, "inner_size": 27392,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 6656, "inner_size": 17920,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 22016,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 24576,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 28672,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 8192, "inner_size": 49152,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 12288, "inner_size": 32768,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
    {"batch": 1, "seq_len": 1024, "hidden_size": 14336, "inner_size": 57344,
     "gated_ffn": True, "act_mode": "gelu", "input_dtype": [torch.float16, torch.bfloat16]},
]


def main():
    """Benchmark tmo.ffn over the configured shapes and dtypes, printing a
    latency table and optionally saving it as CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        hidden_size, inner_size = cfg["hidden_size"], cfg["inner_size"]
        gated_ffn, act_mode = cfg["gated_ffn"], cfg["act_mode"]
        for dtype in cfg["input_dtype"]:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
            up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
            down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
            # Gate projection is only materialized for gated FFN variants.
            gate_up_proj_weight, gate_up_proj_bias = None, None
            if gated_ffn:
                gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
                gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.ffn,
                                                        input,
                                                        up_proj_weight,
                                                        up_proj_bias,
                                                        down_proj_weight,
                                                        down_proj_bias,
                                                        gate_up_proj_weight,
                                                        gate_up_proj_bias,
                                                        act_mode,
                                                        repeats=args.repeat_times)
            contents.append([f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}",
                             f"{gated_ffn}", f"{act_mode}", f"{dtype}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        # Report file is named after this script.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
92
torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py
Normal file
92
torch_mlu_ops-v1.3.2/benchmarks/benchmark_flash_attn.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os


# for e2e time test
e2e_time_param_dict_list = [{"batch": 1, "seq_q": 32768, "seq_kv": 32768, "head_num": 8,
                             "head_num_kv": 1, "head_size": 128, "use_causal": True,
                             "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_q": 16384, "seq_kv": 16384, "head_num": 8,
                             "head_num_kv": 1, "head_size": 128, "use_causal": True,
                             "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_q": 8192, "seq_kv": 24576, "head_num": 8,
                             "head_num_kv": 1, "head_size": 128, "use_causal": True,
                             "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_q": 4096, "seq_kv": 28672, "head_num": 8,
                             "head_num_kv": 1, "head_size": 128, "use_causal": True,
                             "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"batch": 1, "seq_q": 4096, "seq_kv": 32768, "head_num": 8,
                             "head_num_kv": 1, "head_size": 128, "use_causal": True,
                             "softmax_scale": 1e-6, "input_dtype": [torch.float16, torch.bfloat16]},
                            ]


def main():
    """Benchmark tmo.flash_attention over the configured shapes and dtypes,
    printing a latency table and optionally saving it as CSV.

    Fixes vs. the previous revision:
    - the V slice of the packed QKV tensor used an upper bound of
      head_num + head_num * 2, which only produced the right result because
      Python slicing clamps out-of-range bounds; the correct bound is
      head_num + 2 * head_num_kv;
    - seq_q > seq_kv previously fell through with q/k/v undefined (NameError);
      it now raises an explicit ValueError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_q", "seq_kv", "head_num", "head_num_kv", "head_size", "use_causal", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_q = params_dict["seq_q"]
        seq_kv = params_dict["seq_kv"]
        head_num = params_dict["head_num"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        use_causal = params_dict["use_causal"]
        softmax_scale = params_dict["softmax_scale"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            if seq_q == seq_kv:
                # Packed layout: [batch, seq, head_num + 2*head_num_kv, head_size].
                qkv = torch.randn(batch, seq_q, head_num + 2 * head_num_kv, head_size).to(dtype).to(device)
                q = qkv[:, :, : head_num, :]
                k = qkv[:, :, head_num : head_num + head_num_kv, :]
                v = qkv[:, :, head_num + head_num_kv : head_num + 2 * head_num_kv, :]
            elif seq_q < seq_kv:
                q = torch.randn(batch, seq_q, head_num, head_size).to(device).to(dtype)
                kv = torch.randn(batch, seq_kv, head_num_kv * 2, head_size).to(device).to(dtype)
                k = kv[:, :, : head_num_kv, :]
                v = kv[:, :, head_num_kv :, :]
            else:
                raise ValueError("seq_q > seq_kv is not supported by this benchmark")
            hardware_time, e2e_time = benchmark_forward(tmo.flash_attention,
                                                        q = q,
                                                        k = k,
                                                        v = v,
                                                        out = None,
                                                        cu_seq_lens_q = None,
                                                        cu_seq_lens_kv = None,
                                                        alibi_slope = None,
                                                        attn_bias = None,
                                                        max_seq_len_q = seq_q,
                                                        max_seq_len_kv = seq_kv,
                                                        softmax_scale = softmax_scale,
                                                        is_causal = use_causal,
                                                        window_size_left = -1,
                                                        window_size_right = -1,
                                                        compute_dtype = dtype,
                                                        return_lse = False,
                                                        block_tables = None,
                                                        k_cache_quant_scale = None,
                                                        v_cache_quant_scale = None,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_q}", f"{seq_kv}", f"{head_num}", f"{head_num_kv}", f"{head_size}", f"{use_causal}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
103
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py
Normal file
103
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_layer_norm.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os


# Layer-norm shapes to benchmark; one dtype per entry (bf16 with fp16 fallback).
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 2048, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16},
    {"batch": 1, "seq_len": 4096, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16},
    {"batch": 1, "seq_len": 8192, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16},
    {"batch": 1, "seq_len": 32768, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16},
    {"batch": 490, "seq_len": 1, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16},
    {"batch": 525, "seq_len": 1, "hidden_size": 8192, "has_residual": True,
     "has_bias": False, "has_quant": True, "dynamic_quant": True,
     "input_dtype": torch.bfloat16},
]


def main():
    """Benchmark tmo.fused_layer_norm over the configured shapes, reporting
    latency and an IO-efficiency estimate against the device bandwidth."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["batch", "seq_len", "hidden_size", "has_residual", "has_bias", "has_quant",
              "dynamic_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for cfg in e2e_time_param_dict_list:
        batch, seq_len = cfg["batch"], cfg["seq_len"]
        hidden_size = cfg["hidden_size"]
        has_residual, has_bias = cfg["has_residual"], cfg["has_bias"]
        has_quant, dynamic_quant = cfg["has_quant"], cfg["dynamic_quant"]
        dtype = cfg["input_dtype"]
        eps = 1e-6

        # Fall back to fp16 (rather than skipping) when bf16 is unsupported.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half

        x = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
        beta = torch.randn(hidden_size, dtype=dtype, device=device)
        gamma = torch.randn(hidden_size, dtype=dtype, device=device)
        residual, bias, quant_scale = None, None, None
        if has_residual:
            residual = torch.randn(batch, seq_len, hidden_size, dtype=dtype, device=device)
        if has_bias:
            bias = torch.randn(hidden_size, dtype=dtype, device=device)
        if has_quant or dynamic_quant:
            quant_scale = torch.randn(hidden_size, dtype=torch.float, device=device)

        store_output_before_norm = has_residual
        hardware_time, e2e_time = benchmark_forward(tmo.fused_layer_norm,
                                                    x,
                                                    residual,
                                                    gamma,
                                                    beta,
                                                    bias,
                                                    eps,
                                                    store_output_before_norm,
                                                    quant_scale,
                                                    None,
                                                    dynamic_quant,
                                                    repeats=args.repeat_times)
        # Estimated bytes moved: input read + int8 output, residual traffic,
        # gamma/beta, static quant scale, and per-token dynamic scales.
        n = x.nelement()
        sizeoft = x.element_size()
        io_bytes = (sizeoft + 1) * n + \
                   (1 + store_output_before_norm) * (sizeoft * n if has_residual else 0) + \
                   sizeoft * hidden_size * 2 + \
                   (hidden_size * 4 if has_quant else 0) + \
                   (batch * seq_len * 4 if dynamic_quant else 0)
        io_eff = io_bytes / hardware_time / bd

        contents.append([f"{batch}", f"{seq_len}", f"{hidden_size}", f"{has_residual}",
                         f"{has_bias}", f"{has_quant}", f"{dynamic_quant}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        # Report file is named after this script.
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
143
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py
Normal file
143
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_moe.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark configurations for tmo.fused_moe. Every case shares the same MoE
# geometry (hidden 8192, inner 1024, 32 experts, top-5 routing, smooth-quant,
# gelu, no gating/residual); only (batch, seq_len) varies — first a batch
# sweep at seq_len 1, then a seq_len sweep at batch 1.
_COMMON_MOE_PARAMS = {
    "hidden_size": 8192, "inner_size": 1024, "num_expert": 32,
    "start_expert_id": 0, "expert_size": 32, "gated_ffn": False,
    "has_residual": False, "smooth_quant": True, "act_mode": "gelu",
    "topk": 5, "renormalize": False, "dtype": [torch.bfloat16],
}

e2e_time_param_dict_list = [
    dict(_COMMON_MOE_PARAMS, batch=b, seq_len=s)
    for b, s in [(1, 1), (16, 1), (72, 1), (128, 1), (490, 1), (525, 1),
                 (1, 1024), (1, 2048), (1, 4096), (1, 8192), (1, 32768)]
]
|
||||
|
||||
def main():
    """Benchmark tmo.fused_moe over the configurations in
    e2e_time_param_dict_list and print a latency table (optionally to CSV)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "num_expert", "topk", "act_mode", "quant_weight", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        inner_size = params_dict["inner_size"]
        gated_ffn = params_dict["gated_ffn"]
        act_mode = params_dict["act_mode"]
        num_expert = params_dict["num_expert"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        topk = params_dict["topk"]
        has_residual = params_dict["has_residual"]
        smooth_quant = params_dict["smooth_quant"]
        renormalize = params_dict["renormalize"]
        input_dtype_list = params_dict["dtype"]
        for dtype in input_dtype_list:
            # Fall back to fp16 when bf16 is unsupported — keeps the case
            # running instead of skipping it (unlike the other benchmarks).
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            hidden_states = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
            # Router logits are always fp32 regardless of the activation dtype.
            router_logit = torch.randn(batch, seq_len, num_expert, device=device, dtype=torch.float32)

            residual = None
            if has_residual:
                residual = torch.randn(batch, seq_len, hidden_size, device=device, dtype=dtype)
            weight1 = torch.randn(num_expert, inner_size*(1+gated_ffn), hidden_size, device=device, dtype=dtype)
            bias1 = None  # biases are unused by these benchmark cases
            weight2 = torch.randn(num_expert, hidden_size, inner_size, device=device, dtype=dtype)
            bias2 = None
            input_smooth, act_smooth, w1_scale, w2_scale = None, None, None, None
            if smooth_quant:
                # int8 expert weights with per-expert smoothing vectors
                # (shifted away from zero) and dequantization scales.
                input_smooth = torch.randn(expert_size, hidden_size, device=device, dtype=torch.float32).abs() + 0.1
                act_smooth = torch.randn(expert_size, inner_size, device=device, dtype=torch.float32).abs() + 0.1
                weight1 = torch.randint(-128, 127, (num_expert, inner_size*(1+gated_ffn), hidden_size)).to(torch.int8).mlu()
                weight2 = torch.randint(-128, 127, (num_expert, hidden_size, inner_size)).to(torch.int8).mlu()
                w1_scale = torch.randn(expert_size, (1+gated_ffn)*inner_size).to(device).to(torch.float32)
                w2_scale = torch.randn(expert_size, hidden_size).to(device).to(torch.float32)

            # Only the [start_expert_id, start_expert_id + expert_size) slice
            # of the expert weights is handed to the kernel.
            hardware_time, e2e_time = benchmark_forward(tmo.fused_moe,
                                                        hidden_states,
                                                        router_logit,
                                                        weight1[start_expert_id:start_expert_id+expert_size],
                                                        weight2[start_expert_id:start_expert_id+expert_size],
                                                        bias1,
                                                        bias2,
                                                        residual,
                                                        input_smooth,
                                                        act_smooth,
                                                        w1_scale,
                                                        w2_scale,
                                                        topk,
                                                        renormalize,
                                                        gated_ffn,
                                                        act_mode,
                                                        start_expert_id,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{num_expert}", f"{topk}", f"{act_mode}", f"{smooth_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        file_name = os.path.basename(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,71 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark configurations for tmo.fused_norm_attention_project.
# Every case uses batch 1, seq_len 1024, hidden_size == input_size; only the
# projection width varies (head_size is 80 for the 1600 case, 128 otherwise).
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "input_size": size, "head_size": head_size,
     "hidden_size": size, "input_dtype": [torch.float16, torch.bfloat16]}
    for size, head_size in [(1600, 80), (2048, 128), (4096, 128), (6144, 128),
                            (6656, 128), (8192, 128), (12288, 128), (14336, 128)]
]
|
||||
|
||||
def main():
    """Benchmark tmo.fused_norm_attention_project (layernorm + fused QKV
    projection) over the configurations above and print a latency table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'
    titles = ["batch", "seq_len", "input_size", "hidden_size", "head_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        input_size = params_dict["input_size"]
        hidden_size = params_dict["hidden_size"]
        head_size = params_dict["head_size"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, input_size).to(device).to(dtype)
            # One fused QKV weight/bias, chunked into the three projections.
            weight = torch.randn(hidden_size * 3, input_size).to(device).to(dtype)
            bias = torch.randn(hidden_size * 3).to(device).to(dtype)
            weights = torch.chunk(weight, 3)
            biases = torch.chunk(bias, 3)
            norm_weight = torch.randn(input_size).to(device).to(dtype)
            norm_bias = torch.randn(input_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_attention_project,
                                                        input,
                                                        weights[0],
                                                        biases[0],
                                                        weights[1],
                                                        biases[1],
                                                        weights[2],
                                                        biases[2],
                                                        norm_weight,
                                                        norm_bias,
                                                        1e-6,    # layernorm eps
                                                        'nthc',  # presumably the output layout string — confirm in tmo docs
                                                        head_size,
                                                        False,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{input_size}", f"{hidden_size}", f"{head_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        file_name = os.path.basename(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,97 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark configurations for tmo.fused_norm_residual_ffn.
# Every case uses batch 1, seq_len 1024, gelu activation; the tuple below is
# (hidden_size, inner_size, gated_ffn, residual_is).
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "hidden_size": h, "inner_size": i,
     "gated_ffn": gated, "residual_is": residual_is, "act_mode": "gelu",
     "input_dtype": [torch.float16, torch.bfloat16]}
    for h, i, gated, residual_is in [
        (1600, 6400, False, "input"),
        (2048, 8192, True, "input"),
        (4096, 11008, True, "input"),
        (4096, 14336, True, "input"),
        (4096, 16384, True, "input"),
        (5120, 13824, True, "input"),
        (5120, 27392, True, "input"),
        (6656, 17920, True, "normed_input"),
        (8192, 22016, True, "normed_input"),
        (8192, 24576, True, "normed_input"),
        (8192, 28672, True, "normed_input"),
        (8192, 49152, True, "none"),
        (12288, 32768, True, "none"),
        (14336, 57344, True, "none"),
    ]
]
|
||||
|
||||
def main():
    """Benchmark tmo.fused_norm_residual_ffn over the configurations above
    and print a latency table (optionally to CSV)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["batch", "seq_len", "hidden_size", "inner_size", "gated_ffn", "residual_is", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        hidden_size = params_dict["hidden_size"]
        inner_size = params_dict["inner_size"]
        gated_ffn = params_dict["gated_ffn"]
        residual_is = params_dict["residual_is"]
        act_mode = params_dict["act_mode"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input = torch.randn(batch, seq_len, hidden_size).to(device).to(dtype)
            up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
            up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            down_proj_weight = torch.randn(hidden_size, inner_size).to(device).to(dtype)
            down_proj_bias = torch.randn(hidden_size).to(device).to(dtype)
            layernorm_weight = torch.randn(hidden_size).to(device).to(dtype)
            layernorm_bias = torch.randn(hidden_size).to(device).to(dtype)
            # Gate projection only exists for gated FFN variants.
            gate_up_proj_weight, gate_up_proj_bias = None, None
            if gated_ffn:
                gate_up_proj_weight = torch.randn(inner_size, hidden_size).to(device).to(dtype)
                gate_up_proj_bias = torch.randn(inner_size).to(device).to(dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_norm_residual_ffn,
                                                        input,
                                                        up_proj_weight,
                                                        up_proj_bias,
                                                        down_proj_weight,
                                                        down_proj_bias,
                                                        gate_up_proj_weight,
                                                        gate_up_proj_bias,
                                                        layernorm_weight,
                                                        layernorm_bias,
                                                        1e-6,  # layernorm eps
                                                        act_mode,
                                                        residual_is,
                                                        repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{inner_size}", f"{gated_ffn}", f"{residual_is}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        file_name = os.path.basename(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
90
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py
Normal file
90
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rms_norm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
from itertools import product
|
||||
|
||||
# Benchmark configurations for tmo.fused_rms_norm. Every case uses batch 1,
# seq_len 1024, residual/bias/quant enabled, eps 1e-6, and sweeps both
# dynamic_quant values and both input dtypes; only (head_num, head_size)
# varies.
e2e_time_param_dict_list = [
    {"batch": 1, "seq_len": 1024, "head_num": head_num, "head_size": head_size,
     "has_residual": True, "has_bias": True, "has_quant": True, "eps": 1e-6,
     "dynamic_quant": [False, True],
     "input_dtype": [torch.float16, torch.bfloat16]}
    for head_num, head_size in [(25, 64), (16, 128), (32, 128), (40, 128),
                                (64, 96), (52, 128), (64, 128), (96, 128),
                                (112, 128)]
]
|
||||
|
||||
def main():
    """Benchmark tmo.fused_rms_norm over the configurations above, sweeping
    dynamic_quant x dtype per case, and print a latency table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["input_shape", "has_residual", "has_bias", "has_quant", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        head_num = params_dict["head_num"]
        head_size = params_dict["head_size"]
        has_residual = params_dict["has_residual"]
        has_quant = params_dict["has_quant"]
        has_bias = params_dict["has_bias"]
        eps = params_dict["eps"]
        # The original read these two entries twice; duplicates removed.
        dynamic_quant_list = params_dict["dynamic_quant"]
        input_dtype_list = params_dict["input_dtype"]
        for dynamic_quant, dtype in product(dynamic_quant_list, input_dtype_list):
            # Skip bf16 cases on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            x = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
            beta = torch.randn(head_size).to(dtype).to(device)
            gamma = torch.randn(head_size).to(dtype).to(device)
            residual, bias, quant_scale = None, None, None
            if has_residual:
                residual = torch.randn(batch, seq_len, head_num, head_size).to(dtype).to(device)
            if has_bias:
                bias = torch.randn(head_size).to(dtype).to(device)
            if has_quant or dynamic_quant:
                quant_scale = torch.randn(head_size).to(device)
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rms_norm,
                                                        x,
                                                        residual,
                                                        gamma,
                                                        beta,
                                                        bias,
                                                        eps,
                                                        False,  # presumably store_output_before_norm (cf. fused_layer_norm benchmark) — confirm
                                                        quant_scale,
                                                        None,
                                                        dynamic_quant,
                                                        repeats=args.repeat_times)
            content = [f"{batch, seq_len, head_num, head_size}", f"{has_residual}", f"{has_bias}", f"{has_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        file_name = os.path.basename(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
207
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py
Normal file
207
torch_mlu_ops-v1.3.2/benchmarks/benchmark_fused_rope.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Benchmark configurations for the fused rope benchmark. Three families, all
# sharing the same attention geometry (seq_len 1, 8 query heads, 1 kv head,
# head/rotary dim 128) over batches {40, 72, 128, 256, 512}:
#   1. contiguous KV cache (max_decode_len 2048), quant_kv False then True;
#   2. paged KV cache with a per-batch (num_blocks, block_size) geometry,
#      quant_kv False then True;
#   3. mixed cache (no quant_kv/paged_cache keys — main() applies defaults).
_ROPE_DTYPES = [torch.float16, torch.bfloat16]
_ROPE_BATCHES = (40, 72, 128, 256, 512)
_ROPE_BASE = {"seq_len": 1, "head_num_q": 8, "head_num_k": 1,
              "head_size": 128, "rotary_dim": 128}
# batch -> (num_blocks, block_size) for the paged-cache family.
_PAGED_GEOM = {40: (128, 16), 72: (64, 32), 128: (128, 16),
               256: (64, 32), 512: (128, 16)}

e2e_time_param_dict_list = (
    [dict(_ROPE_BASE, batch=b, quant_kv=q, paged_cache=False,
          max_decode_len=2048, input_dtype=_ROPE_DTYPES)
     for b in _ROPE_BATCHES for q in (False, True)]
    + [dict(_ROPE_BASE, batch=b, quant_kv=q, paged_cache=True,
            num_blocks=_PAGED_GEOM[b][0], block_size=_PAGED_GEOM[b][1],
            input_dtype=_ROPE_DTYPES)
       for b in _ROPE_BATCHES for q in (False, True)]
    + [dict(_ROPE_BASE, batch=b, mixed_cache=True, input_dtype=_ROPE_DTYPES)
       for b in _ROPE_BATCHES]
)
|
||||
def main():
    """Benchmark tmo.fused_rope (fused RoPE + KV-cache write, optionally int8-quantized,
    paged, or mixed low-precision cache) over the configured cases and print/save a report."""
    # This benchmark is skipped on MLU3xx devices.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'  # NOTE(review): assigned but unused below; tensors use .mlu()/device='mlu' directly

    titles = ["batch", "head_num_q", "head_num_k", "head_size", "rotary_dim", "quant_kv", "paged_cache", "max_decode_len", "num_blocks", \
              "block_size", "mixed_cache", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        bs = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        q_heads = params_dict["head_num_q"]
        kv_heads = params_dict["head_num_k"]
        head_size = params_dict["head_size"]
        rope_dim = params_dict["rotary_dim"]
        # Optional keys fall back to defaults when absent from the case dict.
        quant_kv = params_dict["quant_kv"] if "quant_kv" in params_dict else True
        paged_cache = params_dict["paged_cache"] if "paged_cache" in params_dict else False
        mixed_cache = params_dict["mixed_cache"] if "mixed_cache" in params_dict else False
        max_decode_len = 0
        num_blocks = 0
        block_size = 0
        if paged_cache:
            # Paged cache is sized by (num_blocks, block_size) instead of max_decode_len.
            num_blocks = params_dict["num_blocks"]
            block_size = params_dict["block_size"]
        else:
            max_decode_len = params_dict["max_decode_len"] if "max_decode_len" in params_dict else 32
        input_dtype_list = params_dict["input_dtype"]

        for dtype in input_dtype_list:
            discrete_batch = True
            # One spare batch slot so cache_bs_id can be a non-trivial permutation.
            max_bs = bs + 1 if discrete_batch else bs
            # Packed QKV layout: q_heads query heads plus one K and one V group per kv head.
            input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
            input = torch.randn(size=input_shape, dtype=dtype).mlu()
            input_ref = input.clone()  # NOTE(review): computed but never used afterwards

            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            cache_dtype = dtype
            if quant_kv:
                # Per-(head, channel) fp32 scales; the op takes reciprocal scales.
                k_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                v_scale = torch.randn(size=(kv_heads, head_size), dtype=torch.float).mlu()
                cache_dtype = torch.int8
                k_scale_ops = 1 / k_scale
                v_scale_ops = 1 / v_scale
            else:
                k_scale = None
                v_scale = None
                k_scale_ops = None
                v_scale_ops = None
            if paged_cache:
                # Leading dim 2 packs K cache at index 0 and V cache at index 1.
                cache = torch.randn((2, num_blocks, kv_heads, block_size, head_size), dtype=dtype, device='mlu')
            else:
                cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')

            if quant_kv:
                # Spread values over roughly the int8 range before the cast.
                cache = (cache - 0.5) * 256
            cache = cache.to(cache_dtype)
            k_cache = cache[0]
            v_cache = cache[1]

            cache_bs_id = None
            cache_seq_offsets = None
            slot_mapping = None
            if not paged_cache:
                if discrete_batch:
                    # Random mapping from logical batch index to cache batch slot.
                    cache_bs_id = random.sample([*range(0, max_bs)], bs)
                    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()

                # -1 offsets are included to exercise the "no write" path.
                cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                                  dtype=torch.int32, device='mlu')
            else:
                # Paged path addresses the cache through per-token slot ids (-1 allowed).
                slot_mapping = random.sample([*range(-1, block_size * num_blocks)], bs)
                slot_mapping = torch.IntTensor(slot_mapping).mlu()

            position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')

            # Low-precision ("lp") mixed-cache inputs; only populated when mixed_cache is set.
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            cache_bs_id_lp = None
            cache_seq_offsets_lp = None

            if mixed_cache:
                max_decode_len_lp = 1024
                # Halved trailing dims: presumably an int4-style packed layout — TODO confirm.
                k_cache_raw = torch.randn((max_bs, kv_heads, max_decode_len_lp, int(head_size / 2)), dtype=dtype, device='mlu')
                v_cache_raw = torch.randn((max_bs, kv_heads, int(max_decode_len_lp / 2), head_size), dtype=dtype, device='mlu')
                # Rescale so values fit in [-7, 7] before the int8 cast.
                max_value = torch.amax(torch.abs(k_cache_raw))
                k_cache_raw = k_cache_raw * (7 / max_value)
                max_value = torch.amax(torch.abs(v_cache_raw))
                v_cache_raw = v_cache_raw * (7 / max_value)
                k_cache_lp = k_cache_raw.to(torch.int8)
                v_cache_lp = v_cache_raw.to(torch.int8)

                # Per-position fp32 scales for the low-precision cache.
                k_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()
                v_scale_lp = torch.randn(size=(max_bs, kv_heads, max_decode_len_lp, 1), dtype=torch.float).mlu()

                cache_bs_id_lp = random.sample([*range(0, max_bs)], bs)
                cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
                cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                                                     dtype=torch.int32, device='mlu')

            # Measure device ("hardware") time and host-to-host e2e latency in microseconds.
            hardware_time, e2e_time = benchmark_forward(tmo.fused_rope,
                                                        input,
                                                        k_cache,
                                                        v_cache,
                                                        sin_table,
                                                        cos_table,
                                                        position_id,
                                                        gamma,
                                                        beta,
                                                        k_cache_lp,
                                                        v_cache_lp,
                                                        cache_bs_id,
                                                        cache_seq_offsets,
                                                        cache_bs_id_lp,
                                                        cache_seq_offsets_lp,
                                                        k_scale_ops,
                                                        v_scale_ops,
                                                        k_scale_lp,
                                                        v_scale_lp,
                                                        slot_mapping,
                                                        None,
                                                        1e-5,
                                                        repeats=args.repeat_times)
            content = [f"{bs}", f"{q_heads}", f"{kv_heads}", f"{head_size}", f"{rope_dim}", f"{quant_kv}", f"{paged_cache}", \
                       f"{max_decode_len}", f"{num_blocks}", f"{block_size}", f"{mixed_cache}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    # Optionally persist the report next to this script's name under the -o folder.
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
117
torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py
Normal file
117
torch_mlu_ops-v1.3.2/benchmarks/benchmark_group_gemm.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Benchmark grid for (smooth-quant) grouped GEMM. Two MoE projection shapes are
# swept over decode-style cases (seq_len == 1, varying batch) and prefill-style
# cases (batch == 1, varying seq_len). All cases use 32 experts, topk 5, int8
# smooth-quant, bf16 output.
_GEMM_SHAPES = [(8192, 1024), (1024, 8192)]
_DECODE_BATCHES = [1, 16, 72, 128, 490, 525]
_PREFILL_SEQ_LENS = [1024, 2048, 4096, 8192, 32768]


def _gg_case(batch, seq_len, k, n):
    # Build one benchmark case dict with the shared fixed parameters.
    return {"batch": batch, "seq_len": seq_len, "k": k, "n": n, "expert_num": 32,
            "topk": 5, "is_quant": True, "dtype": [torch.bfloat16]}


e2e_time_param_dict_list = (
    [_gg_case(b, 1, k, n) for b in _DECODE_BATCHES for (k, n) in _GEMM_SHAPES]
    + [_gg_case(1, s, k, n) for s in _PREFILL_SEQ_LENS for (k, n) in _GEMM_SHAPES]
)
|
||||
|
||||
def main():
    """Benchmark tmo.group_gemm / tmo.smooth_quant_group_gemm over the configured
    MoE shapes and print (optionally save) a timing report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()

    titles = ["batch", "seq_len", "k", "n", "expert_num", "topk", "smooth_quant", "dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        k = params_dict["k"]
        n = params_dict["n"]
        expert_num = params_dict["expert_num"]
        topk = params_dict["topk"]
        is_quant = params_dict["is_quant"]
        input_dtype_list = params_dict["dtype"]
        # print(f"batch:{batch}, seq_len:{seq_len}, k:{k}, n:{n}, expert_num:{expert_num}, topk:{topk}, is_quant:{is_quant}")
        for dtype in input_dtype_list:
            # Fall back to fp16 on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            max_m = batch * seq_len
            # Each token is routed to topk experts, so the grouped GEMM sees
            # batch * seq_len * topk rows in total.
            m = batch * seq_len * topk
            # Spread m rows as evenly as possible across experts (first `rem`
            # experts get one extra row).
            avg, rem = m // expert_num, m % expert_num
            m_list = [avg + (i < rem) for i in range(expert_num)]
            token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')

            if not is_quant:
                a = torch.randn(m, k, dtype=dtype, device='mlu')
                b = torch.randn(expert_num, n, k, dtype=dtype, device='mlu')
                hardware_time, e2e_time = benchmark_forward(tmo.group_gemm,
                                                            a, b, token_count,
                                                            None, None, None, None,
                                                            max_m,
                                                            repeats=args.repeat_times)
            else:
                # int8 activations/weights with per-row (a) and per-(expert, n)
                # (b) fp32 scales; `dtype` selects the output dtype.
                a = torch.randint(-128, 127, (m, k)).to(torch.int8).mlu()
                b = torch.randint(-128, 127, (expert_num, n, k)).to(torch.int8).mlu()
                a_scale = torch.randn(a.size(0), dtype=torch.float32, device='mlu')
                b_scale = torch.randn(expert_num, n, dtype=torch.float32, device='mlu')
                hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_group_gemm,
                                                            a, b, token_count,
                                                            None, None, None, None,
                                                            a_scale, b_scale, dtype, max_m,
                                                            repeats=args.repeat_times)
            content = [f"{batch}", f"{seq_len}", f"{k}", f"{n}", f"{expert_num}", f"{topk}",
                       f"{is_quant}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    # Optionally persist the report next to this script's name under the -o folder.
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
75
torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py
Normal file
75
torch_mlu_ops-v1.3.2/benchmarks/benchmark_matmul.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark grid for tmo.matmul: m is fixed at 1024 and only the (k, n) GEMM
# shape varies; every case uses a residual (c), a bias, and relu activation,
# and is run in both fp16 and bf16.
_MATMUL_KN_SHAPES = [
    (1600, 6400), (2048, 8192), (4096, 11008), (4096, 14336), (4096, 16384),
    (5120, 16384), (5120, 27392), (6144, 24576), (6656, 17920), (8192, 22016),
    (8192, 24576), (8192, 28672), (8192, 49152), (12288, 32768), (14336, 57344),
]

e2e_time_param_dict_list = [
    {"m": 1024, "k": k, "n": n, "has_c": True, "has_bias": True,
     "act_mode": "relu", "input_dtype": [torch.float16, torch.bfloat16]}
    for (k, n) in _MATMUL_KN_SHAPES
]
|
||||
|
||||
def main():
    """Benchmark tmo.matmul (GEMM with optional bias, residual c, and activation)
    over the configured shapes and print (optionally save) a timing report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'

    titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        m = params_dict["m"]
        k = params_dict["k"]
        n = params_dict["n"]
        has_c = params_dict["has_c"]
        has_bias = params_dict["has_bias"]
        act_mode = params_dict["act_mode"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # Skip bf16 cases entirely on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            a = torch.randn(m, k).to(device).to(dtype)
            b = torch.randn(n, k).to(device).to(dtype)  # note: weight is laid out (n, k)
            c = None
            if has_c:
                c = torch.randn(m, n).to(device).to(dtype)
            bias = None
            if has_bias:
                bias = torch.randn(n).to(device).to(dtype)
            # Trailing 1.0, 1.0 are scalar coefficients passed to tmo.matmul
            # (presumably alpha/beta — confirm against the tmo.matmul signature).
            hardware_time, e2e_time = benchmark_forward(tmo.matmul,
                                                        a,
                                                        b,
                                                        bias,
                                                        c,
                                                        act_mode,
                                                        1.0,
                                                        1.0,
                                                        repeats=args.repeat_times)
            content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    # Optionally persist the report next to this script's name under the -o folder.
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
117
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py
Normal file
117
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_active.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Benchmark cases for tmo.moe_active. All cases share inner_size 1024, gelu
# activation, bias, and bf16 input; four prefill-style cases (batch == 1,
# varying seq_len) are followed by decode-style cases (seq_len == 1, varying
# batch). Only the first case is gated and only the second disables expert
# parallelism.
def _ma_case(batch, seq_len, is_gated=False, is_ep=True):
    # Build one benchmark case dict with the shared fixed parameters.
    return {"batch": batch, "seq_len": seq_len, "inner_size": 1024,
            "act_mode": "gelu", "is_gated": is_gated, "has_bias": True,
            "is_ep": is_ep, "input_dtype": [torch.bfloat16]}


e2e_time_param_dict_list = (
    [_ma_case(1, 1024, is_gated=True),
     _ma_case(1, 4096, is_ep=False),
     _ma_case(1, 8192),
     _ma_case(1, 32768)]
    + [_ma_case(b, 1) for b in (1, 16, 32, 64, 128, 256, 512)]
)
|
||||
|
||||
def gen_data(num_expert,
             total_tokens,
             inner_size,
             output_stride,
             dtype,
             is_gated,
             has_bias,
             is_ep):
    """Build random inputs for one moe_active benchmark case.

    Args:
        num_expert: total number of experts.
        total_tokens: number of (expanded) tokens, i.e. rows of the input.
        inner_size: FFN inner size; gated activations double the input channels.
        output_stride: row stride (in elements) for the strided output view.
        dtype: tensor dtype for input/bias/output.
        is_gated: whether the activation is gated (input has 2x channels).
        has_bias: whether a per-expert bias tensor is generated.
        is_ep: expert parallelism — pick a random contiguous expert sub-range.

    Returns:
        (input, bias, token_count, cusum_token_count, output,
         start_expert_id, expert_size)
    """
    # Gated activations consume [gate | up] halves, hence the doubled width.
    ci = inner_size * (1 + is_gated)
    input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu')
    cusum_token_count, token_count = generate_token_count(num_expert, total_tokens)
    output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu')
    # BUG FIX: Tensor.as_strided returns a new view and is NOT in-place; the
    # original discarded its result, so `output` kept its default contiguous
    # strides and output_stride had no effect. Rebind the strided view.
    output = output.as_strided(output.size(), (output_stride, 1))
    # Under expert parallelism, benchmark a random contiguous slice of experts.
    start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0
    expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert
    bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None
    return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size
|
||||
|
||||
def main():
    """Benchmark tmo.moe_active over the configured cases and report hardware
    time, e2e latency, and achieved IO efficiency versus device bandwidth."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    titles = ["input_shape", "act_mode", "is_gated", "has_bias", "expert_num", "start_expert_id",
              "expert_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        batch = params_dict["batch"]
        seq_len = params_dict["seq_len"]
        inner_size = params_dict["inner_size"]
        act_mode = params_dict["act_mode"]
        is_gated = params_dict["is_gated"]
        input_dtype_list = params_dict["input_dtype"]
        has_bias = params_dict["has_bias"]
        is_ep = params_dict["is_ep"]
        for dtype in input_dtype_list:
            # Fall back to fp16 on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            # FIX: removed the duplicated `expert_num = expert_num = ...` assignment.
            expert_num = random.randint(1, 256)
            input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
                gen_data(expert_num, batch * seq_len, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)

            # Only the active expert sub-range's bias is passed to the op.
            real_bias = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
            hardware_time, e2e_time = benchmark_forward(tmo.moe_active,
                                                        input,
                                                        act_mode,
                                                        is_gated,
                                                        output,
                                                        real_bias,
                                                        cusum_token_count.mlu() if has_bias or is_ep else None,
                                                        start_expert_id,
                                                        expert_size,
                                                        repeats=args.repeat_times)

            # BUG FIX: the original computed
            #   a + b + (c) if has_bias or is_ep else 0
            # which Python parses as `(a + b + c) if ... else 0`: io_bytes was 0
            # whenever neither bias nor EP was used, and `real_bias.element_size()`
            # raised AttributeError when real_bias was None (is_ep without bias).
            # Accumulate each term explicitly instead.
            io_bytes = input.element_size() * input.nelement() * (2 - 0.5 * is_gated)
            if real_bias is not None:
                io_bytes += real_bias.element_size() * real_bias.nelement()
            if has_bias or is_ep:
                io_bytes += cusum_token_count.element_size() * cusum_token_count.nelement()
            io_eff = io_bytes / hardware_time / bd
            content = [f"{batch,seq_len,inner_size}", f"{act_mode}", f"{is_gated}", f"{has_bias}", f"{expert_num}",
                       f"{start_expert_id}", f"{expert_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    # Optionally persist the report next to this script's name under the -o folder.
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
59
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py
Normal file
59
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_cast_gating.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Benchmark cases for tmo.moe_cast_gating: hidden_size 8192, 32 experts, bf16
# input throughout; only (batch, seq_len) varies — three prefill-style and
# three decode-style cases.
_CG_BATCH_SEQ = [(1, 2048), (1, 4096), (1, 32768), (16, 1), (128, 1), (512, 1)]

e2e_time_param_dict_list = [
    {"batch": b, "seq_len": s, "hidden_size": 8192, "expert_num": 32,
     "input_dtype": torch.bfloat16}
    for (b, s) in _CG_BATCH_SEQ
]
|
||||
|
||||
def main():
    """Benchmark tmo.moe_cast_gating (fused cast + gating GEMM) over the
    configured cases and report timing plus achieved IO efficiency."""
    # This benchmark is skipped on MLU3xx devices.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    titles = ["batch", "seq_len", "hidden_size", "expert_num", "input_dtype", "hardware_time(us)",
              "e2e_latency(us)", "IO efficiency"]

    contents = []
    bandwidth = get_band_width()
    for param_dict in e2e_time_param_dict_list:
        batch = param_dict["batch"]
        seq_len = param_dict["seq_len"]
        hidden_size = param_dict["hidden_size"]
        expert_num = param_dict["expert_num"]
        input_dtype = param_dict["input_dtype"]
        # Fall back to fp16 on devices without bf16 support.
        if input_dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            input_dtype = torch.half
        input = torch.randn(batch, seq_len, hidden_size, dtype=input_dtype, device="mlu")
        weight = torch.randn(expert_num, hidden_size, dtype=torch.float32, device="mlu")
        # BUG FIX: --repeat_times was parsed but never forwarded, so
        # benchmark_forward silently used its default repeat count. Every
        # sibling benchmark in this suite passes repeats=args.repeat_times.
        hardware_time, e2e_time = benchmark_forward(tmo.moe_cast_gating,
                                                    input,
                                                    weight,
                                                    repeats=args.repeat_times)
        # Bytes moved: read input + read gating weight + write logits.
        io_bytes = batch * seq_len * hidden_size * input.element_size() + \
                   expert_num * hidden_size * weight.element_size() + batch * seq_len * expert_num * weight.element_size()
        io_coeff = io_bytes / hardware_time / bandwidth
        content = [f"{batch}", f"{seq_len}", f"{hidden_size}", f"{expert_num}", f"{input_dtype}",
                   f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    # Optionally persist the report next to this script's name under the -o folder.
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
166
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py
Normal file
166
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_combine_result.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark cases for tmo.moe_combine_result: only the token count varies;
# every case uses 32 experts (all active), topk 5, hidden_size 8192, no
# residual, bf16 input.
_COMBINE_NUM_TOKENS = [16, 128, 490, 525, 2048, 4096, 8192, 32768]

e2e_time_param_dict_list = [
    {"num_tokens": t, "num_expert": 32, "topk": 5, "start_expert_id": 0,
     "expert_size": 32, "has_residual": False, "hidden_size": 8192,
     "dtype": [torch.bfloat16]}
    for t in _COMBINE_NUM_TOKENS
]
|
||||
|
||||
def gen_case(num_tokens,
             topk,
             hidden_size,
             num_expert,
             expert_size,
             has_bias,
             has_residual,
             dtype,
             device):
    """Build random inputs for one moe_combine_result benchmark case.

    Returns (input, reduce_weight, gather_ids, residual, bias,
    cusum_token_count); bias/residual/cusum_token_count are None when not
    requested. cusum_token_count is only materialized when a bias is present
    or only a sub-range of experts is active.
    """
    expanded = num_tokens * topk  # one row per (token, expert) pair
    input = torch.randn((expanded, hidden_size), dtype=dtype, device=device)
    reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device)
    # Random permutation mapping expanded rows back to gather positions.
    gather_ids = torch.randperm(expanded, dtype=torch.int32, device=device)

    bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) if has_bias else None
    residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) if has_residual else None

    cusum_token_count = None
    if has_bias or expert_size < num_expert:
        cusum_token_count, _ = generate_token_count(num_expert, expanded)
        cusum_token_count = cusum_token_count.to(device=device)
    return input, reduce_weight, gather_ids, residual, bias, cusum_token_count
|
||||
|
||||
|
||||
def get_io_bytes(num_tokens,
                 topk,
                 hidden_size,
                 num_expert,
                 expert_size,
                 start_expert_id,
                 has_bias,
                 has_residual,
                 dtype,
                 cusum_token_count,
                 gather_ids):
    """Estimate the bytes moved by moe_combine_result for one benchmark case.

    When cusum_token_count is given, only the expanded rows whose gather id
    falls inside the active expert range [start_expert_id,
    start_expert_id + expert_size) are counted as read; otherwise every
    expanded row is read.
    """
    elem_size = 4 if dtype is torch.float32 else 2
    total = 0

    if cusum_token_count is None:
        # No expert filtering: every expanded row is read.
        total += num_tokens * topk * hidden_size * elem_size
    else:
        lo = cusum_token_count[start_expert_id]
        hi = cusum_token_count[start_expert_id + expert_size]
        in_range = ((gather_ids >= lo) & (gather_ids < hi)).to(dtype=torch.float32)
        total += torch.sum(in_range).item() * hidden_size * elem_size

    if has_bias:
        total += expert_size * hidden_size * elem_size

    if has_residual:
        total += num_tokens * hidden_size * elem_size

    total += num_tokens * topk * 4            # fp32 reduce weights
    total += num_tokens * hidden_size * elem_size  # combined output write
    return total
|
||||
|
||||
|
||||
def main():
    """Benchmark tmo.moe_combine_result over the configured cases and report
    hardware time, e2e latency, and achieved IO efficiency versus bandwidth."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'
    titles = ["num_tokens", "num_expert", "topk", "start_expert_id", "expert_size", \
              "hidden_size", "has_residual", "dtype", "hardware_time(us)", "e2e_latency(us)", "io_coeff"]
    contents = []
    bandwidth = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        num_tokens = params_dict["num_tokens"]
        num_expert = params_dict["num_expert"]
        topk = params_dict["topk"]
        start_expert_id = params_dict["start_expert_id"]
        expert_size = params_dict["expert_size"]
        has_residual = params_dict["has_residual"]
        hidden_size = params_dict["hidden_size"]
        dtype_list = params_dict["dtype"]
        for dtype in dtype_list:
            # Skip bf16 cases entirely on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # has_bias is fixed to False for this benchmark (the False literal below).
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              num_expert,
                              expert_size,
                              False,
                              has_residual,
                              dtype,
                              device)
            input = inputs[0]
            reduce_weight = inputs[1]
            gather_ids = inputs[2]
            residual = inputs[3]
            bias = inputs[4]
            cusum_token_count = inputs[5]

            # Analytical bytes-moved estimate used for the IO-efficiency column.
            io_bytes = get_io_bytes(num_tokens,
                                    topk,
                                    hidden_size,
                                    num_expert,
                                    expert_size,
                                    start_expert_id,
                                    False,
                                    has_residual,
                                    dtype,
                                    cusum_token_count,
                                    gather_ids)

            hardware_time, e2e_time = benchmark_forward(tmo.moe_combine_result, input, reduce_weight,
                                                        gather_ids,residual, cusum_token_count,
                                                        start_expert_id, expert_size,
                                                        repeats=args.repeat_times)
            io_coeff = io_bytes / hardware_time / bandwidth
            content = [f"{num_tokens}", f"{num_expert}", f"{topk}", f"{start_expert_id}", \
                       f"{expert_size}", f"{hidden_size}", f"{has_residual}", f"{dtype}", \
                       f"{hardware_time}", f"{e2e_time}", f"{io_coeff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    # Optionally persist the report next to this script's name under the -o folder.
    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Benchmark sweep for moe_expand_input: token_num scales from 1 to 32768 while
# the expert/topk configuration is held fixed (expert_num == expert_size, so
# the all-experts kernel path is exercised).  Generated instead of hand-written
# so the per-case fields cannot drift out of sync.
_TOKEN_NUMS = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192, 32768]

e2e_time_param_dict_list = [
    {
        "token_num": token_num,
        "hidden_size": 4096,
        "expert_num": 32,
        "topk": 5,
        "start_expert_id": 0,
        "expert_size": 32,
        "input_dtype": [torch.int8, torch.float16, torch.bfloat16],
    }
    for token_num in _TOKEN_NUMS
]
|
||||
|
||||
def gen_tensor(token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype):
    """Build the input tensors for one moe_expand_input benchmark case.

    Returns (input, gather_idx, cusum_token_count, real_token_count).  When the
    case covers every expert (expert_num == expert_size) cusum_token_count is
    None and real_token_count is simply token_num * topk.
    """
    expanded = token_num * topk
    hidden = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
    gather_idx = torch.randint(low=0, high=token_num, size=(expanded,)).to(torch.int32).to('mlu')
    cusum_token_count, _ = generate_token_count(expert_num, expanded)
    cusum_token_count = cusum_token_count.to('mlu')
    if expert_num == expert_size:
        # All experts participate: the kernel does not need cumulative
        # per-expert token counts, and every expanded token is real.
        return hidden, gather_idx, None, expanded
    # Tokens actually routed to the [start_expert_id, start_expert_id + expert_size) window.
    real_token_count = (cusum_token_count[start_expert_id + expert_size]
                        - cusum_token_count[start_expert_id])
    return hidden, gather_idx, cusum_token_count, real_token_count
|
||||
|
||||
def main():
    """Benchmark tmo.moe_expand_input over the parameter sweep and print/save a report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["token_num", "hidden_size", "expert_num", "topk", "start_expert_id", "expert_size", "input_dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    rows = []
    bandwidth = get_band_width()

    for case in e2e_time_param_dict_list:
        token_num = case["token_num"]
        hidden_size = case["hidden_size"]
        expert_num = case["expert_num"]
        topk = case["topk"]
        start_expert_id = case["start_expert_id"]
        expert_size = case["expert_size"]
        for dtype in case["input_dtype"]:
            # bfloat16 cases are skipped on devices without bf16 support.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            input, gather_idx, cusum_token_count, real_token_count = gen_tensor(
                token_num, hidden_size, expert_num, topk,
                start_expert_id, expert_size, dtype)
            hardware_time, e2e_time = benchmark_forward(tmo.moe_expand_input,
                                                        input,
                                                        gather_idx,
                                                        cusum_token_count,
                                                        start_expert_id,
                                                        expert_size,
                                                        repeats=args.repeat_times)
            # Byte accounting: full input read, index read, optional cumulative
            # counts read, plus real_token_count elements of output.
            io_bytes = input.element_size() * input.nelement()
            io_bytes += gather_idx.element_size() * gather_idx.nelement()
            if cusum_token_count is not None:
                io_bytes += cusum_token_count.element_size() * cusum_token_count.nelement()
            io_bytes += real_token_count * input.element_size()
            io_eff = io_bytes / hardware_time / bandwidth
            rows.append([f"{token_num}", f"{hidden_size}", f"{expert_num}", f"{topk}",
                         f"{start_expert_id}", f"{expert_size}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])

    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
69
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py
Normal file
69
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_gen_idx.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark sweep for moe_gen_idx: the same token_num scale is run against two
# expert/topk configurations; expert ids are always int32.  Generated with a
# comprehension so the 18 cases stay consistent by construction.
_TOKEN_NUMS = [1, 16, 32, 64, 512, 1024, 4096, 8192, 32767]

e2e_time_param_dict_list = [
    {"token_num": token_num, "expert_num": expert_num, "topk": topk,
     "input_dtype": torch.int32}
    for expert_num, topk in ((32, 5), (8, 2))
    for token_num in _TOKEN_NUMS
]
|
||||
def main():
    """Benchmark tmo.moe_gen_idx over the parameter sweep and print/save a report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["token_num", "expert_num", "topk", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    rows = []
    bandwidth = get_band_width()

    def nbytes(t):
        # Total bytes occupied by a tensor.
        return t.element_size() * t.nelement()

    for case in e2e_time_param_dict_list:
        token_num = case["token_num"]
        expert_num = case["expert_num"]
        topk = case["topk"]
        dtype = case["input_dtype"]
        expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu')
        # Buffers sized as the op's outputs; used only for IO-byte accounting.
        gather_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
        combine_idx = torch.empty((token_num * topk), dtype=dtype, device='mlu')
        token_count = torch.empty((expert_num), dtype=dtype, device='mlu')
        cusum_token_count = torch.empty((expert_num + 1), dtype=dtype, device='mlu')
        hardware_time, e2e_time = benchmark_forward(tmo.moe_gen_idx,
                                                    expert_id,
                                                    expert_num,
                                                    repeats=args.repeat_times)
        io_bytes = (nbytes(expert_id) + nbytes(gather_idx) + nbytes(combine_idx)
                    + nbytes(token_count) + nbytes(cusum_token_count))
        io_eff = io_bytes / hardware_time / bandwidth
        rows.append([f"{token_num}", f"{expert_num}", f"{topk}", f"{dtype}",
                     f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])

    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
114
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py
Normal file
114
torch_mlu_ops-v1.3.2/benchmarks/benchmark_moe_quantize.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Benchmark sweep for moe_quantize: one group with gather indices on a large
# hidden size, one without on a small hidden size; bfloat16 throughout.
# Key order is significant: main() unpacks param.values() positionally.
_TOKEN_NUMS = [1, 16, 128, 490, 512, 525, 2048, 4096, 8192, 32768]

params_dict = [
    {"token_num": token_num, "hidden_size": hidden_size, "expert_num": 32,
     "topk": 5, "has_gather_idx": has_gather_idx, "dtype": torch.bfloat16}
    for hidden_size, has_gather_idx in ((8192, True), (1024, False))
    for token_num in _TOKEN_NUMS
]
|
||||
|
||||
def main():
    """Benchmark tmo.moe_quantize over the parameter sweep and print/save a report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    device = 'mlu'
    titles = ["token_num", "hidden_size", "expert_num", "topk", "has_gather_idx", "dtype",
              "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bd = get_band_width()
    for param in params_dict:
        # Read each field by key instead of the original `param.values()`
        # positional unpacking, which silently breaks if a key is added to or
        # reordered in params_dict.
        token_num = param["token_num"]
        hidden_size = param["hidden_size"]
        expert_num = param["expert_num"]
        topk = param["topk"]
        has_gather_idx = param["has_gather_idx"]
        dtype = param["dtype"]
        # Fall back to float16 when the device lacks bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        # The gather-index path is disabled on MLU3xx devices.
        if "MLU3" in torch.mlu.get_device_name():
            has_gather_idx = False
        expand_token_num = token_num * topk
        # With gather indices the kernel expands token_num rows internally;
        # otherwise the input is pre-expanded to expand_token_num rows.
        input_shape = (token_num if has_gather_idx else expand_token_num, hidden_size)
        input = torch.randn(input_shape).to(device).to(dtype)
        scale = torch.randn(expert_num, hidden_size).to(device).to(torch.float32)

        # Distribute expanded tokens as evenly as possible across experts.
        avg, rem = divmod(expand_token_num, expert_num)
        m_list = [avg + (i < rem) for i in range(expert_num)]
        token_count = torch.tensor(m_list, dtype=torch.int32, device='mlu')

        if has_gather_idx:
            gather_idx = torch.arange(0, token_num).repeat([topk])
            gather_idx = gather_idx[torch.randperm(gather_idx.size(0))].to(torch.int32).mlu()
        else:
            gather_idx = None

        hardware_time, e2e_time = benchmark_forward(tmo.moe_quantize,
                                                    input,
                                                    scale,
                                                    None,
                                                    token_count,
                                                    gather_idx,
                                                    None,
                                                    None,
                                                    None,
                                                    True,
                                                    repeats=args.repeat_times)
        # IO accounting: element_size bytes read + 1 byte written per element
        # (presumably the int8 quantized output); the gather path touches each
        # row topk times.
        expand_num = topk if has_gather_idx else 1
        io_bytes = (input.element_size() + 1) * input.nelement() * expand_num
        io_eff = io_bytes / hardware_time / bd
        contents.append([f"{token_num}", f"{hidden_size}", f"{expert_num}",
                         f"{topk}", f"{has_gather_idx}", f"{dtype}",
                         f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])

    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
# Benchmark sweep for moe_softmax_topk, built from three structured groups:
#   1) no normalization, single batch, 32 experts / topk 5, seq_len scale
#   2) normalization on, varying batch sizes, 32 experts / topk 5
#   3) grouped experts (160 experts in 8 groups, topk_group 3), no normalization
_BASE = {"num_expert": 32, "topk": 5, "num_expert_group": -1, "topk_group": -1,
         "input_dtype": torch.bfloat16}
_GROUPED = {"num_expert": 160, "topk": 6, "num_expert_group": 8, "topk_group": 3,
            "input_dtype": torch.bfloat16}

e2e_time_param_dict_list = (
    [dict(_BASE, num_batch=1, seq_len=s, normalize=False)
     for s in (1, 32, 72, 1024, 2048, 4096, 8192, 32768)]
    + [dict(_BASE, num_batch=b, seq_len=s, normalize=True)
       for b, s in ((1, 1), (2, 16), (2, 36), (8, 128), (16, 128),
                    (4, 1024), (2, 4096), (16, 2048))]
    + [dict(_GROUPED, num_batch=1, seq_len=s, normalize=False)
       for s in (1, 16, 64, 1024, 2048, 8192, 32768)]
)
|
||||
|
||||
def main():
    """Benchmark tmo.moe_softmax_topk over the parameter sweep and print/save a report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["num_batch", "seq_len", "num_expert", "topk", "num_expert_group", "topk_group", "normalize", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    rows = []
    bandwidth = get_band_width()

    for case in e2e_time_param_dict_list:
        num_batch = case["num_batch"]
        seq_len = case["seq_len"]
        num_expert = case["num_expert"]
        topk = case["topk"]
        num_expert_group = case["num_expert_group"]
        topk_group = case["topk_group"]
        normalize = case["normalize"]
        dtype = case["input_dtype"]
        # Fall back to float16 when the device lacks bf16 support.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        input = torch.randn(num_batch, seq_len, num_expert, dtype=dtype, device='mlu')
        mask = torch.randint(0, 2, (1, seq_len, num_expert), dtype = dtype, device='mlu')
        # Grouped-expert cases run without a mask.
        if num_expert_group > 1:
            mask = None
        normed_by = "softmax_logit"
        # Buffers used only for the IO-byte accounting below.
        # NOTE(review): they are sized (num_batch, topk) although the input has
        # num_batch * seq_len token rows -- confirm this accounting is intended.
        reduce_weight = torch.empty(num_batch, topk, dtype=torch.float, device='mlu')
        expert_id = torch.empty(num_batch, topk, dtype=torch.int32, device='mlu')
        hardware_time, e2e_time = benchmark_forward(tmo.moe_softmax_topk,
                                                    input,
                                                    topk,
                                                    normalize,
                                                    num_expert_group,
                                                    topk_group,
                                                    mask,
                                                    normed_by,
                                                    repeats=args.repeat_times)
        io_bytes = (input.element_size() * input.nelement()
                    + reduce_weight.element_size() * reduce_weight.nelement()
                    + expert_id.element_size() * expert_id.nelement())
        io_eff = io_bytes / hardware_time / bandwidth
        rows.append([f"{num_batch}", f"{seq_len}", f"{num_expert}", f"{topk}",
                     f"{num_expert_group}", f"{topk_group}", f"{normalize}",
                     f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"])

    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
# Benchmark sweep for offline_quant_to_linear_cache: one fixed geometry run
# with packed/unpacked context layouts in both quantization modes.
e2e_time_param_dict_list = [
    {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
     "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
     "packed": packed, "input_dtype": [torch.float16, torch.bfloat16],
     "quantize_mode": quantize_mode}
    for quantize_mode in ("per_head", "per_channel")
    for packed in (True, False)
]
|
||||
|
||||
def main():
    """Benchmark tmo.offline_quant_to_linear_cache for packed/unpacked context
    layouts in both quantization modes and print/save a report.

    The four original copy-pasted branches differed only in the quantize-scale
    shape, the quantize flag and the packed-vs-unpacked length arguments; they
    are unified below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dytpe", "quantize_mode", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        max_seq_offset = params_dict["max_seq_offset"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        quantize_mode = params_dict["quantize_mode"]
        if quantize_mode not in ("per_head", "per_channel"):
            # The original branching silently skipped the kernel call for
            # unknown modes and then crashed on undefined results; fail fast.
            raise ValueError(f"unsupported quantize_mode: {quantize_mode}")
        for dtype in params_dict["input_dtype"]:
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Every sequence has exactly max_context_len tokens and a fixed
            # context start offset of max_seq_offset.
            context_lens = torch.randint(size=(batch, ), low=max_context_len,
                                         high=max_context_len + 1,
                                         dtype=torch.int32, device='mlu')
            context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset,
                                                high=max_seq_offset + 1,
                                                dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
                                              high=(cache_mem_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')

            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1, 0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            # Q, K and V heads share one contiguous context tensor.
            total_heads = head_num_q + 2 * head_num_kv
            if packed > 0:
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
            context = context.to(dtype)
            cache = cache.to(dtype)
            key = context[..., head_num_q : head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]

            # int8 caches with values spread over roughly [-128, 128).
            key_cache = ((cache[0] - 0.5) * 256).to(torch.int8)
            value_cache = ((cache[1] - 0.5) * 256).to(torch.int8)

            # Map each benchmarked batch entry to a distinct cache slot.
            cache_bs_id = torch.IntTensor(random.sample([*range(0, max_batch)], batch)).mlu()

            # per_channel scales span head_size; per_head scales span the
            # cache length.  The quantize flag matches: 0 = per_channel,
            # 1 = per_head.
            scale_cols = head_size if quantize_mode == "per_channel" else cache_mem_len
            quantize_flag = 0 if quantize_mode == "per_channel" else 1
            key_cache_quantize_scale = torch.randn(head_num_kv, scale_cols).to('mlu').to(torch.float32)
            value_cache_quantize_scale = torch.randn(head_num_kv, scale_cols).to('mlu').to(torch.float32)
            # Packed layout passes cumulative lengths and no context offsets;
            # unpacked passes raw lengths plus per-sequence offsets.
            seq_lens = cu_context_lens if packed > 0 else context_lens
            ctx_offsets = None if packed > 0 else context_seq_offsets
            hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_linear_cache,
                                                        key, value,
                                                        key_cache, value_cache,
                                                        key_cache_quantize_scale,
                                                        value_cache_quantize_scale,
                                                        seq_lens, max_context_len, quantize_flag,
                                                        packed > 0, ctx_offsets,
                                                        cache_bs_id, cache_seq_offsets,
                                                        repeats=args.repeat_times)

            contents.append([f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}",
                             f"{head_size}", f"{packed}", f"{dtype}", f"{quantize_mode}",
                             f"{hardware_time}", f"{e2e_time}"])
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from itertools import product
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
|
||||
# Benchmark sweep for offline_quant_to_paged_cache: the base token counts are
# scaled by 1x, 32x and 64x while the cache geometry stays fixed.
e2e_time_param_dict_list = [
    {"token_num": [1024 * scale, 2048 * scale, 3072 * scale, 4096 * scale],
     "head_num_kv": 1, "head_size": 128, "block_size": 16,
     "input_dtype": [torch.float16, torch.bfloat16]}
    for scale in (1, 32, 64)
]
|
||||
|
||||
def main():
    """Benchmark tmo.offline_quant_to_paged_cache over the parameter sweep and print/save a report."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
    args = parser.parse_args()

    titles = ["token_num", "head_num_kv", "head_size", "block_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    if "MLU3" in torch.mlu.get_device_name():
        print("Op offline_quant_to_paged_cache does not support MLU300 devices.")
        return

    for case in e2e_time_param_dict_list:
        block_size = case["block_size"]
        head_num_kv = case["head_num_kv"]
        head_size = case["head_size"]
        for token_num, dtype in product(case["token_num"], case["input_dtype"]):
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue
            # Enough blocks to hold every token.
            block_num = (token_num + block_size - 1) // block_size
            key = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
            value = torch.randn(token_num, head_num_kv, head_size, dtype=dtype, device="mlu")
            key_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
            value_cache = torch.randint(-128, 127, (block_num, head_num_kv, block_size, head_size), dtype=torch.int8).to("mlu")
            # Each token lands in a distinct random cache slot.
            num_slots = block_num * block_size
            slot_mapping = random.sample(range(num_slots), token_num)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu")
            key_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
            value_cache_scale = torch.randn(head_num_kv, head_size, dtype=torch.float, device="mlu")
            hardware_time, e2e_time = benchmark_forward(tmo.offline_quant_to_paged_cache,
                                                        key, value,
                                                        key_cache_scale, value_cache_scale,
                                                        slot_mapping,
                                                        key_cache, value_cache,
                                                        repeats=args.repeat_times)
            rows.append([f"{token_num}", f"{head_num_kv}", f"{head_size}",
                         f"{block_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"])

    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))
    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__=="__main__":
    main()
|
||||
@@ -0,0 +1,61 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random
import csv

# Benchmark grid: token counts are multiples of 5 (mixes aligned and
# unaligned sizes); only fp16 is exercised here.
params_dict = {
    "token_num": [n * 5 for n in [1, 72, 512, 1024, 4096, 32768]],
    "hidden_size": [1024, 8192],
    "input_dtype": [torch.float16]
}


def main():
    """Benchmark tmo.per_token_smooth_quantize over the parameter grid and
    report hardware time, end-to-end latency and IO efficiency."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'
    titles = ["token_num", "hidden_size", "input_dytpe", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    params_list = product(params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
    bd = get_band_width()
    for params in params_list:
        # Unpack the (token_num, hidden_size, dtype) combination directly.
        token_num, hidden_size, dtype = params
        input_shape = (token_num, hidden_size)
        # Fixed: was `(hidden_size)` which is just an int, not a 1-tuple.
        smooth_shape = (hidden_size,)
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            continue  # skip bf16 cases on hardware without bf16 support
        input = torch.randn(input_shape).to(device).to(dtype)
        smooth = torch.randn(smooth_shape).to(device).to(torch.float32)
        zero = None         # no zero-point (symmetric quantization path)
        token_count = None  # quantize every token
        hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
                                                    input,
                                                    smooth,
                                                    zero,
                                                    token_count,
                                                    repeats=args.repeat_times)
        # IO model: read input + write quantized output (+1 byte/elem),
        # read smooth vector, plus 4 bytes per token for the per-token scale.
        io_bytes = (input.element_size() + 1) * input.nelement() + \
                   smooth.element_size() * smooth.nelement() + \
                   token_num * 4
        io_eff = io_bytes / hardware_time / bd
        content = [f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
46
torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py
Normal file
46
torch_mlu_ops-v1.3.2/benchmarks/benchmark_preload.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os

# Shapes/dtypes to time; bf16 entries are skipped on devices without bf16.
e2e_time_param_dict_list = [
    {"input_shape": [100, 100, 100], "input_dtype": [torch.float16, torch.bfloat16]},
    {"input_shape": [100, 100], "input_dtype": [torch.float16, torch.bfloat16]},
    {"input_shape": [50, 50, 50], "input_dtype": [torch.float16, torch.bfloat16]},
    {"input_shape": [1, 100, 1000], "input_dtype": [torch.float16, torch.bfloat16]}
]


def main():
    """Benchmark tmo.preload for each configured shape/dtype combination and
    print (and optionally save) a timing table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'
    titles = ["input_shape", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    rows = []
    for cfg in e2e_time_param_dict_list:
        input_shape = cfg["input_shape"]
        for dtype in cfg["input_dtype"]:
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue  # hardware lacks bf16 support
            data = torch.randn(input_shape).to(device).to(dtype)
            # Second argument is the number of bytes to preload.
            nbytes = data.element_size() * data.numel()
            hw_us, e2e_us = benchmark_forward(tmo.preload,
                                              data,
                                              nbytes,
                                              repeats=args.repeat_times)
            rows.append([f"{input_shape}", f"{dtype}", f"{hw_us}", f"{e2e_us}"])
    table = [titles] + rows
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        _, file_name = os.path.split(__file__)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,201 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import random
import os


def _build_param_grid():
    """Build the benchmark grid.

    For every (quant_bit, group_size) combination there are five prefill
    cases (packed layout, growing context length) and five decode cases
    (unpacked layout, growing batch).  Generating the grid replaces 40
    hand-written, copy-pasted dict literals while preserving the exact
    original ordering and contents.
    """
    prefill = [{"max_batch": 8, "batch": 1, "cache_mem_len": 32768,
                "max_context_len": ctx, "head_num_kv": 1, "head_size": 128,
                "packed": True, "input_dtype": [torch.bfloat16]}
               for ctx in (1024, 2048, 4096, 8192, 32768)]
    decode = [{"max_batch": 512, "batch": b, "cache_mem_len": 4096,
               "max_context_len": 1, "head_num_kv": 1, "head_size": 128,
               "packed": False, "input_dtype": [torch.bfloat16]}
              for b in (16, 72, 128, 256, 512)]
    grid = []
    for quant_bit, group_size in ((8, 128), (4, 128), (8, 16), (4, 16)):
        for base in prefill + decode:
            grid.append(dict(base, quant_bit=quant_bit, group_size=group_size))
    return grid


e2e_time_param_dict_list = _build_param_grid()


def main():
    """Benchmark tmo.quant_to_linear_cache (int8/int4 caches, per-head or
    per-group scales) and report hardware time, e2e latency and IO
    efficiency."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    titles = ["batch", "max_context_len", "head_num_kv", "head_size", "packed", "input_dtype",
              "quant_bit", "group_size", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    bandwidth = get_band_width()
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        input_dtype_list = params_dict["input_dtype"]
        quant_bit = params_dict["quant_bit"]
        group_size = params_dict["group_size"]
        for dtype in input_dtype_list:
            # This benchmark falls back to fp16 (instead of skipping) when
            # bf16 is unsupported.
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                dtype = torch.half
            context_lens = torch.tensor([max_context_len] * batch).to(torch.int32).mlu()
            context_seq_offsets = torch.zeros(batch, dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch, ),
                                              low=0,
                                              high=1 if max_context_len > 1 else cache_mem_len,
                                              dtype=torch.int32, device='mlu')
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1, 0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            if packed > 0:
                # Packed layout: all sequences concatenated along dim 0.
                key = torch.randn((total_seqlen, head_num_kv, head_size),
                                  dtype=torch.float, device='mlu')
                value = torch.randn((total_seqlen, head_num_kv, head_size),
                                    dtype=torch.float, device='mlu')
            else:
                key = torch.randn((batch, max_context_len, head_num_kv, head_size),
                                  dtype=torch.float, device='mlu')
                value = torch.randn((batch, max_context_len, head_num_kv, head_size),
                                    dtype=torch.float, device='mlu')
            key = key.to(dtype)
            value = value.to(dtype)
            # Cache/scale shapes depend on quantization width and granularity:
            # int4 packs two values per int8 byte; per-group scales add a
            # head_size // group_size dimension.
            if quant_bit == 8 and group_size == head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
            if quant_bit == 8 and group_size != head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
            if quant_bit == 4 and group_size == head_size:
                # NOTE(review): key packs int4 pairs along head_size while
                # value packs along cache_mem_len -- looks intentional but
                # worth confirming against the kernel's expected layout.
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len), dtype=torch.float32).mlu()
            if quant_bit == 4 and group_size != head_size:
                key_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len, head_size // 2), dtype=torch.int8).mlu()
                value_cache = torch.randint(-128, 127, (max_batch, head_num_kv, cache_mem_len // 2, head_size), dtype=torch.int8).mlu()
                key_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()
                value_cache_scale = torch.randn((max_batch, head_num_kv, cache_mem_len, head_size // group_size), dtype=torch.float32).mlu()

            # Map each request in the batch to a distinct cache slot.
            # (Removed a dead `cache_bs_id = None` that was immediately
            # overwritten.)
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()

            if packed > 0:
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            key_cache_scale,
                                                            value_cache_scale,
                                                            cu_context_lens, max_context_len,
                                                            packed > 0, None,
                                                            cache_bs_id, cache_seq_offsets,
                                                            quant_bit,
                                                            repeats=args.repeat_times)
            else:
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            key_cache_scale,
                                                            value_cache_scale,
                                                            context_lens, max_context_len,
                                                            packed > 0, context_seq_offsets,
                                                            cache_bs_id, cache_seq_offsets,
                                                            quant_bit,
                                                            repeats=args.repeat_times)

            # Rough IO model: (read + quantized write) for both key and value.
            io_bytes = key.nelement() * (key.element_size() + 1) * 2
            io_eff = io_bytes / hardware_time / bandwidth
            content = [f"{batch}", f"{max_context_len}", f"{head_num_kv}", f"{head_size}", f"{packed}",
                       f"{dtype}", f"{quant_bit}", f"{group_size}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
61
torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py
Normal file
61
torch_mlu_ops-v1.3.2/benchmarks/benchmark_quantize.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import *
import argparse
from tabulate import tabulate
import os
import random

# Benchmark grid; "dynamic" selects per-token (dynamic) vs static quantization.
params_dict = {"dynamic": [True],
               "token_num": [1, 72, 490, 512, 525, 1024, 4096, 8192, 32768],
               "hidden_size": [8192, 1024],
               "input_dtype": [torch.bfloat16]}


def main():
    """Benchmark tmo.per_token_smooth_quantize (dynamic) or tmo.quantize
    (static) over the parameter grid and report hardware time, e2e latency
    and IO efficiency."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    device = 'mlu'
    # Fixed: header column previously read "input_dytpe".
    titles = ["dynamic", "token_num", "hidden_size", "input_dtype", "hardware_time(us)", "e2e_latency(us)", "IO efficiency"]
    contents = []
    params_list = product(params_dict["dynamic"], params_dict["token_num"], params_dict["hidden_size"], params_dict["input_dtype"])
    bd = get_band_width()
    for param in params_list:
        # Unpack the combination directly instead of indexing param[0..3].
        dynamic, token_num, hidden_size, dtype = param
        input_shape = (token_num, hidden_size)
        # Fall back to fp16 (instead of skipping) when bf16 is unsupported.
        if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.half
        input = torch.randn(input_shape).to(device).to(dtype)
        scale = torch.randn(input_shape[-1]).to(device).to(torch.float32)
        zero = None  # no zero-point (symmetric quantization)
        if dynamic:
            hardware_time, e2e_time = benchmark_forward(tmo.per_token_smooth_quantize,
                                                        input,
                                                        scale,
                                                        zero,
                                                        None,
                                                        repeats=args.repeat_times)
        else:
            hardware_time, e2e_time = benchmark_forward(tmo.quantize,
                                                        input,
                                                        scale,
                                                        zero,
                                                        repeats=args.repeat_times)
        # IO model: read input + write int8 output (+1 byte/elem) plus the
        # scale vector.
        io_bytes = (input.element_size() + 1) * input.nelement() + scale.element_size() * scale.nelement()
        io_eff = io_bytes / hardware_time / bd
        content = [f"{dynamic}", f"{token_num}", f"{hidden_size}", f"{dtype}", f"{hardware_time}", f"{e2e_time}", f"{io_eff}"]
        contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,109 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random

# NOTE(review): entries 3 and 4 duplicate entries 1 and 2 exactly -- kept
# as-is to preserve the original number of benchmark runs; confirm whether
# the repetition is intended.
e2e_time_param_dict_list = [{"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"max_batch": 20, "batch": 10, "cache_mem_len": 1024, "max_context_len": 512,
                             "max_seq_offset": 20, "head_num_q": 32, "head_num_kv": 32, "head_size": 128,
                             "packed": False, "input_dtype": [torch.float16, torch.bfloat16]}
                            ]


def main():
    """Benchmark tmo.reshape_linear_cache for packed/unpacked KV layouts and
    print (and optionally save) a timing table."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    # Fixed: header column previously read "input_dytpe".
    titles = ["batch", "max_context_len", "head_num_q", "head_num_kv", "head_size", "packed", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        max_batch = params_dict["max_batch"]
        batch = params_dict["batch"]
        cache_mem_len = params_dict["cache_mem_len"]
        max_context_len = params_dict["max_context_len"]
        max_seq_offset = params_dict["max_seq_offset"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        packed = params_dict["packed"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
                continue  # hardware lacks bf16 support
            # All sequences use exactly max_context_len tokens
            # (randint with low == high - 1 is a constant fill).
            context_lens = torch.randint(size=(batch, ), low=max_context_len,
                                         high=max_context_len+1,
                                         dtype=torch.int32, device='mlu')
            # max_seq_offset = max_context_len // 3 + 1
            context_seq_offsets = torch.randint(size=(batch, ), low=max_seq_offset, high=max_seq_offset+1,
                                                dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch, ), low=-1,
                                              high=(cache_mem_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')

            # Exclusive-prefix-sum of lengths (packed layout uses cu_seqlens).
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            total_heads = head_num_q + 2 * head_num_kv
            if packed > 0:
                # Packed layout: all sequences concatenated along dim 0.
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            cache = torch.randn((2, max_batch, head_num_kv, cache_mem_len, head_size), dtype=torch.float, device='mlu')
            context = context.to(dtype)
            cache = cache.to(dtype)
            # QKV is interleaved along the head dimension: [Q | K | V].
            key = context[..., head_num_q : head_num_q + head_num_kv, :]
            value = context[..., head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]
            key_cache = cache[0]
            value_cache = cache[1]

            # Map each request to a distinct cache slot.  (Removed a dead
            # `cache_bs_id = None` that was immediately overwritten.)
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()

            if packed > 0:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            cu_context_lens, max_context_len,
                                                            packed > 0, None,
                                                            cache_bs_id, cache_seq_offsets,
                                                            repeats=args.repeat_times)

            else:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_linear_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            context_lens, max_context_len,
                                                            packed > 0, context_seq_offsets,
                                                            cache_bs_id, cache_seq_offsets,
                                                            repeats=args.repeat_times)

            content = [f"{batch}", f"{max_context_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{packed}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,76 @@
|
||||
import torch
import torch_mlu
import torch_mlu_ops as tmo
from common import benchmark_forward, save_to_csv
import argparse
from tabulate import tabulate
import os
import random

# Same shape in both entries; only "quantize" toggles which op is timed.
e2e_time_param_dict_list = [{"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
                             "head_num_kv": 32, "head_size": 128, "quantize": True, "input_dtype": [torch.float16, torch.bfloat16]},
                            {"num_tokens": 1024, "num_block": 500, "block_size": 6, "head_num_q": 32,
                             "head_num_kv": 32, "head_size": 128, "quantize": False, "input_dtype": [torch.float16, torch.bfloat16]}
                            ]


def main():
    """Benchmark tmo.reshape_paged_cache / tmo.quant_to_paged_cache and print
    (and optionally save) a timing table."""
    # This benchmark targets non-MLU3 devices only.
    if 'MLU3' in torch.mlu.get_device_name():
        exit()
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
    parser.add_argument('--csv', action='store_true', help='write the report data to csv')
    parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')

    args = parser.parse_args()
    # Fixed: header column previously read "input_dytpe".
    titles = ["num_tokens", "num_block", "block_size", "head_num_q", "head_num_kv", "head_size", "input_dtype", "quantize", "hardware_time(us)", "e2e_latency(us)"]
    contents = []
    for params_dict in e2e_time_param_dict_list:
        num_tokens = params_dict["num_tokens"]
        num_blocks = params_dict["num_block"]
        block_size = params_dict["block_size"]
        head_num_q = params_dict["head_num_q"]
        head_num_kv = params_dict["head_num_kv"]
        head_size = params_dict["head_size"]
        quantize = params_dict["quantize"]
        input_dtype_list = params_dict["input_dtype"]
        for dtype in input_dtype_list:
            # NOTE(review): unlike the sibling benchmarks there is no
            # is_bf16_supported() guard here -- confirm bf16 is always
            # available on the (non-MLU3) devices this script targets.
            # QKV is interleaved along the head dimension: [Q | K | V].
            qkv = torch.randn(num_tokens, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
            key = qkv[:, head_num_q : head_num_q + head_num_kv, :]
            value = qkv[:, head_num_q + head_num_kv : head_num_q + 2 * head_num_kv, :]

            # NOTE(review): caches stay in `dtype` even on the quantize path
            # (int8 caches might be expected) -- confirm against the op's API.
            key_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()
            value_cache = torch.randn(num_blocks, head_num_kv, block_size, head_size, dtype=dtype).mlu()

            # Unique slot per token; the last token is deliberately mapped to
            # -1 (skip sentinel).
            num_slots = num_blocks * block_size
            slot_mapping = random.sample(range(num_slots), num_tokens)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
            slot_mapping[-1] = -1
            if not quantize:
                hardware_time, e2e_time = benchmark_forward(tmo.reshape_paged_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            slot_mapping,
                                                            repeats=args.repeat_times)
            else:
                k_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
                v_cache_quant_scale = torch.randn(num_blocks, head_num_kv, block_size).to('mlu').to(torch.float32)
                hardware_time, e2e_time = benchmark_forward(tmo.quant_to_paged_cache,
                                                            key, value,
                                                            key_cache, value_cache,
                                                            k_cache_quant_scale,
                                                            v_cache_quant_scale,
                                                            slot_mapping,
                                                            repeats=args.repeat_times)

            content = [f"{num_tokens}", f"{num_blocks}", f"{block_size}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{dtype}", f"{quantize}", f"{hardware_time}", f"{e2e_time}"]
            contents.append(content)
    table = [titles] + contents
    print(tabulate(table, headers="firstrow", tablefmt="grid"))

    if args.csv:
        current_file_path = __file__
        _, file_name = os.path.split(current_file_path)
        save_to_csv(table, args.o, file_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
|
||||
e2e_time_param_dict_list = [
|
||||
{"batch": 16, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
|
||||
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
|
||||
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 128, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
|
||||
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
|
||||
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 512, "max_seq_len": 4096, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
|
||||
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
|
||||
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
|
||||
|
||||
{"batch": 16, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
|
||||
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
|
||||
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
|
||||
{"batch": 128, "max_seq_len": 32768, "head_num_q": 8, "head_num_kv": 1, "head_size": 128,
|
||||
"block_size": 16, "alibi_bias": False, "kv_cache_dtype": torch.int8, "use_paged_attn": False,
|
||||
"is_pertoken": False, "input_dtype": [torch.bfloat16]},
|
||||
]
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
|
||||
titles = ["batch", "max_seq_len", "head_num_q", "head_num_kv", "head_size", "block_size", "alibi_bias", "kv_cache_dtype", "use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
|
||||
contents = []
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
batch = params_dict["batch"]
|
||||
max_seq_len = params_dict["max_seq_len"]
|
||||
head_num_q = params_dict["head_num_q"]
|
||||
head_num_kv = params_dict["head_num_kv"]
|
||||
head_size = params_dict["head_size"]
|
||||
block_size = params_dict["block_size"]
|
||||
alibi_bias = params_dict["alibi_bias"]
|
||||
kv_cache_dtype = params_dict["kv_cache_dtype"]
|
||||
use_paged_attn = params_dict["use_paged_attn"]
|
||||
input_dtype_list = params_dict["input_dtype"]
|
||||
is_pertoken = params_dict["is_pertoken"]
|
||||
for dtype in input_dtype_list:
|
||||
if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
|
||||
continue
|
||||
input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
|
||||
input_q = input_qkv[..., 0 : head_num_q, :]
|
||||
|
||||
context_lens = torch.randint(max_seq_len, max_seq_len + 1, (batch, ), dtype=torch.int32).to(device)
|
||||
max_context_len = int(max(context_lens))
|
||||
if use_paged_attn:
|
||||
mlu_name = torch.mlu.get_device_name()
|
||||
if "MLU3" in mlu_name:
|
||||
print("pagedattn is not implement on mlu370, skip it")
|
||||
continue
|
||||
block_size = 16
|
||||
else:
|
||||
block_size = max_seq_len + 512
|
||||
num_blocks = batch * ((max_seq_len + block_size - 1) // block_size)
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
cache_shape = (num_blocks, head_num_kv, block_size, head_size)
|
||||
scale_shape = (num_blocks, head_num_kv, block_size) if is_pertoken else (head_num_kv, head_size)
|
||||
block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
|
||||
if kv_cache_dtype is not torch.int8:
|
||||
key_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
|
||||
value_cache = torch.randn(size=cache_shape, dtype=torch.float16).to(device)
|
||||
key_cache_scale = None
|
||||
value_cache_scale = None
|
||||
else:
|
||||
key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
|
||||
value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_cache_dtype).to(device)
|
||||
key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
|
||||
value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).to(device)
|
||||
alibi_slopes = None
|
||||
if alibi_bias:
|
||||
alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
|
||||
alibi_slopes.uniform_(0, 0.125)
|
||||
softmax_scale = 1 / math.sqrt(head_size)
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.single_query_cached_kv_attn,
|
||||
input_q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
None,
|
||||
block_tables,
|
||||
context_lens,
|
||||
key_cache_scale,
|
||||
value_cache_scale,
|
||||
alibi_slopes,
|
||||
max_context_len,
|
||||
-1,
|
||||
-1,
|
||||
softmax_scale,
|
||||
repeats=args.repeat_times)
|
||||
content = [f"{batch}", f"{max_seq_len}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{block_size}", f"{alibi_bias}", f"{kv_cache_dtype}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
@@ -0,0 +1,131 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import *
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
|
||||
e2e_time_param_dict_list = [{"batch": 16, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
|
||||
"head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
|
||||
"use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 128, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
|
||||
"head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
|
||||
"use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 512, "max_seq_len_lp": 4064, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
|
||||
"head_size": 128, "alibi_bias": True, "quant_bit_lp": 4, "quant_bit_hp": 8,
|
||||
"use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 16, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
|
||||
"head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
|
||||
"use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"batch": 128, "max_seq_len_lp": 32736, "max_seq_len_hp": 32, "head_num_q": 8, "head_num_kv": 1,
|
||||
"head_size": 128, "alibi_bias": False, "quant_bit_lp": 4, "quant_bit_hp": 8,
|
||||
"use_paged_attn": False, "input_dtype": [torch.float16, torch.bfloat16]}]
|
||||
|
||||
def gen_cache(batch, num_kv_heads, head_size, is_pagedattn, max_context_len, data_type, quant_bit, quant_mode):
|
||||
int_max = float(2 ** (quant_bit - 1) - 1)
|
||||
int_min = -float(2 ** (quant_bit - 1))
|
||||
context_lens = torch.randint(max_context_len, max_context_len + 1, (batch, ), dtype=torch.int32).mlu()
|
||||
block_size = 16
|
||||
if is_pagedattn is False:
|
||||
block_size = max_context_len
|
||||
num_blocks = (int)(batch * ((max_context_len + block_size - 1)/ block_size))
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
|
||||
if quant_mode == "per_token":
|
||||
scale_shape = (num_blocks, num_kv_heads, block_size, 1)
|
||||
else: # per channel
|
||||
scale_shape = (num_kv_heads, head_size)
|
||||
|
||||
if quant_bit == 4:
|
||||
cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size//2)
|
||||
cache_shape_v_int4 = (num_blocks, num_kv_heads, block_size//2, head_size)
|
||||
key_cache = torch.zeros(cache_shape_k_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
|
||||
value_cache = torch.zeros(cache_shape_v_int4).uniform_(int_min, int_max).to(torch.int8).mlu()
|
||||
key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
|
||||
value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
|
||||
elif quant_bit == 8:
|
||||
key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
|
||||
value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
|
||||
key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
|
||||
value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
|
||||
elif quant_bit == -1:
|
||||
key_cache = torch.randn(cache_shape, dtype=data_type).mlu()
|
||||
value_cache = torch.randn(cache_shape, dtype=data_type).mlu()
|
||||
key_cache_scale = None
|
||||
value_cache_scale = None
|
||||
else:
|
||||
print("!!!!!!!!!!!gen case error, quant_bit must be in {-1, 4, 8}")
|
||||
block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
|
||||
return key_cache, value_cache, key_cache_scale, value_cache_scale, context_lens, block_tables
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
|
||||
titles = ["batch", "max_seq_len_lp", "max_seq_len_hp", "head_num_q", "head_num_kv", "head_size", "alibi_bias", "quant_bit_lp", "quant_bit_hp","use_paged_attn", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
|
||||
contents = []
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
batch = params_dict["batch"]
|
||||
max_seq_len_lp = params_dict["max_seq_len_lp"]
|
||||
max_seq_len_hp = params_dict["max_seq_len_hp"]
|
||||
head_num_q = params_dict["head_num_q"]
|
||||
head_num_kv = params_dict["head_num_kv"]
|
||||
head_size = params_dict["head_size"]
|
||||
alibi_bias = params_dict["alibi_bias"]
|
||||
quant_bit_lp = params_dict["quant_bit_lp"]
|
||||
quant_bit_hp = params_dict["quant_bit_hp"]
|
||||
use_paged_attn = params_dict["use_paged_attn"]
|
||||
input_dtype_list = params_dict["input_dtype"]
|
||||
|
||||
for dtype in input_dtype_list:
|
||||
if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
|
||||
continue
|
||||
input_qkv = torch.randn((batch, 1, head_num_q + 2 * head_num_kv, head_size)).to(device).to(dtype)
|
||||
input_q = input_qkv[..., 0 : head_num_q, :]
|
||||
|
||||
params_lp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_lp, dtype, quant_bit_lp, "per_token")
|
||||
params_hp = gen_cache(batch, head_num_kv, head_size, use_paged_attn, max_seq_len_hp, dtype, quant_bit_hp, "per_channel")
|
||||
key_cache_lp, value_cache_lp, key_cache_scale_lp, value_cache_scale_lp, context_lens_lp, block_tables_lp = params_lp
|
||||
key_cache_hp, value_cache_hp, key_cache_scale_hp, value_cache_scale_hp, context_lens_hp, block_tables_hp = params_hp
|
||||
|
||||
alibi_slopes = None
|
||||
if alibi_bias:
|
||||
alibi_slopes = torch.zeros((batch, head_num_q), dtype=torch.float32).to(device)
|
||||
alibi_slopes.uniform_(0, 0.125)
|
||||
softmax_scale = 1 / math.sqrt(head_size)
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.single_query_mixed_cached_kv_attn,
|
||||
input_q,
|
||||
key_cache_lp, value_cache_lp,
|
||||
key_cache_hp, value_cache_hp,
|
||||
None, #output
|
||||
block_tables_lp, block_tables_hp,
|
||||
context_lens_lp, context_lens_hp,
|
||||
key_cache_scale_lp, value_cache_scale_lp,
|
||||
key_cache_scale_hp, value_cache_scale_hp,
|
||||
alibi_slopes,
|
||||
max_seq_len_lp, max_seq_len_hp,
|
||||
softmax_scale, True,
|
||||
quant_bit_lp, quant_bit_hp,
|
||||
repeats=args.repeat_times)
|
||||
content = [f"{batch}", f"{max_seq_len_lp}", f"{max_seq_len_hp}", f"{head_num_q}", f"{head_num_kv}", f"{head_size}", f"{alibi_bias}", f"{quant_bit_lp}", f"{quant_bit_hp}", f"{use_paged_attn}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
|
||||
"act_mode": "none", "output_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
|
||||
"act_mode": "silu", "output_dtype": [torch.float16, torch.bfloat16]}]
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
|
||||
titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "output_dtype", "hardware_time(us)", "e2e_latency(us)"]
|
||||
contents = []
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
m = params_dict["m"]
|
||||
k = params_dict["k"]
|
||||
n = params_dict["n"]
|
||||
has_c = params_dict["has_c"]
|
||||
has_bias = params_dict["has_bias"]
|
||||
act_mode = params_dict["act_mode"]
|
||||
output_dtype_list = params_dict["output_dtype"]
|
||||
for dtype in output_dtype_list:
|
||||
if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
|
||||
continue
|
||||
a = torch.randn(m, k).to(device).to(torch.int8)
|
||||
b = torch.randn(n, k).to(device).to(torch.int8)
|
||||
a_scale = torch.randn(m).to(device)
|
||||
b_scale = torch.randn(n).to(device)
|
||||
c = None
|
||||
if has_c:
|
||||
c = torch.randn(m, n).to(device).to(dtype)
|
||||
bias = None
|
||||
if has_bias:
|
||||
bias = torch.randn(n).to(device).to(dtype)
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.smooth_quant_matmul,
|
||||
a,
|
||||
a_scale,
|
||||
b,
|
||||
b_scale,
|
||||
dtype,
|
||||
bias,
|
||||
c,
|
||||
act_mode,
|
||||
1.0,
|
||||
1.0,
|
||||
repeats=args.repeat_times)
|
||||
content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
15
torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh
Normal file
15
torch_mlu_ops-v1.3.2/benchmarks/benchmark_test.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
LOG_PATH=${LOG_PATH:-.}
|
||||
|
||||
files=($(ls benchmark_*.py))
|
||||
for file in "${files[@]}"; do
|
||||
echo "test ${file}..."
|
||||
op_name=$(basename "$file" .py)
|
||||
python "$file" > ${LOG_PATH}/${op_name}.log 2>&1
|
||||
ret_tmp=$?
|
||||
cat ${LOG_PATH}/${op_name}.log
|
||||
if [ $ret_tmp != 0 ]; then
|
||||
echo "${sc} test failed..."
|
||||
exit $ret_tmp
|
||||
fi
|
||||
done
|
||||
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
import random
|
||||
|
||||
e2e_time_param_dict_list = [{"batch": 16, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
|
||||
"dtype": [torch.float16], "pack": False},
|
||||
{"batch": 128, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
|
||||
"dtype": [torch.float16], "pack": False},
|
||||
{"batch": 512, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
|
||||
"dtype": [torch.float16], "pack": False},
|
||||
{"batch": 1024, "head_num": 8, "head_size": 128, "block_seq_len": 1, "max_seq_len": 1,
|
||||
"dtype": [torch.float16], "pack": False},
|
||||
{"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 1024, "max_seq_len": 1024,
|
||||
"dtype": [torch.float16], "pack": True},
|
||||
{"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 2048, "max_seq_len": 2048,
|
||||
"dtype": [torch.float16], "pack": True},
|
||||
{"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 8192, "max_seq_len": 8192,
|
||||
"dtype": [torch.float16], "pack": True},
|
||||
{"batch": 1, "head_num": 8, "head_size": 128, "block_seq_len": 32768, "max_seq_len": 32768,
|
||||
"dtype": [torch.float16], "pack": True},]
|
||||
|
||||
def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack):
|
||||
if not pack:
|
||||
out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype)
|
||||
lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
|
||||
block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype)
|
||||
block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
|
||||
seq_offset = None
|
||||
cu_seqs = None
|
||||
block_cu_seqs = None
|
||||
else:
|
||||
seq_lens = torch.randint(low=max_seq_len, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32)
|
||||
block_seq_lens = torch.randint(low=block_seq_len, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32)
|
||||
block_seq_lens = torch.minimum(seq_lens, block_seq_lens)
|
||||
seq_offset = torch.zeros_like(seq_lens)
|
||||
for i in range(batch):
|
||||
seq_offset[i] = torch.randint(low=0, high=seq_lens[i]-block_seq_lens[i]+1, size=(1,), dtype=torch.int32)
|
||||
seq_offset = seq_offset.mlu()
|
||||
cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu()
|
||||
block_cu_seqs = torch.cat((torch.tensor([0]),torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu()
|
||||
total_seqs = torch.sum(seq_lens)
|
||||
block_total_seqs = torch.sum(block_seq_lens)
|
||||
|
||||
out = torch.randn(total_seqs, head_num, head_size, device="mlu", dtype=dtype)
|
||||
lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
|
||||
block_out = torch.randn(block_total_seqs, head_num, head_size, device="mlu", dtype=dtype)
|
||||
block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
|
||||
return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
|
||||
titles = ["batch", "head_num", "head_size", "block_seq_len", "max_seq_len", "dtype", "pack", "hardware_time(us)", "e2e_latency(us)"]
|
||||
contents = []
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
batch = params_dict["batch"]
|
||||
head_num = params_dict["head_num"]
|
||||
head_size = params_dict["head_size"]
|
||||
block_seq_len = params_dict["block_seq_len"]
|
||||
max_seq_len = params_dict["max_seq_len"]
|
||||
dtype_list = params_dict["dtype"]
|
||||
pack = params_dict["pack"]
|
||||
for dtype in dtype_list:
|
||||
out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
|
||||
gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
|
||||
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.update_out_and_lse,
|
||||
out,
|
||||
lse,
|
||||
block_out,
|
||||
block_lse,
|
||||
seq_offset,
|
||||
cu_seqs,
|
||||
block_cu_seqs,
|
||||
repeats=args.repeat_times)
|
||||
content = [f"{batch}", f"{head_num}", f"{head_size}", f"{block_seq_len}", f"{max_seq_len}", f"{dtype}", f"{pack}", f"{hardware_time}", f"{e2e_time}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch_mlu_ops as tmo
|
||||
from common import benchmark_forward, save_to_csv, save_to_csv
|
||||
import argparse
|
||||
from tabulate import tabulate
|
||||
import os
|
||||
|
||||
e2e_time_param_dict_list = [{"m": 1024, "k": 4096, "n": 14336, "has_c": True, "has_bias": True,
|
||||
"act_mode": "none", "quant_bit": 8, "input_dtype": [torch.float16, torch.bfloat16]},
|
||||
{"m": 1024, "k": 5120, "n": 13824, "has_c": False, "has_bias": True,
|
||||
"act_mode": "silu", "quant_bit": 4, "input_dtype": [torch.float16, torch.bfloat16]}]
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--repeat_times', type=int, default=10, help='repeat times for testing')
|
||||
parser.add_argument('--csv', action='store_true', help='write the report data to csv')
|
||||
parser.add_argument('-o', type=str, help='specify the output folder name under --csv mode')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = 'mlu'
|
||||
|
||||
titles = ["m", "k", "n", "has_c", "has_bias", "act_mode", "quant_bit", "input_dtype", "hardware_time(us)", "e2e_latency(us)"]
|
||||
contents = []
|
||||
for params_dict in e2e_time_param_dict_list:
|
||||
m = params_dict["m"]
|
||||
k = params_dict["k"]
|
||||
n = params_dict["n"]
|
||||
quant_bit = params_dict["quant_bit"]
|
||||
has_c = params_dict["has_c"]
|
||||
has_bias = params_dict["has_bias"]
|
||||
act_mode = params_dict["act_mode"]
|
||||
input_dtype_list = params_dict["input_dtype"]
|
||||
for dtype in input_dtype_list:
|
||||
if dtype == torch.bfloat16 and not torch_mlu.mlu.is_bf16_supported():
|
||||
continue
|
||||
a = torch.randn(m, k).to(device).to(dtype)
|
||||
b = torch.randn(n, k if quant_bit == 8 else k//2).to(device).to(torch.int8)
|
||||
scale = torch.randn(n).to(device)
|
||||
zero = None
|
||||
c = None
|
||||
if has_c:
|
||||
c = torch.randn(m, n).to(device).to(dtype)
|
||||
bias = None
|
||||
if has_bias:
|
||||
bias = torch.randn(n).to(device).to(dtype)
|
||||
hardware_time, e2e_time = benchmark_forward(tmo.weight_only_quant_matmul,
|
||||
a,
|
||||
b,
|
||||
scale,
|
||||
zero,
|
||||
bias,
|
||||
c,
|
||||
act_mode,
|
||||
quant_bit,
|
||||
1.0,
|
||||
1.0,
|
||||
repeats=args.repeat_times)
|
||||
content = [f"{m}", f"{k}", f"{n}", f"{has_c}", f"{has_bias}", f"{act_mode}", f"{quant_bit}", f"{dtype}", f"{hardware_time}", f"{e2e_time}"]
|
||||
contents.append(content)
|
||||
table = [titles] + contents
|
||||
print(tabulate(table, headers="firstrow", tablefmt="grid"))
|
||||
|
||||
if args.csv:
|
||||
current_file_path = __file__
|
||||
_, file_name = os.path.split(current_file_path)
|
||||
save_to_csv(table, args.o, file_name)
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
65
torch_mlu_ops-v1.3.2/benchmarks/common.py
Normal file
65
torch_mlu_ops-v1.3.2/benchmarks/common.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import torch_mlu
|
||||
import time
|
||||
from pathlib import Path
|
||||
import csv
|
||||
import os
|
||||
import subprocess
|
||||
from itertools import product
|
||||
|
||||
def benchmark_forward(fn, *inputs, repeats=1, **kwinputs):
|
||||
notify_start = torch.mlu.Event(enable_timing=True)
|
||||
notify_end = torch.mlu.Event(enable_timing=True)
|
||||
notify_start.record()
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(repeats):
|
||||
fn(*inputs, **kwinputs)
|
||||
notify_end.record()
|
||||
notify_end.synchronize()
|
||||
total_e2e_time = time.perf_counter() - t0
|
||||
average_e2e_time = total_e2e_time / repeats * 1e6
|
||||
total_hardware_time = notify_start.hardware_time(notify_end)
|
||||
average_hardware_time = total_hardware_time / repeats
|
||||
return average_hardware_time, average_e2e_time
|
||||
|
||||
def save_to_csv(table, file_path, file_name):
|
||||
file_name_without_ext, _ = os.path.splitext(file_name)
|
||||
new_file_name = file_name_without_ext + '.csv'
|
||||
if file_path is None:
|
||||
file_path = './'
|
||||
path = Path(file_path)
|
||||
if path.suffix:
|
||||
directory = path.parent
|
||||
filename = path.name
|
||||
else:
|
||||
directory = path
|
||||
filename = new_file_name
|
||||
if not directory.exists():
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
full_path = directory / filename
|
||||
if not full_path.exists():
|
||||
full_path.touch()
|
||||
with open(full_path, mode="w", newline="") as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerows(table)
|
||||
print(f"output saved at: {full_path}")
|
||||
|
||||
def get_band_width(card_id: int = 0):
|
||||
cmd = "cnmon info -c " + str(card_id) + " | grep 'MEM BandWidth'| cut -d ':' -f2 | cut -d ' ' -f 2"
|
||||
res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
assert res.returncode == 0, "Failed to get BandWidth."
|
||||
bd = int(res.stdout.decode().strip())
|
||||
return bd
|
||||
|
||||
def generate_token_count(num_expert,
|
||||
total_token_count):
|
||||
token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
|
||||
sum = torch.sum(token_count, dim=-1) * 1.0
|
||||
token_count *= total_token_count / sum.item()
|
||||
token_count = token_count.to(dtype=torch.int32)
|
||||
cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
|
||||
end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
|
||||
cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
|
||||
cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
|
||||
cusum_token_count[-1] = total_token_count
|
||||
return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]
|
||||
8
torch_mlu_ops-v1.3.2/build.property
Normal file
8
torch_mlu_ops-v1.3.2/build.property
Normal file
@@ -0,0 +1,8 @@
|
||||
TMO_VERSION=1.3.2-1
|
||||
CNCL_VERSION=1.24.1-1
|
||||
CNNL_VERSION=1.28.3-1
|
||||
CNNLEXTRA_VERSION=1.12.3-1
|
||||
CNTOOLKIT_VERSION=3.15.6-1
|
||||
MLUOPS_VERSION=1.4.2-1
|
||||
TORCH_VERSION=1.24.1-1
|
||||
TRITON_VERSION=1.3.1-1
|
||||
66
torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp
Normal file
66
torch_mlu_ops-v1.3.2/csrc/common/stack_exception.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifdef __GNUC__
|
||||
|
||||
#include "stack_exception.h"
|
||||
#include <cxxabi.h>
|
||||
#include <dlfcn.h>
|
||||
#include <execinfo.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#define MAX_DEPTH 32
|
||||
|
||||
namespace tmo {
|
||||
namespace stack_exception {
|
||||
|
||||
call_stack::call_stack(const size_t num_discard /*= 0*/) {
|
||||
using namespace abi;
|
||||
|
||||
// retrieve call-stack
|
||||
void *trace[MAX_DEPTH];
|
||||
int stack_depth = backtrace(trace, MAX_DEPTH);
|
||||
|
||||
for (int i = num_discard + 1; i < stack_depth; i++) {
|
||||
Dl_info dlinfo;
|
||||
if (!dladdr(trace[i], &dlinfo)) break;
|
||||
|
||||
const char *symname = dlinfo.dli_sname;
|
||||
|
||||
int status;
|
||||
char *demangled = abi::__cxa_demangle(symname, NULL, 0, &status);
|
||||
if (status == 0 && demangled) symname = demangled;
|
||||
|
||||
// printf("entry: %s, %s\n", dlinfo.dli_fname,symname);
|
||||
|
||||
// store entry to stack
|
||||
if (dlinfo.dli_fname && symname) {
|
||||
entry e;
|
||||
e.file = dlinfo.dli_fname;
|
||||
e.line = 0; // unsupported
|
||||
e.function = symname;
|
||||
stack.push_back(e);
|
||||
} else {
|
||||
break; // skip last entries below main
|
||||
}
|
||||
|
||||
if (demangled) free(demangled);
|
||||
}
|
||||
}
|
||||
|
||||
call_stack::~call_stack() throw() {
|
||||
// automatic cleanup
|
||||
}
|
||||
} // namespace stack_exception
|
||||
} // namespace tmo
|
||||
|
||||
#endif // __GNUC__
|
||||
127
torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h
Normal file
127
torch_mlu_ops-v1.3.2/csrc/common/stack_exception.h
Normal file
@@ -0,0 +1,127 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_COMMON_STACK_EXCEPTION_H_
|
||||
#define CSRC_COMMON_STACK_EXCEPTION_H_
|
||||
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tmo {
|
||||
namespace stack_exception {
|
||||
/** Call-stack entry datastructure. */
|
||||
struct entry {
|
||||
/** Default constructor that clears all fields. */
|
||||
entry() : line(0) {}
|
||||
|
||||
std::string file; ///< filename
|
||||
size_t line; ///< line number
|
||||
std::string function; ///< name of function or method
|
||||
|
||||
/** Serialize entry into a text string. */
|
||||
std::string to_string() const {
|
||||
std::ostringstream os;
|
||||
os << file;
|
||||
if (line > 0) {
|
||||
os << ":" << line;
|
||||
}
|
||||
os << " @" << function;
|
||||
return os.str();
|
||||
}
|
||||
};
|
||||
|
||||
/** Stack-trace base class, for retrieving the current call-stack. */
|
||||
class call_stack {
|
||||
public:
|
||||
/** Stack-trace consructor.
|
||||
\param num_discard - number of stack entries to discard at the top. */
|
||||
call_stack(const size_t num_discard = 0);
|
||||
|
||||
virtual ~call_stack() throw();
|
||||
|
||||
/** Serializes the entire call-stack into a text string. */
|
||||
std::string to_string() const {
|
||||
std::ostringstream os;
|
||||
for (size_t i = 0; i < stack.size(); i++)
|
||||
os << stack[i].to_string() << std::endl;
|
||||
return os.str();
|
||||
}
|
||||
|
||||
/** Call stack. */
|
||||
std::vector<entry> stack;
|
||||
};
|
||||
|
||||
/** Abstract base-class for all stack-augmented exception classes.
|
||||
* Enables catching of all stack-augmented exception classes. */
|
||||
/** Abstract base-class for all stack-augmented exception classes.
 *  Enables catching of all stack-augmented exception classes regardless of
 *  the wrapped std exception type. */
class stack_exception_base : public call_stack {
 public:
  // call_stack(2) discards the two innermost stack entries — presumably the
  // frames of this ctor and the derived exception's ctor, so the trace starts
  // at the throw site (TODO confirm against call_stack's implementation).
  stack_exception_base(const bool _show_stack) : call_stack(2), show_stack(_show_stack) {}
  virtual ~stack_exception_base() throw() {}

  /** Message accessor implemented by the concrete stack_exception<T>. */
  virtual const char *what() const throw() = 0;

  /// flag to indicate if stack-trace is included in what() messages
  bool show_stack;
};
|
||||
|
||||
/** Template for stack-augmented exception classes. */
|
||||
template <class T>
|
||||
class stack_exception : public T, public stack_exception_base {
|
||||
public:
|
||||
stack_exception(const std::string &msg) : T(msg), stack_exception_base(true) {}
|
||||
virtual ~stack_exception() throw() {}
|
||||
|
||||
stack_exception(const char *file, int line, const char *pretty_function, const std::string &msg)
|
||||
: T(msg), stack_exception_base(true) {
|
||||
entry e;
|
||||
e.file = file;
|
||||
e.line = line;
|
||||
e.function = pretty_function;
|
||||
stack.insert(stack.begin(), e);
|
||||
}
|
||||
|
||||
virtual const char *what() const throw() {
|
||||
if (show_stack) {
|
||||
// concatenate message with stack trace
|
||||
buffer = "[" + std::string(T::what()) + "]\n" + stack_exception::to_string();
|
||||
return buffer.c_str();
|
||||
} else {
|
||||
return T::what();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::string buffer;
|
||||
};
|
||||
/** Stack-augmented exception classes for all std::exception classes. */
|
||||
// typedef stack_exception<std::runtime_error> TmoException;
|
||||
// typedef stack_exception<std::range_error> stack_range_error;
|
||||
// typedef stack_exception<std::overflow_error> stack_overflow_error;
|
||||
// typedef stack_exception<std::underflow_error> stack_underflow_error;
|
||||
// typedef stack_exception<std::logic_error> stack_logic_error;
|
||||
// typedef stack_exception<std::domain_error> stack_domain_error;
|
||||
// typedef stack_exception<std::invalid_argument> stack_invalid_argument;
|
||||
// typedef stack_exception<std::length_error> stack_length_error;
|
||||
// typedef stack_exception<std::out_of_range> stack_out_of_range;
|
||||
} // namespace stack_exception
|
||||
|
||||
/** Concrete stack-augmented runtime error used throughout the project.
 *  NOTE(review): the identifier _TmoException (leading underscore followed by
 *  an uppercase letter) is reserved for the implementation by the C++
 *  standard — consider renaming in a future interface revision. */
class _TmoException : public stack_exception::stack_exception<std::runtime_error> {
 public:
  _TmoException(const std::string &msg) : stack_exception<std::runtime_error>(msg) {}
  _TmoException(const char *file, int line, const char *pretty_function, const std::string &msg)
      : stack_exception<std::runtime_error>(file, line, pretty_function, msg) {}
};
/// Throw-site helper: records __FILE__/__LINE__/__PRETTY_FUNCTION__ automatically.
#define TmoException(msg) tmo::_TmoException(__FILE__, __LINE__, __PRETTY_FUNCTION__, msg)
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_COMMON_STACK_EXCEPTION_H_
|
||||
293
torch_mlu_ops-v1.3.2/csrc/common/utils.h
Normal file
293
torch_mlu_ops-v1.3.2/csrc/common/utils.h
Normal file
@@ -0,0 +1,293 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_COMMON_UTILS_H_
|
||||
#define CSRC_COMMON_UTILS_H_
|
||||
|
||||
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <future>  // NOLINT
#include <initializer_list>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>  // NOLINT
#include <tuple>
#include <utility>
#include <vector>
#include "cn_api.h"
#include "cnnl.h"
#include "cnnl_extra.h"
#include "cnrt.h"
#include "stack_exception.h"
|
||||
|
||||
namespace tmo {
|
||||
/** Map a textual quantize-layout name to its CNNL enum value.
 *  @param param one of "quantize_none", "quantize_per_tensor",
 *         "quantize_per_channel", "quantize_per_token", "quantize_group_wise".
 *  @throws TmoException for an unrecognized name.
 *  NOTE: the previous operator[] lookup silently inserted a value-initialized
 *  entry into the shared static map for unknown names (wrong result, and a
 *  data race under concurrent callers); find() on a const map fixes both. */
inline cnnlQuantizeLayout_t strToQuantizeLayout(std::string param) {
  static const std::map<std::string, cnnlQuantizeLayout_t> quantize_layout_map = {
      {"quantize_none", CNNL_QUANTIZE_NONE},
      {"quantize_per_tensor", CNNL_QUANTIZE_PER_TENSOR},
      {"quantize_per_channel", CNNL_QUANTIZE_PER_CHANNEL},
      {"quantize_per_token", CNNL_QUANTIZE_PER_TOKEN},
      {"quantize_group_wise", CNNL_QUANTIZE_GROUP_WISE}};
  auto it = quantize_layout_map.find(param);
  if (it == quantize_layout_map.end()) {
    throw TmoException("unknown quantize layout: " + param);
  }
  return it->second;
}
|
||||
|
||||
/** Map a textual activation name to its CNNL enum value.
 *  @param param one of "gelu", "relu", "sigmoid", "silu", "none".
 *  @throws TmoException for an unrecognized name.
 *  Uses find() on a const map so unknown names cannot silently insert a
 *  default entry into the shared static map (old operator[] behavior). */
inline cnnlActivationMode_t strToActivationMode(std::string param) {
  static const std::map<std::string, cnnlActivationMode_t> act_mode_map = {
      {"gelu", CNNL_ACTIVATION_GELU},
      {"relu", CNNL_ACTIVATION_RELU},
      {"sigmoid", CNNL_ACTIVATION_SIGMOID},
      {"silu", CNNL_ACTIVATION_SWISH},
      {"none", CNNL_ACTIVATION_IDENTITY}};
  auto it = act_mode_map.find(param);
  if (it == act_mode_map.end()) {
    throw TmoException("unknown activation mode: " + param);
  }
  return it->second;
}
|
||||
/** Map a textual quantization-algorithm name to its CNNL enum value.
 *  @param param one of "weight_only", "smooth_quant", "none".
 *  @throws TmoException for an unrecognized name.
 *  Uses find() on a const map so unknown names cannot silently insert a
 *  default entry into the shared static map (old operator[] behavior). */
inline cnnlLLMQuantAlgo_t strToQuantizeAlgo(std::string param) {
  static const std::map<std::string, cnnlLLMQuantAlgo_t> quant_algo_map = {
      {"weight_only", CNNL_WEIGHT_ONLY},
      {"smooth_quant", CNNL_SMOOTH_QUANT},
      {"none", CNNL_NO_QUANT}};
  auto it = quant_algo_map.find(param);
  if (it == quant_algo_map.end()) {
    throw TmoException("unknown quantize algo: " + param);
  }
  return it->second;
}
|
||||
|
||||
namespace lnres {
|
||||
namespace internal {
|
||||
// Shorthand for the CNNL enum these helpers encode and decode.
using LnresEnum = cnnlTransformerLayernormResidualStructure_t;

// Splits a combined layernorm/residual enum into two independent axes so the
// predicates in the enclosing lnres namespace can test each axis separately.
struct Helper {
  int layernorm_position; // 0: no layernorm, 1: pre layernorm, 2: post layernorm
  int residual_position;  // 0: no residual, 1: layernorm inside residual, 2: layernorm outside
                          // residual
  // Decode a CNNL enum value; defined below, after the `pairs` table.
  constexpr Helper(cnnlTransformerLayernormResidualStructure_t mode);

  constexpr Helper(int layernorm_position, int residual_position)
      : layernorm_position(layernorm_position), residual_position(residual_position) {}

  constexpr bool operator==(const Helper &other) const {
    return layernorm_position == other.layernorm_position &&
           residual_position == other.residual_position;
  }

  // Encode back to the CNNL enum; defined below, after the `pairs` table.
  constexpr operator cnnlTransformerLayernormResidualStructure_t() const;
};

// Values for Helper::layernorm_position.
constexpr int NO = 0;
constexpr int PRE = 1;
constexpr int POST = 2;
// Additional values for Helper::residual_position (NO is shared).
constexpr int CONTAIN = 1;
constexpr int EXCLUDE = 2;
using TPair = std::pair<Helper, LnresEnum>;
// Lookup table between (layernorm, residual) pairs and the CNNL enum.
// NOTE(review): {NO, CONTAIN} and {NO, EXCLUDE} both map to
// CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL; since from() returns the first
// match, from(to({NO, EXCLUDE})) yields {NO, CONTAIN} — the encode/decode
// round trip is not bijective for that pair.
constexpr std::array<TPair, 9> pairs = {
    TPair{{NO, NO}, CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL},  // noResidual
    {{NO, CONTAIN}, CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL},  // useInputAsResidual
    {{NO, EXCLUDE}, CNNL_TRANSFORMER_NO_LAYERNORM_WITH_RESIDUAL},  // useInputAsResidual
    {{PRE, NO}, CNNL_TRANSFORMER_PRE_LAYERNORM_NO_RESIDUAL},  // noResidual
    {{PRE, CONTAIN}, CNNL_TRANSFORMER_PRE_LAYERNORM_INSIDE_RESIDUAL},  // useInputAsResidual
    // residualThenLayernorm
    {{PRE, EXCLUDE}, CNNL_TRANSFORMER_PRE_LAYERNORM_OUTSIDE_RESIDUAL},  // useLayernormAsResidual
    // residualThenLayernorm
    {{POST, NO}, CNNL_TRANSFORMER_POST_LAYERNORM_NO_RESIDUAL},  // noResidual
    {{POST, CONTAIN}, CNNL_TRANSFORMER_POST_LAYERNORM_INSIDE_RESIDUAL},  // useInputAsResidual
    // layernormThenResidual
    {{POST, EXCLUDE}, CNNL_TRANSFORMER_POST_LAYERNORM_OUTSIDE_RESIDUAL},  // useInputAsResidual
    // residualThenLayernorm
};

// Linear-search decode: CNNL enum -> Helper. Falls back to {NO, NO} for
// unrecognized values (the throwing path is left commented out).
constexpr Helper from(LnresEnum mode) {
  for (size_t i = 0; i < pairs.size(); ++i) {
    if (pairs[i].second == mode) {
      return pairs[i].first;
    }
  }
  // throw TmoException("Invalid cnnlTransformerLayernormResidualStructure_t");
  return Helper(NO, NO);
}

// Linear-search encode: Helper -> CNNL enum. Falls back to the
// no-layernorm/no-residual value for unrecognized pairs.
constexpr LnresEnum to(Helper mode) {
  for (size_t i = 0; i < pairs.size(); ++i) {
    if (pairs[i].first == mode) {
      return pairs[i].second;
    }
  }
  return CNNL_TRANSFORMER_NO_LAYERNORM_NO_RESIDUAL;
  // throw TmoException("Invalid Helper");
}

// Out-of-line definitions deferred until after `pairs` is visible.
constexpr Helper::Helper(LnresEnum mode) : Helper(from(mode)) {}

constexpr Helper::operator LnresEnum() const {
  return to(*this);
}
|
||||
} // namespace internal
|
||||
|
||||
using namespace internal;
|
||||
|
||||
/** Build a LnresEnum from three boolean properties: whether the layer has a
 *  (pre-)layernorm, whether it has a residual, and whether that residual is
 *  the raw layer input (CONTAIN) or the layernorm output (EXCLUDE). */
inline LnresEnum makeLnresEnum(bool has_ln, bool has_residual, bool residual_is_input) {
  const int ln_pos = has_ln ? PRE : NO;
  int res_pos = NO;
  if (has_residual) {
    res_pos = residual_is_input ? CONTAIN : EXCLUDE;
  }
  return Helper(ln_pos, res_pos);
}
|
||||
|
||||
/** Same layernorm placement as mode, with the residual stripped. */
inline LnresEnum removeResidual(LnresEnum mode) {
  return Helper(Helper(mode).layernorm_position, NO);
}
|
||||
|
||||
/** Same residual placement as mode, with the layernorm stripped. */
inline LnresEnum removeLayernorm(LnresEnum mode) {
  return Helper(NO, Helper(mode).residual_position);
}
|
||||
|
||||
/** True when the residual branch takes the pre-layernorm output rather than
 *  the raw layer input (PRE layernorm outside the residual). */
inline bool useLayernormAsResidual(LnresEnum mode) {
  const Helper h(mode);
  return h.layernorm_position == PRE && h.residual_position == EXCLUDE;
}
|
||||
|
||||
/** True when the residual branch takes the raw layer input: either the
 *  residual contains the layernorm, or there is a residual with no layernorm,
 *  or a post-layernorm sits outside the residual. */
inline bool useInputAsResidual(LnresEnum mode) {
  const Helper h(mode);
  if (h.residual_position == CONTAIN) {
    return true;
  }
  if (h.layernorm_position == NO) {
    return h.residual_position != NO;
  }
  return h.layernorm_position == POST && h.residual_position == EXCLUDE;
}
|
||||
|
||||
/** True when mode includes any residual connection. */
inline bool hasResidual(LnresEnum mode) {
  return Helper(mode).residual_position != NO;
}
|
||||
|
||||
/** True when mode includes any layernorm (pre or post). */
inline bool hasLayernorm(LnresEnum mode) {
  return Helper(mode).layernorm_position != NO;
}
|
||||
|
||||
/** True when mode places the layernorm after the layer (POST). */
inline bool isPostLayernorm(LnresEnum mode) {
  return Helper(mode).layernorm_position == POST;
}
|
||||
|
||||
/** True when mode places the layernorm before the layer (PRE). */
inline bool isPreLayernorm(LnresEnum mode) {
  return Helper(mode).layernorm_position == PRE;
}
|
||||
|
||||
// True when, between first_layer and second_layer, the residual add happens
// before a layernorm is applied: either layer 1 carries a residual (not
// post-layernormed) feeding layer 2's pre-layernorm, or layer 1's
// post-layernorm sits outside its residual.
// Throws when layer 1 requests a post-layernorm and layer 2 a pre-layernorm
// at the same time (two layernorms back to back).
inline bool residualThenLayernorm(LnresEnum first_layer, LnresEnum second_layer) {
  Helper h1(first_layer);
  Helper h2(second_layer);
  if (h1.residual_position == NO) {  // h1 has no residual
    return false;
  }

  if (h1.layernorm_position == POST && h2.layernorm_position == PRE) {
    throw TmoException("too many layernorms");
  }

  return (h1.residual_position != NO && h1.layernorm_position != POST &&
          h2.layernorm_position == PRE) ||  // l1 residual + l2 pre layernorm
         (h1.layernorm_position == POST && h2.layernorm_position != PRE &&
          h1.residual_position == EXCLUDE);  // l1 inside residual + l1 post layernorm
}
|
||||
|
||||
// True when layer 1 applies a post-layernorm inside its residual branch
// (layernorm output feeds the residual add).
// Throws when layer 1 requests a post-layernorm and layer 2 a pre-layernorm
// at the same time (two layernorms back to back).
inline bool layernormThenResidual(LnresEnum first_layer, LnresEnum second_layer) {
  Helper h1(first_layer);
  Helper h2(second_layer);
  if (h1.residual_position == NO) {  // h1 has no residual
    return false;
  }

  if (h1.layernorm_position == POST && h2.layernorm_position == PRE) {
    throw TmoException("too many layernorms");
  }
  return (h1.layernorm_position == POST && h1.residual_position == CONTAIN);
}
|
||||
|
||||
/** True when layer 1 adds a residual and no layernorm intervenes between the
 *  two layers (no post-layernorm on layer 1, no pre-layernorm on layer 2). */
inline bool residualOnly(LnresEnum first_layer, LnresEnum second_layer) {
  const Helper h1(first_layer);
  const Helper h2(second_layer);
  const bool l1_adds_residual = h1.residual_position != NO;
  const bool no_ln_between = h1.layernorm_position != POST && h2.layernorm_position != PRE;
  return l1_adds_residual && no_ln_between;
}
|
||||
} // namespace lnres
|
||||
} // namespace tmo
|
||||
|
||||
// Non-fatal CNNL status check: logs the failing expression and continues.
// Each macro is wrapped in do { ... } while (0) so it behaves as a single
// statement (safe in unbraced if/else), and the checked expression is fully
// parenthesized — the original CNNL_CHECK expanded `expr != ...` unguarded,
// which mis-parses for low-precedence expressions (CERT PRE01-C / PRE10-C).
#ifndef CNNL_CHECK
#define CNNL_CHECK(expr)                                                       \
  do {                                                                         \
    if ((expr) != CNNL_STATUS_SUCCESS) {                                       \
      std::cerr << __FILE__ << ":" << __LINE__                                 \
                << " Check failed: " #expr " == CNNL_STATUS_SUCCESS. "         \
                << std::endl;                                                  \
    }                                                                          \
  } while (0)
#endif

// Fatal CNNL status check: logs and throws TmoException on failure.
#define CNNL_CHECK_FATAL(expr)                                                 \
  do {                                                                         \
    if ((expr) != CNNL_STATUS_SUCCESS) {                                       \
      std::cerr << __FILE__ << ":" << __LINE__ << ": "                         \
                << " Check failed: " #expr " == CNNL_STATUS_SUCCESS. "         \
                << std::endl;                                                  \
      throw TmoException("Check failed: " #expr " == CNNL_STATUS_SUCCESS.");   \
    }                                                                          \
  } while (0)

// Fatal check for tmo kernel invocation results.
#define TMO_KERNEL_CHECK_FATAL(expr)                                                         \
  do {                                                                                       \
    if ((expr) != tmo::KernelStatus::KERNEL_STATUS_SUCCESS) {                                \
      std::cerr << __FILE__ << ":" << __LINE__ << ": "                                       \
                << " Check failed: " #expr " == KernelStatus::KERNEL_STATUS_SUCCESS. "       \
                << std::endl;                                                                \
      throw TmoException("Check failed: " #expr " == KernelStatus::KERNEL_STATUS_SUCCESS."); \
    }                                                                                        \
  } while (0)

// Generic fatal assertion; extra arguments are stringized into the message.
#define CHECK_FATAL(expr, ...)                                                       \
  do {                                                                               \
    if (!(expr)) {                                                                   \
      std::cerr << __FILE__ << ":" << __LINE__ << ": "                               \
                << " Check failed: " #expr ". " << tmo::stringize(__VA_ARGS__)       \
                << std::endl;                                                        \
      throw TmoException("Check failed: " #expr ". " + tmo::stringize(__VA_ARGS__)); \
    }                                                                                \
  } while (0)
|
||||
|
||||
// Replace any CNRT-provided CNRT_CHECK with a throwing variant: logs the
// failing call and throws TmoException on any nonzero cnrt status.
#undef CNRT_CHECK
#define CNRT_CHECK(val)                                                                            \
  do {                                                                                             \
    cnrtRet_t __ret = val;                                                                         \
    if (__ret) {                                                                                   \
      printf("[%s:%d] CNRT error, code=%d(%s) \"%s\" \n", __FILE__, __LINE__, (unsigned int)__ret, \
             cnrtGetErrorStr(__ret), #val);                                                        \
      throw TmoException(cnrtGetErrorStr(__ret));                                                  \
    }                                                                                              \
  } while (0)

// Driver-API (CN) check: resolves the error string, logs the failing call,
// and throws TmoException on any nonzero CNresult.
#define CN_CHECK(val)                                                                            \
  do {                                                                                           \
    CNresult __ret = val;                                                                        \
    if (__ret) {                                                                                 \
      const char *cn_err_string = nullptr;                                                       \
      cnGetErrorString(__ret, &cn_err_string);                                                   \
      printf("[%s:%d] CN error, code=%d(%s) \"%s\" \n", __FILE__, __LINE__, (unsigned int)__ret, \
             cn_err_string, #val);                                                               \
      throw TmoException(cn_err_string);                                                         \
    }                                                                                            \
  } while (0)
|
||||
|
||||
// Integer ceiling division: number of y-sized chunks needed to cover x.
#define PAD_UP_DIV(x, y) (((x) + (y) - 1) / (y))

// ELF symbol visibility helpers for the public/private library interface.
#define TMO_EXPORT __attribute__((__visibility__("default")))
#define TMO_HIDDEN __attribute__((__visibility__("hidden")))

// Delete all four copy/move special members of CLASSNAME.
#define DELETE_COPY_ASSIGN_CONSTRUCT(CLASSNAME) \
  CLASSNAME(const CLASSNAME &) = delete;        \
  CLASSNAME(CLASSNAME &&) = delete;             \
  CLASSNAME &operator=(const CLASSNAME &) = delete; \
  CLASSNAME &operator=(CLASSNAME &&) = delete;

// Defines implicit conversion operators from a wrapper class to DESCNAME,
// for both const and non-const objects.
// Note: Return type without const when const object called.
#define CLASS_CAST_TYPE_OPERATOR_DEFINE(DESCNAME, DESCOBJECT) \
  inline operator DESCNAME() const {                          \
    return const_cast<DESCNAME>(DESCOBJECT);                  \
  }                                                           \
  inline operator DESCNAME() {                                \
    return DESCOBJECT;                                        \
  }
|
||||
|
||||
#endif // CSRC_COMMON_UTILS_H_
|
||||
103
torch_mlu_ops-v1.3.2/csrc/kernels/CMakeLists.txt
Normal file
103
torch_mlu_ops-v1.3.2/csrc/kernels/CMakeLists.txt
Normal file
@@ -0,0 +1,103 @@
|
||||
cmake_minimum_required(VERSION 3.8)
|
||||
project(tmo_kernels)
|
||||
message(STATUS "project name: ${PROJECT_NAME}")
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
################################################################################
|
||||
# Build Environment
|
||||
################################################################################
|
||||
set(BANG_TARGET_CPU_ARCH ${TARGET_CPU_ARCH})
|
||||
message("-- TARGET_CPU_ARCH=${TARGET_CPU_ARCH}")
|
||||
set(TARGET_MLU_ARCH ${TARGET_MLU_ARCH})
|
||||
message("-- TARGET_MLU_ARCH=${TARGET_MLU_ARCH}")
|
||||
set(NEUWARE_HOME ${NEUWARE_HOME})
|
||||
message("-- NEUWARE_HOME=${NEUWARE_HOME}")
|
||||
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
|
||||
"${CMAKE_SOURCE_DIR}/cmake"
|
||||
"${NEUWARE_HOME}/cmake"
|
||||
"${NEUWARE_HOME}/cmake/modules"
|
||||
)
|
||||
|
||||
find_package(BANG)
|
||||
if(NOT BANG_FOUND)
|
||||
message(FATAL_ERROR "BANG cannot be found.")
|
||||
else ()
|
||||
if (NOT BANG_CNCC_EXECUTABLE)
|
||||
message(FATAL_ERROR "cncc not found, please ensure cncc is in your PATH env or set variable BANG_CNCC_EXECUTABLE from cmake. Otherwise you should check path used by find_program(BANG_CNCC_EXECUTABLE) in FindBANG.cmake")
|
||||
endif()
|
||||
endif()
|
||||
set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}/test")
|
||||
set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/lib")
|
||||
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -pthread -pipe")
|
||||
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${CMAKE_C_FLAGS} -g3 -O0")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${CMAKE_C_FLAGS} -O3")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fPIC -std=c++17 -pthread -pipe")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${CMAKE_CXX_FLAGS} -g3 -O0")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CMAKE_CXX_FLAGS} -O3")
|
||||
|
||||
set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS_RELEASE} -Wl,--gc-sections -fPIC")
|
||||
|
||||
set(BANG_CNCC_FLAGS "-Wall -Werror -Wdeprecated-declarations -fPIC -std=c++17 -pthread --target=${TARGET_CPU_ARCH}")
|
||||
|
||||
if ( "${_cncc_version}" VERSION_LESS "5.0.0") # [CNNLCORE-19128]
|
||||
message(STATUS "Default rounding mode will be rn when computing float numbers, otherwise will be tz when computing int numbers")
|
||||
# This compile option was enabled by JIRA: CNNLCORE-12027
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -Xbang-cnas --deprecated-cvt-default-round-mode-rn")
|
||||
endif()
|
||||
|
||||
if(${TARGET_CPU_ARCH} MATCHES ".*x86_64.*")
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -mcmodel=large")
|
||||
endif()
|
||||
|
||||
string(TOLOWER ${CMAKE_BUILD_TYPE} _CMAKE_BUILD_TYPE_LOWER)
|
||||
if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "debug")
|
||||
message(STATUS "Build debug mode")
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g3 -O0")
|
||||
endif()
|
||||
|
||||
if(${_CMAKE_BUILD_TYPE_LOWER} MATCHES "release")
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3 -DNDEBUG")
|
||||
endif()
|
||||
|
||||
if(${TARGET_MLU_ARCH} MATCHES "CNFATBIN")
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=mtp_592 --bang-mlu-arch=mtp_613 --no-neuware-version-check")
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64")
|
||||
else()
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=${TARGET_MLU_ARCH}")
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-wram-align64")
|
||||
endif()
|
||||
|
||||
# setup predefined macro for host sources, only for single mlu arch, useful for edge
|
||||
if (${TARGET_MLU_ARCH} MATCHES "^(m?tp_)?([0-9]+)$")
|
||||
# convert mtp_xxx or tp_xxx to xxx
|
||||
string(REGEX REPLACE "^(m?tp_)?([0-9]+)$" "\\2" _TARGET_MLU_ARCH ${TARGET_MLU_ARCH})
|
||||
add_definitions(-DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH})
|
||||
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -DTARGET_MLU_ARCH=${_TARGET_MLU_ARCH}")
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Neuware Environment
|
||||
################################################################################
|
||||
if(EXISTS ${NEUWARE_HOME})
|
||||
include_directories("${NEUWARE_HOME}/include")
|
||||
link_directories("${NEUWARE_HOME}/lib64")
|
||||
link_directories("${NEUWARE_HOME}/lib")
|
||||
else()
|
||||
message(FATAL_ERROR "NEUWARE cannot be found, refer README.md to prepare NEUWARE_HOME environment.")
|
||||
endif()
|
||||
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
|
||||
################################################################################
|
||||
# Build TMO kernels
|
||||
################################################################################
|
||||
# aux_source_directory(src DIR_SRCS)
|
||||
file(GLOB_RECURSE bang_src_files FOLLOW_SYMLINKS "${CMAKE_CURRENT_SOURCE_DIR}/*.mlu")
|
||||
|
||||
bang_add_library(tmo_kernels STATIC "${bang_src_files}")
|
||||
target_link_libraries(tmo_kernels cnnl cnrt cndrv dl)
|
||||
28
torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mlu
Normal file
28
torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mlu
Normal file
@@ -0,0 +1,28 @@
|
||||
#include "add_scalar.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
// Number of ints that fit in the on-chip NRAM scratch buffer (3/4 of NRAM).
#define ONCHIP_DATA_NUM ((int)(__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)))
// Per-core on-chip scratch buffer.
__nram__ int nram_buffer[ONCHIP_DATA_NUM];

// Device kernel: each task (taskId) adds `scalar` to its ONCHIP_DATA_NUM-sized
// slice of `src` and writes the result to the matching slice of `dst`.
__mlu_global__ void MLUBlockAddScalar(int *dst, int *src, int count, int scalar) {
  int offset = ONCHIP_DATA_NUM * taskId;
  // Clamp the last task's slice to the elements that remain.
  int deal_num = std::min(ONCHIP_DATA_NUM, count - offset);
  if (deal_num <= 0) return;
  // GDRAM -> NRAM, compute on-chip, NRAM -> GDRAM.
  __memcpy(nram_buffer, src + offset, deal_num * sizeof(int), GDRAM2NRAM);
  __bang_add_scalar(nram_buffer, nram_buffer, scalar, deal_num);
  __memcpy(dst + offset, nram_buffer, deal_num * sizeof(int), NRAM2GDRAM);
}
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher: one Block task per ONCHIP_DATA_NUM-sized chunk of src.
// Parameter contract is documented in add_scalar.mluh (int data only; dst may
// overlap src). Always reports success — the launch itself is asynchronous
// and is not error-checked here.
KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar) {
  // Ceiling division: enough tasks to cover `count` elements.
  uint32_t task_dim = (count + ONCHIP_DATA_NUM - 1) / ONCHIP_DATA_NUM;
  cnrtDim3_t dim{task_dim, 1, 1};
  kernels::MLUBlockAddScalar<<<dim, cnrtFuncTypeBlock, queue>>>(dst, src, count, scalar);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
} // namespace tmo
|
||||
29
torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mluh
Normal file
29
torch_mlu_ops-v1.3.2/csrc/kernels/add_scalar.mluh
Normal file
@@ -0,0 +1,29 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_ADD_SCALAR_MLUH_
|
||||
#define CSRC_KERNELS_ADD_SCALAR_MLUH_
|
||||
#include "kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Add src with a scalar and save the result to dst.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param dst: Pointer to the MLU memory of dst.
|
||||
* @param src: Pointer to the MLU memory of src.
|
||||
* @param count: The elements number in src.
|
||||
* @param scalar: The scalar to add.
|
||||
* @note: only support int. dst can overlap with src.
|
||||
*/
|
||||
KernelStatus invokeMLUAddScalar(cnrtQueue_t queue, int *dst, int *src, int count, int scalar);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_ADD_SCALAR_MLUH_
|
||||
205
torch_mlu_ops-v1.3.2/csrc/kernels/build.sh
Executable file
205
torch_mlu_ops-v1.3.2/csrc/kernels/build.sh
Executable file
@@ -0,0 +1,205 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
TOP_DIR="$( cd "$( dirname "$0" )" && pwd )"
|
||||
cd ${TOP_DIR}
|
||||
|
||||
################################################################################
|
||||
# Environment Variables
|
||||
|
||||
# BUILD_MODE: release/debug
|
||||
# BUILD_DIR: build(default)
|
||||
# TARGET_MLU_ARCH: CNFATBIN/MLU590
|
||||
# TARGET_CPU_ARCH: x86_64-linux-gnu
|
||||
# TARGET_C_COMPILER: C compiler full-path
|
||||
# TARGET_CXX_COMPILER: CXX compiler full-path
|
||||
# STRIP strip tool path
|
||||
################################################################################
|
||||
BUILD_MODE=${BUILD_MODE:-release}
|
||||
BUILD_DIR="${BUILD_DIR:-build}"
|
||||
BUILD_JOBS=${BUILD_JOBS:-32}
|
||||
TARGET_MLU_ARCH=${TARGET_MLU_ARCH:-CNFATBIN}
|
||||
TARGET_CPU_ARCH=${TARGET_CPU_ARCH:-$(uname -m)-linux-gnu}
|
||||
TARGET_C_COMPILER=${TARGET_C_COMPILER:-gcc}
|
||||
TARGET_CXX_COMPILER=${TARGET_CXX_COMPILER:-g++}
|
||||
STRIP="${STRIP}" # empty by default, check later
|
||||
|
||||
# to forward variable to other scripts
|
||||
export BUILD_DIR
|
||||
|
||||
################################################################################
|
||||
# Shell Common Functions
|
||||
################################################################################
|
||||
# Abort unless Debian package $1 appears in `dpkg -l`.
check_deb_package() {
  if [ -z "$(dpkg -l | grep ${1})" ]; then
    echo "-- Please sudo apt install ${1}"
    exit 1  # was `exit -1`: negative statuses are non-portable (wrap to 255)
  fi
}
|
||||
|
||||
# Abort unless RPM package $1 appears in `rpm -qa`.
check_rpm_package() {
  if [ -z "$(rpm -qa | grep ${1})" ]; then
    echo "-- Please sudo yum install ${1}"
    exit 1  # was `exit -1`: negative statuses are non-portable (wrap to 255)
  fi
}
|
||||
|
||||
# Print command-line help for build.sh.
# NOTE(review): column alignment inside the option descriptions may have been
# collapsed by text extraction — verify spacing against the original script.
usage () {
  echo "USAGE: build.sh <options>"
  echo
  echo "  If need specify neuware path, please:"
  echo "  export NEUWARE_HOME=/path/of/your/neuware"
  echo
  echo "OPTIONS:"
  echo "  -h, --help    Print usage"
  echo "  <null>        If no --mluxxx specified, default arch is cnfatbin which contain all mlu arch"
  echo "  --mlu590      Build for target product MLU590: __BANG_ARCH__ = 592"
  echo "                cncc --bang-mlu-arch=mtp_592, cnas --mlu-arch mtp_592"
  echo "  -d, --debug   Build test case with debug mode"
  echo "  -v, --verbose Build with verbose output"
  echo "  -j, --jobs=*  Build parallel jobs"
  echo "  --cache       Build without deleting BUILD_DIR contents first"
}
|
||||
|
||||
################################################################################
|
||||
# Build Main Entry
|
||||
################################################################################
|
||||
# 1. Check cmake tool for build, cmake-3.23.1 is recommended
|
||||
if [ -f "/etc/os-release" ]; then
|
||||
source /etc/os-release
|
||||
if [[ "${NAME}" == Ubuntu* ]] || [[ "${NAME}" == Debian* ]]; then
|
||||
check_deb_package cmake
|
||||
CMAKE=cmake
|
||||
elif [[ "${NAME}" == CentOS* ]] || [[ "${NAME}" == Kylin* ]]; then
|
||||
if [[ "${VERSION_ID}" == 7 ]]; then
|
||||
check_rpm_package cmake3
|
||||
CMAKE=cmake3
|
||||
else
|
||||
check_rpm_package cmake
|
||||
CMAKE=cmake
|
||||
fi
|
||||
elif [[ "${NAME}" == Anolis* ]];then
|
||||
check_rpm_package cmake
|
||||
CMAKE=cmake
|
||||
else
|
||||
echo "-- Not support build on this os!"
|
||||
exit -1
|
||||
fi
|
||||
else
|
||||
echo "-- Not support build on this os!"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
# 2. Create build dir
|
||||
if [ ! -d "$BUILD_DIR" ]; then
|
||||
mkdir "$BUILD_DIR"
|
||||
fi
|
||||
|
||||
# 3. Handle build options.
# NOTE: getopt's exit status must be checked on the command substitution
# itself — the previous code checked $? only after `eval set`, so $?
# reflected eval and getopt parse failures were silently ignored.
if ! cmdline_args=$(getopt -o h,d,v,j: --long help,debug,verbose,jobs:,mlu590,cache -n 'build.sh' -- "$@"); then
  echo "Unknown options, use -h or --help" >&2
  exit 1
fi
eval set -- "$cmdline_args"
if [ $# != 0 ]; then
  while true; do
    case "$1" in
      --mlu590)
        TARGET_MLU_ARCH="mtp_592"
        shift
        ;;
      -h | --help)
        usage
        exit 0
        ;;
      -d | --debug)
        BUILD_MODE="debug"
        echo "-- Using debug mode."
        shift
        ;;
      -v | --verbose)
        BUILD_VERBOSE="VERBOSE=1"
        shift
        ;;
      -j | --jobs)
        shift
        BUILD_JOBS=$1
        shift
        ;;
      --cache)
        FLAG_KEEP_CACHE=1
        shift
        ;;
      --)
        shift
        break
        ;;
      *)
        echo "-- Unknown options ${1}, use -h or --help"
        usage
        exit 1  # was `exit -1`; negative statuses are non-portable
        ;;
    esac
  done
fi
|
||||
|
||||
# 5. Check NEUWARE_HOME and cncc
|
||||
if [ ! -z "${NEUWARE_HOME}" ]; then
|
||||
echo "-- using NEUWARE_HOME = ${NEUWARE_HOME}"
|
||||
else
|
||||
echo "-- NEUWARE_HOME is null, refer README.md to prepare NEUWARE_HOME environment."
|
||||
exit -1
|
||||
fi
|
||||
|
||||
# 6. Check device compiler
|
||||
export PATH="${NEUWARE_HOME}/bin":$PATH
|
||||
export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64":$LD_LIBRARY_PATH
|
||||
if [ -z $(which cncc) ]; then
|
||||
echo "-- ERROR: cannot find cncc"
|
||||
exit -1
|
||||
fi
|
||||
cncc --version || ( echo "-- ERROR: cncc is not for current CPU target" && exit -1 )
|
||||
echo "-- cncc: $(which cncc)"
|
||||
|
||||
# Check host compiler
|
||||
## check compiler version and consider activate devtoolset for CentOS 7
|
||||
if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then
|
||||
if [ ! -f "/opt/rh/devtoolset-7/enable" ]; then
|
||||
echo "You are using CentOS 7 but without 'devtoolset-7' installed."
|
||||
echo "Please install devtoolset-7 or gnu-g++ that verion >= 5."
|
||||
sleep 2
|
||||
else
|
||||
source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \
|
||||
|| echo "devtoolset-7 has installed on your server, but source failed."
|
||||
fi
|
||||
fi
|
||||
if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then
|
||||
echo "we do not support g++<5, try to use higher version"
|
||||
exit 1
|
||||
fi
|
||||
TARGET_C_COMPILER=$(which gcc)
|
||||
TARGET_CXX_COMPILER=$(which g++)
|
||||
echo "-- TARGET_C_COMPILER: " ${TARGET_C_COMPILER}
|
||||
echo "-- TARGET_CXX_COMPILER: " ${TARGET_CXX_COMPILER}
|
||||
export CC=$(basename ${TARGET_C_COMPILER})
|
||||
export CXX=$(basename ${TARGET_CXX_COMPILER})
|
||||
|
||||
################################################################################
|
||||
# Project Build
|
||||
################################################################################
|
||||
CMAKE_EXTRA_OPTIONS=()
|
||||
|
||||
SOURCE_DIR=${TOP_DIR}
|
||||
pushd ${BUILD_DIR}
|
||||
if [[ -z "${FLAG_KEEP_CACHE}" ]]; then
|
||||
echo "Remove cmake cache ${PWD}"
|
||||
rm -rf ./*
|
||||
fi
|
||||
${CMAKE} -DCMAKE_BUILD_TYPE="${BUILD_MODE}" \
|
||||
-DNEUWARE_HOME="${NEUWARE_HOME}" \
|
||||
-DTARGET_MLU_ARCH="${TARGET_MLU_ARCH}" \
|
||||
-DTARGET_CPU_ARCH="${TARGET_CPU_ARCH}" \
|
||||
-DCMAKE_C_COMPILER="$(basename ${TARGET_C_COMPILER})" \
|
||||
-DCMAKE_CXX_COMPILER="$(basename ${TARGET_CXX_COMPILER})" \
|
||||
-DCMAKE_STRIP="${STRIP}" \
|
||||
${CMAKE_EXTRA_OPTIONS[@]} ${SOURCE_DIR}
|
||||
popd
|
||||
${CMAKE} --build ${BUILD_DIR} -- ${BUILD_VERBOSE} -j${BUILD_JOBS}
|
||||
192
torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mlu
Normal file
192
torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mlu
Normal file
@@ -0,0 +1,192 @@
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "copy_blocks.mluh"
|
||||
#include "kernel_utils.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
#define NRAM_REMAIN_SIZE (32 * 1024)
|
||||
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
|
||||
#define USE_GATHER_THRESHHOLD_BLOCKSIZE 458753
|
||||
#define LAYER_SIZE 128
|
||||
#define BLOCK_PAIR_SIZE 512
|
||||
#define ALIGN_BYTES 64
|
||||
|
||||
// Kernel argument bundle passed by value to the copy-blocks kernels.
// Fixed-capacity arrays keep the struct trivially copyable for host->device
// argument passing.
struct CopyBlocksInfo {
  void *key_addrs[LAYER_SIZE];    // per-layer key-cache base pointers
  void *value_addrs[LAYER_SIZE];  // per-layer value-cache base pointers
  // Flattened (src_block_idx, dst_block_idx) pairs, two entries per copy.
  unsigned int mapping_addrs[BLOCK_PAIR_SIZE * 2];
  bool has_value_cache = true;  // when false, only key caches are copied
};
|
||||
|
||||
namespace kernels {
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
|
||||
/**
 * Fallback copy path: moves each (src, dst) KV-cache block pair with plain
 * GDRAM-to-GDRAM memcpy, one layer at a time.
 *
 * info.mapping_addrs stores flattened (src_idx, dst_idx) pairs; this core
 * handles num_per_core pairs starting at flattened element
 * block_mapping_offset.
 */
__mlu_func__ void copyBlocksNodld(CopyBlocksInfo info,
                                  uint32_t num_per_core,
                                  uint32_t block_mapping_offset,
                                  int32_t num_layers,
                                  uint32_t block_size_in_bytes) {
  for (uint32_t i = 0; i < num_per_core; i++) {
    uint32_t map_offset = block_mapping_offset + i * 2;
    uint32_t src_idx = info.mapping_addrs[map_offset];
    uint32_t dst_idx = info.mapping_addrs[map_offset + 1];
    // Widen BEFORE multiplying: a uint32 * uint32 product wraps at 4 GiB, so
    // caches larger than that per layer would compute a bogus byte offset.
    int64_t src_offset = (int64_t)block_size_in_bytes * src_idx;
    int64_t dst_offset = (int64_t)block_size_in_bytes * dst_idx;
    for (uint32_t j = 0; j < num_layers; j++) {
      __memcpy((int8_t *)info.key_addrs[j] + dst_offset, (int8_t *)info.key_addrs[j] + src_offset,
               block_size_in_bytes, GDRAM2GDRAM);
      // The value cache is optional (info.has_value_cache is cleared by the
      // host when no value_caches were supplied).
      if (info.has_value_cache) {
        __memcpy((int8_t *)info.value_addrs[j] + dst_offset,
                 (int8_t *)info.value_addrs[j] + src_offset, block_size_in_bytes, GDRAM2GDRAM);
      }
    }
  }
}
|
||||
|
||||
// Per-core copy of KV-cache blocks according to flattened (src, dst) index
// pairs in info.mapping_addrs. Work split: num_pairs pairs are divided over
// taskDim cores; the first (num_pairs % taskDim) cores take one extra pair.
__mlu_global__ void launchCopyBlocksKernel(CopyBlocksInfo info,
                                           int32_t num_pairs,
                                           int32_t num_layers,
                                           uint32_t block_size_in_bytes) {
  uint32_t num_per_core = num_pairs / taskDim;
  uint32_t remain_for_core = num_pairs % taskDim;
  num_per_core += ((taskId < remain_for_core) ? 1 : 0);
  // Start of this core's pair range, then scaled to flattened int units (*2).
  uint32_t block_mapping_offset =
      num_per_core * taskId + ((taskId < remain_for_core) ? 0 : remain_for_core);
  block_mapping_offset *= 2;
#if (__BANG_ARCH__ >= 592)
  // Fast path (arch >= 592 only): stage blocks through NRAM with batched
  // gather/scatter DMA when a block fits the staging budget.
  if (block_size_in_bytes < USE_GATHER_THRESHHOLD_BLOCKSIZE) {
    auto num_pair_data_width = sizeof(int32_t);
    uint32_t align_num = ALIGN_BYTES / num_pair_data_width;
    unsigned int num_per_core_2 = num_per_core * 2;
    // Offset-table element counts rounded up to ALIGN_BYTES alignment.
    unsigned int num_per_core_2_align = (num_per_core_2 + align_num - 1) / align_num * align_num;
    // NRAM layout: [byte offsets | raw src/dst indices | block staging buffer].
    unsigned int *gather_src_offset = (unsigned int *)nram_buffer;
    unsigned int *block_mapping_src_dst = gather_src_offset + num_per_core_2_align;
    int8_t *n_buffer = (int8_t *)(block_mapping_src_dst + num_per_core_2_align);
    // NOTE(review): sizeof(unsigned int *) (pointer size) looks like it was
    // meant to be sizeof(unsigned int); as written it only over-reserves, so
    // the staging budget is conservative rather than wrong — confirm intent.
    uint32_t nram_remain = NRAM_BUFFER_SIZE - sizeof(unsigned int *) * num_per_core_2_align * 2;
    // First num_per_core entries of gather_src_offset hold source byte
    // offsets, the next num_per_core hold destination byte offsets.
    unsigned int *scatter_dst_offset = gather_src_offset + num_per_core;

    // NOTE(review): assumes num_per_loop >= 1, i.e. one block always fits in
    // nram_remain given USE_GATHER_THRESHHOLD_BLOCKSIZE — verify sizing.
    uint32_t num_per_loop = nram_remain / block_size_in_bytes;
    uint32_t repeat = num_per_core / num_per_loop;
    uint32_t remain = num_per_core % num_per_loop;

    // De-interleave (src, dst) pairs into [all srcs | all dsts].
    for (int i = 0; i < num_per_core; i++) {
      unsigned int mapping_addrs_idx = block_mapping_offset + i * 2;
      block_mapping_src_dst[i] = info.mapping_addrs[mapping_addrs_idx];
      block_mapping_src_dst[num_per_core + i] = info.mapping_addrs[mapping_addrs_idx + 1];
    }
    // Convert block indices to byte offsets in one vector multiply.
    // NOTE(review): 32-bit product — offsets past 4 GiB would wrap; confirm
    // cache extents stay below that on this path.
    __bang_mul_scalar(gather_src_offset, block_mapping_src_dst, (unsigned int)block_size_in_bytes,
                      num_per_core_2);
    __sync();
    for (uint32_t k = 0; k < num_layers; k++) {
      for (uint32_t i = 0; i < repeat; i++) {
        // NOTE(review): gather then scatter on the same n_buffer with no
        // explicit sync in between — presumably the async DMA queue orders
        // them; verify against BANG async-copy ordering guarantees.
        __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + i * num_per_loop,
                       (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                       (unsigned int)block_size_in_bytes, num_per_loop);
        __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
                        (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                        (unsigned int)block_size_in_bytes, num_per_loop);
        if (info.has_value_cache) {
          __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + i * num_per_loop,
                         (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                         (unsigned int)block_size_in_bytes, num_per_loop);
          __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + i * num_per_loop,
                          (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                          (unsigned int)block_size_in_bytes, num_per_loop);
        }
      }
      // Tail: pairs that did not fill a whole staging pass.
      if (remain != 0) {
        uint32_t repeat_nums = repeat * num_per_loop;
        __gather_async(n_buffer, info.key_addrs[k], gather_src_offset + repeat_nums,
                       (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                       (unsigned int)block_size_in_bytes, remain);
        __scatter_async(info.key_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
                        (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                        (unsigned int)block_size_in_bytes, remain);
        if (info.has_value_cache) {
          __gather_async(n_buffer, info.value_addrs[k], gather_src_offset + repeat_nums,
                         (unsigned int)block_size_in_bytes, GDRAM2NRAM,
                         (unsigned int)block_size_in_bytes, remain);
          __scatter_async(info.value_addrs[k], n_buffer, scatter_dst_offset + repeat_nums,
                          (unsigned int)block_size_in_bytes, NRAM2GDRAM,
                          (unsigned int)block_size_in_bytes, remain);
        }
      }
    }
  } else {
    // Block too large to stage through NRAM: direct GDRAM-to-GDRAM copies.
    copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
  }
#else
  copyBlocksNodld(info, num_per_core, block_mapping_offset, num_layers, block_size_in_bytes);
#endif
}
|
||||
} // namespace kernels
|
||||
|
||||
/**
 * Host-side launcher for the copy-blocks kernel.
 *
 * Chunks layers into LAYER_SIZE (128) groups and mapping pairs into
 * BLOCK_PAIR_SIZE (512) groups so the whole working set fits into the
 * by-value CopyBlocksInfo kernel argument, then launches one kernel per
 * (layer-chunk, pair-chunk) combination on `queue`.
 *
 * @param queue               MLU queue the kernels are enqueued on.
 * @param key_caches          Per-layer key cache base pointers; must be non-empty.
 * @param value_caches        Per-layer value cache base pointers; may be empty,
 *                            otherwise must match key_caches in size.
 * @param block_mapping_vec   Flattened (src_idx, dst_idx) pairs.
 * @param block_size_in_bytes Bytes per cache block.
 * @return KERNEL_STATUS_SUCCESS, or KERNEL_STATUS_FAILED on invalid arguments.
 */
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
                                    const std::vector<void *> &key_caches,
                                    const std::vector<void *> &value_caches,
                                    const std::vector<int32_t> &block_mapping_vec,
                                    const size_t block_size_in_bytes) {
  CNdev dev;
  cnCtxGetDevice(&dev);

  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
  if (key_caches.empty()) {
    std::cerr << "[invokeCopyBlocksKernel]: key_caches can not be empty." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (!value_caches.empty() && key_caches.size() != value_caches.size()) {
    std::cerr << "[invokeCopyBlocksKernel]: key_caches size must equal to value_caches "
              << "size if value_caches is not empty." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int32_t mapping_size = block_mapping_vec.size();
  int32_t num_pairs = mapping_size / 2;
  if (num_pairs == 0) {
    // Nothing to copy. Bail out early: the ceil-division below would
    // otherwise compute float(0) / 0 == NaN and cast it to int32_t, which is
    // undefined behavior.
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }
  uint32_t task_dim = std::min(num_pairs, cluster_num * core_num);
  cnrtDim3_t k_dim{task_dim, 1, 1};

  // Balance work: e.g. 130 layers become 2 loops of 65 rather than 128 + 2.
  int32_t num_layers = key_caches.size();
  int32_t layer_loop_num = std::ceil(float(num_layers) / LAYER_SIZE);
  int32_t layer_num_per_loop = std::ceil(float(num_layers) / layer_loop_num);
  int32_t pair_loop_num = std::ceil(float(num_pairs) / BLOCK_PAIR_SIZE);
  int32_t pair_num_per_loop = std::ceil(float(num_pairs) / pair_loop_num);

  CopyBlocksInfo info;
  if (value_caches.empty()) {
    info.has_value_cache = false;
  }
  for (int32_t i = 0; i < layer_loop_num; i++) {
    int32_t sub_num_layers =
        std::min(int32_t(layer_num_per_loop), num_layers - i * layer_num_per_loop);
    // Stage this chunk's per-layer base pointers into the kernel argument.
    for (int32_t l = 0; l < sub_num_layers; l++) {
      info.key_addrs[l] = key_caches[l + i * layer_num_per_loop];
      if (info.has_value_cache) {
        info.value_addrs[l] = value_caches[l + i * layer_num_per_loop];
      }
    }
    for (int32_t j = 0; j < pair_loop_num; j++) {
      int32_t sub_num_pairs =
          std::min(int32_t(pair_num_per_loop), num_pairs - j * pair_num_per_loop);
      int32_t lens_block_mapping = sub_num_pairs * 2;
      int32_t block_vec_offset = j * pair_num_per_loop * 2;
      // Stage this chunk's (src, dst) pairs into the kernel argument.
      for (int32_t m = 0; m < lens_block_mapping; m++) {
        info.mapping_addrs[m] = block_mapping_vec[m + block_vec_offset];
      }
      kernels::launchCopyBlocksKernel<<<k_dim, k_type, queue>>>(info, sub_num_pairs, sub_num_layers,
                                                                block_size_in_bytes);
    }
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
} // namespace tmo
|
||||
37
torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mluh
Normal file
37
torch_mlu_ops-v1.3.2/csrc/kernels/copy_blocks.mluh
Normal file
@@ -0,0 +1,37 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_COPY_BLOCKS_MLUH_
|
||||
#define CSRC_KERNELS_COPY_BLOCKS_MLUH_
|
||||
|
||||
#include <vector>
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Perform copy_blocks operation.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param key_caches: Output/Input. Pointer to the MLU memory that stores the key_caches
|
||||
* vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
|
||||
* @param value_caches: Output/Input. Pointer to the MLU memory that stores the value_caches
|
||||
* vector<tensor> which has shape [num_layers<num_blocks, num_heads, block_size, head_size>].
|
||||
* @param block_mapping_vec: block_mapping vector.
|
||||
* @param block_size_in_bytes: one block data size.
|
||||
*/
|
||||
KernelStatus invokeCopyBlocksKernel(const cnrtQueue_t queue,
|
||||
const std::vector<void *> &key_caches,
|
||||
const std::vector<void *> &value_caches,
|
||||
const std::vector<int32_t> &block_mapping_vec,
|
||||
const size_t block_size_in_bytes);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_COPY_BLOCKS_MLUH_
|
||||
271
torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mlu
Normal file
271
torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mlu
Normal file
@@ -0,0 +1,271 @@
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "create_cos_sin_table.mluh"
|
||||
|
||||
namespace {
|
||||
// constexpr int LINEAR_SCALING = 0;
|
||||
// constexpr int FIX_NTK_SCALING = 1;
|
||||
constexpr int DYNAMIC_NTK_SCALING = 2;
|
||||
} // namespace
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - 32 * 1024];
|
||||
__nram__ const float range[64] = {
|
||||
0.0F, 2.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 20.0F,
|
||||
22.0F, 24.0F, 26.0F, 28.0F, 30.0F, 32.0F, 34.0F, 36.0F, 38.0F, 40.0F, 42.0F,
|
||||
44.0F, 46.0F, 48.0F, 50.0F, 52.0F, 54.0F, 56.0F, 58.0F, 60.0F, 62.0F, 64.0F,
|
||||
66.0F, 68.0F, 70.0F, 72.0F, 74.0F, 76.0F, 78.0F, 80.0F, 82.0F, 84.0F, 86.0F,
|
||||
88.0F, 90.0F, 92.0F, 94.0F, 96.0F, 98.0F, 100.0F, 102.0F, 104.0F, 106.0F, 108.0F,
|
||||
110.0F, 112.0F, 114.0F, 116.0F, 118.0F, 120.0F, 122.0F, 124.0F, 126.0F};
|
||||
|
||||
// Fills range_nram with the arithmetic sequence 0, 2, 4, ... (elem_count
// elements). Seeds from the 64-entry precomputed `range` table, then each
// pass appends a shifted copy of what is already present, doubling coverage.
__mlu_func__ void genRangeDims(float *range_nram, int elem_count) {
  int have = 64;
  __bang_move(range_nram, range, std::min(have, elem_count) * sizeof(float));
  while (have < elem_count) {
    // Elements [have, 2*have) are elements [0, have) plus 2*have.
    __bang_add_scalar(range_nram + have, range_nram, (float)have * 2.0F,
                      std::min(have, elem_count - have));
    have *= 2;
  }
}
|
||||
|
||||
// Returns the maximum of seq_lens[0..batch), staging the data through
// seq_lens_nram (clobbered).
// NOTE(review): the int32 data is reinterpreted as float for __bang_argmax.
// For non-negative int32 values the IEEE-754 bit pattern is monotonically
// ordered, so the max is preserved under reinterpretation — confirm sequence
// lengths are guaranteed >= 0.
__mlu_func__ int getBatchMaxSeqLen(int *seq_lens_nram, int *seq_lens, int batch) {
  __memcpy(seq_lens_nram, seq_lens, batch * sizeof(int), GDRAM2NRAM);
  __bang_argmax((float *)seq_lens_nram, (float *)seq_lens_nram, batch);
  // The reduced max lands in element 0; read it back as int.
  return __load_nram(seq_lens_nram);
}
|
||||
|
||||
// Computes the dynamic-NTK scaling factor alpha = max(1, 2^ceil(context) - 1)
// where context = log2(len / max_position_embeddings) + 1. Once the cache
// already exceeds the trained range, only the current chunk length is used.
__mlu_func__ float getNTKAlpha(int curr_seq_len, int max_position_embeddings, int kv_seq_len) {
  const int effective_len = (kv_seq_len > max_position_embeddings) ? curr_seq_len : kv_seq_len;
  const float ratio = (float)effective_len / (float)max_position_embeddings;
  const float context_value = std::log2(ratio) + 1.0F;
  float alpha = std::pow(2.0F, std::ceil(context_value)) - 1.0F;
  if (alpha < 1.0F) {
    alpha = 1.0F;  // never shrink the base below its trained value
  }
  return alpha;
}
|
||||
|
||||
// Computes the rotary inverse frequencies into inv_freq_nram:
//   inv_freq = 1.0 / (base ** (range / rotary_dim))
// evaluated as 1 / 2^((range / rotary_dim) * log(base)).
// base_nram and range_nram are scratch; range_nram must already hold
// 0, 2, 4, ... (see genRangeDims).
// NOTE(review): correctness of the exp/log pairing assumes __bang_log is
// base-2 (so pow2(e * log(base)) == base ** e) — confirm against BANG docs.
__mlu_func__ void getRotaryInvFreq(float *inv_freq_nram,
                                   float *base_nram,
                                   float *range_nram,
                                   float base,
                                   int rotary_dim,
                                   int elem_count) {
  // inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
  __bang_write_value(base_nram, elem_count, base);
  __bang_mul_scalar(inv_freq_nram, range_nram, 1.0F / (float)rotary_dim, elem_count);
  __bang_log(base_nram, base_nram, elem_count);             // log(base)
  __bang_mul(inv_freq_nram, inv_freq_nram, base_nram, elem_count);
  __bang_pow2(inv_freq_nram, inv_freq_nram, elem_count);    // 2^(exponent)
  __bang_recip(inv_freq_nram, inv_freq_nram, elem_count);   // reciprocal
}
|
||||
|
||||
// Converts the fp32 cos/sin working buffers in place to output dtype T.
// Generic template: T == float needs no conversion, so this is a no-op.
template <typename T>
__mlu_func__ void convertCosSinTable(float *cos_table, float *sin_table, int elem_count) {}

// half: narrow fp32 -> fp16 in place (narrowed elements occupy the front of
// each buffer).
template <>
__mlu_func__ void convertCosSinTable<half>(float *cos_table, float *sin_table, int elem_count) {
  __bang_float2half((half *)cos_table, cos_table, elem_count);
  __bang_float2half((half *)sin_table, sin_table, elem_count);
}

// bfloat16: narrow fp32 -> bf16 in place.
template <>
__mlu_func__ void convertCosSinTable<bfloat16_t>(float *cos_table,
                                                 float *sin_table,
                                                 int elem_count) {
  __bang_float2bfloat16((bfloat16_t *)cos_table, cos_table, elem_count);
  __bang_float2bfloat16((bfloat16_t *)sin_table, sin_table, elem_count);
}
|
||||
|
||||
// Refreshes the per-batch cached NTK alpha after the table kernel has run.
// Launched with dim.y == batch (see invokeCreateCosSinTable); each task
// recomputes the batch-wide max sequence length, then writes the alpha for
// its own batch entry (taskIdY).
__mlu_global__ void MLUUpdateCachedAlpha(float *rotary_emb_alpha_cached,
                                         int *seq_lens,
                                         int max_position_embeddings,
                                         int batch) {
  int *seq_lens_nram = (int *)nram_buffer;  // [batch] staging area
  int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch);
  rotary_emb_alpha_cached[taskIdY] =
      getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len);
}
|
||||
|
||||
// Fills the cos/sin rotary-embedding table for the sequence segment assigned
// to this task. taskIdX selects the sequence segment ([seq_start, seq_end));
// taskIdY selects the batch entry (only > 0 under dynamic-NTK scaling, where
// the launcher sets dim.y = batch).
template <typename T>
__mlu_global__ void MLUCreateCosSinTableKernel(void *cos_sin_table,
                                               float *rotary_emb_alpha_cached,
                                               int *seq_lens,
                                               int max_position_embeddings,
                                               int batch,
                                               int batch_stride,
                                               int rotary_seq_len,
                                               int rotary_dim,
                                               int rotary_stride,
                                               float rotary_base,
                                               float rotary_scaling,
                                               int rotary_scaling_type,
                                               int seq_seg,
                                               bool interleaved,
                                               cnnlDataType_t dtype) {
  int half_rotary_dim = rotary_dim / 2;
  // NRAM scratch layout (fp32 unless noted), carved out of nram_buffer:
  float *base_nram = (float *)nram_buffer;               // [rotary_dim / 2]
  float *range_nram = base_nram + half_rotary_dim;       // [rotary_dim / 2]
  float *inv_freq_nram = range_nram + half_rotary_dim;   // [rotary_dim / 2]
  float *freqs_nram = inv_freq_nram + half_rotary_dim;   // [rotary_dim / 2]
  float *cos_nram = freqs_nram + half_rotary_dim;        // [rotary_dim]
  float *sin_nram = cos_nram + rotary_dim;               // [rotary_dim]
  float *swap_nram = sin_nram + rotary_dim;              // [rotary_dim]
  int *seq_lens_nram = (int *)(swap_nram + rotary_dim);  // [batch]

  // range_nram := 0, 2, 4, ... (half_rotary_dim elements).
  genRangeDims(range_nram, half_rotary_dim);

  float adjust_base = rotary_base;
  if (rotary_scaling_type == DYNAMIC_NTK_SCALING) {
    int kv_seq_len = getBatchMaxSeqLen(seq_lens_nram, seq_lens, batch);
    float ntk_alpha = getNTKAlpha(seq_lens[taskIdY], max_position_embeddings, kv_seq_len);
    // Unchanged cached alpha: the table for this batch entry is already
    // valid, skip regeneration.
    if (rotary_emb_alpha_cached[taskIdY] == ntk_alpha) {
      return;
    }
    adjust_base = rotary_base * std::pow(ntk_alpha, (float)rotary_dim / (float)(rotary_dim - 2));
  }

  getRotaryInvFreq(inv_freq_nram, base_nram, range_nram, adjust_base, rotary_dim, half_rotary_dim);

  // Sequence positions this task is responsible for.
  int seq_start = taskIdX * seq_seg;
  int seq_end = (taskIdX + 1) * seq_seg > rotary_seq_len ? rotary_seq_len : (taskIdX + 1) * seq_seg;

  // cos and sin halves live side by side within one row of the table.
  T *cos_table = (T *)cos_sin_table + (size_t)taskIdY * batch_stride;
  T *sin_table = cos_table + rotary_dim;

  for (int idx = seq_start; idx < seq_end; ++idx) {
    // freqs = idx * inv_freq, then elementwise cos/sin, narrowed to T.
    __bang_mul_scalar(freqs_nram, inv_freq_nram, idx, half_rotary_dim);
    __bang_cos(cos_nram, freqs_nram, half_rotary_dim);
    __bang_sin(sin_nram, freqs_nram, half_rotary_dim);
    convertCosSinTable<T>(cos_nram, sin_nram, half_rotary_dim);
    if (!interleaved) {
      // NOTE(review): the extra (stride 0, count 1) 2D-memcpy dims appear to
      // write the half-dim vector twice back-to-back into the row — confirm
      // against the __memcpy multi-dim parameter semantics.
      __memcpy(cos_table + idx * rotary_stride, cos_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM,
               half_rotary_dim * sizeof(T), 0, 1);
      __memcpy(sin_table + idx * rotary_stride, sin_nram, half_rotary_dim * sizeof(T), NRAM2GDRAM,
               half_rotary_dim * sizeof(T), 0, 1);
    } else {
      // Interleaved layout: duplicate the half vector, then a
      // 2 x half_rotary_dim transpose interleaves the element pairs before a
      // single contiguous store.
      __bang_move((T *)cos_nram + half_rotary_dim, (T *)cos_nram, half_rotary_dim * sizeof(T));
      __bang_transpose((T *)swap_nram, (T *)cos_nram, 2, half_rotary_dim);
      __memcpy(cos_table + idx * rotary_stride, (T *)swap_nram, half_rotary_dim * 2 * sizeof(T),
               NRAM2GDRAM);
      __bang_move((T *)sin_nram + half_rotary_dim, (T *)sin_nram, half_rotary_dim * sizeof(T));
      // cos_nram is reused as scratch for the sin transpose (cos was already
      // stored above).
      __bang_transpose((T *)cos_nram, (T *)sin_nram, 2, half_rotary_dim);
      __memcpy((T *)sin_table + idx * rotary_stride, (T *)cos_nram, half_rotary_dim * 2 * sizeof(T),
               NRAM2GDRAM);
    }
  }
}
|
||||
|
||||
#if __BANG_ARCH__ < 592
|
||||
// No-op stub: below arch 592 the bfloat16 conversion intrinsics are not
// available, so the bf16 kernel body is compiled out. The host launcher
// rejects bf16 beforehand (isBf16Supported() check in
// invokeCreateCosSinTable), so this specialization is never reached at
// runtime; it exists only to satisfy the kernel-pointer table.
template <>
__mlu_global__ void MLUCreateCosSinTableKernel<bfloat16_t>(void *cos_sin_table,
                                                           float *rotary_emb_alpha_cached,
                                                           int *seq_lens,
                                                           int max_position_embeddings,
                                                           int batch,
                                                           int batch_stride,
                                                           int rotary_seq_len,
                                                           int rotary_dim,
                                                           int rotary_stride,
                                                           float rotary_base,
                                                           float rotary_scaling,
                                                           int rotary_scaling_type,
                                                           int seq_seg,
                                                           bool interleaved,
                                                           cnnlDataType_t dtype) {}
|
||||
#endif
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher: validates the dtype, picks the matching kernel
// instantiation, splits the sequence over the available cores, launches the
// table kernel, and (for dynamic NTK) a follow-up kernel that refreshes the
// cached alphas. See the header for parameter documentation.
KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue,
                                     void *cos_sin_table,
                                     float *rotary_emb_alpha_cached,
                                     int *seq_lens,
                                     int max_position_embeddings,
                                     int batch,
                                     int batch_stride,
                                     int rotary_seq_len,
                                     int rotary_dim,
                                     int rotary_stride,
                                     float rotary_base,
                                     float rotary_scaling,
                                     int rotary_scaling_type,
                                     bool interleaved,
                                     cnnlDataType_t data_type) {
  bool is_supported_dtype = data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_FLOAT ||
                            data_type == CNNL_DTYPE_BFLOAT16;
  if (!is_supported_dtype) {
    std::cerr << "[invokeCreateCosSinTable]: unsupport data type for create cos sin table kernel."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Kernel-pointer table; order must match the kernel_index mapping below
  // (0 = half, 1 = float, 2 = bfloat16).
  // clang-format off
  void (*create_sin_cos_kernels[])(
      void*,           /* cos_sin_table */
      float*,          /* rotary_emb_alpha_cached */
      int*,            /* seq_lens */
      int,             /* max_position_embeddings */
      int,             /* batch */
      int,             /* batch_stride */
      int,             /* rotary_seq_len */
      int,             /* rotary_dim */
      int,             /* rotary_stride */
      float,           /* rotary_base */
      float,           /* rotary_scaling */
      int,             /* rotary_scaling_type */
      int,             /* seq_seg */
      bool,            /* interleaved */
      cnnlDataType_t   /* data_type */
  ) = {
      kernels::MLUCreateCosSinTableKernel<half>,
      kernels::MLUCreateCosSinTableKernel<float>,
      kernels::MLUCreateCosSinTableKernel<bfloat16_t>
  };
  // clang-format on

  int kernel_index = 0;
  if (data_type == CNNL_DTYPE_HALF) {
    kernel_index = 0;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    kernel_index = 1;
  } else {
    // Only bfloat16 remains after the supported-dtype check above.
    kernel_index = 2;
  }

  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num = 1;
  int core_num = 1;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  // One core per sequence segment; seq_seg positions per core (ceil split).
  int used_core_num = std::min(rotary_seq_len, cluster_num * core_num);
  int seq_seg = (rotary_seq_len + used_core_num - 1) / used_core_num;

  // dim.y spreads over batch only for dynamic NTK, where each batch entry
  // gets its own table.
  cnrtDim3_t dim1;
  dim1.x = used_core_num;
  dim1.y = rotary_scaling_type == DYNAMIC_NTK_SCALING ? batch : 1;
  dim1.z = 1;

  if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
    std::cerr << "[invokeCreateCosSinTable]: MLU300 devices do not support bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  create_sin_cos_kernels[kernel_index]<<<dim1, cnrtFuncTypeBlock, queue>>>(
      cos_sin_table, rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch,
      batch_stride, rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling,
      rotary_scaling_type, seq_seg, interleaved, data_type);

  // The alpha cache is updated AFTER the table kernel (same queue, so it
  // runs in order): the table kernel compares against the old cached value
  // to decide whether regeneration is needed.
  if (rotary_scaling_type == DYNAMIC_NTK_SCALING) {
    cnrtDim3_t dim2;
    dim2.x = 1;
    dim2.y = batch;
    dim2.z = 1;
    kernels::MLUUpdateCachedAlpha<<<dim2, cnrtFuncTypeBlock, queue>>>(
        rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch);
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
62
torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mluh
Normal file
62
torch_mlu_ops-v1.3.2/csrc/kernels/create_cos_sin_table.mluh
Normal file
@@ -0,0 +1,62 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
|
||||
#define CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
/**
|
||||
* @brief Create cos and sin table for rotary embedding.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param cos_sin_table: Output. Pointer to the MLU memory that stores the cos and sin table.
|
||||
* If rotary_scaling_type is linear, the shape is [rotary_seq_len, rotary_stride].
|
||||
* If rotary_scaling_type is dynamic ntk, the shape is [batch, rotary_seq_len,
|
||||
* rotary_stride].
|
||||
* @param rotary_emb_alpha_cached: Output/Input. Pointer to the MLU memory that
|
||||
* stores the ntk alpha cache. Only used in dynamic ntk, the shape is [batch].
|
||||
* @param seq_lens: Input. Pointer to the MLU memory that stores the true sequence len.
|
||||
* The shape is [batch].
|
||||
* @param max_position_embeddings: The maximum rotary embedding positions.
|
||||
* @param batch: Batch size.
|
||||
* @param batch_stride: The stride for batch dim of cos_sin_table.
|
||||
* Only used in dynamic ntk, the value is rotary_seq_len * rotary_stride.
|
||||
* @param rotary_seq_len: The rotary sequence length of cos and sin table.
|
||||
* @param rotary_dim: The rotary dim value of cos and sin table.
|
||||
* @param rotary_stride: The stride of rotary_seq_len dim for cos and sin table.
|
||||
* @param rotary_base: The rotary base, value is usually 10000.
|
||||
* @param rotary_scaling: The rotary scaling, value is usually 1.
|
||||
* @param rotary_scaling_type: The rotary scaling type, value is linear or dynamic ntk.
|
||||
* @param interleaved: A boolean value indicates compute mode of rotary embedding.
|
||||
* @param dtype: Data type of cos and sin table generated.
|
||||
*/
|
||||
KernelStatus invokeCreateCosSinTable(cnrtQueue_t queue,
|
||||
void *cos_sin_table,
|
||||
float *rotary_emb_alpha_cached,
|
||||
int *seq_lens,
|
||||
int max_position_embeddings,
|
||||
int batch,
|
||||
int batch_stride,
|
||||
int rotary_seq_len,
|
||||
int rotary_dim,
|
||||
int rotary_stride,
|
||||
float rotary_base,
|
||||
float rotary_scaling,
|
||||
int rotary_scaling_type,
|
||||
bool interleaved,
|
||||
cnnlDataType_t dtype);
|
||||
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_CREATE_COS_SIN_TABLE_MLUH_
|
||||
812
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mlu
Normal file
812
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mlu
Normal file
@@ -0,0 +1,812 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include "dequant_from_linear_cache.mluh"
|
||||
#include "quant_utils.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
#pragma bang walign(16)
|
||||
#define REM_FOR_STACK (32 * 1024)
|
||||
#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
|
||||
#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK)
|
||||
#define DEQUANT_LINEAR_PERHEAD kernels::MLUDequantFromLinearCacheKernelPerHead
|
||||
#define DEQUANT_LINEAR_PERCHANNEL kernels::MLUDequantFromLinearCacheKernelPerChannel
|
||||
#define DEQUANT_FUNC_LEN (24)
|
||||
#define DEQUANT_BATCH_NUM (1024)
|
||||
|
||||
__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE];
|
||||
__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE];
|
||||
__nram__ uint8_t pre_table_nram[TRANS_TABLE_SIZE];
|
||||
// Uses 8K = 1K * (4 + 4) to process offsets
|
||||
__nram__ int32_t n_lens[DEQUANT_BATCH_NUM];
|
||||
__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM];
|
||||
|
||||
// Resolves the element offsets for one batch entry of the per-channel
// dequant path. On success writes context/cache/scale offsets; if the entry
// is invalid (negative ids or the span would overrun cache_mem_len), sets
// cache_id to -1 as a skip sentinel and leaves the offsets untouched.
__mlu_func__ void calcu_offsets_per_channel(int32_t &cache_id,
                                            size_t &context_offset,
                                            size_t &cache_offset,
                                            size_t &scale_offset,
                                            const int32_t *cache_bs_id,
                                            const int32_t *cache_seq_offsets,
                                            const int32_t cache_mem_len,
                                            const int32_t seq_len,
                                            const int32_t seq_begin,
                                            const int32_t seq_offset,
                                            const int32_t batch_idx,
                                            const size_t context_seq_stride,
                                            const size_t cache_bs_stride,
                                            const size_t cache_seq_stride,
                                            const size_t scale_bs_stride) {
  // Optional indirection tables: nullptr means identity mapping / zero shift.
  cache_id =
      (cache_bs_id == nullptr) ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
  const int32_t seq_shift =
      (cache_seq_offsets == nullptr) ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);

  const bool in_range =
      cache_id >= 0 && seq_shift >= 0 && (seq_shift + seq_len) <= cache_mem_len;
  if (!in_range) {
    cache_id = -1;  // sentinel: caller skips this batch entry
    return;
  }

  context_offset = context_seq_stride * (seq_offset + seq_begin);
  cache_offset = cache_bs_stride * cache_id + cache_seq_stride * (seq_shift + seq_begin);
  scale_offset = scale_bs_stride * cache_id;
}
|
||||
|
||||
// Resolves the element offsets for one batch entry of the per-head dequant
// path. Like calcu_offsets_per_channel, but key and value caches have
// separate sequence strides and the scale offset also advances with the
// sequence position. Sets cache_id to -1 (skip sentinel) on invalid entries,
// leaving the offsets untouched.
__mlu_func__ void calcu_offsets_per_head(int32_t &cache_id,
                                         size_t &context_offset,
                                         size_t &key_cache_offset,
                                         size_t &value_cache_offset,
                                         size_t &scale_offset,
                                         const int32_t *cache_bs_id,
                                         const int32_t *cache_seq_offsets,
                                         const int32_t cache_mem_len,
                                         const int32_t seq_len,
                                         const int32_t seq_begin,
                                         const int32_t seq_offset,
                                         const int32_t batch_idx,
                                         const size_t context_seq_stride,
                                         const size_t cache_bs_stride,
                                         const size_t key_cache_seq_stride,
                                         const size_t value_cache_seq_stride,
                                         const size_t scale_bs_stride) {
  // Optional indirection tables: nullptr means identity mapping / zero shift.
  cache_id =
      (cache_bs_id == nullptr) ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
  const int32_t seq_shift =
      (cache_seq_offsets == nullptr) ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);

  const bool in_range =
      cache_id >= 0 && seq_shift >= 0 && (seq_shift + seq_len) <= cache_mem_len;
  if (!in_range) {
    cache_id = -1;  // sentinel: caller skips this batch entry
    return;
  }

  context_offset = context_seq_stride * (seq_offset + seq_begin);
  key_cache_offset = cache_bs_stride * cache_id + key_cache_seq_stride * (seq_shift + seq_begin);
  value_cache_offset =
      cache_bs_stride * cache_id + value_cache_seq_stride * (seq_shift + seq_begin);
  // Per-head scales are stored per sequence position, hence the seq term.
  scale_offset = seq_shift + seq_begin + scale_bs_stride * cache_id;
}
|
||||
|
||||
// Dequantizes a [seq_num, head_num, head_size] tile of the linear KV cache
// (per-channel scales) into the fp context buffer.
//   T  : output dtype, Tc : quantized cache dtype (int8 or packed int4x2),
//   Ts : scale dtype.
// scale_num is the per-sequence element count (head_num * head_size layout in
// NRAM); all *_offset/*_stride values are in elements of their buffer.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_channel(T *output_nram,
                                         Tc *input_nram,
                                         Ts *scale_nram,
                                         T *data,
                                         Tc *cache,
                                         Ts *scale,
                                         const int32_t scale_num,
                                         const int32_t seq_num,
                                         const int32_t head_num,
                                         const int32_t head_size,
                                         const size_t context_offset,
                                         const size_t cache_offset,
                                         const size_t scale_offset,
                                         const size_t context_seq_stride,
                                         const size_t context_head_stride,
                                         const size_t cache_head_stride,
                                         const size_t cache_seq_stride,
                                         const size_t scale_bs_stride,
                                         const size_t scale_head_stride) {
  // Load the per-head scale rows [head_num, head_size].
  // NOTE(review): skipped when scale_bs_stride == 0 — presumably the scales
  // were already staged by the caller in that case; confirm.
  if (scale_bs_stride != 0) {
    __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM,
             head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  }

  // Load the quantized tile. For packed int4x2 every size/stride is in bytes
  // at half the element count (two 4-bit values per byte).
  if (std::is_same<Tc, int4x2_t>::value) {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM,
             head_size >> 1, head_num - 1, scale_num >> 1, seq_num - 1, cache_head_stride,
             head_num - 1, cache_seq_stride, seq_num - 1);
  } else {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM,
             head_size * sizeof_(Tc), head_num - 1, scale_num * sizeof_(Tc), seq_num - 1,
             cache_head_stride * sizeof_(Tc), head_num - 1, cache_seq_stride * sizeof_(Tc),
             seq_num - 1);
  }

  // Dequantize in NRAM (scales broadcast along the sequence dimension).
  dequantize<T, Tc, Ts>((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf,
                        seq_num * scale_num, scale_num);
  // Store the fp tile back with the context's head/sequence strides.
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
           seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1);
}
|
||||
|
||||
// Dequantizes a value-cache tile stored as packed int4x2 with per-channel
// scales. Loads sequence positions pairwise (two 4-bit values per byte),
// unpacks to int8, re-orders with a small-channel transpose, then
// dequantizes and stores [save_seq_num, head_num, head_size] fp elements.
// pad_front drops the first unpacked sequence position (odd alignment at the
// front of the span).
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_value_per_channel(T *output_nram,
                                               Tc *input_nram,
                                               Ts *scale_nram,
                                               int8_t *temp_nram,
                                               T *data,
                                               Tc *cache,
                                               Ts *scale,
                                               const int32_t scale_num,
                                               const int32_t seq_num,
                                               const int32_t head_num,
                                               const int32_t head_size,
                                               const size_t context_offset,
                                               const size_t cache_offset,
                                               const size_t scale_offset,
                                               const size_t context_seq_stride,
                                               const size_t context_head_stride,
                                               const size_t cache_head_stride,
                                               const size_t cache_seq_stride,
                                               const size_t scale_bs_stride,
                                               const size_t scale_head_stride,
                                               const bool pad_front) {
  /* Step 1. load scale [head_num, head_size]*/
  // NOTE(review): skipped when scale_bs_stride == 0 — presumably the scales
  // were already staged by the caller in that case; confirm.
  if (scale_bs_stride != 0) {
    __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, head_size * sizeof_(Ts), GDRAM2NRAM,
             head_size * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  }

  /* Step 2. load cache [load_seq_num, head_num, head_size] */
  // Each loaded byte row covers TWO logical sequence positions.
  int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2);
  int32_t deal_seq_num = load_seq_num << 1;
  __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size,
           head_num - 1, scale_num, load_seq_num - 1, cache_head_stride, head_num - 1,
           cache_seq_stride, load_seq_num - 1);
  /* Step 3. convert into int8 [load_seq_num, head_num, head_size, 2] */
  convert((int8_t *)output_nram, (int4x2_t *)input_nram, deal_seq_num * scale_num);
  /* Step 4. transpose to [deal_seq_num (load_seq_num, 2), head_num, head_size] */
  // Moves the unpacked pair axis next to the sequence axis.
  trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram,
                         load_seq_num, head_num, head_size, 2);
  /* Step 5. dequantize [save_seq_num, head_num, head_size] */
  // pad_front: the first unpacked position is padding — skip one scale_num
  // stride and emit one fewer sequence position.
  int save_seq_num = pad_front ? seq_num - 1 : seq_num;
  dequantize<T, int8_t, Ts>((T *)output_nram, (int8_t *)temp_nram + (pad_front ? scale_num : 0),
                            (Ts *)scale_nram, (Ts *)nbuf, save_seq_num * scale_num, scale_num);
  /* Step 6. store [save_seq_num, head_num, head_size]*/
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
           save_seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T),
           save_seq_num - 1);
}
|
||||
|
||||
template <typename T, typename Tc, typename Ts>
/**
 * De-quantizes one tile of key or value cache quantized with per-head scales
 * (one scale per [head, seq] position) and stores it to GDRAM.
 *
 * The cache tile is loaded as [head_num, seq_num, head_size], converted to
 * float, multiplied by its scale via conv_fuse_mul_cvt (scale broadcast over
 * head_size), and written back in [seq, head] order.
 *
 * @param output_nram  NRAM scratch for the float intermediate / final output.
 * @param temp_nram    NRAM scratch used as the output buffer when T != float
 *                     (conv_fuse_mul_cvt then converts into it).
 *
 * NOTE(review): uses the file-scope WRAM buffer `wbuf`, pre-filled with ones
 * by the calling kernel — this function must not be called before that setup.
 */
__mlu_func__ void dequantize_per_head(T *output_nram,
                                      Tc *input_nram,
                                      Ts *scale_nram,
                                      T *temp_nram,
                                      T *data,
                                      Tc *cache,
                                      Ts *scale,
                                      const int32_t scale_num,
                                      const int32_t seq_num,
                                      const int32_t head_num,
                                      const int32_t head_size,
                                      const size_t context_offset,
                                      const size_t cache_offset,
                                      const size_t scale_offset,
                                      const size_t context_seq_stride,
                                      const size_t context_head_stride,
                                      const size_t cache_head_stride,
                                      const size_t cache_seq_stride,
                                      const size_t scale_head_stride) {
  /* Load per-(head, seq) scales: [head_num, seq_num]. */
  __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, seq_num * sizeof_(Ts), GDRAM2NRAM,
           seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  /* Load cache tile [head_num, seq_num, head_size]; int4x2_t packs two codes
   * per byte, halving the innermost copy size. */
  if (std::is_same<Tc, int4x2_t>::value) {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size >> 1, GDRAM2NRAM,
             head_size >> 1, seq_num - 1, seq_num * (head_size >> 1), head_num - 1,
             cache_seq_stride, seq_num - 1, cache_head_stride, head_num - 1);
  } else {
    __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size * sizeof_(Tc), GDRAM2NRAM,
             head_size * sizeof_(Tc), seq_num - 1, seq_num * head_size * sizeof_(Tc), head_num - 1,
             cache_seq_stride * sizeof_(Tc), seq_num - 1, cache_head_stride * sizeof_(Tc),
             head_num - 1);
  }

  /* Widen codes to float, then apply the per-row scale (and convert to T in
   * the same fused op when T is half/bfloat16). */
  convert((float *)output_nram, (Tc *)input_nram, head_num * seq_num * head_size);
  if (std::is_same<T, float>::value) {
    conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * seq_num, head_size, 1);
  } else {
    conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * seq_num, head_size, 1);
    output_nram = (T *)temp_nram;
  }

  /* Store back to GDRAM, swapping to [seq, head] layout via the strided copy. */
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_seq_stride * sizeof_(T), seq_num - 1, context_head_stride * sizeof_(T),
           head_num - 1, head_size * sizeof_(T), seq_num - 1, head_size * seq_num * sizeof_(T),
           head_num - 1);
}
|
||||
|
||||
template <typename T, typename Tc, typename Ts>
/**
 * De-quantizes one tile of the int4-packed value cache with per-head scales.
 *
 * Like dequantize_value_per_channel, pairs of sequence positions share one
 * packed byte, so load_seq_num packed rows expand to deal_seq_num logical
 * rows. When pad_front is set the first unpacked row belongs to the previous
 * tile and is skipped on store (output pointer advanced by head_size).
 *
 * NOTE(review): depends on file-scope `wbuf` (ones, staged by the caller) and
 * `pre_table_nram` (transpose table) — confirm the caller initialized both.
 */
__mlu_func__ void dequantize_value_per_head(T *output_nram,
                                            Tc *input_nram,
                                            Ts *scale_nram,
                                            T *temp_nram,
                                            T *data,
                                            Tc *cache,
                                            Ts *scale,
                                            const int32_t scale_num,
                                            const int32_t seq_num,
                                            const int32_t head_num,
                                            const int32_t head_size,
                                            const size_t context_offset,
                                            const size_t cache_offset,
                                            const size_t scale_offset,
                                            const size_t context_seq_stride,
                                            const size_t context_head_stride,
                                            const size_t cache_head_stride,
                                            const size_t cache_seq_stride,
                                            const size_t scale_bs_stride,
                                            const size_t scale_head_stride,
                                            const bool pad_front) {
  int32_t load_seq_num = (seq_num >> 1) + int32_t(seq_num % 2);
  int32_t deal_seq_num = load_seq_num << 1;
  /* Step1. load scale first, [head_num, deal_seq_num] */
  __memcpy((Ts *)scale_nram, (Ts *)scale + scale_offset, deal_seq_num * sizeof_(Ts), GDRAM2NRAM,
           deal_seq_num * sizeof_(Ts), scale_head_stride * sizeof_(Ts), head_num - 1);
  /* Step2. load cache input, [head_num, load_seq_num, head_size, 2] for int4 */
  __memcpy((Tc *)input_nram, (Tc *)cache + cache_offset, head_size, GDRAM2NRAM, head_size,
           load_seq_num - 1, load_seq_num * head_size, head_num - 1, cache_seq_stride,
           load_seq_num - 1, cache_head_stride, head_num - 1);
  /* Unpack the 4-bit pairs into int8. */
  convert((int8_t *)output_nram, (Tc *)input_nram, head_num * head_size * deal_seq_num);
  /* Step3. trans to [head_num, load_seq_num, 2, head_size] */
  trans_nhwc2nchw_smallc((int8_t *)temp_nram, (int8_t *)output_nram, (uint8_t *)pre_table_nram,
                         head_num * load_seq_num, head_size, 1, 2);
  /* Step4. dequant to T [head_num, deal_seq_num, head_size]: widen to float,
   * then fused scale-multiply (+ optional narrowing to T). */
  convert((float *)output_nram, (int8_t *)temp_nram, head_num * deal_seq_num * head_size);
  if (std::is_same<T, float>::value) {
    conv_fuse_mul_cvt((T *)output_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * deal_seq_num, head_size, 1);
  } else {
    conv_fuse_mul_cvt((T *)temp_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                      head_num * deal_seq_num, head_size, 1);
    output_nram = (T *)temp_nram;
  }

  /* Step5. save [head_num, save_seq_num, head_size], dropping the padded
   * leading row when present. */
  int32_t save_seq_num = pad_front ? seq_num - 1 : seq_num;
  __memcpy((T *)data + context_offset, (T *)output_nram + (pad_front ? head_size : 0),
           head_size * sizeof_(T), NRAM2GDRAM, context_seq_stride * sizeof_(T), save_seq_num - 1,
           context_head_stride * sizeof_(T), head_num - 1, head_size * sizeof_(T), save_seq_num - 1,
           head_size * deal_seq_num * sizeof_(T), head_num - 1);
}
|
||||
|
||||
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
/**
 * Entry kernel: de-quantizes key and/or value from a linear (non-paged)
 * quantized cache using per-channel scales.
 *
 * Work split: taskIdY strides over batches, taskIdZ selects the sequence tile
 * of `seq_block` positions. Either key or value side may be disabled by
 * passing nullptr for the tensor, its cache, or its scale.
 *
 * ProcessOffsets=true means context_seq_offsets is absent and per-batch
 * offsets are derived on-device (process_offsets into n_lens/n_offsets);
 * false means lengths/offsets are read directly from the input arrays.
 *
 * The int4 (int4x2_t) value path is special-cased because the value cache
 * packs two sequence positions per byte: when the batch's cache offset is
 * odd, seq_begin is shifted left by one and the extra leading row is dropped
 * inside dequantize_value_per_channel (pad_front).
 */
__mlu_global__ void MLUDequantFromLinearCacheKernelPerChannel(void *key,
                                                              void *value,
                                                              const void *key_cache,
                                                              const void *value_cache,
                                                              const void *key_scale,
                                                              const void *value_scale,
                                                              const int32_t *context_lens,
                                                              const int32_t *context_seq_offsets,
                                                              const int32_t *cache_bs_id,
                                                              const int32_t *cache_seq_offsets,
                                                              const int32_t max_context_len,
                                                              const int32_t batch_size,
                                                              const int32_t head_num,
                                                              const int32_t key_group_num,
                                                              const int32_t value_group_num,
                                                              const int32_t cache_mem_len,
                                                              const int32_t head_size,
                                                              const int32_t seq_block,
                                                              const size_t context_head_stride,
                                                              const size_t context_seq_stride,
                                                              const size_t cache_bs_stride,
                                                              const size_t cache_head_stride,
                                                              const size_t key_cache_seq_stride,
                                                              const size_t value_cache_seq_stride,
                                                              const size_t scale_bs_stride,
                                                              const size_t scale_head_stride) {
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }

  /* *********************************nram space **************************************
   * NRAM |scale[head_num, head_size]|output/input[seq_block, head_num, head_size]|
   */
  int32_t scale_num = head_num * head_size;
  /* Scale is right-aligned inside a float-sized slot so the float output
   * region can start immediately after scale_num floats. */
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  /* Input tile is placed past the float output so in-place widening is safe;
   * the int4 layout reserves extra room for the unpack/transpose stages. */
  Tc *input_nram = (Tc *)output_nram +
                   (std::is_same<Tc, int4x2_t>::value
                        ? (7 * seq_block * (scale_num >> 1))
                        : (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)));
  // temp_nram for nram to store int8 input
  int8_t *temp_nram = (int8_t *)output_nram + seq_block * scale_num * 3;
  int32_t seq_offset;
  int32_t seq_len;
  int32_t seq_begin;
  int32_t deal_seq_num;
  int32_t cache_id;
  int32_t cache_seq_offset;
  size_t context_offset;
  size_t cache_offset;
  size_t scale_offset;
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  if (has_key) {
    /* Key side: one shared per-channel scale staged once, then a plain
     * per-batch / per-tile dequant loop. */
    load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      calcu_offsets_per_channel(cache_id, context_offset, cache_offset, scale_offset,
                                (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets, cache_mem_len,
                                seq_len, seq_begin, seq_offset, batch_idx, context_seq_stride,
                                cache_bs_stride, key_cache_seq_stride, scale_bs_stride);
      if (cache_id < 0) continue;
      dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key,
                             (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num, head_num,
                             head_size, context_offset, cache_offset, scale_offset,
                             context_seq_stride, context_head_stride, cache_head_stride,
                             key_cache_seq_stride, scale_bs_stride, scale_head_stride);
    }
  }

  if (has_value) {
    if (std::is_same<Tc, int4x2_t>::value) {
      /* Pre-build the transpose lookup table used by the int4 unpack path. */
      __reshape_nhwc2nchw_smallc_init<int8_t>(pre_table_nram, 2);
    }
    load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      if (std::is_same<Tc, int4x2_t>::value) {
        /* int4 value path: offsets computed inline because the packed layout
         * halves the seq dimension and may need a one-row front pad. */
        seq_begin = taskIdZ * seq_block;
        cache_id =
            cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
        cache_seq_offset = cache_seq_offsets == nullptr
                               ? 0
                               : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
        // move seq_begin left by 1 when cache_seq_offset is odd
        seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin;
        deal_seq_num = std::min(seq_len - seq_begin, seq_block);
        if (deal_seq_num <= 0 || seq_offset < 0) continue;
        if (cache_id >= 0 && cache_seq_offset >= 0 &&
            (cache_seq_offset + seq_len) <= cache_mem_len) {
          context_offset =
              context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0));
          // value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t
          cache_offset = cache_bs_stride * cache_id +
                         value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2);
          scale_offset = scale_bs_stride * cache_id;
        } else {
          cache_id = -1;
        }
        if (cache_id < 0) continue;
        dequantize_value_per_channel(
            (T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (int8_t *)temp_nram, (T *)value,
            (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num, head_num, head_size,
            context_offset, cache_offset, scale_offset, context_seq_stride, context_head_stride,
            cache_head_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride,
            seq_begin == -1);
      } else {
        /* int8 value path: identical shape handling to the key side. */
        seq_begin = taskIdZ * seq_block;
        deal_seq_num = std::min(seq_len - seq_begin, seq_block);
        if (deal_seq_num <= 0 || seq_offset < 0) continue;
        calcu_offsets_per_channel(
            cache_id, context_offset, cache_offset, scale_offset, (int32_t *)cache_bs_id,
            (int32_t *)cache_seq_offsets, cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx,
            context_seq_stride, cache_bs_stride, value_cache_seq_stride, scale_bs_stride);
        if (cache_id < 0) continue;
        dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value,
                               (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num,
                               head_num, head_size, context_offset, cache_offset, scale_offset,
                               context_seq_stride, context_head_stride, cache_head_stride,
                               value_cache_seq_stride, scale_bs_stride, scale_head_stride);
      }
    }
  }
}
|
||||
|
||||
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
/**
 * Entry kernel: de-quantizes key and/or value from a linear quantized cache
 * using per-head scales (one scale per [head, seq] position).
 *
 * Work split mirrors the per-channel kernel: taskIdY strides over batches,
 * taskIdZ selects the `seq_block`-sized sequence tile. Key (any Tc) and int8
 * value are handled in the first loop; int4 value is handled in a separate
 * second loop because its packed layout changes offset math and may require a
 * one-row front pad when the cache offset is odd.
 *
 * NOTE(review): the kernel seeds `wbuf` with a ones matrix (head_size x 16)
 * before any dequant call — dequantize_per_head's fused multiply depends on
 * this staging.
 */
__mlu_global__ void MLUDequantFromLinearCacheKernelPerHead(void *key,
                                                           void *value,
                                                           const void *key_cache,
                                                           const void *value_cache,
                                                           const void *key_scale,
                                                           const void *value_scale,
                                                           const int32_t *context_lens,
                                                           const int32_t *context_seq_offsets,
                                                           const int32_t *cache_bs_id,
                                                           const int32_t *cache_seq_offsets,
                                                           const int32_t max_context_len,
                                                           const int32_t batch_size,
                                                           const int32_t head_num,
                                                           const int32_t key_group_num,
                                                           const int32_t value_group_num,
                                                           const int32_t cache_mem_len,
                                                           const int32_t head_size,
                                                           const int32_t seq_block,
                                                           const size_t context_head_stride,
                                                           const size_t context_seq_stride,
                                                           const size_t cache_bs_stride,
                                                           const size_t cache_head_stride,
                                                           const size_t key_cache_seq_stride,
                                                           const size_t value_cache_seq_stride,
                                                           const size_t scale_bs_stride,
                                                           const size_t scale_head_stride) {
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }

  /* *********************************nram space **************************************
   * NRAM |scale[seq_block, head_num]|output/input[head_size, seq_block, head_num]|temp|
   */
  int32_t scale_num = seq_block * head_num;
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram = (Tc *)output_nram +
                   (std::is_same<Tc, int4x2_t>::value
                        ? (7 * head_size * (scale_num >> 1))
                        : head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc));
  // temp_nram for nram to store converted output
  float *temp_nram = (float *)output_nram + head_size * scale_num;
  int32_t seq_offset;
  int32_t seq_len;
  int32_t seq_begin;
  int32_t deal_seq_num;
  int32_t cache_id;
  size_t context_offset;
  size_t key_cache_offset;
  size_t value_cache_offset;
  size_t scale_offset;
  /* Stage a ones matrix into WRAM; dequantize_per_head's conv_fuse_mul_cvt
   * consumes it as the convolution weight. */
  __bang_write_value((float *)nbuf, head_size * 16, 1.0f);
  mvNram2WramLT16<float>((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16);
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
    load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                    (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                    batch_idx);
    seq_begin = taskIdZ * seq_block;
    deal_seq_num = std::min(seq_len - seq_begin, seq_block);
    if (deal_seq_num <= 0 || seq_offset < 0) continue;
    calcu_offsets_per_head(cache_id, context_offset, key_cache_offset, value_cache_offset,
                           scale_offset, (int32_t *)cache_bs_id, (int32_t *)cache_seq_offsets,
                           cache_mem_len, seq_len, seq_begin, seq_offset, batch_idx,
                           context_seq_stride, cache_bs_stride, key_cache_seq_stride,
                           value_cache_seq_stride, scale_bs_stride);
    if (cache_id < 0) continue;
    if (has_key) {
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)key, (Tc *)key_cache, (Ts *)key_scale, scale_num, deal_seq_num,
                          head_num, head_size, context_offset, key_cache_offset, scale_offset,
                          context_seq_stride, context_head_stride, cache_head_stride,
                          key_cache_seq_stride, scale_head_stride);
    }
    /* int8 value shares this loop; the int4 value layout needs its own pass. */
    if (has_value && std::is_same<Tc, int8_t>::value) {
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)value, (Tc *)value_cache, (Ts *)value_scale, scale_num, deal_seq_num,
                          head_num, head_size, context_offset, value_cache_offset, scale_offset,
                          context_seq_stride, context_head_stride, cache_head_stride,
                          value_cache_seq_stride, scale_head_stride);
    }
  }

  // process value int4 differently
  if (has_value && std::is_same<Tc, int4x2_t>::value) {
    int32_t cache_seq_offset;
    __reshape_nhwc2nchw_smallc_init<int8_t>(pre_table_nram, 2);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      cache_id =
          cache_bs_id == nullptr ? batch_idx : __load_gdram((int32_t *)cache_bs_id + batch_idx);
      cache_seq_offset =
          cache_seq_offsets == nullptr ? 0 : __load_gdram((int32_t *)cache_seq_offsets + batch_idx);
      // move seq_begin left by 1 when cache_seq_offset is odd
      seq_begin = cache_seq_offset % 2 ? seq_begin - 1 : seq_begin;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      if (cache_id >= 0 && cache_seq_offset >= 0 && (cache_seq_offset + seq_len) <= cache_mem_len) {
        context_offset =
            context_seq_stride * (seq_offset + seq_begin + ((seq_begin == -1) ? 1 : 0));
        // value cache is [max_batch_size, head_num, cache_mem_len/2, head_size] for int4x2_t
        value_cache_offset = cache_bs_stride * cache_id +
                             value_cache_seq_stride * ((cache_seq_offset + seq_begin) / 2);
        scale_offset = cache_seq_offset + seq_begin + scale_bs_stride * cache_id;
      } else {
        cache_id = -1;
      }
      if (cache_id < 0) continue;
      dequantize_value_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram,
                                (T *)temp_nram, (T *)value, (Tc *)value_cache, (Ts *)value_scale,
                                scale_num, deal_seq_num, head_num, head_size, context_offset,
                                value_cache_offset, scale_offset, context_seq_stride,
                                context_head_stride, cache_head_stride, value_cache_seq_stride,
                                scale_bs_stride, scale_head_stride, seq_begin == -1);
    }
  }
}
|
||||
} // namespace kernels
|
||||
|
||||
// Explicit instantiation helper for the dequant-from-linear-cache kernels.
// T = context dtype, Tc = cache code type (int8 or packed int4), Ts = scale
// dtype, C = ProcessOffsets flag, Name = PerChannel / PerHead kernel variant.
// The 24 instantiations below must cover every index produced by
// getDequantLinearIdx so the dispatch table can take their addresses.
#define DEQUANT_LINEAR_INIT(T, Tc, Ts, C, Name) \
  template __mlu_global__ void kernels::MLUDequantFromLinearCacheKernel##Name<T, Tc, Ts, C>( \
      void *key, void *value, const void *key_cache, const void *value_cache, \
      const void *key_scale, const void *value_scale, const int32_t *context_lens, \
      const int32_t *context_seq_offsets, const int32_t *cache_bs_id, \
      const int32_t *cache_seq_offsets, const int32_t max_context_len, const int32_t batch_size, \
      const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num, \
      const int32_t cache_mem_len, const int32_t head_size, const int32_t seq_block, \
      const size_t context_head_stride, const size_t context_seq_stride, \
      const size_t cache_bs_stride, const size_t cache_head_stride, \
      const size_t key_cache_seq_stride, const size_t value_cache_seq_stride, \
      const size_t scale_bs_stride, const size_t scale_head_stride);

DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerChannel)
DEQUANT_LINEAR_INIT(half, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(float, int8_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, false, PerHead)
DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerChannel)
DEQUANT_LINEAR_INIT(half, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(float, int8_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(half, int4x2_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(bfloat16_t, int4x2_t, float, true, PerHead)
DEQUANT_LINEAR_INIT(float, int4x2_t, float, true, PerHead)
|
||||
|
||||
// Common signature of every instantiated dequant-from-linear-cache kernel;
// entries of DequantFromLinearCacheFuncArr (indexed by getDequantLinearIdx)
// all have this type.
typedef void (*DequantFromLinearCachePointer)(void *,           // key
                                              void *,           // value
                                              const void *,     // key_cache
                                              const void *,     // value_cache
                                              const void *,     // key_scale
                                              const void *,     // value_scale
                                              const int32_t *,  // context_lens
                                              const int32_t *,  // context_seq_offsets
                                              const int32_t *,  // cache_bs_id
                                              const int32_t *,  // cache_seq_offsets
                                              const int32_t,    // max_context_len
                                              const int32_t,    // batch_size
                                              const int32_t,    // head_num
                                              const int32_t,    // key_group_num
                                              const int32_t,    // value_group_num
                                              const int32_t,    // cache_mem_len
                                              const int32_t,    // head_size
                                              const int32_t,    // seq_block
                                              const size_t,     // context_head_stride
                                              const size_t,     // context_seq_stride
                                              const size_t,     // cache_bs_stride
                                              const size_t,     // cache_head_stride
                                              const size_t,     // key_cache_seq_stride
                                              const size_t,     // value_cache_seq_stride
                                              const size_t,     // scale_bs_stride
                                              const size_t);    // scale_head_stride
|
||||
|
||||
// Kernel dispatch table. Index layout must match getDequantLinearIdx:
//   +1/+2  dtype (half / bfloat16 / float)
//   +3     4-bit quantization (quant_bit != 8)
//   +6     per-head scales (quant_mode != 0)
//   +12    ProcessOffsets variant (context_seq_offsets == nullptr)
static DequantFromLinearCachePointer DequantFromLinearCacheFuncArr[DEQUANT_FUNC_LEN] = {
    DEQUANT_LINEAR_PERCHANNEL<half, int8_t, float, false>,
    DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int8_t, float, false>,
    DEQUANT_LINEAR_PERCHANNEL<float, int8_t, float, false>,
    DEQUANT_LINEAR_PERCHANNEL<half, int4x2_t, float, false>,
    DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int4x2_t, float, false>,
    DEQUANT_LINEAR_PERCHANNEL<float, int4x2_t, float, false>,
    DEQUANT_LINEAR_PERHEAD<half, int8_t, float, false>,
    DEQUANT_LINEAR_PERHEAD<bfloat16_t, int8_t, float, false>,
    DEQUANT_LINEAR_PERHEAD<float, int8_t, float, false>,
    DEQUANT_LINEAR_PERHEAD<half, int4x2_t, float, false>,
    DEQUANT_LINEAR_PERHEAD<bfloat16_t, int4x2_t, float, false>,
    DEQUANT_LINEAR_PERHEAD<float, int4x2_t, float, false>,
    DEQUANT_LINEAR_PERCHANNEL<half, int8_t, float, true>,
    DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int8_t, float, true>,
    DEQUANT_LINEAR_PERCHANNEL<float, int8_t, float, true>,
    DEQUANT_LINEAR_PERCHANNEL<half, int4x2_t, float, true>,
    DEQUANT_LINEAR_PERCHANNEL<bfloat16_t, int4x2_t, float, true>,
    DEQUANT_LINEAR_PERCHANNEL<float, int4x2_t, float, true>,
    DEQUANT_LINEAR_PERHEAD<half, int8_t, float, true>,
    DEQUANT_LINEAR_PERHEAD<bfloat16_t, int8_t, float, true>,
    DEQUANT_LINEAR_PERHEAD<float, int8_t, float, true>,
    DEQUANT_LINEAR_PERHEAD<half, int4x2_t, float, true>,
    DEQUANT_LINEAR_PERHEAD<bfloat16_t, int4x2_t, float, true>,
    DEQUANT_LINEAR_PERHEAD<float, int4x2_t, float, true>};
|
||||
|
||||
/**
 * Maps (dtype, quant mode, quant bit width, presence of precomputed sequence
 * offsets) to an index into DequantFromLinearCacheFuncArr.
 *
 * Layout: dtype selects within groups of 3, 4-bit adds 3, per-head scales add
 * 6, and the offsets-absent (ProcessOffsets) variants occupy the upper 12.
 */
uint32_t getDequantLinearIdx(cnnlDataType_t dtype,
                             int32_t quant_mode,
                             int32_t quant_bit,
                             const void *context_seq_offset) {
  uint32_t idx = (context_seq_offset == nullptr) ? 12u : 0u;
  if (quant_mode != 0) {
    idx += 6;
  }
  if (quant_bit != 8) {
    idx += 3;
  }
  switch (dtype) {
    case CNNL_DTYPE_BFLOAT16:
      idx += 1;
      break;
    case CNNL_DTYPE_FLOAT:
      idx += 2;
      break;
    default:
      break;
  }
  return idx;
}
|
||||
|
||||
/**
 * Chooses the per-tile sequence block size and the kernel launch dimensions
 * for the linear-cache dequant kernels.
 *
 * @param seq_block  [out] number of sequence positions each task processes;
 *                   sized so one tile's NRAM footprint fits the device.
 * @param task_dim   [out] launch grid: y strides batches, z strides seq tiles.
 * @param task_type  [out] always a Block-type launch.
 * @param quant_mode 0 = per-channel scales, otherwise per-head scales
 *                   (the NRAM budget formula differs between the two).
 * @param quant_bit  8 or 4; int4 needs an even seq_block and one extra
 *                   sequence segment to absorb the packed-byte front pad.
 *
 * Failures to fit only print to stderr; callers proceed regardless —
 * NOTE(review): consider propagating an error instead of logging.
 */
void getBlockAndDimForLinear(int32_t &seq_block,
                             cnrtDim3_t &task_dim,
                             cnrtFunctionType_t &task_type,
                             const int32_t max_context_len,
                             const int32_t head_num,
                             const int32_t batch_size,
                             const int32_t head_size,
                             const int32_t quant_mode,
                             const int32_t quant_bit,
                             const cnnlDataType_t dtype) {
  int32_t core_dim;
  int32_t cluster_dim;
  // Requested on-chip memory budgets; getDeviceCoreAndRam clamps them to the
  // actual device and reserves REM_FOR_STACK.
  int32_t nram_size = 480 * 1024;
  int32_t wram_size = 512 * 1024;
  int32_t sram_size = 2016 * 1024;
  getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK);
  if (quant_mode == 0) {
    /* Per-channel: one float row of head_num*head_size per sequence position,
     * minus one row reserved for the scale buffer. */
    seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1);
    if (quant_bit == 4) {
      if (seq_block <= 1) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * head_size * sizeof_(float) should be less than "
                  << (nram_size >> 1) << " when quant_mode is 0." << std::endl;
      }
    } else {
      if (seq_block <= 0) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * head_size * sizeof_(float) should be less than " << nram_size
                  << " when quant_mode is 0." << std::endl;
      }
    }
  } else {
    /* Per-head: per position needs head_num scales + a float row + a row in
     * the context dtype (the temp buffer used when T != float). */
    int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half);
    seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) +
                                              (int64_t)head_num * head_size * dtype_size));
    if (quant_bit == 4) {
      if (seq_block <= 1) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
                     "context_dtype_size) "
                  << "should be less than " << (nram_size >> 1) << " when quant_mode is 1."
                  << std::endl;
      }
    } else {
      if (seq_block <= 0) {
        std::cerr << __func__ << "," << __LINE__
                  << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
                     "context_dtype_size) "
                  << "should be less than " << nram_size << " when quant_mode is 1." << std::endl;
      }
    }

    /* head_size * 64B put in the wram. */
    if (head_size * ONE_LINE >= wram_size) {
      std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than "
                << wram_size << " when quant_mode is 1." << std::endl;
    }
  }
  seq_block = std::min(seq_block, max_context_len);
  // Align mid-sized blocks down to 16 for copy efficiency; int4 additionally
  // needs an even block so packed bytes are never split across tiles.
  if (seq_block > 16 && seq_block < max_context_len) {
    seq_block = PAD_DOWN(seq_block, 16);
  }
  if (quant_bit == 4) {
    seq_block = PAD_DOWN(seq_block, 2);
  }
  int seq_seg = DIV_UP(max_context_len, seq_block);
  // need an extra seg block to dealwith int4 value_cache [...,seq_len/2, head_size]
  if (quant_bit == 4) {
    seq_seg += 1;
  }
  uint32_t core_num = cluster_dim * core_dim;
  // If the grid would leave most cores idle, shrink seq_block to create more
  // segments and recompute the segment count.
  if (batch_size * seq_seg <= (core_num / 2)) {
    int times = core_num / batch_size / seq_seg;
    seq_block = std::max(seq_block / times, 2);
    if (quant_bit == 4) {
      seq_block = PAD_DOWN(seq_block, 2);
    }
    seq_seg = DIV_UP(max_context_len, seq_block);
    // same as above to dealwise int4 value_cache with an extra seg block
    if (quant_bit == 4) {
      seq_seg += 1;
    }
  }

  task_dim.x = 1;
  task_dim.y = uint32_t(std::min(batch_size, cluster_dim * core_dim));
  task_dim.z = uint32_t(seq_seg);
  task_type = cnrtFuncTypeBlock;
}
|
||||
|
||||
/**
 * Host entry point: launches the linear-cache dequant kernel matching the
 * requested dtype / quantization mode / bit width.
 *
 * Selects the kernel from DequantFromLinearCacheFuncArr via
 * getDequantLinearIdx, computes the launch grid and per-tile sequence block
 * with getBlockAndDimForLinear, and enqueues the kernel on `queue`.
 *
 * @return KERNEL_STATUS_FAILED when bfloat16 is requested on a device
 *         without bf16 support; KERNEL_STATUS_SUCCESS otherwise (the launch
 *         itself is asynchronous and not error-checked here).
 */
KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue,
                                          void *key,
                                          void *value,
                                          const void *key_cache,
                                          const void *value_cache,
                                          const void *key_scale,
                                          const void *value_scale,
                                          const void *context_lens,
                                          const void *context_seq_offsets,
                                          const void *cache_bs_id,
                                          const void *cache_seq_offsets,
                                          const int32_t max_context_len,
                                          const int32_t batch_size,
                                          const int32_t head_num,
                                          const int32_t key_group_num,
                                          const int32_t value_group_num,
                                          const int32_t cache_mem_len,
                                          const int32_t head_size,
                                          const int32_t quant_mode,
                                          const int32_t quant_bit,
                                          const size_t context_head_stride,
                                          const size_t context_seq_stride,
                                          const size_t cache_bs_stride,
                                          const size_t cache_head_stride,
                                          const size_t key_cache_seq_stride,
                                          const size_t value_cache_seq_stride,
                                          const size_t scale_bs_stride,
                                          const size_t scale_head_stride,
                                          const cnnlDataType_t dtype) {
  if (dtype == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
    // Fix: the message previously said "[invokeDequantFromPagedCache]",
    // a copy-paste from the paged-cache variant of this entry point.
    std::cerr << "[invokeDequantFromLinearCache]: "
                 "MLU300 devices do not support bfloat16."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Fix: index now matches getDequantLinearIdx's return type (was int32_t).
  uint32_t index;
  int32_t seq_block;
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  getBlockAndDimForLinear(seq_block, k_dim, k_type, max_context_len, head_num, batch_size,
                          head_size, quant_mode, quant_bit, dtype);
  index = getDequantLinearIdx(dtype, quant_mode, quant_bit, context_seq_offsets);
  auto dequant_linear_func = DequantFromLinearCacheFuncArr[index];
  dequant_linear_func<<<k_dim, k_type, queue>>>(
      (void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache,
      (const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens,
      (const int32_t *)context_seq_offsets, (const int32_t *)cache_bs_id,
      (const int32_t *)cache_seq_offsets, max_context_len, batch_size, head_num, key_group_num,
      value_group_num, cache_mem_len, head_size, seq_block, context_head_stride, context_seq_stride,
      cache_bs_stride, cache_head_stride, key_cache_seq_stride, value_cache_seq_stride,
      scale_bs_stride, scale_head_stride);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
} // namespace tmo
|
||||
108
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mluh
Normal file
108
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_linear_cache.mluh
Normal file
@@ -0,0 +1,108 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
|
||||
#define CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
/**
 * @brief De-quantizes the key and value tensors from the provided linear cache and scale.
 * @param queue: The queue for mlu.
 * @param key: Pointer to the MLU memory that stores the key tensor,
 *     with shape [total_seqlen, head_num, head_size]. Data type can be float32, half,
 *     or bfloat16. This parameter can be nullptr.
 * @param value: Pointer to the MLU memory that stores the value tensor,
 *     with shape [total_seqlen, head_num, head_size]. Data type can be float32,
 *     half, or bfloat16. This parameter can be nullptr.
 * @param key_cache: Pointer to the MLU memory that stores the key cache tensor,
 *     with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization
 *     or [max_bs, head_num, cache_mem_len, head_size//2] for 4-bit quantization.
 *     Data type must be int8. This parameter can be nullptr.
 * @param value_cache: Pointer to the MLU memory that stores the value cache tensor,
 *     with shape [max_batch_size, head_num, cache_mem_len, head_size] for 8-bit quantization
 *     or [max_bs, head_num, cache_mem_len//2, head_size] for 4-bit quantization.
 *     Data type must be int8. This parameter can be nullptr.
 * @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale.
 *     Shape depends on quantization mode:
 *     - For per-channel quantization (quant_mode = 0): [head_num, head_size].
 *     - For per-token quantization (quant_mode = 1): [max_batch_size, head_num, cache_mem_len].
 *     Data type must be float32. This parameter can be nullptr.
 * @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale,
 *     with the same shape as key_scale. Data type must be float32. This parameter can be
 *     nullptr.
 * @param context_lens: Pointer to the MLU memory that stores the sequence lengths.
 *     The shape must be [batch].
 * @param context_seq_offsets: Pointer to the MLU memory that stores the sequence offset in the
 *     context. The shape must be [batch]. If nullptr, the default value is the cumulative sum
 *     of context_lens.
 * @param cache_bs_ids: Pointer to the MLU memory that stores the batch index in the cache.
 *     The shape must be [batch]. If nullptr, the default value is {0, 1, 2, ..., batch - 1}.
 * @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence offset in the
 *     cache. The shape must be [batch]. If nullptr, the default value is 0 for every batch.
 * @param max_context_len: The maximum sequence length of context.
 * @param batch: Batch size.
 * @param head_num: Head number.
 * @param key_group_num: Group number of key group-wise quantization.
 * @param value_group_num: Group number of value group-wise quantization.
 * @param cache_mem_len: The maximum sequence length of cache.
 * @param head_size: Head size.
 * @param quant_mode: An integer value indicating the quantization mode:
 *     0 for per-channel quantization and 1 for per-token quantization.
 * @param quant_bit: An integer value indicating the quantization bit width:
 *     8 for 8-bit quantization and 4 for 4-bit quantization.
 * @param context_head_stride: The stride of head_num in context.
 * @param context_seq_stride: The stride of max_context_len in context.
 * @param cache_bs_stride: The stride of batch in cache.
 * @param cache_head_stride: The stride of head_num in cache.
 * @param key_cache_seq_stride: The stride of cache_mem_len in key cache.
 * @param value_cache_seq_stride: The stride of cache_mem_len in value cache.
 * @param cache_scale_bs_stride: The stride of batch in cache scale; only valid for per-token
 *     quantization (quant_mode = 1).
 * @param cache_scale_head_stride: The stride of head in cache scale.
 * @param dtype: The data type of the key and value tensors.
 * @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key.
 *     If any of value/value_cache/value_scale is nullptr, no operation is performed on the value.
 */
KernelStatus invokeDequantFromLinearCache(cnrtQueue_t queue,
                                          void *key,
                                          void *value,
                                          const void *key_cache,
                                          const void *value_cache,
                                          const void *key_scale,
                                          const void *value_scale,
                                          const void *context_lens,
                                          const void *context_seq_offsets,
                                          const void *cache_bs_ids,
                                          const void *cache_seq_offsets,
                                          const int32_t max_context_len,
                                          const int32_t batch,
                                          const int32_t head_num,
                                          const int32_t key_group_num,
                                          const int32_t value_group_num,
                                          const int32_t cache_mem_len,
                                          const int32_t head_size,
                                          const int32_t quant_mode,
                                          const int32_t quant_bit,
                                          const size_t context_head_stride,
                                          const size_t context_seq_stride,
                                          const size_t cache_bs_stride,
                                          const size_t cache_head_stride,
                                          const size_t key_cache_seq_stride,
                                          const size_t value_cache_seq_stride,
                                          const size_t cache_scale_bs_stride,
                                          const size_t cache_scale_head_stride,
                                          const cnnlDataType_t dtype);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_DEQUANT_FROM_LINEAR_CACHE_MLUH_
|
||||
616
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mlu
Normal file
616
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mlu
Normal file
@@ -0,0 +1,616 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include "dequant_from_paged_cache.mluh"
|
||||
#include "quant_utils.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
#pragma bang walign(16)
// Bytes reserved at the top of NRAM for the kernel's stack usage.
#define REM_FOR_STACK (32 * 1024)
// Total WRAM available to this kernel (compiler macro is in KB).
#define DEQUANT_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
// Usable NRAM: full NRAM minus the stack reservation.
#define DEQUANT_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK)
// Short aliases for the two kernel templates used when building the dispatch table.
#define DEQUANT_PAGED_PERHEAD kernels::MLUDequantFromPagedCacheKernelPerHead
#define DEQUANT_PAGED_PERCHANNEL kernels::MLUDequantFromPagedCacheKernelPerChannel
// Number of slots in DequantFromPagedCacheFuncArr (see getDequantPagedIdx).
#define DEQUANT_FUNC_LEN (24)
// Maximum batch size supported by the on-chip lens/offsets buffers below.
#define DEQUANT_BATCH_NUM (1024)

// Scratch buffers in on-chip weight RAM and neuron RAM shared by all kernels in this file.
__wram__ int8_t wbuf[DEQUANT_WRAM_SIZE];
__nram__ int8_t nbuf[DEQUANT_NRAM_SIZE];
// Uses 8K = 1K * (4 + 4) to process offsets
__nram__ int32_t n_lens[DEQUANT_BATCH_NUM];
__nram__ int32_t n_offsets[DEQUANT_BATCH_NUM];
||||
|
||||
// De-quantizes one tile of quantized cache data with per-channel scales and writes
// it back to the context tensor in GDRAM.
// T  = output dtype (float/half/bfloat16), Tc = quantized cache dtype, Ts = scale dtype.
// input_nram holds [seq_num, head_num, head_size] quantized values; scale_nram holds
// one [head_num, head_size] (= scale_num) scale table applied to every token.
// The strided __memcpy scatters the de-quantized [seq_num, scale_num] tile into the
// GDRAM context layout using context_head_stride / context_seq_stride (element strides).
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_channel(T *output_nram,
                                         Tc *input_nram,
                                         Ts *scale_nram,
                                         T *data,
                                         const int32_t scale_num,
                                         const int32_t seq_num,
                                         const int32_t head_num,
                                         const int32_t head_size,
                                         const size_t context_offset,
                                         const size_t context_seq_stride,
                                         const size_t context_head_stride) {
  // Element-wise de-quantize seq_num * scale_num values; the scale table repeats
  // every scale_num elements. nbuf is used as extra scratch by the helper.
  dequantize<T, Tc, Ts>((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (Ts *)nbuf,
                        seq_num * scale_num, scale_num);
  // Copy head_size-element rows out to GDRAM: head_num rows per token, seq_num tokens.
  __memcpy((T *)data + context_offset, (T *)output_nram, head_size * sizeof_(T), NRAM2GDRAM,
           context_head_stride * sizeof_(T), head_num - 1, context_seq_stride * sizeof_(T),
           seq_num - 1, head_size * sizeof_(T), head_num - 1, scale_num * sizeof_(T), seq_num - 1);
}
|
||||
|
||||
// De-quantizes one tile with per-head (per-token) scales and writes it back to the
// context tensor in GDRAM.
// input_nram is laid out as [block_count, head_num, block_size, head_size] (the paged
// cache block layout); scale_nram carries one scale per (block, head, token). The tile
// is first widened to fp32, multiplied by its scales via conv_fuse_mul_cvt (wbuf holds
// the identity-like weight written by the caller), then scattered to the contiguous
// [seq_num, head_num, head_size] context layout. Full blocks and the trailing partial
// block (rem_token tokens) are copied separately because their NRAM strides differ.
template <typename T, typename Tc, typename Ts>
__mlu_func__ void dequantize_per_head(T *output_nram,
                                      Tc *input_nram,
                                      Ts *scale_nram,
                                      T *temp_nram,
                                      T *data,
                                      const int32_t seq_num,
                                      const int32_t head_num,
                                      const int32_t block_size,
                                      const int32_t head_size,
                                      const size_t context_offset,
                                      const size_t context_seq_stride,
                                      const size_t context_head_stride) {
  int block_count = DIV_UP(seq_num, block_size);
  int rem_token = seq_num % block_size;
  // Widen quantized values to fp32 in place.
  convert((float *)output_nram, (Tc *)input_nram, block_count * head_num * block_size * head_size);
  // For fp32 output the multiply can happen in place; otherwise the narrowed result
  // goes to temp_nram.
  T *res_nram = std::is_same<T, float>::value ? (T *)output_nram : (T *)temp_nram;
  // dequantize [block_count, head_num, block_size, head_size]
  conv_fuse_mul_cvt((T *)res_nram, (float *)scale_nram, (float *)wbuf, (float *)output_nram,
                    block_count * head_num * block_size, head_size, 1);
  // copy to [seq_num, head_num, head_size]
  int whole_block_count = block_count - int(rem_token > 0);
  if (whole_block_count) {
    for (int i = 0; i < head_num; ++i) {
      // copy from [whole_block_count, i, block_size, head_size]
      // to [whole_block_count * block_size, 1, head_size]
      __memcpy((T *)data + context_offset + i * context_head_stride,
               (T *)res_nram + i * block_size * head_size, head_size * sizeof_(T), NRAM2GDRAM,
               context_seq_stride * sizeof_(T), block_size - 1,
               block_size * context_seq_stride * sizeof_(T), whole_block_count - 1,
               head_size * sizeof_(T), block_size - 1,
               head_num * block_size * head_size * sizeof_(T), whole_block_count - 1);
    }
  }

  if (rem_token) {
    // copy from [last, head_num, block_size(rem_token), head_size]
    // to [rem_token, head_num, head_size]
    __memcpy((T *)data + context_offset + whole_block_count * block_size * context_seq_stride,
             (T *)res_nram + whole_block_count * head_num * block_size * head_size,
             head_size * sizeof_(T), NRAM2GDRAM, context_head_stride * sizeof_(T), head_num - 1,
             context_seq_stride * sizeof_(T), rem_token - 1, block_size * head_size * sizeof_(T),
             head_num - 1, head_size * sizeof_(T), rem_token - 1);
  }
}
|
||||
|
||||
// Loads one batch's quantized tokens from the paged cache into NRAM (per-channel path).
// block_tables maps [batch, max_block_num] logical block slots to physical cache blocks;
// the function copies the relevant table slice to NRAM, turns block ids into byte
// offsets, gathers the blocks, and (on >= 500 arch) transposes the gathered
// [block_count, head_num, block_size, head_size] layout into
// [block_count * block_size, head_num, head_size] via two strided NRAM copies.
// NOTE(review): the __gather path is only compiled for __BANG_ARCH__ >= 500; how older
// architectures fill input_nram is not visible here — confirm against callers.
template <typename Tc>
__mlu_func__ void load_input_per_channel(Tc *input_nram,
                                         Tc *cache,
                                         Tc *temp_nram,
                                         int32_t *block_offsets,
                                         uint32_t *cache_offsets,
                                         const int32_t *block_tables,
                                         const int32_t batch_idx,
                                         const int32_t scale_num,
                                         const int32_t max_block_num,
                                         const int32_t head_num,
                                         const int32_t block_size,
                                         const int32_t head_size,
                                         const int32_t seq_begin,
                                         const int32_t deal_seq_num,
                                         const size_t cache_bs_stride,
                                         const size_t cache_head_stride) {
  // Range of block-table entries covering tokens [seq_begin, seq_begin + deal_seq_num).
  int32_t block_start = batch_idx * max_block_num + seq_begin / block_size;
  int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size;
  int32_t block_count = block_end - block_start + 1;
  // make sure elements in block_tables >= 0
  __memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start,
           block_count * sizeof_(int32_t), GDRAM2NRAM);
  // Convert block ids to byte offsets into the cache.
  __bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets,
                    (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
#if __BANG_ARCH__ >= 500
  // gather [block_count, head_num, block_size, head_size]
  __gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets,
           (uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM,
           (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
  if (head_num != 1 && block_size != 1) {
    // mv to [head_num, whole_block_count, block_size, head_size]
    __memcpy((Tc *)temp_nram, (Tc *)input_nram, block_size * head_size * sizeof(Tc), NRAM2NRAM,
             block_size * head_size * sizeof(Tc), block_count - 1,
             block_count * block_size * head_size * sizeof(Tc), head_num - 1,
             head_num * block_size * head_size * sizeof(Tc), block_count - 1,
             block_size * head_size * sizeof(Tc), head_num - 1);
    // mv to [whole_block_count, block_size, head_num, head_size]
    __memcpy((Tc *)input_nram, (Tc *)temp_nram, head_size * sizeof(Tc), NRAM2NRAM,
             head_size * sizeof(Tc), head_num - 1, head_num * head_size * sizeof(Tc),
             block_count * block_size - 1, block_count * block_size * head_size * sizeof(Tc),
             head_num - 1, head_size * sizeof(Tc), block_count * block_size - 1);
  }
#endif
}
|
||||
|
||||
// Loads one batch's quantized tokens AND their per-head/per-token scales from the paged
// cache into NRAM (per-head path). Like load_input_per_channel it resolves physical
// block offsets from block_tables, but performs two gathers: one for the cache blocks
// ([block_count, head_num, block_size, head_size]) and one for the matching scale
// blocks ([block_count, head_num, block_size]). No layout transpose is done here; the
// block layout is consumed directly by dequantize_per_head.
// NOTE(review): gathers are only compiled for __BANG_ARCH__ >= 500 — confirm older-arch
// handling against callers.
template <typename Tc, typename Ts>
__mlu_func__ void load_input_per_head(Tc *input_nram,
                                      Ts *scale_nram,
                                      Tc *cache,
                                      Ts *scale,
                                      Tc *temp_nram,
                                      int32_t *block_offsets,
                                      uint32_t *cache_offsets,
                                      uint32_t *scale_offsets,
                                      const int32_t *block_tables,
                                      const int32_t batch_idx,
                                      const int32_t max_block_num,
                                      const int32_t head_num,
                                      const int32_t block_size,
                                      const int32_t head_size,
                                      const int32_t seq_begin,
                                      const int32_t deal_seq_num,
                                      const size_t cache_bs_stride,
                                      const size_t cache_head_stride,
                                      const size_t scale_bs_stride,
                                      const size_t scale_head_stride) {
  // Range of block-table entries covering tokens [seq_begin, seq_begin + deal_seq_num).
  int32_t block_start = batch_idx * max_block_num + seq_begin / block_size;
  int32_t block_end = batch_idx * max_block_num + (seq_begin + deal_seq_num - 1) / block_size;
  int32_t block_count = block_end - block_start + 1;
  // make sure elements in block_tables >= 0
  __memcpy((int32_t *)block_offsets, (int32_t *)block_tables + block_start,
           block_count * sizeof_(int32_t), GDRAM2NRAM);
  // Block ids -> byte offsets into the cache and into the scale tensor.
  __bang_mul_scalar((uint32_t *)cache_offsets, (uint32_t *)block_offsets,
                    (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
  __bang_mul_scalar((uint32_t *)scale_offsets, (uint32_t *)block_offsets,
                    (uint32_t)scale_bs_stride * sizeof(Ts), block_count);
#if __BANG_ARCH__ >= 500
  // gather [block_count, head_num, block_size, head_size]
  __gather((Tc *)input_nram, (Tc *)cache, (uint32_t *)cache_offsets,
           (uint32_t)cache_bs_stride * sizeof(Tc), GDRAM2NRAM,
           (uint32_t)cache_bs_stride * sizeof(Tc), block_count);
  // gather [block_count, head_num, block_size]
  __gather((Ts *)scale_nram, (Ts *)scale, (uint32_t *)scale_offsets,
           (uint32_t)scale_bs_stride * sizeof(Ts), GDRAM2NRAM,
           (uint32_t)scale_bs_stride * sizeof(Ts), block_count);
#endif
}
|
||||
|
||||
// Device entry point: de-quantizes K/V from the paged cache with PER-CHANNEL scales
// (quant_mode == 0). Work is split across tasks: taskIdY strides over batches,
// taskIdZ picks a seq_block-sized token window within each sequence.
// ProcessOffsets selects the variant where context_seq_offsets is nullptr and the
// per-batch offsets are derived on-chip from context_lens (see getDequantPagedIdx,
// which adds 12 to the dispatch index exactly when context_seq_offsets == nullptr).
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromPagedCacheKernelPerChannel(void *key,
                                                             void *value,
                                                             const void *key_cache,
                                                             const void *value_cache,
                                                             const void *key_scale,
                                                             const void *value_scale,
                                                             const int32_t *context_lens,
                                                             const int32_t *context_seq_offsets,
                                                             const int32_t *block_tables,
                                                             const int32_t max_context_len,
                                                             const int32_t max_block_num,
                                                             const int32_t batch_size,
                                                             const int32_t head_num,
                                                             const int32_t key_group_num,
                                                             const int32_t value_group_num,
                                                             const int32_t block_size,
                                                             const int32_t head_size,
                                                             const int32_t seq_block,
                                                             const size_t context_head_stride,
                                                             const size_t context_seq_stride,
                                                             const size_t cache_bs_stride,
                                                             const size_t cache_head_stride,
                                                             const size_t key_cache_seq_stride,
                                                             const size_t value_cache_seq_stride,
                                                             const size_t scale_bs_stride,
                                                             const size_t scale_head_stride) {
  // Key (resp. value) is handled only when its output, cache, and scale pointers
  // are all non-null.
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }

  /* *********************************nram space **************************************
   * NRAM |scale[head_num, head_size] fp32|output/input[seq_block, head_num, head_size] fp32|
   */
  int32_t scale_num = head_num * head_size;
  // NOTE(review): scale_nram / input_nram appear to be right-aligned inside their
  // fp32-sized slots so the fp32 widening can happen in place — confirm against the
  // dequantize/convert helpers.
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram = (Tc *)output_nram +
                   (std::is_same<Tc, int4x2_t>::value
                        ? (7 * seq_block * (scale_num >> 1))
                        : (seq_block * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc)));
  // NOTE(review): this offset is seq_block * scale_num BYTES past output_nram (int8
  // cast), i.e. past the int8 input region — confirm it cannot collide with the fp32
  // output region.
  int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + seq_block * scale_num);
  uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size);
  int32_t seq_offset;
  int32_t seq_len;
  int32_t seq_begin;
  int32_t deal_seq_num;
  size_t context_offset;
  // When ProcessOffsets, fill n_lens/n_offsets from context_lens (presumably a
  // cumulative sum — confirm against process_offsets).
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  if (has_key) {
    // Per-channel scale is shared by every token: load it once per K/V tensor.
    load_scale_once((Ts *)scale_nram, (Ts *)key_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      // seq_begin % block_size != 0 only when seq_block < block_size
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      // Skip tasks whose window lies past this sequence, or invalid offsets.
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_channel((Tc *)input_nram, (Tc *)key_cache, (Tc *)output_nram,
                             (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                             (int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num,
                             block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride,
                             cache_head_stride);
      dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)key,
                             scale_num, deal_seq_num, head_num, head_size, context_offset,
                             context_seq_stride, context_head_stride);
    }
  }

  if (has_value) {
    // Same pipeline as the key path, with the value tensors.
    load_scale_once((Ts *)scale_nram, (Ts *)value_scale, head_num, head_size, scale_bs_stride,
                    scale_head_stride);
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      // seq_begin % block_size != 0 only when seq_block < block_size
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_channel((Tc *)input_nram, (Tc *)value_cache, (Tc *)output_nram,
                             (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                             (int32_t *)block_tables, batch_idx, scale_num, max_block_num, head_num,
                             block_size, head_size, seq_begin, deal_seq_num, cache_bs_stride,
                             cache_head_stride);
      dequantize_per_channel((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)value,
                             scale_num, deal_seq_num, head_num, head_size, context_offset,
                             context_seq_stride, context_head_stride);
    }
  }
}
|
||||
|
||||
// Device entry point: de-quantizes K/V from the paged cache with PER-HEAD/per-token
// scales (quant_mode != 0). Task split mirrors the per-channel kernel: taskIdY
// strides over batches, taskIdZ over seq_block-sized token windows. Unlike the
// per-channel kernel, scales are gathered per block alongside the data
// (load_input_per_head), and WRAM holds a [head_size x 16] buffer of 1.0f used as
// the weight for conv_fuse_mul_cvt in dequantize_per_head.
// ProcessOffsets == true corresponds to context_seq_offsets == nullptr (offsets are
// then derived on-chip from context_lens).
template <typename T, typename Tc, typename Ts, bool ProcessOffsets>
__mlu_global__ void MLUDequantFromPagedCacheKernelPerHead(void *key,
                                                          void *value,
                                                          const void *key_cache,
                                                          const void *value_cache,
                                                          const void *key_scale,
                                                          const void *value_scale,
                                                          const int32_t *context_lens,
                                                          const int32_t *context_seq_offsets,
                                                          const int32_t *block_tables,
                                                          const int32_t max_context_len,
                                                          const int32_t max_block_num,
                                                          const int32_t batch_size,
                                                          const int32_t head_num,
                                                          const int32_t key_group_num,
                                                          const int32_t value_group_num,
                                                          const int32_t block_size,
                                                          const int32_t head_size,
                                                          const int32_t seq_block,
                                                          const size_t context_head_stride,
                                                          const size_t context_seq_stride,
                                                          const size_t cache_bs_stride,
                                                          const size_t cache_head_stride,
                                                          const size_t key_cache_seq_stride,
                                                          const size_t value_cache_seq_stride,
                                                          const size_t scale_bs_stride,
                                                          const size_t scale_head_stride) {
  // Key (resp. value) is handled only when its output, cache, and scale pointers
  // are all non-null.
  bool has_key = (key && key_cache && key_scale);
  bool has_value = (value && value_cache && value_scale);
  if (!(has_key || has_value)) {
    return;
  }

  /* *********************************nram space **************************************
   * NRAM |scale[seq_block, head_num] * fp32|output/input[seq_block, head_num, head_size] * fp32|
   * |temp[seq_block, head_num, head_size] * output_dtype|
   * WRAM |head_size * 64B|
   */
  int32_t scale_num = seq_block * head_num;
  // NOTE(review): scale_nram / input_nram appear right-aligned within their fp32-sized
  // slots for in-place widening — confirm against convert/conv_fuse_mul_cvt.
  Ts *scale_nram = (Ts *)nbuf + (scale_num * (sizeof_(float) - sizeof_(Ts)) / sizeof_(Ts));
  float *output_nram = (float *)nbuf + scale_num;
  Tc *input_nram =
      (Tc *)output_nram + (head_size * scale_num * (sizeof_(float) - sizeof_(Tc)) / sizeof_(Tc));
  // temp_nram for nram to store temp output
  float *temp_nram = (float *)output_nram + head_size * scale_num;
  // NOTE(review): byte-based offset (int8 cast) vs. the float-element offset of
  // temp_nram above — confirm the regions cannot collide.
  int32_t *block_offsets = (int32_t *)((int8_t *)output_nram + head_size * scale_num);
  uint32_t *cache_offsets = (uint32_t *)block_offsets + DIV_UP(seq_block, block_size);
  uint32_t *scale_offsets = (uint32_t *)cache_offsets + DIV_UP(seq_block, block_size);
  int32_t seq_len;
  int32_t seq_begin;
  int32_t seq_offset;
  int32_t deal_seq_num;
  size_t context_offset;
  // Stage a [head_size, 16] buffer of ones into WRAM for conv_fuse_mul_cvt.
  __bang_write_value((float *)nbuf, head_size * 16, 1.0f);
  mvNram2WramLT16<float>((int8_t *)wbuf, (int8_t *)nbuf, head_size, 16, 16);
  // When ProcessOffsets, fill n_lens/n_offsets from context_lens (presumably a
  // cumulative sum — confirm against process_offsets).
  process_offsets<ProcessOffsets>((int32_t *)n_lens, (int32_t *)n_offsets, (int32_t *)context_lens,
                                  (int32_t *)context_seq_offsets, batch_size);
  if (has_key) {
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      // Skip tasks whose window lies past this sequence, or invalid offsets.
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)key_cache, (Ts *)key_scale,
                          (Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                          (uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx,
                          max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num,
                          cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride);
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)key, deal_seq_num, head_num, block_size, head_size, context_offset,
                          context_seq_stride, context_head_stride);
    }
  }

  if (has_value) {
    // Same pipeline as the key path, with the value tensors.
    for (int32_t batch_idx = taskIdY; batch_idx < batch_size; batch_idx += taskDimY) {
      load_len_offset<ProcessOffsets>(seq_len, seq_offset, (int32_t *)n_lens, (int32_t *)n_offsets,
                                      (int32_t *)context_lens, (int32_t *)context_seq_offsets,
                                      batch_idx);
      seq_begin = taskIdZ * seq_block;
      deal_seq_num = std::min(seq_len - seq_begin, seq_block);
      if (deal_seq_num <= 0 || seq_offset < 0) continue;
      context_offset = context_seq_stride * (seq_offset + seq_begin);
      load_input_per_head((Tc *)input_nram, (Ts *)scale_nram, (Tc *)value_cache, (Ts *)value_scale,
                          (Tc *)temp_nram, (int32_t *)block_offsets, (uint32_t *)cache_offsets,
                          (uint32_t *)scale_offsets, (int32_t *)block_tables, batch_idx,
                          max_block_num, head_num, block_size, head_size, seq_begin, deal_seq_num,
                          cache_bs_stride, cache_head_stride, scale_bs_stride, scale_head_stride);
      dequantize_per_head((T *)output_nram, (Tc *)input_nram, (Ts *)scale_nram, (T *)temp_nram,
                          (T *)value, deal_seq_num, head_num, block_size, head_size, context_offset,
                          context_seq_stride, context_head_stride);
    }
  }
}
|
||||
} // namespace kernels
|
||||
|
||||
// Explicitly instantiates one paged-cache dequant kernel template.
// T = output dtype, Tc = quantized cache dtype, Ts = scale dtype,
// C = ProcessOffsets (true when context_seq_offsets is nullptr),
// Name = PerChannel | PerHead.
#define DEQUANT_PAGED_INIT(T, Tc, Ts, C, Name)                                               \
  template __mlu_global__ void kernels::MLUDequantFromPagedCacheKernel##Name<T, Tc, Ts, C>(  \
      void *key, void *value, const void *key_cache, const void *value_cache,                \
      const void *key_scale, const void *value_scale, const int32_t *context_lens,           \
      const int32_t *context_seq_offsets, const int32_t *block_tables,                       \
      const int32_t max_context_len, const int32_t max_block_num, const int32_t batch_size,  \
      const int32_t head_num, const int32_t key_group_num, const int32_t value_group_num,    \
      const int32_t block_size, const int32_t head_size, const int32_t seq_block,            \
      const size_t context_head_stride, const size_t context_seq_stride,                    \
      const size_t cache_bs_stride, const size_t cache_head_stride,                          \
      const size_t key_cache_seq_stride, const size_t value_cache_seq_stride,                \
      const size_t scale_bs_stride, const size_t scale_head_stride);

// Instantiations for every dispatch-table slot; the int4x2_t (4-bit) variants are
// currently disabled (their table slots are nullptr).
DEQUANT_PAGED_INIT(half, int8_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(float, int8_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerChannel)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerChannel)
DEQUANT_PAGED_INIT(half, int8_t, float, false, PerHead)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, false, PerHead)
DEQUANT_PAGED_INIT(float, int8_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, false, PerHead)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, false, PerHead)
DEQUANT_PAGED_INIT(half, int8_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(float, int8_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerChannel)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerChannel)
DEQUANT_PAGED_INIT(half, int8_t, float, true, PerHead)
DEQUANT_PAGED_INIT(bfloat16_t, int8_t, float, true, PerHead)
DEQUANT_PAGED_INIT(float, int8_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(half, int4x2_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(bfloat16_t, int4x2_t, float, true, PerHead)
// DEQUANT_PAGED_INIT(float, int4x2_t, float, true, PerHead)
|
||||
|
||||
// Signature shared by all paged-cache dequant kernel instantiations, used for
// table-based dispatch below.
typedef void (*DequantFromPagedCachePointer)(void *,           // key
                                             void *,           // value
                                             const void *,     // key_cache
                                             const void *,     // value_cache
                                             const void *,     // key_scale
                                             const void *,     // value_scale
                                             const int32_t *,  // context_lens
                                             const int32_t *,  // context_seq_offsets
                                             const int32_t *,  // block_tables
                                             const int32_t,    // max_context_len
                                             const int32_t,    // max_block_num
                                             const int32_t,    // batch_size
                                             const int32_t,    // head_num
                                             const int32_t,    // key_group_num
                                             const int32_t,    // value_group_num
                                             const int32_t,    // block_size
                                             const int32_t,    // head_size
                                             const int32_t,    // seq_block
                                             const size_t,     // context_head_stride
                                             const size_t,     // context_seq_stride
                                             const size_t,     // cache_bs_stride
                                             const size_t,     // cache_head_stride
                                             const size_t,     // key_cache_seq_stride
                                             const size_t,     // value_cache_seq_stride
                                             const size_t,     // scale_bs_stride
                                             const size_t);    // scale_head_stride

// Dispatch table indexed by getDequantPagedIdx(): groups of 6 alternate
// per-channel/per-head, within a group the first 3 slots are the int8 variants
// (half/bf16/float output) and the last 3 are the disabled int4x2 (4-bit) variants,
// kept as nullptr placeholders.
static DequantFromPagedCachePointer DequantFromPagedCacheFuncArr[DEQUANT_FUNC_LEN] = {
    DEQUANT_PAGED_PERCHANNEL<half, int8_t, float, false>,
    DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int8_t, float, false>,
    DEQUANT_PAGED_PERCHANNEL<float, int8_t, float, false>, nullptr, nullptr, nullptr,
    // DEQUANT_PAGED_PERCHANNEL<half, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERCHANNEL<float, int4x2_t, float, false>,
    DEQUANT_PAGED_PERHEAD<half, int8_t, float, false>,
    DEQUANT_PAGED_PERHEAD<bfloat16_t, int8_t, float, false>,
    DEQUANT_PAGED_PERHEAD<float, int8_t, float, false>, nullptr, nullptr, nullptr,
    // DEQUANT_PAGED_PERHEAD<half, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERHEAD<bfloat16_t, int4x2_t, float, false>,
    // DEQUANT_PAGED_PERHEAD<float, int4x2_t, float, false>,
    DEQUANT_PAGED_PERCHANNEL<half, int8_t, float, true>,
    DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int8_t, float, true>,
    DEQUANT_PAGED_PERCHANNEL<float, int8_t, float, true>, nullptr, nullptr, nullptr,
    // DEQUANT_PAGED_PERCHANNEL<half, int4x2_t, float, true>,
    // DEQUANT_PAGED_PERCHANNEL<bfloat16_t, int4x2_t, float, true>,
    // DEQUANT_PAGED_PERCHANNEL<float, int4x2_t, float, true>,
    DEQUANT_PAGED_PERHEAD<half, int8_t, float, true>,
    DEQUANT_PAGED_PERHEAD<bfloat16_t, int8_t, float, true>,
    DEQUANT_PAGED_PERHEAD<float, int8_t, float, true>, nullptr, nullptr, nullptr};
// DEQUANT_PAGED_PERHEAD<half, int4x2_t, float, true>,
// DEQUANT_PAGED_PERHEAD<bfloat16_t, int4x2_t, float, true>,
// DEQUANT_PAGED_PERHEAD<float, int4x2_t, float, true>};
|
||||
|
||||
// Computes the index into DequantFromPagedCacheFuncArr for the requested
// configuration. Layout of the 24-entry table:
//   +12 when context_seq_offsets is nullptr (ProcessOffsets = true variants),
//   +6  for per-token quantization (quant_mode != 0, per-head kernels),
//   +3  for non-8-bit quantization (the int4x2 slots, currently nullptr),
//   +1  for bfloat16 output, +2 for float32 output (half = +0).
uint32_t getDequantPagedIdx(cnnlDataType_t dtype,
                            int32_t quant_mode,
                            int32_t quant_bit,
                            const void *context_seq_offset) {
  uint32_t index = (context_seq_offset == nullptr) ? 12u : 0u;
  if (quant_mode != 0) {
    index += 6u;
  }
  if (quant_bit != 8) {
    index += 3u;
  }
  if (dtype == CNNL_DTYPE_BFLOAT16) {
    index += 1u;
  } else if (dtype == CNNL_DTYPE_FLOAT) {
    index += 2u;
  }
  return index;
}
|
||||
|
||||
// Chooses the launch tiling for the paged-cache dequant kernel.
//
// Outputs (by reference):
//   seq_block - number of sequence positions each kernel iteration handles;
//               always padded down to a multiple of block_size.
//   task_dim  - launch dims: x = 1, y = min(batch_size, total core count),
//               z = number of seq_block segments covering max_context_len.
//   task_type - always cnrtFuncTypeBlock.
//
// The NRAM budget differs by quant_mode: per-channel (0) only stages the
// float output tile; per-token (1) additionally stages the quantized input
// tile (dtype-sized) and one scale per head.
// NOTE(review): when the budget check fails this function only logs to
// stderr and continues with the undersized seq_block — callers still launch.
void getBlockAndDimForPaged(int32_t &seq_block,
                            cnrtDim3_t &task_dim,
                            cnrtFunctionType_t &task_type,
                            const int32_t max_context_len,
                            const int32_t head_num,
                            const int32_t batch_size,
                            const int32_t head_size,
                            const int32_t block_size,
                            const int32_t quant_mode,
                            const int32_t quant_bit,
                            const cnnlDataType_t dtype) {
  int32_t core_dim;
  int32_t cluster_dim;
  // Defaults; getDeviceCoreAndRam overwrites these with the actual device
  // capacities (minus REM_FOR_STACK reserved for the stack).
  int32_t nram_size = 480 * 1024;
  int32_t wram_size = 512 * 1024;
  int32_t sram_size = 2016 * 1024;
  getDeviceCoreAndRam(cluster_dim, core_dim, nram_size, wram_size, sram_size, REM_FOR_STACK);
  if (quant_mode == 0) {
    // Per-channel: NRAM holds seq_block rows of [head_num, head_size] floats.
    seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * head_size * sizeof_(float)) - 1);
    if (seq_block < block_size) {
      std::cerr << __func__ << "," << __LINE__
                << " :head_num * head_size * sizeof_(float) should be less than "
                << nram_size / block_size << " when quant_mode is 0." << std::endl;
    }
  } else {
    // Per-token: each row also needs the quantized input (dtype-sized) and
    // one float scale per head, hence (head_size + 1) floats + dtype bytes.
    int32_t dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof_(float) : sizeof_(half);
    seq_block = int32_t((int64_t)nram_size / ((int64_t)head_num * (head_size + 1) * sizeof_(float) +
                                              (int64_t)head_num * head_size * dtype_size));
    if (seq_block < block_size) {
      std::cerr << __func__ << "," << __LINE__
                << " :head_num * sizeof_(float) + head_num * head_size * (sizeof_(float) + "
                   "context_dtype_size) "
                << "should be less than " << nram_size / block_size << " when quant_mode is 1."
                << std::endl;
    }

    /* head_size * 64B put in the wram. */
    if (head_size * ONE_LINE >= wram_size) {
      std::cerr << __func__ << "," << __LINE__ << " head_size * 64 " << "should be less than "
                << wram_size << " when quant_mode is 1." << std::endl;
    }
  }
  // seq_block should be a multiply of block_size
  seq_block = PAD_DOWN(seq_block, block_size);
  int seq_seg = DIV_UP(max_context_len, seq_block);
  int32_t core_num = cluster_dim * core_dim;
  // If the launch would keep fewer than half the cores busy, shrink seq_block
  // to create more segments and spread work across more cores.
  if (batch_size * seq_seg <= (core_num / 2)) {
    int times = core_num / batch_size / seq_seg;
    seq_block = std::max(seq_block / times, 2);
    if (seq_block > block_size) {
      seq_block = PAD_DOWN(seq_block, block_size);
    } else {
      seq_block = block_size;
    }
    seq_seg = DIV_UP(max_context_len, seq_block);
  }

  task_dim.x = 1;
  task_dim.y = uint32_t(std::min(batch_size, core_num));
  task_dim.z = uint32_t(seq_seg);
  task_type = cnrtFuncTypeBlock;
}
||||
// Host-side launcher: de-quantizes key/value tensors from the paged cache.
// Computes the launch tiling, selects the kernel instantiation from
// DequantFromPagedCacheFuncArr by (dtype, quant_mode, quant_bit, offsets),
// and enqueues it on `queue`.
// Returns KERNEL_STATUS_FAILED on MLU300 devices or for unsupported
// (quant_bit, dtype) combinations; KERNEL_STATUS_SUCCESS otherwise.
KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue,
                                         void *key,
                                         void *value,
                                         const void *key_cache,
                                         const void *value_cache,
                                         const void *key_scale,
                                         const void *value_scale,
                                         const void *context_lens,
                                         const void *context_seq_offsets,
                                         const void *block_tables,
                                         const int32_t max_context_len,
                                         const int32_t max_block_num,
                                         const int32_t batch_size,
                                         const int32_t head_num,
                                         const int32_t key_group_num,
                                         const int32_t value_group_num,
                                         const int32_t block_size,
                                         const int32_t head_size,
                                         const int32_t quant_mode,
                                         const int32_t quant_bit,
                                         const size_t context_head_stride,
                                         const size_t context_seq_stride,
                                         const size_t cache_bs_stride,
                                         const size_t cache_head_stride,
                                         const size_t key_cache_seq_stride,
                                         const size_t value_cache_seq_stride,
                                         const size_t scale_bs_stride,
                                         const size_t scale_head_stride,
                                         const cnnlDataType_t dtype) {
  if (is_arch300()) {
    std::cerr << "[invokeDequantFromPagedCache]: kernel does not support MLU300 devices."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  int32_t index;
  int32_t seq_block;
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  getBlockAndDimForPaged(seq_block, k_dim, k_type, max_context_len, head_num, batch_size, head_size,
                         block_size, quant_mode, quant_bit, dtype);
  index = getDequantPagedIdx(dtype, quant_mode, quant_bit, context_seq_offsets);
  auto dequant_paged_func = DequantFromPagedCacheFuncArr[index];
  // FIX: the dispatch table holds nullptr placeholders for the int4x2 variants
  // (they are commented out); quant_bit != 8 previously selected a null kernel
  // pointer and launched it. Fail explicitly instead.
  if (dequant_paged_func == nullptr) {
    std::cerr << "[invokeDequantFromPagedCache]: unsupported quant_bit (" << quant_bit
              << ") / dtype combination." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  dequant_paged_func<<<k_dim, k_type, queue>>>(
      (void *)key, (void *)value, (const void *)key_cache, (const void *)value_cache,
      (const void *)key_scale, (const void *)value_scale, (const int32_t *)context_lens,
      (const int32_t *)context_seq_offsets, (const int32_t *)block_tables, max_context_len,
      max_block_num, batch_size, head_num, key_group_num, value_group_num, block_size, head_size,
      seq_block, context_head_stride, context_seq_stride, cache_bs_stride, cache_head_stride,
      key_cache_seq_stride, value_cache_seq_stride, scale_bs_stride, scale_head_stride);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
||||
} // namespace tmo
|
||||
106
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mluh
Normal file
106
torch_mlu_ops-v1.3.2/csrc/kernels/dequant_from_paged_cache.mluh
Normal file
@@ -0,0 +1,106 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
|
||||
#define CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief De-quantizes the key and value tensors from the provided paged cache and scale.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param key: Pointer to the MLU memory that stores the key tensor,
|
||||
* with shape [total_seqlen, head_num, head_size]. Data type can be half
|
||||
* or bfloat16. This parameter can be nullptr.
|
||||
* @param value: Pointer to the MLU memory that stores the value tensor,
|
||||
* with shape [total_seqlen, head_num, head_size]. Data type can be
|
||||
* half or bfloat16. This parameter can be nullptr.
|
||||
* @param key_cache: Pointer to the MLU memory that stores the key cache tensor,
|
||||
* with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization.
|
||||
* Data type must be int8. This parameter can be nullptr.
|
||||
* @param value_cache: Pointer to the MLU memory that stores the value cache tensor,
|
||||
* with shape [total_blocks, head_num, block_size, head_size] for 8-bit quantization.
|
||||
* Data type must be int8. This parameter can be nullptr.
|
||||
* @param key_scale: Pointer to the MLU memory that stores the key cache quantization scale.
|
||||
* Shape depends on quantization mode:
|
||||
* - For per-channel quantization (quant_mode = 0): [head_num, head_size].
|
||||
* - For per-token quantization (quant_mode = 1): [total_blocks, head_num, block_size].
|
||||
* Data type must be float32. This parameter can be nullptr.
|
||||
* @param value_scale: Pointer to the MLU memory that stores the value cache quantization scale,
|
||||
* with the same shape as key_scale. Data type must be float32. This parameter can be
|
||||
nullptr.
|
||||
* @param context_lens: Pointer to the MLU memory that stores the sequence lengths.
|
||||
* The shape must be [batch].
|
||||
* @param context_seq_offset: Pointer to the MLU memory that stores the sequence offset in the
|
||||
context.
|
||||
* The shape must be [batch]. If nullptr, the default value is the cumulative sum of
|
||||
context_lengths.
|
||||
* @param block_tables: Pointer to the MLU memory that stores the block tables for indexing.
|
||||
* The shape must be [batch, max_block_num].
|
||||
* @param max_context_len: The maximum sequence length of context.
|
||||
* @param max_block_num: The maximum block number of each batch.
|
||||
* @param batch: Batch size.
|
||||
* @param head_num: Head number.
|
||||
* @param key_group_num: group number of key group-wise quantization.
|
||||
* @param value_group_num: group number of value group-wise quantization.
|
||||
* @param block_size: The block size of the cache.
|
||||
* @param head_size: Head size.
|
||||
* @param quant_mode: An integer value indicating the quantization mode:
|
||||
* 0 for per-channel quantization and 1 for per-token quantization.
|
||||
* @param quant_bit: An integer value indicating the quantization bit width:
|
||||
* 8 for 8-bit quantization.
|
||||
* @param context_head_stride: The stride of head_num in context.
|
||||
* @param context_seq_stride: The stride of max_context_len in context.
|
||||
* @param cache_bs_stride: The stride of batch in cache.
|
||||
* @param cache_head_stride: The stride of head_num in cache.
|
||||
* @param key_cache_seq_stride: The stride of cache_mem_len in key cache.
|
||||
* @param value_cache_seq_stride: The stride of cache_mem_len in value cache.
|
||||
* @param cache_scale_bs_stride: The stride of batch in cache scale, only valid if quant_per_quant.
|
||||
* @param cache_scale_head_stride: The stride of head in cache scale.
|
||||
* @param dtype: The data type of the key and value tensors.
|
||||
|
||||
* @note If any of key/key_cache/key_scale is nullptr, no operation is performed on the key.
|
||||
* If any of value/value_cache/value_scale is nullptr, no operation is performed on the value.
|
||||
*/
|
||||
KernelStatus invokeDequantFromPagedCache(cnrtQueue_t queue,
|
||||
void *key,
|
||||
void *value,
|
||||
const void *key_cache,
|
||||
const void *value_cache,
|
||||
const void *key_scale,
|
||||
const void *value_scale,
|
||||
const void *context_lens,
|
||||
const void *context_seq_offsets,
|
||||
const void *block_tables,
|
||||
const int32_t max_context_len,
|
||||
const int32_t max_block_num,
|
||||
const int32_t batch,
|
||||
const int32_t head_num,
|
||||
const int32_t key_group_num,
|
||||
const int32_t value_group_num,
|
||||
const int32_t block_size,
|
||||
const int32_t head_size,
|
||||
const int32_t quant_mode,
|
||||
const int32_t quant_bit,
|
||||
const size_t context_head_stride,
|
||||
const size_t context_seq_stride,
|
||||
const size_t cache_bs_stride,
|
||||
const size_t cache_head_stride,
|
||||
const size_t key_cache_seq_stride,
|
||||
const size_t value_cache_seq_stride,
|
||||
const size_t cache_scale_bs_stride,
|
||||
const size_t cache_scale_head_stride,
|
||||
const cnnlDataType_t dtype);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_DEQUANT_FROM_PAGED_CACHE_MLUH_
|
||||
254
torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu
Normal file
254
torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mlu
Normal file
@@ -0,0 +1,254 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "dequantify.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
// Number of logical values packed into one storage element of T.
// Default: one value per element.
template <typename T>
struct PackValueNum {
  const static int value = 1;
};

// int4x2_t packs two 4-bit values into a single byte.
template <>
struct PackValueNum<int4x2_t> {
  const static int value = 2;
};
|
||||
// Scratch NRAM for the dequant kernels: 3/8 of the per-core NRAM; the
// per-tensor kernel splits it in half for ping/pong staging.
__nram__ int8_t nram_buf[(__MLU_NRAM_SIZE__ * 3 / 8 * 1024)];

// Per-core NRAM staging area for per-channel dequant scales (8 KiB).
__nram__ int8_t nram_buf_scale[8192];
|
||||
// Dequantize int8 -> float: widen, then multiply by the (already inverted) scale.
__mlu_func__ void convert(float *dst, const int8_t *src, int count, float scale) {
  __bang_int82float(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, scale, count);
}

// Dequantize packed int4 -> float. `count` is the number of unpacked values.
__mlu_func__ void convert(float *dst, const int4x2_t *src, int count, float scale) {
  __bang_int42float_rn(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, scale, count);
}

// Dequantize int8 -> half. The scale is narrowed to half before the multiply.
__mlu_func__ void convert(half *dst, const int8_t *src, int count, float scale) {
  __bang_int82half(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, (half)scale, count);
}

// Dequantize packed int4 -> half. `count` is the number of unpacked values.
__mlu_func__ void convert(half *dst, const int4x2_t *src, int count, float scale) {
  __bang_int42half_rn(dst, src, count, 0);
  __bang_mul_scalar(dst, dst, (half)scale, count);
}
||||
template <typename T>
|
||||
__mlu_func__ void swap(T *&ping, T *&pong) {
|
||||
T *tmp = ping;
|
||||
ping = pong;
|
||||
pong = tmp;
|
||||
}
|
||||
|
||||
// Per-tensor dequantization kernel: dst[i] = src[i] * (1 / scale), with a
// three-stage software pipeline (Load / Compute / Store) over ping-pong NRAM
// buffers. Each core handles a contiguous slice of the flattened input.
// NOTE(review): the prologue unconditionally loads two full units and the
// drain stages index with (loop_count - 2) / (loop_count - 1); this looks
// like it assumes src_count >= 2 * src_num_unit per core — confirm the host
// side guarantees this for small inputs.
template <typename TDst, typename TSrc>
__mlu_global__ void dequantifyPerTensor(void *all_dst,
                                        const void *all_src,
                                        size_t all_src_count,
                                        float scale) {
  // Invert once so the inner loop multiplies instead of divides.
  scale = 1.0f / scale;
  // Split elements across cores; the first src_remain cores take one extra.
  size_t src_per_core = all_src_count / taskDim;
  size_t src_remain = all_src_count % taskDim;
  size_t start = taskId * src_per_core + (taskId < src_remain ? taskId : src_remain);
  const size_t src_count = src_per_core + (taskId < src_remain ? 1 : 0);
  // Each TSrc element expands to PackValueNum<TSrc>::value TDst outputs.
  TDst *dst = reinterpret_cast<TDst *>(all_dst) + start * PackValueNum<TSrc>::value;
  const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + start;

  constexpr int size_unit = sizeof(nram_buf) / 2 / // divide by 2 for ping pong
                            (sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
                            128; // align to 128
  constexpr int src_num_unit = size_unit / sizeof(TSrc);
  constexpr int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
  int8_t *nram_buf_ping = nram_buf;
  int8_t *nram_buf_pong = nram_buf + sizeof(nram_buf) / 2;

  // Each half of nram_buf holds a src tile followed by its dst tile.
  TSrc *nram_src_ping = reinterpret_cast<TSrc *>(nram_buf_ping);
  TDst *nram_dst_ping =
      reinterpret_cast<TDst *>(nram_buf_ping + static_cast<int>(sizeof(TSrc)) * size_unit);
  TSrc *nram_src_pong = reinterpret_cast<TSrc *>(nram_buf_pong);
  TDst *nram_dst_pong =
      reinterpret_cast<TDst *>(nram_buf_pong + static_cast<int>(sizeof(TSrc)) * size_unit);

  int loop_count = src_count / src_num_unit;
  int remain_count = src_count % src_num_unit;

  // Pipeline prologue — L: load the first tile only.
  __memcpy_async(nram_src_ping, src, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
  swap(nram_src_ping, nram_src_pong);
  swap(nram_dst_ping, nram_dst_pong);
  __sync_io_move_compute();

  // L C: load tile 1 while converting tile 0.
  __memcpy_async(nram_src_ping, src + 1 * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
  convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
  swap(nram_src_ping, nram_src_pong);
  swap(nram_dst_ping, nram_dst_pong);
  __sync_io_move_compute();

  // Steady state — L C S: load tile i+2, convert tile i+1, store tile i.
  for (int i = 0; i < loop_count - 2; ++i) {
    __memcpy_async(nram_src_ping, src + (i + 2) * src_num_unit, sizeof(TSrc) * src_num_unit,
                   GDRAM2NRAM);
    __memcpy_async(dst + i * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
    convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
    swap(nram_src_ping, nram_src_pong);
    swap(nram_dst_ping, nram_dst_pong);
    __sync_io_move_compute();
  }

  // Epilogue — C S: convert the last tile while storing the second-to-last.
  __memcpy_async(dst + (loop_count - 2) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
                 NRAM2GDRAM);
  convert(nram_dst_pong, nram_src_pong, dst_num_unit, scale);
  swap(nram_src_ping, nram_src_pong);
  swap(nram_dst_ping, nram_dst_pong);
  __sync_io_move_compute();

  // S: store the final full tile.
  __memcpy_async(dst + (loop_count - 1) * dst_num_unit, nram_dst_ping, sizeof(TDst) * dst_num_unit,
                 NRAM2GDRAM);

  __sync_io_move_compute();

  // Handle the partial tail tile synchronously (no pipelining needed).
  if (remain_count > 0) {
    __memcpy(nram_src_ping, src + loop_count * src_num_unit, sizeof(TSrc) * remain_count,
             GDRAM2NRAM);
    convert(nram_dst_ping, nram_src_ping, remain_count * PackValueNum<TSrc>::value, scale);
    __memcpy(dst + loop_count * dst_num_unit, nram_dst_ping,
             sizeof(TDst) * remain_count * PackValueNum<TSrc>::value, NRAM2GDRAM);
  }
}
|
||||
// does not use a pipeline because per channel is more complicated but it's a one-time operation, so
|
||||
// performance doesn't matter.
|
||||
template <typename TDst, typename TSrc>
|
||||
__mlu_global__ void dequantifyPerChannel(void *all_dst,
|
||||
const void *all_src,
|
||||
int src_ci,
|
||||
int all_co,
|
||||
const void *scale) {
|
||||
const int co_per_core = all_co / taskDim;
|
||||
const int co_remain = all_co % taskDim;
|
||||
const int start_co = taskId * co_per_core + (taskId < co_remain ? taskId : co_remain);
|
||||
const int co_count = co_per_core + (taskId < co_remain ? 1 : 0);
|
||||
assert(co_count <= sizeof(nram_buf_scale) / sizeof(TDst));
|
||||
|
||||
constexpr int size_unit = sizeof(nram_buf) /
|
||||
(sizeof(TSrc) + sizeof(TDst) * PackValueNum<TSrc>::value) / 128 *
|
||||
128; // align to 128
|
||||
// yes, we only deal with 1 channel at a time
|
||||
// no, there's no need to optimize a one-time operation
|
||||
const int src_num_unit = std::min((int)(size_unit / sizeof(TSrc)), src_ci);
|
||||
const int dst_num_unit = src_num_unit * PackValueNum<TSrc>::value;
|
||||
TSrc *const nram_src = reinterpret_cast<TSrc *>(nram_buf);
|
||||
TDst *const nram_dst =
|
||||
reinterpret_cast<TDst *>(nram_buf + static_cast<int>(sizeof(TSrc)) * size_unit);
|
||||
|
||||
const TDst *nram_scale = reinterpret_cast<const TDst *>(nram_buf_scale);
|
||||
|
||||
const int loop_one_channel = src_ci / src_num_unit;
|
||||
const int remain_one_channel = src_ci % src_num_unit;
|
||||
|
||||
for (int o = start_co; o < start_co + co_count; ++o) {
|
||||
const TSrc *src = reinterpret_cast<const TSrc *>(all_src) + o * src_ci;
|
||||
TDst *dst = reinterpret_cast<TDst *>(all_dst) + o * src_ci;
|
||||
const TDst scale_value = 1. / nram_scale[o];
|
||||
for (int i = 0; i < loop_one_channel; ++i) {
|
||||
__memcpy(nram_src, src + i * src_num_unit, sizeof(TSrc) * src_num_unit, GDRAM2NRAM);
|
||||
convert(nram_dst, nram_src, dst_num_unit, scale_value);
|
||||
__memcpy(dst + i * dst_num_unit, nram_dst, sizeof(TDst) * dst_num_unit, NRAM2GDRAM);
|
||||
}
|
||||
if (remain_one_channel > 0) {
|
||||
__memcpy(nram_src, src + loop_one_channel * src_num_unit, sizeof(TSrc) * remain_one_channel,
|
||||
GDRAM2NRAM);
|
||||
convert(nram_dst, nram_src, remain_one_channel * PackValueNum<TSrc>::value, scale_value);
|
||||
__memcpy(dst + loop_one_channel * dst_num_unit, nram_dst,
|
||||
sizeof(TDst) * remain_one_channel * PackValueNum<TSrc>::value, NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
// Dispatch table: (src bitwidth, dst dtype) -> per-tensor dequant kernel
// instantiation. Lookups for other combinations fail (caller reports an error).
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerTensor<half, int4x2_t>)>
    per_tensor_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerTensor<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerTensor<float, int8_t>},
};
|
||||
// Host-side launcher for per-tensor dequantization: resolves the queue and
// cluster count from the handle, picks the kernel instantiation for
// (src_bitwidth, dst_dtype), and enqueues it as a Union1 task.
// Returns KERNEL_STATUS_FAILED for unsupported bitwidth/dtype combinations.
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
                                       const void *src,
                                       int src_bitwidth,
                                       void *dst,
                                       cnnlDataType_t dst_dtype,
                                       size_t src_count,
                                       float scale) {
  cnrtQueue_t exec_queue;
  cnnlGetQueue(handle, &exec_queue);
  CNdev device;
  cnnlGetDevice(handle, &device);
  int cluster_count;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_count, cnrtAttrClusterCount, device));
  const cnrtDim3_t launch_dim = {.x = 4, .y = (uint32_t)cluster_count, .z = 1};
  const auto lookup_key = std::make_pair(src_bitwidth, dst_dtype);
  const auto entry = per_tensor_func_map.find(lookup_key);
  if (entry == per_tensor_func_map.end()) {
    std::cerr << "[invokeDequantifyPerTensor]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  entry->second<<<launch_dim, cnrtFuncTypeUnion1, exec_queue>>>(dst, src, src_count, scale);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
// Dispatch table: (src bitwidth, dst dtype) -> per-channel dequant kernel
// instantiation. Lookups for other combinations fail (caller reports an error).
static const std::map<std::pair<int, cnnlDataType_t>,
                      decltype(&kernels::dequantifyPerChannel<half, int4x2_t>)>
    per_channel_func_map = {
        {{4, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int4x2_t>},
        {{4, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int4x2_t>},
        {{8, CNNL_DTYPE_HALF}, &kernels::dequantifyPerChannel<half, int8_t>},
        {{8, CNNL_DTYPE_FLOAT}, &kernels::dequantifyPerChannel<float, int8_t>},
};
||||
|
||||
// Host-side launcher for per-channel dequantization: resolves the queue and
// cluster count from the handle, picks the kernel instantiation for
// (src_bitwidth, dst_dtype), and enqueues it as a Union1 task.
// Returns KERNEL_STATUS_FAILED for unsupported bitwidth/dtype combinations.
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
                                        const void *src,
                                        int src_bitwidth,
                                        void *dst,
                                        cnnlDataType_t dst_dtype,
                                        int src_ci,
                                        int co,
                                        const void *scale) {
  cnrtQueue_t exec_queue;
  cnnlGetQueue(handle, &exec_queue);
  CNdev device;
  cnnlGetDevice(handle, &device);
  int cluster_count;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_count, cnrtAttrClusterCount, device));
  const cnrtDim3_t launch_dim = {.x = 4, .y = (uint32_t)cluster_count, .z = 1};
  const auto lookup_key = std::make_pair(src_bitwidth, dst_dtype);
  const auto entry = per_channel_func_map.find(lookup_key);
  if (entry == per_channel_func_map.end()) {
    std::cerr << "[invokeDequantifyPerChannel]: unsupported src_bitwidth: " << src_bitwidth
              << " dst_dtype: " << dst_dtype;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  entry->second<<<launch_dim, cnrtFuncTypeUnion1, exec_queue>>>(dst, src, src_ci, co, scale);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
||||
|
||||
} // namespace tmo
|
||||
57
torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mluh
Normal file
57
torch_mlu_ops-v1.3.2/csrc/kernels/dequantify.mluh
Normal file
@@ -0,0 +1,57 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_DEQUANTIFY_MLUH_
|
||||
#define CSRC_KERNELS_DEQUANTIFY_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Dequantify per tensor.
|
||||
* @param handle: The handle of cnnl.
|
||||
* @param src: Input. Pointer to the MLU memory that stores the input.
|
||||
* @param src_bitwidth: The bitwidth of input quantized data.
|
||||
* @param dst: Output. Pointer to the MLU memory that stores the output.
|
||||
* @param dst_dtype: The data type of output.
|
||||
* @param src_count: The number of elements in input.
|
||||
* @param scale: The scale for dequantify.
|
||||
*/
|
||||
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
|
||||
const void *src,
|
||||
int src_bitwidth,
|
||||
void *dst,
|
||||
cnnlDataType_t dst_dtype,
|
||||
size_t src_count,
|
||||
float scale);
|
||||
|
||||
/**
|
||||
* @brief Dequantify per channel.
|
||||
* @param handle: The handle of cnnl.
|
||||
* @param src: Input. Pointer to the MLU memory that stores the input.
|
||||
* @param src_bitwidth: The bitwidth of input quantized data.
|
||||
* @param dst: Output. Pointer to the MLU memory that stores the output.
|
||||
* @param dst_dtype: The data type of output.
|
||||
* @param src_ci: The number of input channels (ci) of the input.
|
||||
* @param co: The co of input.
|
||||
* @param scale: Pointer to the MLU memory that stores the scale for dequantify.
|
||||
*/
|
||||
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
|
||||
const void *src,
|
||||
int src_bitwidth,
|
||||
void *dst,
|
||||
cnnlDataType_t dst_dtype,
|
||||
int src_ci,
|
||||
int co,
|
||||
const void *scale);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_DEQUANTIFY_MLUH_
|
||||
310
torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu
Normal file
310
torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu
Normal file
@@ -0,0 +1,310 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "embedding.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
#define MAX_UINT32 (4294967295)
#define MAX_SINT32 (2147483647)
// Usable scratch NRAM: full per-core NRAM minus a 32 KiB reservation
// (presumably for stack/runtime use — TODO confirm).
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];
||||
// Evenly partitions `total` items across `num` workers: worker `id` receives
// `every` items starting at `offset`; the first (total % num) workers get one
// extra item each.
__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
  const int quotient = total / num;
  const int leftover = total % num;
  const bool gets_extra = id < leftover;
  every = gets_extra ? quotient + 1 : quotient;
  offset = quotient * id + (gets_extra ? id : leftover);
}
||||
|
||||
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
|
||||
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
|
||||
|
||||
// Embedding lookup for arch >= 500: gathers rows of `filter` indexed by
// input_ids into `output`, writing zero rows for ids outside
// [vocab_offset, vocab_offset + vocab_size). Uses vectorized mask building
// plus __gather_async; invalid ids are overwritten with zeros via a second
// gather from a zeroed NRAM row.
template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  // Only IPU cores participate.
  if (__is_mpu()) {
    return;
  };

  // This core's slice of the token sequence.
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex <elem_count> must be divisible
  // by 8
  // Max tokens per NRAM pass: one zero row + per-token (embedding row,
  // ones/idxs/mask/temp ints, one byte of zero-offset).
  int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
              (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));

  // Valid id range handled by this vocabulary shard (inclusive).
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  // NRAM layout, carved from nram_buffer in order:
  T *zeros_nram = (T *)nram_buffer; // input_size * sizeof(T)
  T *emb_nram = zeros_nram + input_size; // limit * input_size * sizeof(T)
  int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
  int *idxs_nram = ones_nram + limit; // limit * sizeof(int)
  int *mask_nram = idxs_nram + limit; // limit_pad * sizeof(int)
  int *temp_nram = mask_nram + PAD_UP(limit, 8); // limit * sizeof(int)
  uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit); // limit * sizeof(int8_t)
  __bang_write_zero(zeros_nram, input_size);
  __bang_write_zero(zeros_offset_nram, limit);
  __bang_write_value(ones_nram, limit, 1);

  int repeat = bs_core / limit;
  int remain = bs_core % limit;

  // Process the slice `limit` tokens at a time; final pass handles `remain`.
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    int num_pad = PAD_UP(num, 8); // for __bang_eq_bitindex
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // mask = (id >= vocab_start) && (id <= vocab_end), as 0/1 ints.
    __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
    __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
    __bang_mul(mask_nram, mask_nram, temp_nram, num);
    __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
                       num_pad); // gather valid mask
    __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num); // gather invalid mask
    __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num); // true index
    __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
                      (unsigned int)input_size * sizeof(T), num); // gather offset
    __sync();
    // Valid tokens gather their filter row; invalid tokens gather the zero row.
    __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
                   GDRAM2NRAM, input_size * sizeof(T), num);
    __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
                   NRAM2NRAM, input_size * sizeof(T), num);
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
||||
|
||||
// Zero-fills `elem_count` elements of `dst` in NRAM.
template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
  __bang_write_zero(dst, elem_count);
}

// bfloat16 specialization: __bang_write_zero on bfloat16 is only available on
// arch >= 500, so on older architectures this compiles to a no-op.
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
  __bang_write_zero(dst, elem_count);
#endif
}
|
||||
// Embedding lookup for MLU300-class devices: copies rows of `filter` indexed
// by input_ids into `output`, writing zero rows for ids outside
// [vocab_offset, vocab_offset + vocab_size). Processes token ids two at a
// time so a pair of valid, ordered ids can be fetched with a single strided
// 2-segment copy.
template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  // Only IPU cores participate.
  if (__is_mpu()) {
    return;
  };

  // This core's slice of the token sequence.
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // Max tokens per NRAM pass (embedding row + one int id per token; 64 bytes
  // of slack); rounded down to an even count for the pairwise loop.
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  // Valid id range handled by this vocabulary shard (inclusive).
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)

  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();

    // NOTE(review): idxs_nram[1] is read even when num == 1, and the loop tail
    // below reads idxs_nram[n + 2] / [n + 3] up to two elements past `num`.
    // Those prefetched values are only used when in range, but the reads touch
    // unloaded NRAM — confirm this is intentional.
    int idx1 = idxs_nram[0];
    int idx2 = idxs_nram[1];
    bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
    bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
    for (int n = 0; n < num / 2 * 2; n += 2) {
      if (first && second) {
        // Both valid: one 2-segment strided copy; segment stride is the
        // distance between the two filter rows.
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T),
                       1);
      } else if (!first && !second) {
        write_zero(emb_nram + n * input_size, 2 * input_size);
      } else if (first && !second) {
        write_zero(emb_nram + (n + 1) * input_size, input_size);
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
        __memcpy_async(emb_nram + (n + 1) * input_size,
                       filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      }
      // Prefetch the next pair for the following iteration.
      idx1 = idxs_nram[n + 2];
      idx2 = idxs_nram[n + 3];
      first = (idx1 >= vocab_start && idx1 <= vocab_end);
      second = (idx2 >= vocab_start && idx2 <= vocab_end);
    } // copy loop

    // last idx copy
    if (num % 2 == 1) {
      if (first) {
        __memcpy_async(emb_nram + (num - 1) * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + (num - 1) * input_size, input_size);
      }
    }
    __sync();

    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void embeddingImpl_generic(T *filter,
|
||||
int *input_ids,
|
||||
T *output,
|
||||
int vocab_offset,
|
||||
int vocab_size,
|
||||
int input_size,
|
||||
int total_seq) {
|
||||
if (__is_mpu()) {
|
||||
return;
|
||||
};
|
||||
|
||||
int bs_core = 0;
|
||||
int bs_offset = 0;
|
||||
split(total_seq, taskDim, taskId, bs_core, bs_offset);
|
||||
int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
|
||||
limit = PAD_DOWN(limit, 2);
|
||||
int repeat = bs_core / limit;
|
||||
int remain = bs_core % limit;
|
||||
int vocab_start = vocab_offset;
|
||||
int vocab_end = vocab_offset + vocab_size - 1;
|
||||
|
||||
T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T)
|
||||
int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
|
||||
|
||||
for (int i = 0; i < repeat + 1; i++) {
|
||||
if ((i == repeat) && (remain == 0)) {
|
||||
return;
|
||||
}
|
||||
int num = (i == repeat) ? remain : limit;
|
||||
__memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
|
||||
__sync();
|
||||
|
||||
int idx = idxs_nram[0];
|
||||
bool hit = (idx >= vocab_start && idx <= vocab_end);
|
||||
for (int n = 0; n < num; n++) {
|
||||
if (hit) {
|
||||
__memcpy_async(emb_nram + n * input_size,
|
||||
filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
} else {
|
||||
write_zero(emb_nram + n * input_size, input_size);
|
||||
}
|
||||
idx = idxs_nram[n + 1];
|
||||
hit = (idx >= vocab_start && idx <= vocab_end);
|
||||
}
|
||||
__sync();
|
||||
|
||||
__memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
|
||||
num * input_size * sizeof(T), NRAM2GDRAM);
|
||||
__sync();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUEmbeddingKernel(T *filter,
|
||||
int *input_ids,
|
||||
T *output,
|
||||
int vocab_offset,
|
||||
int vocab_size,
|
||||
int total_vocab_size,
|
||||
int input_size,
|
||||
int total_seq) {
|
||||
#if __BANG_ARCH__ > 372
|
||||
// __gather index maximum dtype is unsigned int
|
||||
if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) {
|
||||
embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
|
||||
} else {
|
||||
embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
|
||||
total_seq);
|
||||
}
|
||||
#else
|
||||
// __memcpy 2D src_stride dtype is int
|
||||
if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) {
|
||||
embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
|
||||
} else {
|
||||
embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
|
||||
total_seq);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher: derives the launch dimensions from the current device
// and dispatches MLUEmbeddingKernel for the requested dtype.
// Returns KERNEL_STATUS_FAILED for unsupported dtypes or bf16 on devices
// without bf16 support; KERNEL_STATUS_SUCCESS after enqueueing the kernel.
KernelStatus invokeEmbedding(cnrtQueue_t queue,
                             void *filter,
                             void *input_ids,
                             void *output,
                             const cnnlDataType_t dtype,
                             int vocab_offset,
                             int vocab_size,
                             int total_vocab_size,
                             int input_size,
                             int total_seq) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};

  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else {
    // Fix: previously an unsupported dtype fell through and reported success
    // without launching any kernel.
    std::cerr << "[invokeEmbedding]: unsupported data type, only float/half/bfloat16 are supported."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
63
torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh
Normal file
63
torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mluh
Normal file
@@ -0,0 +1,63 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_EMBEDDING_MLUH_
|
||||
#define CSRC_KERNELS_EMBEDDING_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
 * @brief Looks up the embedding table for each id in the inclusive range
 * [vocab_offset, vocab_offset + vocab_size - 1], and writes the result back
 * to the position corresponding to the id. For ids outside that range,
 * writes 0 to the corresponding position.
|
||||
* @example
|
||||
* filter:
|
||||
* [[1, 2, 3, 4],
|
||||
* [5, 6, 7, 8],
|
||||
* [4, 3, 2, 1]]
|
||||
* input_ids:
|
||||
* [[1, 5, 6, 7, 8, 9]]
|
||||
* vocab_offset = 5
|
||||
* vocab_size = 3
|
||||
* input_size = 4
|
||||
* total_seq = 6
|
||||
* output:
|
||||
* [[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8],
|
||||
* [4, 3, 2, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param filter: Input. Pointer to the MLU memory that stores the embedding table,
|
||||
* the shape must be [vocab_size, input_size].
|
||||
* @param input_ids: Input. Pointer to the MLU memory that stores the token id,
|
||||
* the shape must be [batch, seq].
|
||||
* @param output: Output. Pointer to the MLU memory that stores the output,
|
||||
* the shape must be [batch, seq, input_size].
|
||||
* @param dtype: Data type.
|
||||
* @param vocab_offset: embedding table offset.
|
||||
* @param vocab_size: embedding table size.
|
||||
* @param total_vocab_size: total embedding table size.
|
||||
* @param input_size: embedding dim.
|
||||
* @param total_seq: Total sequence length.
|
||||
*/
|
||||
KernelStatus invokeEmbedding(cnrtQueue_t queue,
|
||||
void *filter,
|
||||
void *input_ids,
|
||||
void *output,
|
||||
const cnnlDataType_t dtype,
|
||||
int vocab_offset,
|
||||
int vocab_size,
|
||||
int total_vocab_size,
|
||||
int input_size,
|
||||
int total_seq);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_EMBEDDING_MLUH_
|
||||
658
torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mlu
Normal file
658
torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mlu
Normal file
@@ -0,0 +1,658 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "fused_rope.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
using bf16 = bfloat16_t;
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
#ifndef PAD_UP
// Round x up to the nearest multiple of y.
// Fix: the previous expansion "(x) / (y) + (int)((x) % (y) > 0) * (y)" only
// multiplied the carry term, so e.g. PAD_UP(16, 8) evaluated to 2 instead of
// 16; the multiply must apply to the whole quotient-plus-carry.
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif
|
||||
|
||||
#if __BANG_ARCH__ > 500
|
||||
#include <bang_fusor.h>
|
||||
template <typename T>
|
||||
using bang_cycle_fusor = bang::experimental::cycle_fusor<T>;
|
||||
#endif
|
||||
|
||||
#define NRAM_BUFFER_SIZE (480 * 1024)
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
__nram__ float nram_mask[256];
|
||||
__nram__ int nram_rope_offsets[256];
|
||||
__nram__ float nram_zeros[1024] = {0.f};
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void toFloat(float *dst, T *src, int num) {
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_half2float(dst, (half *)src, num);
|
||||
} else if (std::is_same<T, bf16>::value) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__bang_bfloat162float(dst, (bf16 *)src, num);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void floatTo(T *dst, float *src, int num) {
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_float2half_rn((half *)dst, src, num);
|
||||
} else if (std::is_same<T, bf16>::value) {
|
||||
__bang_float2bfloat16_rn((bf16 *)dst, src, num);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes, per (batch, kv-head) pair, the offsets at which K (and, for
// mixed cache, the int4 V payload, on-chip V staging slot, and per-group
// scales) are scattered into the GDRAM caches, plus the bit mask that
// disables scatter lanes whose slot/batch/seq index is negative (those
// lanes' K offsets are written as -1 and masked out at the end).
// Two layouts:
//  - paged_cache: offsets derive from slot_mapping (block id + in-block seq);
//  - linear:      offsets derive from per-batch seq offsets, with batch ids
//                 taken from cache_bs_id when discrete_batch, else implied.
// NOTE(review): K offsets are scaled by kv_out_size (bytes per element) and
// halved for mixed cache; V/scale offsets use different units — presumably
// matching the element sizes the callers pass to __gather/__scatter. Confirm
// against fuseRopeImpl's call sites.
__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin,
                                       int *cache_seq_offsets_begin,
                                       int *slot_mapping_begin,
                                       int *nram_k_cache_offsets,
                                       int *nram_v_cache_offsets,
                                       int *nram_v_onchip_offsets,
                                       int *nram_kv_scale_offsets,
                                       float *nram_cache_mask,
                                       float *nram_zeros,
                                       float *nram_temp,
                                       int task_deal_batch,
                                       int task_begin_batch,
                                       int head_num_k,
                                       int head_size,
                                       int max_decode_len,
                                       int block_size,
                                       int kv_out_size,
                                       int group_num,
                                       bool discrete_batch,
                                       bool paged_cache,
                                       bool mixed_cache) {
  // For now offsets are computed with scalar code for readability
  // (no measurable performance impact).
  int bh = task_deal_batch * head_num_k;
  if (paged_cache) {
    int cache_seq_stride = head_size;
    int cache_head_stride = block_size * head_size;
    int cache_scale_head_stride = block_size * group_num;
    int cache_block_stride = head_num_k * cache_head_stride;
    int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
    int *nram_slot_mapping = (int *)nram_mask;  // reuse the global mask buffer as scratch
    __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    for (int i = 0; i < task_deal_batch; i++) {
      int mapping_idx = __load_nram(nram_slot_mapping + i);
      if (mapping_idx < 0) {
        // Negative slot: mark every head of this batch as masked (-1).
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int block_idx = mapping_idx / block_size;
      int seq_idx = mapping_idx % block_size;
      int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
      int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
      int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
      // Even/odd seq rows land in the first/second half of the on-chip
      // staging buffer (two planes of bh * head_size).
      int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;

      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  } else {
    int *nram_seq_offsets = (int *)nram_mask;  // scratch, reused global buffer
    int *nram_bs_id = nram_seq_offsets + 32;
    int cache_seq_stride = head_size;
    int cache_head_stride = max_decode_len * head_size;
    int cache_scale_head_stride = max_decode_len * group_num;
    int cache_bs_stride = head_num_k * cache_head_stride;
    int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
    __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    if (discrete_batch) {
      __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    }
    for (int i = 0; i < task_deal_batch; i++) {
      // NOTE(review): nram_bs_id is read even when !discrete_batch (then
      // uninitialized); the value is discarded by the ternary below.
      int bs_idx = __load_nram(nram_bs_id + i);
      int seq_idx = __load_nram(nram_seq_offsets + i);
      int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i;
      int temp_seq_idx = seq_idx;
      bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
      if (masked) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
      int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
      int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
      int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;

      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  }

  // Build the mask for the scatter instructions: lanes whose bs/seq offset
  // was negative (stored as -1 above) must be masked out.
  __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
  __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}
|
||||
|
||||
// In-place LayerNorm over `task_deal_batch` rows of `k_hidden` floats:
// per-row mean/variance normalization followed by gamma/beta scale-shift,
// where gamma = norm_params[0 : k_hidden] and beta = norm_params[k_hidden :
// 2 * k_hidden]. Uses the NRAM region immediately after the last row as
// scratch — the caller must guarantee one spare row there.
__mlu_func__ void layernormImpl(float *nram_k,
                                float *norm_params,
                                int task_deal_batch,
                                int k_hidden,
                                float eps) {
#if __BANG_ARCH__ > 500
  float *buffer = nram_k + task_deal_batch * k_hidden;  // scratch row for x^2
  for (int i = 0; i < task_deal_batch; i++) {
    float *k_ = nram_k + i * k_hidden;
    __bang_mul(buffer, k_, k_, k_hidden);  // x^2, for E[x^2]
    float mean = __bang_sum(k_, k_hidden);
    mean = mean / k_hidden;
    float rstd = __bang_sum(buffer, k_hidden);
    rstd = rstd / k_hidden - mean * mean;  // Var = E[x^2] - mean^2
    // Guard against tiny negative variance from fp rounding before rsqrt.
    rstd = rstd < 0 ? eps : rstd + eps;
    rstd = 1.f / std::sqrt(rstd);
    // Fused scalar sub-then-mul: k_ = (k_ - mean) * rstd — TODO confirm
    // FUSION_FSM operand order against the BANG fusion docs.
    __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
  }
  // y = x * gamma + beta, gamma/beta cycled across all rows.
  __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
                task_deal_batch * k_hidden, k_hidden);
#endif
}
|
||||
|
||||
// Applies rotary position embedding in "folded" form:
//   qk = qk * cos + (swapped-halves of qk) * sign_mask * sin
// nram_table holds the cos rows first, then the sin rows (each
// task_deal_batch * head_size floats); the global nram_mask holds -1 for the
// first half of head_size and +1 for the second half, and nram_qk_rot
// already contains qk with its halves swapped — both prepared by the caller.
__mlu_func__ void foldRotaryImpl(float *nram_qk,
                                 float *nram_qk_rot,
                                 float *nram_table,
                                 int task_deal_batch,
                                 int head_num_qk,
                                 int head_size) {
  int rotary_low_dim = task_deal_batch * head_size;
  // qk *= cos (cos rows cycled over the head dimension)
  __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
  // qk_rot *= [-1, ..., -1, +1, ..., +1] sign pattern
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size);
  // qk_rot *= sin
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
                   head_num_qk * rotary_low_dim, rotary_low_dim);
  // qk = qk * cos + qk_rot * sign * sin
  __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}
|
||||
|
||||
// Quantization helper for K/V rows.
// quant_kv_hp: widens `input` to fp32 (into float_input), takes the
//   reciprocal of the per-channel scales in place (scale_hp), and emits
//   int8 into output_hp via one fused multiply-convert-narrow instruction.
// mixed_cache additionally does online per-token-group quantization:
//   transpose to group-major, abs-max per group / 7 as the scale (written
//   to scale_lp), normalize, convert to int8 into output_lp (later narrowed
//   to int4 by the caller).
// NOTE(review): scale_hp is overwritten with its reciprocal and
// float_input/nram_trans/scale_lp_temp are clobbered — callers treat these
// as scratch after the call.
template <typename T>
__mlu_func__ void quantify(T *input,
                           float *float_input,
                           void *output_hp,
                           void *output_lp,
                           float *nram_trans,
                           float *scale_hp,
                           float *scale_lp,
                           float *scale_lp_temp,
                           int batch,
                           int head_num,
                           int head_size,
                           int group_num,
                           int group_size,
                           bool quant_kv_hp,
                           bool mixed_cache) {
  if (quant_kv_hp) {
    int hidden = head_num * head_size;
    int bh = batch * head_num;
    toFloat<T>(float_input, input, batch * hidden);
    __bang_recip(scale_hp, scale_hp, hidden);  // scale -> 1/scale, in place
#if __BANG_ARCH__ > 500
    // Fused: output_hp = (int8) round(float_input * cycle(1/scale)).
    __asm__ __volatile__(
        "fuse.nram.crn.s8.f32 "
        "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
        ";\n\t" ::[dst] "r"(output_hp),
        [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
        [src1] "r"(scale_hp), [pos] "i"(0));
#endif
    if (mixed_cache) {
      // Group-major layout so each group's elements are contiguous.
      __bang_transpose(nram_trans, float_input, bh * group_num, group_size);
      __bang_abs(float_input, nram_trans, bh * head_size);
      // Per-group abs-max via maxpool over group_size.
      __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1,
                     1);
      // int4 range is [-7, 7], hence the / 7 scale.
      __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
      __bang_recip(scale_lp_temp, scale_lp, bh * group_num);
      __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
                       bh * group_num);
      __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
      // Transpose back to the original token-major layout.
      __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
    }
  }
}
|
||||
|
||||
// Fused decode-time RoPE + K-layernorm + KV-cache-scatter for one chunk of
// `task_deal_batch` tokens starting at `task_begin_batch`.
// Stages QKV into NRAM, applies rotary embedding to Q and K, layernorms and
// optionally quantizes K/V, writes rotated Q back over `input`, and scatters
// K/V (hp, and optionally mixed-precision int4 lp) into the caches. The body
// is software-pipelined: each section overlaps async IO with compute and
// ends at an explicit __sync_* barrier — statement order is load-bearing.
template <typename T>
__mlu_func__ void fuseRopeImpl(T *input,
                               void *key_cache_hp,
                               void *value_cache_hp,
                               void *key_cache_lp,
                               void *value_cache_lp,
                               T *sin_table,
                               T *cos_table,
                               int *rope_offsets,
                               T *gamma,
                               T *beta,
                               float *key_scale_hp,
                               float *value_scale_hp,
                               float *key_scale_lp,
                               float *value_scale_lp,
                               int *cache_bs_id_hp,
                               int *cache_seq_offsets_hp,
                               int *cache_bs_id_lp,
                               int *cache_seq_offsets_lp,
                               int *slot_mapping_hp,
                               int *slot_mapping_lp,
                               int rotary_stride,
                               int task_deal_batch,
                               int task_begin_batch,
                               int head_num_q,
                               int head_num_k,
                               int head_size,
                               int max_decode_len_hp,
                               int max_decode_len_lp,
                               int block_size_hp,
                               int block_size_lp,
                               int group_size,
                               int batch_cap,
                               float eps) {
#if __BANG_ARCH__ > 500
  /*
  Because mixed cache must be supported, this kernel covers many cache
  feature combinations; only the following are allowed:
  1. Only hp_cache exists (detected by the lp tensors being null). The cache
     supports bf16/fp16 and, when quantized, offline per-channel int8; both
     linear and paged layouts; key and value caches share the same shape;
     key/value_scale_hp has shape [head_num, head_size].
  2. Mixed cache: hp supports offline per-channel int8 quantization, both
     linear and paged layouts, key and value caches share the same shape, and
     key/value_scale_hp has shape [head_num, head_size]. lp supports online
     per-token group int4 quantization; key_cache has shape
     [batch, head_num_k, max_decode_len_lp, head_size / 2] (paged is also
     head_size / 2); value_cache has shape
     [batch, head_num_k, max_decode_len_lp / 2, head_size]; the paged cache
     has shape [num_blocks, head_num_k, block_size / 2, head_size];
     key/value_scale_lp has shape
     [batch, head_num_k, max_decode_len_lp, group_num], or
     [num_blocks, head_num_k, block_size, group_num] for paged caches.
  */
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool discrete_batch_hp = cache_bs_id_hp != nullptr;
  bool discrete_batch_lp = cache_bs_id_lp != nullptr;
  bool paged_cache_hp = slot_mapping_hp != nullptr;
  bool paged_cache_lp = slot_mapping_lp != nullptr;
  int head_num_qk = head_num_q + head_num_k;
  int head_num_qkv = head_num_q + head_num_k * 2;
  int qkv_hidden = head_num_qkv * head_size;
  int qk_hidden = head_num_qk * head_size;
  int q_hidden = head_num_q * head_size;
  int k_hidden = head_num_k * head_size;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
  int group_num = mixed_cache ? head_size / group_size : 1;

  // task ddr offset
  T *input_begin = input + task_begin_batch * qkv_hidden;
  int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
  int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
  int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;

  // NRAM buffer carve-up. Regions gated by (int)mixed_cache / (int)quant_kv_hp
  // collapse to zero size when the corresponding feature is off.
  float *nram_qk = (float *)nram_buffer;
  float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
  float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
  float *nram_kv_trans = nram_v + batch_cap * k_hidden;
  float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
  float *norm_params = nram_table + 2 * batch_cap * head_size;
  float *nram_k_scale_hp = norm_params + 2 * head_size;
  float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
  int8_t *nram_kv_hp =
      (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
  int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
  int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
  int *nram_kv_cache_offsets_hp =
      (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
  int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
  int *nram_v_cache_offsets_lp =
      nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
  float *cache_mask_hp =
      (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
  float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);

  // qk and qk_rot are placed together so the widening to fp32 can be done in
  // a single instruction (fewer instructions); same for the sin/cos table and
  // the layernorm gamma/beta. The *_in pointers park the narrow data at the
  // tail of each fp32 region so in-place widening does not overlap.
  T *qk_in = (T *)nram_qk_rot;
  T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
  T *v_in =
      (T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
  T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
  T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
  int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);

  // Generate kv-cache offsets and masks used when scattering k/v into the caches.
  genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
                       nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
                       nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
                       head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
                       discrete_batch_hp, paged_cache_hp, false);
  if (mixed_cache) {
    int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
    int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
    int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
    genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp,
                         nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets,
                         nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch,
                         task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp,
                         1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache);
  }

  /*
  Pipeline (IO | compute per stage):
  -----------------------
  load v   |
  -----------------------
  load qk  | quant v
  -----------------------
  store v  | rope qk
  -----------------------
  store_q  | layernorm k
           | quant k
  -----------------------
  store k  |
  */
  // prepare v v_scale cache_v rope_offset
  __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
                 k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
  if (quant_kv_hp) {
    __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
    __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
  }
  __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int),
                 GDRAM2NRAM);

  __sync_io();
  if (mixed_cache) {
    // Fetch the existing int4 V rows that share a packed pair with this step.
    __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp,
             head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
             task_deal_batch * head_num_k);
  }
  // Turn row indices into byte offsets for the sin/cos table gathers below.
  __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
                    task_deal_batch);
  __sync_compute();
  /*==============================================================================================*/

  // load_qk,rope_table | quant v
  __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
                 task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_qk - 1);
  __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
                 GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets,
                 head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
  __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);

  int8_t *nram_temp = (int8_t *)nram_qk;
  if (mixed_cache) {
    // Unpack int4x2 to int8 and split the even/odd planes apart.
    __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0,
                     0);
    __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
  }
  quantify<T>(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp,
              nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num,
              group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    // Merge this step's int8 V into the correct even/odd plane, then re-pack
    // to int4x2 for the lp cache.
    __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
              head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
              task_deal_batch * head_num_k);
    __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
    __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0,
                        0);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/

  // rope | store v
  // Swap the left/right halves of qk to build qk_rot.
  __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  toFloat<T>(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
  toFloat<T>(nram_table, table_in, 2 * task_deal_batch * head_size);
  toFloat<T>(norm_params, norm_params_in, 2 * head_size);
  // Sign pattern used by foldRotaryImpl: -1 on the low half, +1 on the high.
  __bang_write_value(nram_mask, head_size / 2, (float)-1);
  __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
  foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
  floatTo<T>((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);

  int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
  __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
                  cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
                  head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
                    cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
                    head_size * sizeof(int8_t), head_num_k * task_deal_batch);
    __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
                    cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
                    head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // layernorm k, quant k | store q
  // Extract k from the qk NRAM buffer for layernorm and quantization.
  float *nram_k = nram_qk_rot;
  __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
           head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
           task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
           task_deal_batch - 1);
  layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
  quantify<float>(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
                  nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
                  head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
  }
  if (!quant_kv_hp) {
    floatTo<T>((T *)nram_k, nram_k, task_deal_batch * k_hidden);
  }

  // store q
  __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size,
                 task_deal_batch - 1, head_size * dtype_size, head_num_q - 1,
                 head_size * dtype_size, task_deal_batch - 1,
                 task_deal_batch * head_size * dtype_size, head_num_q - 1);
  // ===============================================================================================

  int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
  __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
            head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
            head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
              head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
              head_num_k * task_deal_batch);
    __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
              group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
              head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
#endif
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUFuseRope(T *input,
|
||||
void *key_cache_hp,
|
||||
void *value_cache_hp,
|
||||
void *key_cache_lp,
|
||||
void *value_cache_lp,
|
||||
T *sin_table,
|
||||
T *cos_table,
|
||||
int *rope_offsets,
|
||||
T *gamma,
|
||||
T *beta,
|
||||
float *key_scale_hp,
|
||||
float *value_scale_hp,
|
||||
float *key_scale_lp,
|
||||
float *value_scale_lp,
|
||||
int *cache_bs_id_hp,
|
||||
int *cache_seq_offsets_hp,
|
||||
int *cache_bs_id_lp,
|
||||
int *cache_seq_offsets_lp,
|
||||
int *slot_mapping_hp,
|
||||
int *slot_mapping_lp,
|
||||
int rotary_stride,
|
||||
int batch,
|
||||
int head_num_q,
|
||||
int head_num_k,
|
||||
int head_size,
|
||||
int max_decode_len_hp,
|
||||
int max_decode_len_lp,
|
||||
int block_size_hp,
|
||||
int block_size_lp,
|
||||
int group_size,
|
||||
int batch_cap,
|
||||
int task_avg_batch,
|
||||
float eps) {
|
||||
int task_begin_batch = taskId * task_avg_batch;
|
||||
int task_deal_batch = std::min(batch - task_begin_batch, task_avg_batch);
|
||||
if (task_deal_batch <= 0 || __is_mpu()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int task_loop = (task_deal_batch + batch_cap - 1) / batch_cap;
|
||||
int once_batch = (task_deal_batch + task_loop - 1) / task_loop;
|
||||
|
||||
for (int i = 0; i < task_loop; i++) {
|
||||
int cur_batch = std::min(task_deal_batch - i * once_batch, once_batch);
|
||||
int batch_offset = task_begin_batch + once_batch * i;
|
||||
|
||||
fuseRopeImpl<T>(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
|
||||
cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
|
||||
key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
|
||||
cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
|
||||
rotary_stride, cur_batch, batch_offset, head_num_q, head_num_k, head_size,
|
||||
max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
|
||||
once_batch, eps);
|
||||
}
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher for the fused RoPE kernel: applies rotary embedding to
// query/key, layernorm to key, and quantizes key/value into the KV caches.
// The batch is split evenly across all IPU cores; batch_cap bounds how many
// batches one core processes per kernel-side iteration so the NRAM working
// set fits the 480 KB budget.
//
// Returns KERNEL_STATUS_FAILED on MLU300 devices and on unsupported dtypes
// (only half and bfloat16 are implemented), KERNEL_STATUS_SUCCESS otherwise.
KernelStatus invokeFusedRope(cnrtQueue_t queue,
                             void *input,
                             void *key_cache_hp,
                             void *value_cache_hp,
                             void *key_cache_lp,
                             void *value_cache_lp,
                             const void *sin_table,
                             const void *cos_table,
                             const void *rope_offsets,
                             const void *gamma,
                             const void *beta,
                             const void *key_scale_hp,
                             const void *value_scale_hp,
                             void *key_scale_lp,
                             void *value_scale_lp,
                             const void *cache_bs_id_hp,
                             const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp,
                             const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp,
                             const void *slot_mapping_lp,
                             int rotary_stride,
                             int batch_size,
                             int head_num_q,
                             int head_num_kv,
                             int head_size,
                             int max_decode_len_hp,
                             int max_decode_len_lp,
                             int block_size_hp,
                             int block_size_lp,
                             int group_size,
                             cnnlDataType_t dtype,
                             float eps) {
  if (is_arch300()) {
    std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  uint32_t taskdimx = cluster_num * core_num;

  int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
  int float_size = sizeof(float);
  int group_num = head_size / group_size;
  // High-precision quantization is enabled only when both HP scales are given.
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  // Mixed-cache mode additionally writes a low-precision (int4) cache copy.
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;

  // NRAM budget accounting: subtract the fixed-size buffers, then divide the
  // remainder by the per-batch footprint to obtain the per-iteration cap.
  int nram_available_bytes = 480 * 1024;
  int task_max_batch = 32;
  int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
  int nram_params_bytes = 2 * head_size * float_size;
  int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
  int nram_remain_bytes =
      nram_available_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
  // Per-batch footprints of each working buffer.
  int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;
  int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);
  int nram_table_bytes = 2 * head_size * float_size;
  int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
  int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
  int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
  int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
  int nram_cache_offsets_hp = head_num_kv * sizeof(int);
  int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
  int batch_cap =
      nram_remain_bytes /
      (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes +
       nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp);
  batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;

  cnrtDim3_t dim{taskdimx, 1, 1};
  if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else {
    // Fix: an unsupported dtype previously fell through both branches and
    // reported success without launching any kernel.
    std::cerr << "[invokeFusedRope]: data_type is not supported, only half and bfloat16 are "
                 "supported."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
} // namespace tmo
|
||||
119
torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mluh
Normal file
119
torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mluh
Normal file
@@ -0,0 +1,119 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
|
||||
#define CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
 * @brief Apply query and key rotary embedding, key layernorm and
|
||||
* quantize key and value to kv cache.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param input: Input/Output. Pointer to the MLU memory that stores the input,
|
||||
* the shape must be [batch_size, 1, head_num_q + head_num_kv * 2, head_size].
|
||||
* @param key_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision key
|
||||
* cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
|
||||
* head_num_kv, block_size, head_size].
|
||||
* @param value_cache_hp: Input/Output. Pointer to the MLU memory that stores the high precision
|
||||
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
|
||||
* head_num_kv, block_size, head_size].
|
||||
* @param key_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision key
|
||||
* cache , the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
|
||||
* head_num_kv, block_size, head_size].
|
||||
* @param value_cache_lp: Input/Output. Pointer to the MLU memory that stores the low precision
|
||||
* value cache, the shape must be [max_bs, head_num_kv, max_decode_len, head_size] or [num_blocks,
|
||||
* head_num_kv, block_size, head_size].
|
||||
* @param sin_table: Input. Pointer to the MLU memory that stores the sin value, may not be
|
||||
 * continuous. The shape must be [rotary_seq_len, rotary_dim].
|
||||
* @param cos_table: Input. Pointer to the MLU memory that stores the cos value, may not be
|
||||
 * continuous. The shape must be [rotary_seq_len, rotary_dim].
|
||||
 * @param rope_offsets: Input. Pointer to the MLU memory that stores the sequence offsets of each
|
||||
* batch. The shape must be [batch].
|
||||
 * @param gamma: Input. Pointer to the MLU memory that stores the gamma param of layernorm.
|
||||
 * @param beta: Input. Pointer to the MLU memory that stores the beta param of layernorm.
|
||||
* @param key_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
|
||||
* key. The shape must be [head_num_kv, head_size]. If key_scale is nullptr,
|
||||
* that means key do not need to be quantized.
|
||||
* @param value_scale_hp: Input. Pointer to the MLU memory that stores the scales of high precision
|
||||
* value. The shape must be [head_num_kv, head_size]. If value_scale is nullptr,
|
||||
* that means value do not need to be quantized.
|
||||
* @param key_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
|
||||
 * precision key. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
|
||||
* [num_blocks, head_num_kv, block_size, group_num].
|
||||
* @param value_scale_lp: Input/Output. Pointer to the MLU memory that stores the scales of low
|
||||
 * precision value. The shape must be [batch_size, head_num_kv, max_decode_len, group_num] or
|
||||
* [num_blocks, head_num_kv, block_size, group_num].
|
||||
* @param cache_bs_id_hp: Input. Pointer to the MLU memory that stores the batch
|
||||
* offset of high precision cache, the shape must be [batch], if it's nullptr, the
|
||||
* default value is {0, 1, 2 ... batch - 1}.
|
||||
* @param cache_seq_offsets_hp: Input. Pointer to the MLU memory that stores the sequence
|
||||
* offset of high precision cache, the shape must be [batch].
|
||||
* @param cache_bs_id_lp: Input. Pointer to the MLU memory that stores the batch
|
||||
* offset of low precision cache, the shape must be [batch], if it's nullptr, the
|
||||
* default value is {0, 1, 2 ... batch - 1}.
|
||||
* @param cache_seq_offsets_lp: Input. Pointer to the MLU memory that stores the sequence
|
||||
* offset of low precision cache, the shape must be [batch].
|
||||
* @param slot_mapping_hp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
|
||||
* which has shape [batch]. Data type of slot mapping must be int32_t.
|
||||
* @param slot_mapping_lp: Input. Pointer to the MLU memory that stores the slot_mapping tensor
|
||||
* which has shape [batch]. Data type of slot mapping must be int32_t.
|
||||
* @param rotary_stride: The stride of rotary_seq_len in sin_table and cos_table.
|
||||
* @param batch_size: Batch size.
|
||||
* @param head_num_q: Head number of query.
|
||||
* @param head_num_kv: Head number of key and value.
|
||||
* @param head_size: Head size. For simplify, the rotary dim must be the same as head_size.
|
||||
* @param max_decode_len_hp: The maximum sequence length of high precision cache.
|
||||
* @param max_decode_len_lp: The maximum sequence length of low precision cache.
|
||||
* @param block_size_hp: Number of tokens per block of high precision cache.
|
||||
* @param block_size_lp: Number of tokens per block of low precision cache.
|
||||
* @param data_type: Data type of all inputs and outputs.
|
||||
* @param eps: float number use for layernorm.
|
||||
* @note: Head_num_q and head_num_kv must be in range [1, 32].
|
||||
* Head_size must be in range [1, 128], and must be divided by 2.
|
||||
*/
|
||||
KernelStatus invokeFusedRope(cnrtQueue_t queue,
|
||||
void *input,
|
||||
void *key_cache_hp,
|
||||
void *value_cache_hp,
|
||||
void *key_cache_lp,
|
||||
void *value_cache_lp,
|
||||
const void *sin_table,
|
||||
const void *cos_table,
|
||||
const void *rope_offsets,
|
||||
const void *gamma,
|
||||
const void *beta,
|
||||
const void *key_scale_hp,
|
||||
const void *value_scale_hp,
|
||||
void *key_scale_lp,
|
||||
void *value_scale_lp,
|
||||
const void *cache_bs_id_hp,
|
||||
const void *cache_seq_offsets_hp,
|
||||
const void *cache_bs_id_lp,
|
||||
const void *cache_seq_offsets_lp,
|
||||
const void *slot_mapping_hp,
|
||||
const void *slot_mapping_lp,
|
||||
int rotary_stride,
|
||||
int batch_size,
|
||||
int head_num_q,
|
||||
int head_num_kv,
|
||||
int head_size,
|
||||
int max_decode_len_hp,
|
||||
int max_decode_len_lp,
|
||||
int block_size_hp,
|
||||
int block_size_lp,
|
||||
int group_size,
|
||||
cnnlDataType_t dtype,
|
||||
float eps);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_FUSE_ROPE_FUSE_ROPE_MLUH_
|
||||
130
torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mlu
Normal file
130
torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mlu
Normal file
@@ -0,0 +1,130 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "generate_alibi_slope.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
|
||||
__nram__ int8_t nram_buffer[NRAM_SIZE];
|
||||
|
||||
__nram__ float range_1[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
|
||||
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64};
|
||||
|
||||
__nram__ float range_2[64] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25,
|
||||
27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51,
|
||||
53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77,
|
||||
79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103,
|
||||
105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127};
|
||||
|
||||
// Fill range_nram with fill_num consecutive terms of an arithmetic sequence
// by tiling the base_num-element lookup table range_base_nram and shifting
// each tile by its start position plus offset.
__mlu_func__ void genRange(float *range_nram,
                           float *range_base_nram,
                           int fill_num,
                           int base_num,
                           int offset = 0) {
  const int tiles = (fill_num + base_num - 1) / base_num;
  for (int t = 0; t < tiles; ++t) {
    // Last tile may be partial.
    const int count = std::min(fill_num - t * base_num, base_num);
    float *dst = range_nram + t * base_num;
    __bang_move(dst, range_base_nram, count * sizeof(float));
    __bang_add_scalar(dst, dst, t * base_num + offset, count);
  }
}
|
||||
|
||||
// Compute ALiBi slopes for one batch entry (taskIdX indexes the batch) over
// the head range [head_start, head_start + head_num) of head_num_total heads.
// Heads below closest_power_of_2 get slope base^k (k = head index + 1);
// remaining heads get extra_base^(2k+1) (odd exponents), following the
// standard ALiBi construction for non-power-of-two head counts.
// NOTE(review): farthest_power_of_2 is currently unused — confirm intent.
__mlu_global__ void MLUAlibiSlopeKernel(float *alibi_slopes,
                                        int *true_seq_lens,
                                        int batch_num,
                                        int head_start,
                                        int head_num,
                                        int head_num_total,
                                        int max_sequence_length,
                                        bool use_dynamic,
                                        int closest_power_of_2,
                                        int farthest_power_of_2,
                                        float base,
                                        float extra_base) {
  // NRAM layout: [head_num] exponents | [head_num] bases | [head_num] slopes.
  float *range_nram = (float *)nram_buffer;
  float *base_nram = range_nram + head_num;
  float *slope_nram = base_nram + head_num;

  float scale = 1.0;
  float dynamic_base = base;
  if (use_dynamic) {
    // Dynamic-NTK-style adjustment: grow the base when the actual sequence
    // exceeds the training length (ratio clamped to >= 1).
    float a0 = 1.0;
    float a = a0 * true_seq_lens[taskIdX] / max_sequence_length;
    a = std::max(a, 1.0f);
    scale = powf(a, (1.0 / (head_num_total - 1)));
    dynamic_base = base / scale;
  }

  // How many of this card's heads fall in the "close" (power-of-two) group.
  int close_head_num = 0;
  if (head_start >= closest_power_of_2) {
    close_head_num = 0;
  } else if (head_start + head_num <= closest_power_of_2) {
    close_head_num = head_num;
  } else {
    close_head_num = closest_power_of_2 - head_start;
  }
  int far_head_num = head_num - close_head_num;

  // fill range: 1, 2..., n1, 1, 3, (n - n1) * 2 - 1
  if (close_head_num) {
    genRange(range_nram, range_1, close_head_num, 64, head_start);
    __bang_write_value(base_nram, close_head_num, dynamic_base);
  }
  if (far_head_num) {
    // Odd exponents for the extra heads, offset by this card's position.
    genRange(range_nram + close_head_num, range_2, far_head_num, 64,
             (head_start + close_head_num - closest_power_of_2) * 2);
    __bang_write_value(base_nram + close_head_num, far_head_num, extra_base);
  }

  // base_nram ** range_nram, computed as 2^(range * log2(base)).
  __bang_log(base_nram, base_nram, head_num);
  __bang_mul(slope_nram, base_nram, range_nram, head_num);
  __bang_pow2(slope_nram, slope_nram, head_num);

  if (use_dynamic) {
    // NOTE(review): only the first close_head_num slopes are rescaled here —
    // confirm whether the "far" heads should also receive the dynamic scale.
    __bang_mul_scalar(slope_nram, slope_nram, scale, close_head_num);
  }

  __memcpy(alibi_slopes + taskIdX * head_num, slope_nram, head_num * sizeof(float), NRAM2GDRAM);
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher: one Block task per batch entry computes that entry's
// ALiBi slopes on the device.
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
                                      void *alibi_slopes,
                                      void *true_seq_lens,
                                      int batch_num,
                                      int head_start,
                                      int head_num,
                                      int head_num_total,
                                      int max_sequence_length,
                                      bool use_dynamic) {
  // One task per batch entry.
  cnrtDim3_t launch_dim{.x = (uint32_t)batch_num, .y = 1, .z = 1};

  // ALiBi bases: heads up to the largest power of two below head_num_total
  // use `base`; any remaining heads use `extra_base` derived from the
  // doubled power of two.
  const int closest_power_of_2 = pow(2, floor(log2(head_num_total)));
  const int farthest_power_of_2 = closest_power_of_2 * 2;
  const float base = pow(2, (-pow(2, -(log2(closest_power_of_2) - 3))));
  const float extra_base = pow(2, (-pow(2, -(log2(2 * closest_power_of_2) - 3))));

  kernels::MLUAlibiSlopeKernel<<<launch_dim, cnrtFuncTypeBlock, queue>>>(
      (float *)alibi_slopes, (int *)true_seq_lens, batch_num, head_start, head_num,
      head_num_total, max_sequence_length, use_dynamic, closest_power_of_2, farthest_power_of_2,
      base, extra_base);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
||||
|
||||
} // namespace tmo
|
||||
43
torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mluh
Normal file
43
torch_mlu_ops-v1.3.2/csrc/kernels/generate_alibi_slope.mluh
Normal file
@@ -0,0 +1,43 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
|
||||
#define CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
 * @brief Generate ALiBi attention slopes for each batch and head.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param alibi_slopes: Output. Pointer to the MLU memory that stores the output, the shape must be
|
||||
* [batch_num, head_num].
|
||||
* @param true_seq_lens: Input. Pointer to the MLU memory that stores the actual sequence length of
|
||||
* each batch, the shape must be [batch_num].
|
||||
* @param batch_num: Batch number.
|
||||
* @param head_start: The index of first head.
|
||||
* @param head_num: Head number in this card.
|
||||
* @param head_num_total: Total head number in all cards.
|
||||
* @param max_sequence_length: The maximum sequence length used during training.
|
||||
* @param use_dynamic: A boolean value indicates whether to use dynamic NTK.
|
||||
*/
|
||||
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
|
||||
void *alibi_slopes,
|
||||
void *true_seq_lens,
|
||||
int batch_num,
|
||||
int head_start,
|
||||
int head_num,
|
||||
int head_num_total,
|
||||
int max_sequence_length,
|
||||
bool use_dynamic);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_GENERATE_ALIBI_SLOPE_MLUH_
|
||||
214
torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mlu
Normal file
214
torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mlu
Normal file
@@ -0,0 +1,214 @@
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include "cn_api.h"
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "generate_mask.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
// Broadcast `value` into `elem_count` elements starting at `dst` (NRAM).
template <typename T>
__mlu_func__ void write_value(void *dst, unsigned int elem_count, T value) {
  __bang_write_value(dst, elem_count, value);
}

// bfloat16 specialization: __bang_write_value for bf16 only exists on
// MLU500+ hardware, so on older architectures this compiles to a no-op
// (bf16 paths are rejected at the host level before launch).
template <>
__mlu_func__ void write_value(void *dst, unsigned int elem_count, bfloat16_t value) {
#if __BANG_ARCH__ >= 500
  __bang_write_value(dst, elem_count, value);
#endif
}
|
||||
|
||||
// [once_len, once_len]
|
||||
__nram__ int8_t nram_small[(__MLU_NRAM_SIZE__ * 1 / 4 * 1024)];
|
||||
// [1 + once_len, 2 * once_len]
|
||||
__nram__ int8_t nram_large[(__MLU_NRAM_SIZE__ * 2 / 4 * 1024 + 1024)];
|
||||
// [once_len * 2 + 1]
|
||||
__nram__ int8_t nram_tiny[2048];
|
||||
template <typename T>
|
||||
class GenerateMask {
|
||||
constexpr static int once_len = sizeof(T) == 4 ? 160 : 256;
|
||||
// [once_len, once_len]
|
||||
T *nram_upper = (T *)(nram_small);
|
||||
// [1 + once_len, 2 * once_len]
|
||||
T *nram_buf = (T *)(nram_large);
|
||||
// [once_len, once_len], reuse upper part of nram_buf
|
||||
T *nram_filled = nram_buf;
|
||||
// [once_len, once_len], reuse lower part of nram_buf
|
||||
T *nram_zeros = nram_buf + once_len * once_len;
|
||||
// [once_len]
|
||||
T *nram_ones_zeros = (T *)nram_tiny;
|
||||
|
||||
__mlu_func__ void initBuffers(T fill_value = -10000) {
|
||||
/* nram_buf:
|
||||
|---once_len---||---once_len---|
|
||||
0, 1, 1, 1, ..., 1, 0, 0, 0, ...
|
||||
0, 0, 1, 1, ..., 1, 1, 0, 0, ...
|
||||
0, 0, 0, 1, ..., 1, 1, 1, 0, ...
|
||||
... */
|
||||
nram_buf[0] = 0;
|
||||
constexpr static int copy_size = (once_len * 2 + 1) * sizeof(T);
|
||||
__memcpy(nram_buf + 1, nram_ones_zeros, copy_size, NRAM2NRAM, copy_size, 0, once_len - 1);
|
||||
__memcpy(nram_upper, nram_buf, once_len * sizeof(T), NRAM2NRAM, once_len * sizeof(T),
|
||||
once_len * 2 * sizeof(T), once_len - 1);
|
||||
// nram_buf is nolonger needed
|
||||
write_value(nram_filled, once_len * once_len, (T)fill_value);
|
||||
write_value(nram_zeros, once_len * once_len, (T)0);
|
||||
}
|
||||
|
||||
__mlu_func__ void dealOneBatch(T *output, // [max_seq_len, max_seq_len]
|
||||
int max_seq_len,
|
||||
int seq_len) {
|
||||
/*
|
||||
| once_len |
|
||||
+----------+-----------------------------------+
|
||||
| | | |
|
||||
| upper | fill_value | |
|
||||
| | | |
|
||||
+----------+----------+ | |
|
||||
| | | | |
|
||||
| | upper | | fill |
|
||||
| | | | value |
|
||||
| +----------+----------+ | |
|
||||
| | | | |
|
||||
| | upper | | |
|
||||
| 0 | | | |
|
||||
| +----------+---+ |
|
||||
| | u | |
|
||||
|--------------------------------+---+ |
|
||||
| |
|
||||
| fill_value |
|
||||
| |
|
||||
+----------------------------------------------+
|
||||
*/
|
||||
int tile_count = seq_len / once_len;
|
||||
int tile_remain = seq_len % once_len;
|
||||
int boarder_len = max_seq_len - seq_len;
|
||||
int row = 0;
|
||||
for (; row < tile_count * once_len; row += once_len) {
|
||||
// fill left with zeros
|
||||
// assume that max_seq_len <= once_len^2
|
||||
if (row > 0) {
|
||||
__memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
|
||||
max_seq_len * sizeof(T), 0, once_len - 1);
|
||||
}
|
||||
// fill middle with upper
|
||||
__memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, once_len * sizeof(T),
|
||||
NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), once_len - 1);
|
||||
// fill right with fill_value
|
||||
if (row + once_len < max_seq_len) {
|
||||
__memcpy_async(output + (size_t)row * max_seq_len + row + once_len, nram_filled,
|
||||
(max_seq_len - row - once_len) * sizeof(T), NRAM2GDRAM,
|
||||
max_seq_len * sizeof(T), 0, once_len - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (tile_remain) {
|
||||
// fill left with zeros
|
||||
if (row > 0) {
|
||||
__memcpy_async(output + (size_t)row * max_seq_len, nram_zeros, row * sizeof(T), NRAM2GDRAM,
|
||||
max_seq_len * sizeof(T), 0, tile_remain - 1);
|
||||
}
|
||||
// fill middle with upper
|
||||
__memcpy_async(output + (size_t)row * max_seq_len + row, nram_upper, tile_remain * sizeof(T),
|
||||
NRAM2GDRAM, max_seq_len * sizeof(T), once_len * sizeof(T), tile_remain - 1);
|
||||
// fill right with fill_value
|
||||
if (row + tile_remain < max_seq_len) {
|
||||
__memcpy_async(output + (size_t)row * max_seq_len + row + tile_remain, nram_filled,
|
||||
(max_seq_len - row - tile_remain) * sizeof(T), NRAM2GDRAM,
|
||||
max_seq_len * sizeof(T), 0, tile_remain - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (boarder_len) {
|
||||
// fill right boarder with fill_value
|
||||
__memcpy_async(output + seq_len, nram_filled, boarder_len * sizeof(T), NRAM2GDRAM,
|
||||
max_seq_len * sizeof(T), 0, (max_seq_len - boarder_len) - 1);
|
||||
// fill bottom boarder with fill_value
|
||||
__memcpy_async(output + (size_t)seq_len * max_seq_len, nram_filled, max_seq_len * sizeof(T),
|
||||
NRAM2GDRAM, max_seq_len * sizeof(T), 0, boarder_len - 1);
|
||||
}
|
||||
__sync_io();
|
||||
}
|
||||
|
||||
public:
|
||||
__mlu_func__ void execute(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
|
||||
int *batch_seq_len,
|
||||
int total_batch,
|
||||
int max_seq_len,
|
||||
T fill_value = -10000) {
|
||||
int batch_each = total_batch / taskDimY;
|
||||
int batch_remain = total_batch % taskDimY;
|
||||
int batch_start = taskIdY * batch_each + (taskIdY < batch_remain ? taskIdY : batch_remain);
|
||||
int batch_count = batch_each + (taskIdY < batch_remain ? 1 : 0);
|
||||
write_value(nram_ones_zeros, once_len, (T)fill_value);
|
||||
write_value(nram_ones_zeros + once_len, once_len + 1, (T)0);
|
||||
initBuffers();
|
||||
|
||||
for (int n = batch_start; n < batch_start + batch_count; n++) {
|
||||
T *output = output_ddr + (size_t)n * max_seq_len * max_seq_len;
|
||||
int seq_len = batch_seq_len[n];
|
||||
dealOneBatch(output, max_seq_len, seq_len);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUUnion1GenerateMask(T *output_ddr, // [total_batch, max_seq_len, max_seq_len]
|
||||
int *batch_seq_len,
|
||||
int total_batch,
|
||||
int max_seq_len,
|
||||
T fill_value = -10000) {
|
||||
if (coreId != 0) {
|
||||
return; // we only use 1 core in a cluster
|
||||
}
|
||||
GenerateMask<T>().execute(output_ddr, batch_seq_len, total_batch, max_seq_len, fill_value);
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher: builds causal attention masks of shape
// [total_batch, max_seq_len, max_seq_len] on the device.
// Supports float, half and (on MLU500+) bfloat16; fails otherwise.
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
                                void *output_ddr,
                                int *batch_seq_len,
                                int total_batch,
                                int max_seq_len,
                                cnnlDataType_t data_type,
                                float fill_value) {
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  CNdev dev;
  cnnlGetDevice(handle, &dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  // Union1 launch: one task group per cluster; the kernel itself uses only
  // core 0 of each cluster and splits batches across taskIdY.
  cnrtDim3_t dim;
  dim.x = 4;
  dim.y = cluster_num;
  dim.z = 1;
  if (data_type == CNNL_DTYPE_FLOAT) {
    kernels::MLUUnion1GenerateMask<float><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<float>(fill_value));
  } else if (data_type == CNNL_DTYPE_HALF) {
    kernels::MLUUnion1GenerateMask<half><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<half>(fill_value));
  } else if (data_type == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeGenerateMask]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUUnion1GenerateMask<bfloat16_t><<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(output_ddr), batch_seq_len, total_batch, max_seq_len,
        static_cast<bfloat16_t>(fill_value));
  } else {
    // Fix: the message previously repeated the function name twice
    // ("[invokeGenerateMask]: invokeGenerateMask: ...").
    std::cerr << "[invokeGenerateMask]: data_type is not supported" << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
37
torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mluh
Normal file
37
torch_mlu_ops-v1.3.2/csrc/kernels/generate_mask.mluh
Normal file
@@ -0,0 +1,37 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_GENERATE_MASK_MLUH_
|
||||
#define CSRC_KERNELS_GENERATE_MASK_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Generate causal mask for context stage.
|
||||
* @param handle: The handle of cnnl.
|
||||
* @param output_ddr: Output. Pointer to the MLU memory that stores the output.
|
||||
* @param batch_seq_len: Input. Pointer to the MLU memory that stores the sequence length.
|
||||
* @param total_batch: Batch size.
|
||||
* @param max_seq_len: The maximum sequence length of context.
|
||||
* @param data_type: Data type.
|
||||
* @param fill_value: The fill value of the pad part.
|
||||
*/
|
||||
KernelStatus invokeGenerateMask(cnnlHandle_t handle,
|
||||
void *output_ddr,
|
||||
int *batch_seq_len,
|
||||
int total_batch,
|
||||
int max_seq_len,
|
||||
cnnlDataType_t data_type,
|
||||
float fill_value);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_GENERATE_MASK_MLUH_
|
||||
60
torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mlu
Normal file
60
torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mlu
Normal file
@@ -0,0 +1,60 @@
|
||||
#include <stdexcept>
|
||||
#include "cnrt.h"
|
||||
#include "get_glm_position_id.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
__nram__ int nram_buffer[__MLU_NRAM_SIZE__ * 3 / 4 * 1024 / sizeof(int)];
|
||||
|
||||
__mlu_global__ void MLUBlockSliceIndividualPosId(int *context_pos_id,
|
||||
int *generate_pos_id,
|
||||
int batch,
|
||||
int context_seq_len,
|
||||
int pos_id_dimension /* 1 for 1D, 2 for 2D */) {
|
||||
if (taskId != 0) return;
|
||||
__memcpy(nram_buffer, context_pos_id + context_seq_len - 1, sizeof(int), GDRAM2NRAM, sizeof(int),
|
||||
context_seq_len * sizeof(int), pos_id_dimension * batch - 1);
|
||||
if (pos_id_dimension == 2) {
|
||||
for (int i = 1; i < 2 * batch; i += 2) {
|
||||
nram_buffer[i] += 1;
|
||||
}
|
||||
}
|
||||
__memcpy(generate_pos_id, nram_buffer, pos_id_dimension * batch * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
|
||||
__mlu_global__ void MLUBlockIncrement2DPosId(int *generate_pos_id, int batch) {
|
||||
if (taskId != 0) return;
|
||||
__memcpy(nram_buffer, generate_pos_id, 2 * batch * sizeof(int), GDRAM2NRAM);
|
||||
for (int i = 1; i < 2 * batch; i += 2) {
|
||||
nram_buffer[i] += 1;
|
||||
}
|
||||
__memcpy(generate_pos_id, nram_buffer, 2 * batch * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
} // namespace kernels
|
||||
KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue,
|
||||
int *context_pos_id,
|
||||
int *generate_pos_id,
|
||||
int batch,
|
||||
int context_seq_len,
|
||||
int pos_id_dimension /* 1 for 1D, 2 for 2D */) {
|
||||
if (pos_id_dimension != 1 && pos_id_dimension != 2) {
|
||||
std::cerr << "[invokeSliceIndividualPosId]: pos_id_dimension must be 1 or 2, but got "
|
||||
<< pos_id_dimension << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
cnrtDim3_t dim{4, 1, 1};
|
||||
kernels::MLUBlockSliceIndividualPosId<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
context_pos_id, generate_pos_id, batch, context_seq_len, pos_id_dimension);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch) {
|
||||
cnrtDim3_t dim{4, 1, 1};
|
||||
kernels::MLUBlockIncrement2DPosId<<<dim, cnrtFuncTypeUnion1, queue>>>(generate_pos_id, batch);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
} // namespace tmo
|
||||
59
torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mluh
Normal file
59
torch_mlu_ops-v1.3.2/csrc/kernels/get_glm_position_id.mluh
Normal file
@@ -0,0 +1,59 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
|
||||
#define CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Get generate position id from context position id, when position id is 2D,
|
||||
* increase block position id by one.
|
||||
* @example
|
||||
* in GLM network, context_pos_id shape is [batch, 2, context_seq_len], data is
|
||||
* [[[0, 1, 2, 2, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1]],
|
||||
* [[0, 1, 2, 3, 4, 5, 5], [0, 0, 0, 0, 0, 0, 1]]]
|
||||
* after invoke this kernel, the data is
|
||||
* [[[2], [2]],
|
||||
* [[5], [2]]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param context_pos_id: Input. Pointer to the MLU memory that stores the position id of
|
||||
* context.
|
||||
* @param generate_pos_id: Output. Pointer to the MLU memory that stores the position id of
|
||||
* generate.
|
||||
* @param batch: Batch size.
|
||||
* @param context_seq_len: The sequence length of context.
|
||||
* @param pos_id_dimension: The dimension of position id, 1 for 1D, 2 for 2D.
|
||||
*/
|
||||
KernelStatus invokeSliceIndividualPosId(cnrtQueue_t queue,
|
||||
int *context_pos_id,
|
||||
int *generate_pos_id,
|
||||
int batch,
|
||||
int context_seq_len,
|
||||
int pos_id_dimension);
|
||||
|
||||
/**
|
||||
* @brief Increase block position id by one in generate stage.
|
||||
* @example
|
||||
* in GLM network, generate_pos_id shape is [batch, 2, 1], data is
|
||||
* [[[2], [1]], [[5], [1]]]
|
||||
* after invoke this kernel, the data is
|
||||
* [[[2], [2]], [[5], [2]]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param generate_pos_id: Output/Input. Pointer to the MLU memory that stores the position id of
|
||||
* generate.
|
||||
* @param batch: Batch size.
|
||||
*/
|
||||
KernelStatus invokeIncrement2DPosId(cnrtQueue_t queue, int *generate_pos_id, int batch);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_GET_GLM_POSITION_ID_MLUH_
|
||||
54
torch_mlu_ops-v1.3.2/csrc/kernels/kernel_utils.h
Normal file
54
torch_mlu_ops-v1.3.2/csrc/kernels/kernel_utils.h
Normal file
@@ -0,0 +1,54 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_KERNEL_UTILS_H_
|
||||
#define CSRC_KERNELS_KERNEL_UTILS_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
|
||||
namespace tmo {
|
||||
const std::string arch_370 = "MLU370";
|
||||
enum class KernelStatus { KERNEL_STATUS_SUCCESS = 0, KERNEL_STATUS_FAILED };
|
||||
|
||||
#ifndef PAD_DOWN
|
||||
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
|
||||
#endif
|
||||
|
||||
#ifndef PAD_UP
|
||||
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
|
||||
#endif
|
||||
|
||||
inline bool isMlu300(const std::string &dev_name) {
|
||||
if (dev_name.find("MLU3") != std::string::npos) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_arch300() {
|
||||
int card_id = -1;
|
||||
cnrtDeviceProp_t dev_info;
|
||||
CNRT_CHECK(cnrtGetDevice(&card_id));
|
||||
CNRT_CHECK(cnrtGetDeviceProperties(&dev_info, card_id));
|
||||
std::string dev_name = dev_info.name;
|
||||
return isMlu300(dev_name);
|
||||
}
|
||||
|
||||
inline bool isBf16Supported() {
|
||||
return !is_arch300();
|
||||
}
|
||||
} // namespace tmo
|
||||
#endif // CSRC_KERNELS_KERNEL_UTILS_H_
|
||||
521
torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu
Normal file
521
torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu
Normal file
@@ -0,0 +1,521 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "add_bias_activation.mluh"
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
|
||||
#define USE_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 20 * 1024)
|
||||
#define USE_SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 20 * 1024)
|
||||
|
||||
__nram__ int8_t nram_buffer[USE_NRAM_SIZE];
|
||||
__mlu_shared__ int8_t sram_buffer[USE_SRAM_SIZE];
|
||||
|
||||
__mlu_func__ void get_expert_info(int *nram_count,
|
||||
int *count_sram,
|
||||
uint32_t *gather_offset,
|
||||
int real_inner,
|
||||
int tokens_start,
|
||||
int tokens_end,
|
||||
int tokens_load,
|
||||
int expert_deal_start,
|
||||
int expert_deal_end,
|
||||
uint32_t &expert_start,
|
||||
uint32_t &expert_end,
|
||||
uint32_t &tokens_deal_first,
|
||||
int dtype_size) {
|
||||
bool record_start = false;
|
||||
// loop expert to find first and last deal expert in current core
|
||||
for (int expert_id = expert_deal_start; expert_id <= expert_deal_end; expert_id++) {
|
||||
if (__load_nram(nram_count + expert_id + 1) > tokens_start && !record_start) {
|
||||
expert_start = expert_id;
|
||||
tokens_deal_first =
|
||||
std::min(__load_nram(nram_count + expert_id + 1) - 1, tokens_end) - tokens_start + 1;
|
||||
record_start = true;
|
||||
}
|
||||
if (__load_nram(nram_count + expert_id + 1) > tokens_end) {
|
||||
expert_end = expert_id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// record expert offset to gather bias
|
||||
__bang_write_zero(gather_offset, tokens_load);
|
||||
int tokens_load_total = 0;
|
||||
for (int expert_id = expert_start; expert_id <= expert_end; expert_id++) {
|
||||
int tokens_expand = __load_sram((int *)count_sram + expert_id);
|
||||
if (expert_id == expert_start) {
|
||||
tokens_expand = tokens_deal_first;
|
||||
} else if (expert_id == expert_end) {
|
||||
tokens_expand = tokens_load - tokens_load_total;
|
||||
}
|
||||
if (tokens_expand == 0) {
|
||||
continue;
|
||||
}
|
||||
__bang_write_value(gather_offset + tokens_load_total, tokens_expand,
|
||||
(int)((expert_id - expert_deal_start) * real_inner * dtype_size));
|
||||
tokens_load_total += tokens_expand;
|
||||
}
|
||||
}
|
||||
|
||||
/*************** functions for compute basic operation ***************/
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void add_bias(T *dst_src, T *bias, int number) {
|
||||
// cycle add bias
|
||||
__bang_add((T *)dst_src, (T *)dst_src, (T *)bias, number);
|
||||
}
|
||||
|
||||
template <>
|
||||
__mlu_func__ void add_bias(bfloat16_t *dst_src, bfloat16_t *bias, int number) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__bang_add((bfloat16_t *)dst_src, (bfloat16_t *)dst_src, (bfloat16_t *)bias, number);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void mul_left_right(T *left, T *right, int number) {
|
||||
__bang_mul((T *)left, (T *)left, (T *)right, number);
|
||||
}
|
||||
|
||||
template <>
|
||||
__mlu_func__ void mul_left_right(bfloat16_t *left, bfloat16_t *right, int number) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__bang_mul((bfloat16_t *)left, (bfloat16_t *)left, (bfloat16_t *)right, number);
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ void do_activation(float *input_left,
|
||||
float *act_space,
|
||||
int number,
|
||||
float active_coef,
|
||||
cnnlActivationMode_t act_type) {
|
||||
if (act_type == CNNL_ACTIVATION_GELU) {
|
||||
__bang_active_gelu((float *)input_left, (float *)input_left, number);
|
||||
} else if (act_type == CNNL_ACTIVATION_SWISH) {
|
||||
float *tmp = input_left;
|
||||
if (active_coef != 1.0f) {
|
||||
__bang_mul_scalar(act_space, input_left, active_coef, number);
|
||||
tmp = act_space;
|
||||
}
|
||||
__bang_active_sigmoid((float *)act_space, (float *)tmp, number);
|
||||
__bang_mul((float *)input_left, (float *)input_left, (float *)act_space, number);
|
||||
}
|
||||
}
|
||||
|
||||
/*************** functions for steps of each loop ***************/
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void gather_bias(T *bias_sram,
|
||||
T *bias_nram,
|
||||
uint32_t *gather_offset,
|
||||
int expert_start,
|
||||
int expert_end,
|
||||
int expert_deal_start,
|
||||
int tokens_deal_first,
|
||||
int tokens_deal,
|
||||
int inner_size,
|
||||
bool is_gated) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (is_gated) {
|
||||
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
|
||||
inner_size * sizeof(T), tokens_deal);
|
||||
__gather_async((T *)bias_nram + tokens_deal * inner_size, (T *)bias_sram + inner_size,
|
||||
gather_offset, inner_size * sizeof(T), SRAM2NRAM, inner_size * sizeof(T),
|
||||
tokens_deal);
|
||||
} else {
|
||||
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
|
||||
inner_size * sizeof(T), tokens_deal);
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < tokens_deal; i++) {
|
||||
if (is_gated) {
|
||||
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
|
||||
inner_size * sizeof(T), SRAM2NRAM);
|
||||
__memcpy_async((T *)bias_nram + (tokens_deal * inner_size + i * inner_size),
|
||||
(int8_t *)bias_sram + inner_size * sizeof(T) + gather_offset[i],
|
||||
inner_size * sizeof(T), SRAM2NRAM);
|
||||
} else {
|
||||
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
|
||||
inner_size * sizeof(T), SRAM2NRAM);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void loadBiasInput(T *input,
|
||||
T *left,
|
||||
T *right,
|
||||
T *bias_nram,
|
||||
T *bias_sram,
|
||||
uint32_t *gather_offset,
|
||||
size_t input_offset,
|
||||
int tokens_deal,
|
||||
int inner_size,
|
||||
uint32_t expert_start,
|
||||
uint32_t expert_end,
|
||||
int expert_deal_start,
|
||||
uint32_t tokens_deal_first,
|
||||
bool is_gated,
|
||||
bool has_bias) {
|
||||
if (is_gated) {
|
||||
// if gated, stride io load input, left/right inner to input_left/right
|
||||
__memcpy_async((T *)left, (T *)input + input_offset, inner_size * sizeof(T), GDRAM2NRAM,
|
||||
inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
|
||||
__memcpy_async((T *)right, (T *)input + input_offset + inner_size, inner_size * sizeof(T),
|
||||
GDRAM2NRAM, inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
|
||||
} else {
|
||||
// if not gated, load input to input_left total
|
||||
__memcpy_async((T *)left, (T *)input + input_offset, tokens_deal * inner_size * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
}
|
||||
if (has_bias) {
|
||||
__sync_compute();
|
||||
gather_bias((T *)bias_sram, (T *)bias_nram, gather_offset, expert_start, expert_end,
|
||||
expert_deal_start, tokens_deal_first, tokens_deal, inner_size, is_gated);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void computeAddActivation(T *bias_nram,
|
||||
T *left_dst,
|
||||
T *input_right,
|
||||
float *input_left,
|
||||
float *act_space,
|
||||
int tokens_deal,
|
||||
int inner_size,
|
||||
bool is_gated,
|
||||
bool has_bias,
|
||||
float active_coef,
|
||||
cnnlActivationMode_t act_type) {
|
||||
int number = tokens_deal * inner_size;
|
||||
if (has_bias) {
|
||||
add_bias((T *)left_dst, (T *)bias_nram, number);
|
||||
if (is_gated) {
|
||||
add_bias((T *)input_right, (T *)bias_nram + tokens_deal * inner_size, number);
|
||||
}
|
||||
}
|
||||
|
||||
// cast half/bfloat16 to float to acvication, if float, left_dst is same as input_left
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_half2float((float *)input_left, (half *)left_dst, number);
|
||||
}
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_bfloat162float((float *)input_left, (bfloat16_t *)left_dst, number);
|
||||
}
|
||||
#endif
|
||||
|
||||
// activation
|
||||
do_activation(input_left, act_space, number, active_coef, act_type);
|
||||
|
||||
// if half/bfloat16, cast float to T to mul
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_float2half((half *)input_left, (float *)input_left, number);
|
||||
}
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_float2bfloat16((bfloat16_t *)input_left, (float *)input_left, number);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_gated) {
|
||||
mul_left_right((T *)input_left, (T *)input_right, tokens_deal * inner_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void storeOutput(T *output,
|
||||
T *output_nram,
|
||||
size_t output_offset,
|
||||
int output_stride,
|
||||
int tokens_deal,
|
||||
int inner_size) {
|
||||
__memcpy_async((T *)output + output_offset, (T *)output_nram, inner_size * sizeof(T), NRAM2GDRAM,
|
||||
output_stride * sizeof(T), inner_size * sizeof(T), tokens_deal - 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUAddBiasActivationKernel(T *output,
|
||||
const T *input,
|
||||
const T *bias,
|
||||
const int *cusum_token_count,
|
||||
int num_expert,
|
||||
int total_tokens,
|
||||
int inner_size,
|
||||
int output_stride,
|
||||
bool is_gated,
|
||||
cnnlActivationMode_t act_type,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
float active_coef) {
|
||||
// if bias and token_count is nullptr, not add bias, only activation and gated mul.
|
||||
bool has_bias = (bias != nullptr);
|
||||
|
||||
// 1. distrubute nram space
|
||||
/* if not gated
|
||||
---------------------------- ----------------------------------
|
||||
| nram_token_count | bias_ping/pong |
|
||||
| num_expert * sizeof(int) | 2 * x * inner_size * sizeof(T) |
|
||||
---------------------------- ----------------------------------
|
||||
--------------------------------------
|
||||
| input_ping/pong |
|
||||
| 2 * x * inner_size * sizeof(float) |
|
||||
--------------------------------------
|
||||
------------------------------------------------- -------------------
|
||||
| act_space | gather_offset |
|
||||
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
|
||||
------------------------------------------------- -------------------
|
||||
*/
|
||||
|
||||
/* if gated
|
||||
---------------------------- ----------------------------------
|
||||
| nram_token_count | bias_ping/pong |
|
||||
| num_expert * sizeof(int) | 2 * x * real_inner * sizeof(T) |
|
||||
---------------------------- ----------------------------------
|
||||
-------------------------------------- ----------------------------------
|
||||
| input_left_ping/pong | input_right_ping/pong |
|
||||
| 2 * x * inner_size * sizeof(float) | 2 * x * inner_size * sizeof(T) |
|
||||
-------------------------------------- ----------------------------------
|
||||
------------------------------------------------- -------------------
|
||||
| act_space | gather_offset |
|
||||
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
|
||||
------------------------------------------------- -------------------
|
||||
*/
|
||||
|
||||
// distribute sram
|
||||
int8_t *count_sram = (int8_t *)sram_buffer;
|
||||
int8_t *bias_sram = (int8_t *)count_sram + num_expert * sizeof(int);
|
||||
|
||||
// distrubute nram
|
||||
int real_inner = (is_gated) ? inner_size * 2 : inner_size;
|
||||
int bias_nram_size = (has_bias) ? real_inner * sizeof(T) : 0;
|
||||
int act_space_size = (act_type == CNNL_ACTIVATION_GELU) ? 0 : inner_size * sizeof(float);
|
||||
int gated_ext_size = is_gated ? sizeof(T) : 0;
|
||||
int max_token_deal = (USE_NRAM_SIZE - (num_expert + 1) * sizeof(int)) /
|
||||
(2 * inner_size * (sizeof(float) + gated_ext_size) + act_space_size +
|
||||
2 * bias_nram_size + sizeof(int));
|
||||
|
||||
int8_t *nram_count = (int8_t *)nram_buffer;
|
||||
int8_t *bias_nram = (int8_t *)nram_count + (num_expert + 1) * sizeof(int);
|
||||
int8_t *input_left = (int8_t *)bias_nram + 2 * max_token_deal * bias_nram_size;
|
||||
int8_t *input_right =
|
||||
(int8_t *)input_left + 2 * ((is_gated) ? max_token_deal * inner_size * sizeof(float) : 0);
|
||||
int8_t *act_space = (int8_t *)input_right +
|
||||
2 * max_token_deal * inner_size * (is_gated ? sizeof(T) : sizeof(float));
|
||||
int8_t *gather_offset = (int8_t *)act_space + max_token_deal * act_space_size;
|
||||
|
||||
// 2. cusum_token_count load to nram, because need to reuse in load bias.
|
||||
if (has_bias) {
|
||||
__memcpy((int *)nram_count, (int *)cusum_token_count, (num_expert + 1) * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
if (taskIdX == 0) {
|
||||
__bang_sub((int *)bias_nram, (int *)nram_count + 1, (int *)nram_count, num_expert);
|
||||
__sync();
|
||||
__memcpy((int *)count_sram, (int *)bias_nram, num_expert * sizeof(int), NRAM2SRAM);
|
||||
}
|
||||
__sync_cluster();
|
||||
}
|
||||
|
||||
// 3. sram loop to compute
|
||||
// compute once load bias to sram due to sram_limit
|
||||
int max_expert_deal = (USE_SRAM_SIZE - num_expert * sizeof(int)) / (real_inner * sizeof(T));
|
||||
int real_expert = cusum_token_count == nullptr ? num_expert : expert_size;
|
||||
int sram_loop_rem = real_expert % max_expert_deal;
|
||||
int sram_loop = real_expert / max_expert_deal + (int)(sram_loop_rem != 0);
|
||||
if (!has_bias) {
|
||||
max_expert_deal = real_expert;
|
||||
sram_loop = 1;
|
||||
sram_loop_rem = 0;
|
||||
}
|
||||
for (int deal_loop = 0; deal_loop < sram_loop; deal_loop++) {
|
||||
// load current bias, compute each core deal number
|
||||
int expert_deal =
|
||||
(deal_loop == (sram_loop - 1) && sram_loop_rem != 0) ? sram_loop_rem : max_expert_deal;
|
||||
int expert_deal_start = deal_loop * max_expert_deal + start_expert_id;
|
||||
int expert_deal_end = expert_deal_start + expert_deal - 1;
|
||||
__sync_all();
|
||||
if (has_bias && __is_mpu()) {
|
||||
__memcpy((T *)bias_sram, (T *)bias + deal_loop * max_expert_deal * real_inner,
|
||||
expert_deal * real_inner * sizeof(T), GDRAM2SRAM);
|
||||
}
|
||||
__sync_all();
|
||||
|
||||
// get tokens info of each core
|
||||
int tokens_total_cur = total_tokens;
|
||||
if (has_bias) {
|
||||
tokens_total_cur =
|
||||
__load_nram((int *)nram_count + expert_deal_end + 1) -
|
||||
(start_expert_id == 0 ? 0 : __load_nram((int *)nram_count + expert_deal_start));
|
||||
} else if (cusum_token_count != nullptr) {
|
||||
tokens_total_cur = __load_gdram(cusum_token_count + expert_deal_end + 1) -
|
||||
__load_gdram(cusum_token_count + expert_deal_start);
|
||||
}
|
||||
|
||||
if (sram_loop != 1) {
|
||||
tokens_total_cur = ((int *)nram_count)[expert_deal_end + 1] -
|
||||
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
|
||||
if (deal_loop == 0 && start_expert_id != 0) {
|
||||
tokens_total_cur -= __load_nram((int *)nram_count + expert_deal_start);
|
||||
}
|
||||
}
|
||||
|
||||
int tokens_core_rem = tokens_total_cur % taskDim;
|
||||
int tokens_cur_core = tokens_total_cur / taskDim + (taskId < tokens_core_rem);
|
||||
if (tokens_cur_core == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if ep, input start in current token, have a real start in total network
|
||||
int real_start =
|
||||
cusum_token_count != nullptr ? __load_gdram((int *)cusum_token_count + start_expert_id) : 0;
|
||||
int tokens_core_start = tokens_cur_core * taskId +
|
||||
(taskId < tokens_core_rem ? 0 : tokens_core_rem) +
|
||||
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
|
||||
if (deal_loop != 0) {
|
||||
tokens_core_start -= real_start;
|
||||
}
|
||||
|
||||
uint32_t expert_start = 0;
|
||||
uint32_t expert_end = 0;
|
||||
uint32_t tokens_deal_first = 0;
|
||||
|
||||
// 4. nram loop compute
|
||||
int nram_loop_rem = tokens_cur_core % max_token_deal;
|
||||
int nram_loop = tokens_cur_core / max_token_deal + (int)(nram_loop_rem != 0);
|
||||
int tokens_load = max_token_deal;
|
||||
int tokens_compute = max_token_deal;
|
||||
int tokens_store = max_token_deal;
|
||||
|
||||
for (int loop = 0; loop < nram_loop + 2; loop++) {
|
||||
int inner_io_offset = (loop % 2) * max_token_deal * inner_size;
|
||||
int inner_com_offset = ((loop + 1) % 2) * max_token_deal * inner_size;
|
||||
int real_io_offset = (loop % 2) * max_token_deal * real_inner;
|
||||
int real_com_offset = ((loop + 1) % 2) * max_token_deal * real_inner;
|
||||
if (nram_loop_rem != 0) {
|
||||
if (loop > 1 && (loop - 2) == (nram_loop - 1)) {
|
||||
tokens_store = nram_loop_rem;
|
||||
}
|
||||
if (loop > 0 && (loop - 1) == (nram_loop - 1)) {
|
||||
tokens_compute = nram_loop_rem;
|
||||
}
|
||||
if (loop == (nram_loop - 1)) {
|
||||
tokens_load = nram_loop_rem;
|
||||
}
|
||||
}
|
||||
int tokens_cur_start = tokens_core_start + loop * max_token_deal;
|
||||
int tokens_cur_end = tokens_cur_start + tokens_load - 1;
|
||||
// get current load info
|
||||
if (loop < nram_loop && has_bias) {
|
||||
get_expert_info((int *)nram_count, (int *)count_sram, (uint32_t *)gather_offset, real_inner,
|
||||
tokens_cur_start + real_start, tokens_cur_end + real_start, tokens_load,
|
||||
expert_deal_start, expert_deal_end, expert_start, expert_end,
|
||||
tokens_deal_first, sizeof(T));
|
||||
}
|
||||
|
||||
// store
|
||||
if (loop > 1) {
|
||||
size_t output_offset = (tokens_core_start + (loop - 2) * max_token_deal) * output_stride;
|
||||
storeOutput((T *)output, (T *)((float *)input_left + inner_io_offset), output_offset,
|
||||
output_stride, tokens_store, inner_size);
|
||||
}
|
||||
|
||||
// compute
|
||||
if (loop > 0 && loop <= nram_loop) {
|
||||
T *left_dst = (T *)((float *)input_left + inner_com_offset) +
|
||||
((std::is_same<T, float>::value) ? 0 : tokens_compute * inner_size);
|
||||
computeAddActivation((T *)bias_nram + real_com_offset, (T *)left_dst,
|
||||
(T *)input_right + inner_com_offset,
|
||||
(float *)input_left + inner_com_offset, (float *)act_space,
|
||||
tokens_compute, inner_size, is_gated, has_bias, active_coef, act_type);
|
||||
}
|
||||
|
||||
// load
|
||||
if (loop < nram_loop) {
|
||||
T *left_dst = (T *)((float *)input_left + inner_io_offset) +
|
||||
((std::is_same<T, float>::value) ? 0 : tokens_load * inner_size);
|
||||
size_t input_offset = tokens_cur_start * real_inner;
|
||||
loadBiasInput((T *)input, (T *)left_dst, (T *)input_right + inner_io_offset,
|
||||
(T *)bias_nram + real_io_offset, (T *)bias_sram, (uint32_t *)gather_offset,
|
||||
input_offset, tokens_load, inner_size, expert_start, expert_end,
|
||||
expert_deal_start, tokens_deal_first, is_gated, has_bias);
|
||||
}
|
||||
__sync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
|
||||
void *output,
|
||||
const void *input,
|
||||
const void *bias,
|
||||
const int *cusum_token_count,
|
||||
int num_expert,
|
||||
int total_tokens,
|
||||
int inner_size,
|
||||
int output_stride,
|
||||
cnnlDataType_t dtype,
|
||||
bool is_gated,
|
||||
cnnlActivationMode_t act_type,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
float active_coef) {
|
||||
if (bias != NULL && cusum_token_count == NULL) {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
|
||||
<< "when have bias, cusum_token_count can not be nullptr.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (act_type != CNNL_ACTIVATION_GELU && act_type != CNNL_ACTIVATION_SWISH) {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
|
||||
<< "activation mode only supports gelu and swish now.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
|
||||
|
||||
if (dtype == CNNL_DTYPE_FLOAT) {
|
||||
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(float *)output, (const float *)input, (const float *)bias, cusum_token_count, num_expert,
|
||||
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
|
||||
active_coef);
|
||||
} else if (dtype == CNNL_DTYPE_HALF) {
|
||||
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(half *)output, (const half *)(input), (const half *)bias, cusum_token_count, num_expert,
|
||||
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
|
||||
active_coef);
|
||||
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
|
||||
if (!isBf16Supported()) {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: MLU300 devices do not support bfloat16."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)bias,
|
||||
cusum_token_count, num_expert, total_tokens, inner_size, output_stride, is_gated, act_type,
|
||||
start_expert_id, expert_size, active_coef);
|
||||
} else {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: add_bias_activation data_type not support, "
|
||||
<< "only support float/half/bfloat16." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,84 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
|
||||
#define CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Add bias and activate to all tokens. Different expert with different bias.
|
||||
* If in gated mode, add bias to all input. But only activate in the
|
||||
* input of [:, :inner_size]. Then multiply it to the input of [:, inner_size:].
|
||||
* Else, add bias and activate to all input, has no multiply process.
|
||||
* @example
|
||||
* is_gated = true, num_expert = 4, total_tokens = 6, inner_size = 2
|
||||
* input: (6, 4) = [[2, 4, 5, 6], [1, 4, 5, 3],
|
||||
* [3, 5, 7, 8], [6, 8, 5, 3],
|
||||
* [2, 3, 4 ,5], [2, 9, 2, 3]]
|
||||
* bias: (4, 4) = [[1, 0, 1, 0], [0, 1, 2, 2], [2, 3, 2, 3], [1, 2, 3, 4]]
|
||||
* token_count = [2, 2, 1, 1]
|
||||
* first step: add bias
|
||||
* [[2+1, 4+0, 5+1, 6+0], [1+1, 4+0, 5+1, 3+0],
|
||||
* [3+0, 5+1, 7+2, 8+2], [6+0, 8+1, 5+2, 3+2],
|
||||
* [2+2, 3+3, 4+2, 5+3], [2+1, 9+2, 2+3, 3+4]]
|
||||
* second step: act and mul
|
||||
* output: (6, 2) = [[act(3)*6, act(4)*6], [act(2)*6, act(4)*3],
|
||||
* [act(3)*9, act(6)*10], [act(6)*7, act(9)*5],
|
||||
* [act(4)*6, act(6)*8], [act(3)*5, act(11)*7]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param output: Output. Pointer to the MLU memory that stores the result.
|
||||
* When is_gated is true, The shape is [total_tokens, input_size / 2].
|
||||
* In this case, the input_size must be even. Otherwise the shape is [total_tokens,
|
||||
* input_size]. The memory can be discontinuous in total_tokens dim. The stride is output_stride.
|
||||
* @param input: Input. Pointer to the MLU memory that stores the input tokens.
|
||||
* The shape is [total_tokens, input_size].
|
||||
* When is_gated is true, the shape is [total_tokens, 2 * inner_size].
|
||||
* Otherwise the shape is [total_tokens, inner_size].
|
||||
* @param bias: Input. Pointer to the MLU memory that stores the bias. The memory must be
|
||||
* continuous. When is_gated is true, the shape is [num_expert, 2 * inner_size]. Otherwise the shape
|
||||
* is [num_expert, inner_size]. Bias can be nullptr. If bias is nullptr, has no add bias process.
|
||||
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of token
|
||||
* counts. The shape is [num_expert + 1]. If cusum_token_count
|
||||
* is not nullptr, cusum_token_count, start_expert_id and
|
||||
* expert_size together determine which tokens to process.
|
||||
* If cusum_token_count is nullptr, process all tokens,
|
||||
* the number of which is total_tokens. When bias is not nullptr,
|
||||
* cusum_token_count must also not be nullptr.
|
||||
* @param num_expert: The number of expert.
|
||||
* @param total_tokens: The total number of tokens.
|
||||
* @param inner_size: The inner size of output.
|
||||
* @param output_stride: The stride of output, must be greater than or equal to inner_size.
|
||||
* @param dtype: Data type.
|
||||
* @param is_gated: Gated or not.
|
||||
* @param act_type: The type of activation. Support gelu and swish.
|
||||
* @param start_expert_id: The index of the start expert.
|
||||
* @param expert_size: The number of experts to process.
|
||||
* @param active_coef: The coefficient used in the swish activation.
|
||||
*/
|
||||
// Launches the grouped add-bias + activation kernel on the given MLU queue.
// Shapes, nullability and expert-selection semantics of each argument are
// documented in the comment block above this declaration.
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
                                                void *output,
                                                const void *input,
                                                const void *bias,
                                                const int *cusum_token_count,
                                                int num_expert,
                                                int total_tokens,
                                                int inner_size,
                                                int output_stride,
                                                cnnlDataType_t dtype,
                                                bool is_gated,
                                                cnnlActivationMode_t act_type,
                                                int start_expert_id,
                                                int expert_size,
                                                float active_coef);

}  // namespace tmo

#endif  // CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
|
||||
646
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu
Normal file
646
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu
Normal file
@@ -0,0 +1,646 @@
|
||||
#include <stdint.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "cast_gating.mluh"
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
namespace tmo {
// Ceiling division for non-negative integers.
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))

// On-chip scratch capacities in bytes.
// NOTE(review): hard-coded sizes — presumably sized for the MLU500-class parts
// this kernel targets (guarded by __BANG_ARCH__ >= 500 below); confirm.
#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)

// Bytes per hardware transfer line used by the tiling moves below.
#ifndef ONE_LINE
#define ONE_LINE 64
#endif

// Alignment unit (in rows) for the weight layout; see warp_prompt_weight.
#ifndef LT_NUM
#define LT_NUM 64
#endif
|
||||
|
||||
// Tiling configuration for the cast-gating kernel, filled in by gatingTiling().
struct castGatingTileInfo {
  int32_t block = 64;       // rows (cols when A/B swapped) handled per cluster block
  int32_t split_k_num = 8;  // number of K splits reduced by splitKReduce afterwards
  int32_t block_k = 256;    // K elements processed per inner-loop iteration
};
|
||||
|
||||
namespace kernels {
#pragma bang walign(16)
// NOTE(review): ROW_PER_LT / LT_SIZE presumably describe how rows are packed
// into one LT (weight-RAM tile); they only appear in the WRAM layout math —
// confirm against the BANG ISA documentation.
#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif

#ifndef LT_SIZE
#define LT_SIZE 16
#endif

// Stride between consecutive LT banks when WRAM is viewed as 16 slices.
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif

// Statically allocated on-chip scratch shared by all kernels in this file.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
|
||||
// Asynchronously moves `size` bytes from SRAM (`src`) into NRAM (`dst`) while
// converting the element type through the `convert_type` instruction suffix
// (e.g. ".cvt.rn.f32.f16()").  The bulk is transferred as 64-byte-aligned
// tiles; the tail is transferred by a second move.
// NOTE(review): `n` is an element count (PAD_DOWN of size/src_dsize) while
// `rem` is a byte remainder (`size % 64`); presumably this matches the
// move.tiling.b16 operand convention — confirm against the BANG ISA spec.
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
  do { \
    uint32_t align_num = 64 / src_dsize; \
    uint32_t n = PAD_DOWN(size / src_dsize, align_num); \
    uint32_t rem = size % 64; \
    if (n) { \
      __asm__ __volatile__( \
          "move.tiling.async.nram.sram.b16" \
          " [%[dst_addr]], [%[src_addr]], " \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \
          [src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \
          [src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \
          [src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \
          [dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
    } \
    \
    if (rem) { \
      __asm__ __volatile__( \
          "move.tiling.async.nram.sram.b16" \
          " [%[dst_addr]], [%[src_addr]], " \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \
          [src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \
          [src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \
          [dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
    } \
  } while (false)
|
||||
|
||||
// Loads half-precision activations from SRAM into NRAM, converting to float.
// `size` is in bytes of the half source.  Compiles to a no-op below arch 500.
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}
|
||||
|
||||
// Loads bfloat16 activations from SRAM into NRAM, converting to float.
// `size` is in bytes of the bfloat16 source.  No-op below arch 500.
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
|
||||
|
||||
// Float source needs no conversion: plain async SRAM->NRAM copy of `size` bytes.
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
  __memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM);
}
|
||||
|
||||
// Moves an n x k weight tile from SRAM into the banked WRAM layout (LT tiles of
// ROW_PER_LT rows, LT_NUM-row groups), converting element type via
// `convert_type`.  Three phases: (1) full LT_NUM-row groups, (2) the remaining
// rows padded up to ROW_PER_LT, (3) the K-dimension tail `rem_k`.
// NOTE: unlike SRAM2NRAM_CONVERT_IMPL this macro is NOT wrapped in do/while —
// it declares locals (align_n, sn0, ...) directly in the caller's scope, so it
// must be expanded at most once per scope.
// NOTE(review): the operand layout mirrors the move.tiling.wram.sram.b16
// descriptor format; verify any change against the BANG ISA documentation.
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
                               convert_type) \
  int align_n = PAD_DOWN(n, LT_NUM); \
  int sn0 = ONE_LINE; \
  int size_sn0 = sn0 / src_dsize; \
  int sn1 = ONE_LINE / src_dsize; \
  int ss1 = total_k * src_dsize; \
  int sn3 = k / size_sn0; \
  int sn4 = align_n / sn1; \
  int ss4 = sn1 * ss1; \
  int dn0 = sn0; \
  int dn1 = ROW_PER_LT; \
  int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \
  int ds1 = dst_k * dst_dsize; \
  int dn2 = sn1 / ROW_PER_LT; \
  int ds2 = WRAM_LT_MAP16_STRIDE; \
  int ds3 = sn0 * dst_dsize / src_dsize; \
  int dn4 = LT_SIZE / dn2; \
  int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \
  int dn5 = align_n / LT_NUM; \
  int ds5 = ROW_PER_LT * dst_k * dst_dsize; \
  int rem_k = k % size_sn0; \
  int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \
  int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \
  if (align_n > 0 && sn3 > 0) { \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
        ";\n\t" ::[dst_addr] "r"(wram_dst), \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \
        [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \
        [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \
        [dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \
    sram_src += align_n * total_k; \
    wram_dst += align_n / LT_SIZE * dst_k; \
  } \
  align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \
  if (align_n > 0 && sn3 > 0) { \
    sn1 = align_n; \
    dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \
        "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \
        [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \
        [dst_s3] "r"(ds3)); \
    sram_src += align_n * total_k; \
    wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \
  } \
  if (rem_k > 0) { \
    align_n = PAD_UP(n, LT_NUM); \
    sn0 = rem_k * src_dsize; \
    dn0 = sn0; \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
        ";\n\t" ::[dst_addr] "r"(wram_dst2), \
        [src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
        [src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
        [src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \
        [dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \
        [dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \
        [dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \
        [dst_s5] "r"(0)); \
  }
|
||||
|
||||
// Loads a half-precision weight tile (warp_n rows x len_k cols, row stride
// total_k in SRAM) into WRAM, converting to float.  No-op below arch 500.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     half *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
                         ".cvt.rn.f32.f16()");
#endif
}
|
||||
|
||||
// Loads a bfloat16 weight tile (warp_n rows x len_k cols, row stride total_k
// in SRAM) into WRAM, converting to float.  No-op below arch 500.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     bfloat16_t *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
                         sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
|
||||
|
||||
// Same-type variant: copies an n x len_k weight tile (row stride total_k in
// SRAM) into the LT-banked WRAM layout with plain strided async memcpys — no
// element conversion.  Three phases mirror SRAM2WRAM_CONVERT_IMPL:
// full LT_NUM-row groups, remaining ROW_PER_LT groups, then leftover rows.
template <typename T>
__mlu_func__ void warp_prompt_weight(T *wram_dst,
                                     T *sram_src,
                                     int32_t n,
                                     int32_t len_k,
                                     int32_t total_k) {
  int32_t type_size = sizeof(T);
  int32_t data_size = len_k * type_size;
  // Destination row stride, padded to a full hardware line.
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);
  int32_t ss0 = total_k * type_size;
  int32_t count = n / LT_NUM;
  // Phase 1: whole LT_NUM-row groups.
  for (int32_t i = 0; i < count; ++i) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
    sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
  }
  // Phase 2: remaining full ROW_PER_LT groups.
  count = n % LT_NUM / ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
    sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
  }
  // Phase 3: leftover rows (< ROW_PER_LT).
  count = n % ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
  }
}
|
||||
|
||||
// Evenly partitions num_total_task items across taskdim workers: the first
// (num_total_task % taskdim) workers get one extra item.  Writes this worker's
// starting index to task_offset and its item count to num_cur_task.
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
                                   const int32_t &taskid,
                                   const int32_t &taskdim,
                                   int32_t &task_offset,
                                   int32_t &num_cur_task) {
  const int32_t base = num_total_task / taskdim;
  const int32_t leftover = num_total_task % taskdim;
  const bool takes_extra = taskid < leftover;
  num_cur_task = takes_extra ? base + 1 : base;
  task_offset = takes_extra ? taskid * (base + 1) : taskid * base + leftover;
}
|
||||
|
||||
// Two-way barrier between the cluster's IPU compute cores and the MPU DMA
// core: IPUs arrive on barrier 5 and wait on barrier 3; the MPU does the
// reverse.  Barrier count is coreDim IPUs plus the one MPU.
// NOTE(review): relied upon by MLUCastGating to hand SRAM ping-pong buffers
// between the loader (MPU) and consumers (IPUs); do not reorder around the
// surrounding __sync_io_move_compute calls.
__mlu_func__ void bidirectionBarrierOp() {
  int32_t bcnt = coreDim + 1;
  if (__is_ipu()) {
    __asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  } else {
    __asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  }
}
|
||||
|
||||
// Accumulating m x n x k matrix multiply expressed as a 1x1 partial
// convolution: c += a * b, with c used as both accumulator input and output.
// NOTE(review): the leading double underscore makes this a reserved
// identifier; kept as-is because callers in this file depend on the name.
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
  __bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}
|
||||
|
||||
// Writes `count` rows of `data_num` elements (dt_size bytes each) from NRAM to
// GDRAM.  When both row strides equal the row length the whole tile is
// contiguous and a single flat copy suffices; otherwise a strided copy is
// issued per row.
__mlu_func__ void warp_store(void *ddr_dst,
                             void *nram_src,
                             const int32_t data_num,
                             const int32_t dst_stride,
                             const int32_t src_stride,
                             const int32_t count,
                             const int32_t dt_size) {
  const bool contiguous = (src_stride == data_num) && (dst_stride == data_num);
  if (!contiguous) {
    __memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM, (size_t)dst_stride * dt_size,
                   src_stride * dt_size, count - 1);
    return;
  }
  __memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
}
|
||||
|
||||
// Reduces the split-K partial results: workspace holds split_k_num stacked
// M x N partial sums (Tcc); each task sums its row slice with a sumpool and
// writes the result to `output` (Tc, row stride ldc).
// NOTE(review): block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc)
// could reach 0 for very large split_k_num * N; presumably gatingTiling's
// workspace sizing prevents that — confirm.
template <typename Tc, typename Tcc>
__mlu_func__ void splitKReduce(Tcc *workspace,
                               Tc *output,
                               int32_t M,
                               int32_t N,
                               int32_t split_k_num,
                               int32_t ldc) {
  // Each task handles a contiguous band of rows.
  int32_t offset_m, cta_m;
  assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
  if (cta_m <= 0) return;
  // Rows per NRAM pass: all split_k_num partials of a row must fit on chip.
  int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
  int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
  int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
  Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
  Tc *output_ddr = (Tc *)output + offset_m * ldc;
  for (int32_t i = 0; i < repeat; i++) {
    int32_t current_m = i == repeat - 1 ? rem_m : block_m;
    int32_t data_size = N * sizeof(Tc);
    int32_t data_num = current_m - 1;
    if (ldc == N) {
      // Output rows are contiguous: store the whole chunk in one copy.
      data_size = current_m * N * sizeof(Tc);
      data_num = 0;
    }
    // Gather the split_k_num partial chunks (stride M*N between splits).
    __memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
             current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
    // Sum the split_k_num partials element-wise.
    __bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
                   split_k_num, 1, 1, 1);
    // Store; both Tc and Tcc are float in every current instantiation, so the
    // reinterpreting store does not convert.
    __memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
             N * sizeof(Tc), data_num);
    workspace_ddr = workspace_ddr + block_m * N;
    output_ddr = output_ddr + block_m * ldc;
  }
}
|
||||
|
||||
// Split-K tiled GEMM with on-the-fly dtype cast: C[M,N] = A[M,K] * B[N,K]^T.
// When EXCHANGE_AB is true the caller has swapped the operand roles so the
// LT_NUM-aligned dimension wastes less compute; the per-core result tile is
// then transposed back before the store.  The MPU core streams A/B tiles
// GDRAM->SRAM (double buffered) while the IPU cores convert SRAM tiles into
// NRAM/WRAM and accumulate with __wmma; bidirectionBarrierOp hands the
// ping-pong buffers between the two sides.  With split_k_num > 1 each cluster
// column computes a K slice into `workspace` and splitKReduce sums them.
// NOTE(review): the __sync_io_move_compute flag combinations and barrier
// placement are load-bearing for the pipeline; confirm any change against the
// BANG programming guide before reordering.
template <typename Ta,
          typename Tac,
          typename Tb,
          typename Tbc,
          typename Tc,
          typename Tcc,
          bool EXCHANGE_AB>
__mlu_global__ void MLUCastGating(Ta *A,
                                  Tb *B,
                                  Tc *C,
                                  Tcc *workspace,
                                  int32_t M,
                                  int32_t N,
                                  int32_t K,
                                  int32_t lda,
                                  int32_t ldb,
                                  int32_t ldc,
                                  castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
  // --- Grid decomposition: clusters form a (split-K) x (row/col band) grid.
  int32_t block_k = split_info.block_k;
  int32_t grid_dimx = split_info.split_k_num;
  int32_t block = split_info.block;
  int32_t grid_idx = clusterId % grid_dimx;
  int32_t grid_idy = clusterId / grid_dimx;
  int32_t offset_k = 0, problem_k = 0;
  assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
  int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
  int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
  int32_t cta_k = k_loop == 1 ? rem_k : block_k;
  int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
  int32_t warp_m = cta_m, warp_offset_m = 0;
  int32_t warp_n = cta_n, warp_offset_n = 0;
  int32_t outer_loop = 0, outer_rem = 0;
  if (EXCHANGE_AB) {
    // Cluster bands split N; cores split each block along N.
    assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
    if (cta_n > block && cta_n % block != 0) {
      // Rebalance the block size so the last iteration is not tiny.
      int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_n + block - 1) / block;
    outer_rem = cta_n % block == 0 ? block : cta_n % block;
  } else {
    // Cluster bands split M; cores split each block along M.
    assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
    if (cta_m > block && cta_m % block != 0) {
      int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_m + block - 1) / block;
    outer_rem = cta_m % block == 0 ? block : cta_m % block;
  }

  // --- On-chip buffer partitioning.
  // NRAM tail holds the accumulator tile (plus a transpose scratch tile when
  // EXCHANGE_AB); the front is a double-buffered A-tile area.
  int32_t size_nram_buf =
      NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
  int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
  Tac *nbuf_a = (Tac *)nram_buffer;
  Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
  Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;

  // SRAM is double buffered for the staged A and B tiles.
  int32_t size_sram_buf = SRAM_BUFFER_SIZE;
  int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
  int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
  Ta *sbuf_a = (Ta *)sram_buffer;
  Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));

  // WRAM holds the double-buffered weight tile.
  int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
  Tbc *wbuf_b = (Tbc *)wram_buffer;

  int32_t a_dsize = sizeof(Ta);
  int32_t b_dsize = sizeof(Tb);
  int32_t k_loop_count = 0;
  for (int32_t j = 0; j < outer_loop; j++) {
    Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
    Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
    int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
    if (EXCHANGE_AB) {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
    } else {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
    }
    int32_t compute_total = warp_m * warp_n;
    if (compute_total > 0 && __is_ipu()) {
      if (!EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
      }
      // Zero the accumulator tile for this output block.
      __bang_write_zero((Tcc *)nbuf_c, compute_total);
    }
    // --- K loop: MPU loads tile i while IPUs convert tile i and multiply
    // tile i-1 (software pipeline, accumulating into nbuf_c).
    int32_t i = 0;
    for (; i < k_loop; i++) {
      Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
      Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
      cta_k = i == k_loop - 1 ? rem_k : block_k;
      if (__is_mpu()) {
        if (EXCHANGE_AB) {
          __memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
              [src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
              [src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
        } else {
          __memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
              [src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
              [src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
        }
        a_ddr = (Ta *)a_ddr + block_k;
        b_ddr = (Tb *)b_ddr + block_k;
      }
      // Hand the freshly-loaded SRAM buffer from MPU to IPUs (and the
      // consumed one back).
      bidirectionBarrierOp();
      if (__is_ipu() && compute_total > 0) {
        __sync_io_move_compute(false, true, false, false, false, true);
        __sync_io_move_compute(false, false, true, false, true, false);
        if (i >= 1) {
          // Multiply the previous iteration's converted tiles.
          __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
                 (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
        }
        // Convert this iteration's tiles into the other ping-pong half.
        warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
                          sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
        // mvdma bound for EXCHANGE_AB when n==32
        warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
                           (Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
      }
      k_loop_count += 1;
    }
    // --- Drain the pipeline: multiply the final tile, then store.
    if (compute_total > 0) {
      __sync_io_move_compute(false, true, false, false, false, true);
      __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
             (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
      if (EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
        // Result was computed transposed; flip it back before storing.
        __bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
      }
      int32_t total_offset =
          grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
                                          : (offset_m + warp_offset_m + block * j) * N);
      Tcc *wks = (Tcc *)workspace + total_offset;
      int32_t store_c_size = sizeof(Tcc);
      int8_t *store_ddr = (int8_t *)wks;
      int32_t dst_str = EXCHANGE_AB ? M : N;
      if (grid_dimx == 1) {
        // No split-K: write straight to C instead of the workspace.
        // convert Tcc to Tc
        dst_str = ldc;
        store_ddr =
            (int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
                                              : (offset_m + warp_offset_m + block * j) * ldc));
      }
      __asm__ volatile("sync.psimd.cio;\n\t");
      if (EXCHANGE_AB) {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
      } else {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
      }
    }
  }
  // --- Split-K epilogue: sum the per-slice partial results into C.
  if (grid_dimx != 1) {
    __sync_all();
    splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
                 split_info.split_k_num, ldc);
  }
#endif  // __BANG_ARCH__ >= 500
}
|
||||
} // namespace kernels
|
||||
|
||||
// Chooses the per-cluster block size (rows of M, or columns of N when
// EXCHANGE_AB) bounded by what fits in NRAM, WRAM and SRAM simultaneously
// with double buffering.  Returns the largest block that satisfies all three
// capacity constraints, aligned down to the hardware-friendly multiple where
// possible.
int32_t getBlock(int32_t m,
                 int32_t n,
                 int32_t core_num,
                 int32_t block_k,
                 int32_t a_dtype_size,
                 int32_t b_dtype_size,
                 int32_t compute_dtype_size,
                 bool EXCHANGE_AB) {
  int32_t block = 0;
  if (EXCHANGE_AB) {
    // Roles swapped: n plays the "small" operand dimension.
    int32_t block_m = n;
    int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) /
                           (2 * block_m * compute_dtype_size) * core_num;
    int32_t wram_block_n =
        WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
    int32_t sram_block_n =
        (SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
    int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n);
    // Prefer an LT_NUM-aligned multiple of core_num; fall back to the raw
    // bound when alignment would round down to zero.
    int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM);
    return block_n > 0 ? block_n : block_n_tmp;
  } else {
    int32_t block_n = n;
    int32_t nram_block_m =
        NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2);
    int32_t sram_block_m =
        (SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
    block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num));
    return block;
  }
}
|
||||
|
||||
// Computes the launch tiling for the cast-gating kernel: picks block_k,
// decides whether to swap A/B (EXCHANGE_AB), derives the per-cluster block
// size, and enables split-K when there are idle clusters, K is long enough,
// and the workspace can hold the partial sums.
void gatingTiling(int32_t m,
                  int32_t n,
                  int32_t k,
                  size_t a_dtype_size,
                  size_t b_dtype_size,
                  size_t compute_dtype_size,
                  size_t workspace_size,
                  int32_t union_number,
                  int32_t core_num,
                  int32_t &block,
                  int32_t &split_k_num,
                  int32_t &block_k,
                  bool &EXCHANGE_AB) {
  // K tile sized so one tile row occupies at most 512 bytes of input.
  block_k = std::min(k, int32_t(512 / a_dtype_size));
  split_k_num = 1;
  // swap A and B to reduce computing waste caused by LT_NUM-align of co dimension
  if (m >= core_num * LT_NUM &&
      float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
    EXCHANGE_AB = true;
  }
  int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
                               compute_dtype_size, EXCHANGE_AB);
  int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
  block = tmp_block;
  // Enable split-K only when some unions would otherwise be idle and K is
  // large enough to be worth splitting.
  if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
    for (int32_t i = total_blocks; i <= union_number; i++) {
      if (union_number % i == 0) {
        int32_t tmp_split_k = union_number / i;
        // Each split needs its own m*n partial-sum buffer in the workspace.
        size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
        if (workspace_size >= workspace_size_need) {
          split_k_num = tmp_split_k;
          block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
          if (EXCHANGE_AB && block > LT_NUM * core_num) {
            block = PAD_DOWN(block, LT_NUM * core_num);
          }
          break;
        }
      }
    }
  }
}
|
||||
|
||||
// Queries the current device for the maximum clusters per union task and the
// number of compute cores per cluster.
// NOTE(review): "Contxt" is a typo for "Context", kept because callers use
// this name.
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
|
||||
|
||||
// Host entry point: validates arguments, computes the tiling, and launches
// MLUCastGating with the template instantiation matching the input dtype
// (half or bfloat16) and the EXCHANGE_AB decision.  Filter and output are
// always float32; see the header declaration for shape requirements.
KernelStatus invokeCastGating(cnrtQueue_t queue,
                              void *input,
                              void *filter,
                              void *output,
                              int input_row,
                              int expert_num,
                              int hidden_size,
                              cnnlDataType_t a_dtype,
                              void *workspace,
                              size_t workspace_size_bytes) {
  // The kernel body is compiled only for __BANG_ARCH__ >= 500.
  if (is_arch300()) {
    std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (expert_num > 128) {
    std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // A provided workspace must be at least 16 MiB (split-K partial sums).
  if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
    std::cerr
        << "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (workspace_size_bytes > 0 && workspace == NULL) {
    std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
                 "greater than 0."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Launch one union task spanning all clusters x cores.
  int32_t union_number, core_num;
  getContxtInfo(&union_number, &core_num);
  cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
  cnrtDim3_t dim;
  dim.x = (int32_t)func_type;
  dim.y = 1;
  dim.z = 1;

  // Filter and accumulator are always float.
  cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
  cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
  size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
  cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
  cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
  cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
  castGatingTileInfo split_info;
  bool EXCHANGE_AB = false;
  gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
               workspace_size_bytes, union_number, core_num, split_info.block,
               split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
  // With EXCHANGE_AB the filter becomes operand A and the input operand B, so
  // M/N and the template type slots are swapped accordingly.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else if (a_dtype == CNNL_DTYPE_HALF) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, half, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<half, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else {
    std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
} // namespace tmo
|
||||
50
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh
Normal file
50
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh
Normal file
@@ -0,0 +1,50 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_CAST_GATING_MLUH_
#define CSRC_KERNELS_CAST_GATING_MLUH_

#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
 * @brief Converts the input to float32 and performs the gating operation.
 * @param queue: The MLU queue to launch on.
 * @param input: Input. Pointer to the MLU memory that stores the input;
 *               the shape must be [input_row, hidden_size].
 * @param filter: Input. Pointer to the MLU memory that stores the weight;
 *                the shape must be [expert_num, hidden_size].
 * @param output: Output. Pointer to the MLU memory that stores the output;
 *                the shape must be [input_row, expert_num].
 * @param input_row: Input. Number of input rows (tokens).
 * @param expert_num: Input. Number of experts.
 * @param hidden_size: Input. Hidden dimension shared by input and filter.
 * @param a_dtype: Input. The data type of the input.
 * @param workspace: Input. Pointer to the MLU workspace (split-K partials).
 * @param workspace_size_bytes: Input. The size of the workspace in bytes.
 * @note a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF.
 *       expert_num must be in range [1, 128].
 *       If workspace is NOT NULL, workspace_size_bytes must NOT be smaller
 *       than 16 * 1024 * 1024.
 *       The data type of filter and output must be float.
 *       cast_gating only supports MLU500 devices or higher.
 */
KernelStatus invokeCastGating(cnrtQueue_t queue,
                              void *input,
                              void *filter,
                              void *output,
                              int input_row,
                              int expert_num,
                              int hidden_size,
                              cnnlDataType_t a_dtype,
                              void *workspace,
                              size_t workspace_size_bytes);
}  // namespace tmo
#endif  // CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
760
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu
Normal file
760
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu
Normal file
@@ -0,0 +1,760 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "cnrt.h"
|
||||
#include "combine_result.mluh"
|
||||
// clang-format off
|
||||
#include <bang_device_functions_extra.h>
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
#if __BANG_ARCH__ >= 592
|
||||
#include <bang_fusor.h>
|
||||
template <typename SrcT>
|
||||
using bang_fusor = bang::experimental::fusor<SrcT>;
|
||||
#endif
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
// Reserve 32KB of NRAM headroom (stack/compiler use — NOTE(review): exact
// consumer not visible here; confirm against the BANG toolchain docs).
#define NRAM_REMAIN_SIZE (32 * 1024)
// Everything else is one statically allocated scratch buffer that each kernel
// partitions manually via pointer arithmetic.
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)

__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void swap(T *&ping, T *&pong) {
|
||||
T *temp = ping;
|
||||
ping = pong;
|
||||
pong = temp;
|
||||
}
|
||||
|
||||
// Issue an asynchronous vector gather (GDRAM -> NRAM) on io queue 0.
// `offset_type` selects the element type (u32/u64) of the NRAM offset table.
// Relies on `dst`, `src_gdram`, `nram_offset`, `transfer_size` and
// `token_count` being in scope at the expansion site; stride == transfer_size,
// i.e. gathered rows are packed contiguously at `dst`.
#define GATHER_ASYNC_IO0(offset_type)                                              \
  __asm__ __volatile__(                                                            \
      "gather.vector.async.nram.gdram.nram." #offset_type                          \
      ".io0 [%[dst]], [%[src]], [%[offset]], "                                     \
      "%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst),        \
      [src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
      [transfer_num] "r"(token_count), [stride] "r"(transfer_size))
|
||||
|
||||
// Fused scalar-multiply + down-convert: dst[dst_dtype] = src0[f32] * src1,
// rounding to nearest ("crn"). Relies on `dst`, `nram_input_buffer`,
// `expert_coeff` and `size` being in scope at the expansion site.
#define FUSE_MUL_CVT(dst_dtype)                                       \
  __asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype            \
                       ".f32 [%[dst]], [%[src0]], %[src1],"          \
                       " %[size];\n\t" ::[dst] "r"(dst),             \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
|
||||
|
||||
// Fused multiply-add + down-convert: dst[dst_dtype] = src0[f32] * src1 + dst,
// rounding to nearest ("crn"). Relies on `dst`, `nram_input_buffer`,
// `expert_coeff` and `size` being in scope at the expansion site.
#define FUSE_MULADD_CVT(dst_dtype)                                    \
  __asm__ __volatile__("muladd.nram.crn." #dst_dtype                 \
                       ".f32 [%[dst]], [%[src0]], %[src1], [%[dst]]," \
                       " %[size], %[size];\n\t" ::[dst] "r"(dst),    \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
|
||||
|
||||
// Widen `count` elements from T (half / bfloat16_t / float) into a float
// buffer on NRAM. For T == float the add-zero is effectively a vectorized
// copy, keeping the same call shape as the conversion paths.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, count);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(dst, (bfloat16_t *)src, count);
  } else if (std::is_same<T, float>::value) {
    __bang_add_scalar((float *)dst, (float *)src, (float)0, count);
  }
}
|
||||
|
||||
// Narrow `count` float elements down to T (half / bfloat16_t / float),
// round-to-nearest. For T == float the add-zero is effectively a copy,
// mirroring toFloat() above.
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
  if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, count);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
  } else if (std::is_same<T, float>::value) {
    __bang_add_scalar((float *)dst, (float *)src, (float)0, count);
  }
}
|
||||
|
||||
// Asynchronous strided 2-D load, GDRAM -> NRAM: copies (seg_num + 1) segments
// of `size` bytes, advancing dst/src by their respective strides per segment.
// On newer archs (> 500) this is issued on io queue 0 via inline asm so it can
// overlap with other queues; otherwise it falls back to __memcpy_async.
__mlu_func__ void loadAsync2d(void *dst,
                              void *src,
                              int size,
                              int dststride,
                              int srcstride,
                              int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}
|
||||
|
||||
// Asynchronous strided 2-D store, NRAM -> GDRAM: mirror of loadAsync2d, but
// issued on io queue 1 (> 500 archs) so loads and stores can proceed
// concurrently on separate queues; otherwise falls back to __memcpy_async.
__mlu_func__ void storeAsync2d(void *dst,
                               void *src,
                               int size,
                               int dststride,
                               int srcstride,
                               int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}
|
||||
|
||||
// Gather `token_count` rows of `transfer_size` bytes from GDRAM into a packed
// NRAM buffer, using per-row byte offsets held in `nram_offset` (element type
// T_IDX: uint32_t or uint64_t, chosen by the caller based on tensor size).
// Uses the hardware gather instruction on > 500 archs; otherwise emulates it
// with one async memcpy per row. No-op for empty input or null source (e.g.
// the bias pointer when bias is absent).
template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst,
                                    void *src_gdram,
                                    T_IDX *nram_offset,
                                    int transfer_size,
                                    int token_count) {
  if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
  if (std::is_same<T_IDX, uint32_t>::value) {
    GATHER_ASYNC_IO0(u32);
  } else {
    GATHER_ASYNC_IO0(u64);
  }
#else
  for (int k = 0; k < token_count; k++) {
    __memcpy_async((int8_t *)dst + k * transfer_size,
                   (int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
  }
#endif
}
|
||||
|
||||
// Under expert parallelism, keep only the token slots whose gather index falls
// in [begin_expert_acc_tokens, end_expert_acc_tokens) — i.e. tokens routed to
// this rank's expert range. Builds a 0/1 float mask in `nram_mask`, compacts
// `nram_token_idx` in place via __bang_filter, and returns the number of
// surviving entries. Without expert parallelism every token is active and the
// buffers are untouched.
// NOTE(review): `nram_mask_char` is unused here; it is filled later by the
// caller (via __bang_float2uchar_tz) from the mask this function produces.
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
                                            int *nram_mask,
                                            uint8_t *nram_mask_char,
                                            int *nram_mask_buffer,
                                            int begin_expert_acc_tokens,
                                            int end_expert_acc_tokens,
                                            int token_count,
                                            bool expert_parallelism) {
  if (!expert_parallelism) {
    return token_count;
  }

  // mask_buffer = (idx < end); mask = (idx >= begin) && mask_buffer, as float.
  __bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
  // Fused ge + logical-and + int->float conversion on newer archs.
  bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
      .ge(begin_expert_acc_tokens)
      .land(nram_mask_buffer)
      .cvt<float>(0);
#else
  __bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
  __bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
  __bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
  // Compact the surviving indices to the front of nram_token_idx.
  __bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
  int active_token_count = __bang_count((float *)nram_mask, token_count);
  return active_token_count;
}
|
||||
|
||||
// 64-bit overload: offset[i] = (uint64_t)idx[i] * mul_scalar + add_scalar.
// Used for large tensors where byte offsets can exceed UINT32_MAX; the widen,
// multiply and add are done as three separate vector ops.
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
                                 int *nram_idx,
                                 uint64_t mul_scalar,
                                 int64_t add_scalar,
                                 uint32_t token_count) {
#if __BANG_ARCH__ > 592
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
  __bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
  __bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
|
||||
|
||||
// 32-bit overload: offset[i] = idx[i] * mul_scalar + add_scalar via a single
// fused multiply-add. Valid only when all byte offsets fit in 32 bits (the
// host launcher picks uint32_t vs uint64_t accordingly).
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
                                 int *nram_idx,
                                 uint32_t mul_scalar,
                                 int64_t add_scalar,
                                 uint32_t token_count) {
  __bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
                token_count);
}
|
||||
|
||||
// Compute per-token gather byte offsets into the input tensor and, when bias
// is present, per-token offsets into the bias tensor.
//
// Token offsets: offset[i] = token_idx[i] * hidden_size * dtype_size
//                            + (local_hidden_begin
//                               - begin_expert_acc_tokens * hidden_size) * dtype_size
// (the begin_expert_acc_tokens term rebases indices to this rank's local
// input slab under expert parallelism).
//
// Bias offsets: for each token, count how many expert boundaries in
// cusum_token_count (nram_expert_tables) it has passed — that count is the
// local expert id — then scale into a byte offset into [expert, hidden_size].
//
// NOTE(review): `nram_token_offset` is borrowed as int32 scratch while the
// bias offsets are built, so the bias pass must complete before the token
// offsets are written — do not reorder.
template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
                                T_IDX *nram_bias_offset,
                                int *nram_token_idx,
                                int *nram_expert_tables,
                                int expert_num,
                                int token_count,
                                int active_token_count,
                                int hidden_size,
                                int local_hidden_begin,
                                int dtype_size,
                                int start_expert_id,
                                int expert_size,
                                int begin_expert_acc_tokens,
                                bool has_bias) {
  // for large tensor, convert int322int64 then do multiply and add seperately.
  if (active_token_count <= 0) return;
  if (has_bias) {
    int *nram_bias_offset_temp = (int *)nram_token_offset;
    __bang_write_zero(nram_bias_offset, active_token_count);
    // Local expert id = number of cumulative-token boundaries <= token_idx.
    for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
      __bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
                       active_token_count);
      __bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
                 active_token_count);
    }
    __bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
    computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
                   (T_IDX)local_hidden_begin * dtype_size, active_token_count);
  }

  int64_t offset =
      ((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
  computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
                 active_token_count);
}
|
||||
|
||||
// dst[T] = nram_input_buffer[f32] * expert_coeff, narrowing to T in the same
// instruction where the arch supports it (fused mul+cvt asm on > 500);
// otherwise multiply in fp32 then convert in a second pass. For T == float
// no conversion is needed.
template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MUL_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MUL_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  }
#else
  __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
|
||||
|
||||
// dst[T] = nram_input_buffer[f32] * expert_coeff + dst[f32], narrowing to T
// in the same instruction where the arch supports it (fused muladd+cvt asm on
// > 500); otherwise FMA in fp32 then convert. Note dst is read as fp32
// accumulator and written as T — callers stage the accumulation in fp32.
template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MULADD_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MULADD_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                  size);
  }
#else
  __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
|
||||
|
||||
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Under expert parallelism the input holds ONLY the gathered rows whose topk
// slot is active on this rank (compacted, tracked by token_use_count);
// `og_mask` holds one byte per (token, slot) saying whether that slot is
// active. Inactive tokens produce a zero row. Accumulation is staged in fp32
// over output_begin and narrowed to T on the last active slot.
//
// NOTE(review): nram_input_buffer / output_base are shifted by hidden_size
// depending on `is_ping` for T smaller than float; together with the final
// strided __memcpy this appears to repack half/bf16 rows between alternating
// calls — confirm layout against the pipelined caller before changing.
template <typename T,
          bool expert_parallelism,
          typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
                                    T *input,
                                    float *weight,
                                    T *input_buffer,
                                    int8_t *og_mask,
                                    int topk,
                                    int hidden_size,
                                    int token_count,
                                    bool &is_ping) {
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  // Per-token mask bytes are pulled into registers 4-at-a-time as int32 loads;
  // supports topk <= 128 (32 * 4 mask bytes, 128 weights).
  int32_t index[32];
  float reg_weight[128];
  int8_t *index_ = (int8_t *)index;
  int topk_divide_4 = PAD_UP(topk, 4) / 4;
  int token_use_count = 0;  // running position in the compacted input rows
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    // Load this token's mask bytes and topk weights into registers.
    for (int i = 0; i < topk_divide_4; i++) {
      index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
      float *weight_begin = weight + t_i * topk + i * 4;
      reg_weight[i * 4] = __load_nram(weight_begin);
      if (i * 4 + 1 < topk) {
        reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
      }
      if (i * 4 + 2 < topk) {
        reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
      }
      if (i * 4 + 3 < topk) {
        reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
      }
    }

    // First active slot (among slots 0..topk-2) initializes the accumulator
    // with a plain multiply; later slots FMA into it; the LAST slot uses the
    // fused mul/muladd+convert to write the final T result.
    int first_in_expert = 0;
    float expert_coeff;
    for (; first_in_expert < topk - 1; first_in_expert++) {
      bool in_expert_range = index_[first_in_expert];
      if (!in_expert_range) continue;

      expert_coeff = reg_weight[first_in_expert];
      toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
      __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
      token_use_count++;
      break;
    }
    if (first_in_expert == topk - 1) {
      // No active slot before the last one.
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        // Token has no active expert on this rank: emit zeros.
        __bang_write_zero((T *)output_begin, hidden_size);
      }
    } else {
      // Accumulate the remaining active middle slots in fp32.
      for (int j = first_in_expert + 1; j < topk - 1; j++) {
        bool in_expert_range = index_[j];
        if (!in_expert_range) continue;

        expert_coeff = reg_weight[j];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                      hidden_size, hidden_size);
      }
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        // Last slot inactive: just narrow the fp32 accumulator to T.
        floatTo((T *)output_begin, (float *)output_begin, hidden_size);
      }
    }
  }
  // For 16-bit T on the shifted (pong) phase, repack rows back into the
  // caller-expected layout with a strided in-NRAM copy.
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
|
||||
|
||||
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Every (token, slot) row is present, so no mask is consumed (`og_mask` is
// unused in this specialization). topk == 1 is a fast path: one fused
// multiply+convert per token. Otherwise the fp32 accumulator is initialized
// from slot 0, slots 1..topk-2 are FMA'd in (note the load of slot k+1 is
// software-pipelined one iteration ahead), and the last slot uses the fused
// muladd+convert to produce the T result.
template <typename T,
          bool expert_parallelism,
          typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
                                    T *input,
                                    float *weight,
                                    T *input_buffer,
                                    int8_t *og_mask,
                                    int topk,
                                    int hidden_size,
                                    int token_count,
                                    bool &is_ping) {
  // Same ping/pong shifting scheme as the EP specialization — see the
  // NOTE(review) there about the 16-bit repacking layout.
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  if (topk == 1) {
    for (int i = 0; i < token_count; i++) {
      float expert_coeff = __load_nram(weight + i);
      toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
      mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
    }
    return;
  }
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    float expert_coeff = __load_nram(weight + t_i * topk);
    toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
    toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
    __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
    expert_coeff = __load_nram(weight + t_i * topk + 1);
    for (int k_i = 2; k_i < topk; k_i++) {
      // FMA slot k_i-1 while prefetching slot k_i's weight and row.
      __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                    hidden_size, hidden_size);
      expert_coeff = __load_nram(weight + t_i * topk + k_i);
      toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
    }
    mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
  }
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
|
||||
|
||||
// MoE combine-result device kernel: for each token, gather its topk expert
// outputs (scattered by gather_idx), optionally add bias, weight-and-reduce
// the topk rows, optionally add residual, and write [num_token, hidden_size].
//
// Work split: taskIdX tiles hidden_size into HIDDEN_BLOCK columns; taskIdY
// tiles tokens (remainder tokens go to the lowest taskIdY ranks). The main
// loop is a 3-stage software pipeline running from task_begin = -1 so that
// iteration i loads block i+1, computes block i, and stores block i-1, with
// all buffers double-buffered (ping/pong).
//
// T_IDX (uint32_t/uint64_t) is the byte-offset width for the gathers, chosen
// by the host launcher based on total tensor size.
template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
                                              T *input,
                                              T *bias,
                                              T *residual,
                                              float *reduce_weight,
                                              int *cusum_token_count,
                                              int *gather_idx,
                                              int num_token,
                                              int topk,
                                              int num_expert,
                                              int hidden_size,
                                              int start_expert_id,
                                              int expert_size,
                                              int HIDDEN_BLOCK,
                                              int TOKEN_BLOCK) {
  if (__is_mpu()) {
    return;
  }
  // Per-core tile of the hidden dimension and of the token range.
  int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
  int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
  int task_avg_tokens = num_token / taskDimY;
  int task_remain_tokens = num_token % taskDimY;
  int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
  int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
  if (local_hidden_size <= 0) return;
  if (task_tokens <= 0) return;

  constexpr int int32_dtype_size = (int)sizeof(int);
  constexpr int fp32_dtype_size = (int)sizeof(float);
  int pad_num_expert = PAD_UP(num_expert + 1, 32);
  bool has_bias = bias != nullptr;
  bool has_residual = residual != nullptr;
  bool using_acc_sum = cusum_token_count != nullptr;
  bool expert_parallelism = expert_size < num_expert;

  int block_size = TOKEN_BLOCK * topk;
  int pad_block_size = PAD_UP(block_size, 64);

  // Manual partition of the single NRAM scratch buffer. Order and sizes must
  // match the host-side capacity computation in invokeMoeCombineResultKernel.
  int *nram_expert_tables = (int *)nram_buffer;
  int *nram_token_idx = nram_expert_tables + pad_num_expert;
  T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
  T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
  int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
  T *nram_input_ping = (T *)(nram_mask + pad_block_size);
  T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
  T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
  T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
  float *nram_weight_ping =
      (float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
  float *nram_weight_pong = nram_weight_ping + pad_block_size;
  // 16-bit types need a third conversion scratch block (see host side).
  int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
  T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
  T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
  T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
                              buffer_block_num * HIDDEN_BLOCK * sizeof(half));
  // Mask scratch aliases the offset buffer; safe because the mask is consumed
  // before offsets are rebuilt for the same block.
  int *nram_mask_buffer = (int *)nram_token_offset;
  uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);

  // Prologue: load the expert cumsum table and the first block of gather ids.
  int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
  int begin_expert_acc_tokens = 0;
  int end_expert_acc_tokens = num_token * topk;
  if (using_acc_sum) {
    __memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
                   GDRAM2NRAM);
  }
  __memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
                 init_token_count * sizeof(int), GDRAM2NRAM);
  __sync_io();

  if (expert_parallelism) {
    // Token-slot range [begin, end) owned by this rank's experts.
    begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
    end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
  }

  int active_token_count = getMaskAndActiveTokenCount(
      nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
      end_expert_acc_tokens, init_token_count, expert_parallelism);

  computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
                init_token_count, active_token_count, hidden_size, local_hidden_begin,
                (int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
  __sync_io_move_compute(true, false, false, false, false, true);
  __sync_io_move_compute(false, false, true, true, false, false);

  int next_active_token_count = active_token_count;
  int previous_global_token_begin = 0;
  int previous_token_count = 0;
  bool is_ping = false;
  // 3-stage pipeline: starts at -1 (load-only) and runs one extra trip so the
  // final compute still gets its store in the epilogue below.
  for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
    int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
    int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
    bool is_last_loop = next_token_begin >= task_tokens;
    bool is_last_2_loop = next_next_token_begin >= task_tokens;
    int current_token_begin = task_begin * TOKEN_BLOCK;
    int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
    int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
    int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
    int current_global_token_begin = task_token_begin + current_token_begin;
    int next_global_token_begin = task_token_begin + next_token_begin;
    int next_next_global_token_begin = task_token_begin + next_next_token_begin;

    // Stage L: async-load block i+1 (ids for i+2, weights/residual/input/bias
    // for i+1) into the ping buffers.
    if (!is_last_loop) {
      if (!is_last_2_loop) {
        loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
                    next_next_token_count * topk * sizeof(int), 0, 0, 0);
      }
      loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
                  next_token_count * topk * fp32_dtype_size, 0, 0, 0);
      if (has_residual) {
        loadAsync2d(nram_residual_ping,
                    residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
                    local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
                    hidden_size * sizeof(T), next_token_count - 1);
      }
      gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
                               local_hidden_size * sizeof(T), next_active_token_count);
      gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
                               local_hidden_size * sizeof(T), next_active_token_count);
    }

    // Stage S: async-store block i-1's finished output.
    if (task_begin >= 1) {
      storeAsync2d(
          output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
          nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
          local_hidden_size * sizeof(T), previous_token_count - 1);
    }

    // Stage C: compute block i from the pong buffers.
    if (task_begin >= 0) {
      if (has_bias && active_token_count) {
        __bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
                   active_token_count * local_hidden_size);
      }
      if (expert_parallelism) {
        weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                   nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                   local_hidden_size, current_token_count, is_ping);
      } else {
        weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                    nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                    local_hidden_size, current_token_count, is_ping);
      }
      if (has_residual) {
        __bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
                   current_token_count * local_hidden_size);
      }
    }

    // Barrier between pipeline stages, then prepare block i+2's mask/offsets
    // (nram_mask is still valid here; convert it to bytes for block i+1 first).
    __sync_io_move_compute();
    active_token_count = next_active_token_count;
    if (expert_parallelism && !is_last_loop) {
      __bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
    }
    if (!is_last_2_loop) {
      next_active_token_count = getMaskAndActiveTokenCount(
          nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
          end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
      computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
                    num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
                    local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
                    begin_expert_acc_tokens, has_bias);
    }

    swap(nram_input_ping, nram_input_pong);
    swap(nram_bias_ping, nram_bias_pong);
    swap(nram_residual_ping, nram_residual_pong);
    swap(nram_weight_ping, nram_weight_pong);
    swap(nram_output_ping, nram_output_pong);
    previous_global_token_begin = current_global_token_begin;
    previous_token_count = current_token_count;
  }
  // Epilogue: store the last computed block (pointers were swapped above).
  storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
               nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
               local_hidden_size * sizeof(T), previous_token_count - 1);
}
|
||||
|
||||
// Empty bfloat16 specializations for MLU300-class devices (__BANG_ARCH__ <
// 500), which lack bf16 support. These exist only so the dispatch code in
// invokeMoeCombineResultKernel links; the host rejects bf16 on such devices
// (via isBf16Supported) before the kernels could ever be launched.
#if __BANG_ARCH__ < 500
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
                                                                    bfloat16_t *input,
                                                                    bfloat16_t *bias,
                                                                    bfloat16_t *residual,
                                                                    float *reduce_weight,
                                                                    int *cusum_token_count,
                                                                    int *gather_ids,
                                                                    int num_token,
                                                                    int topk,
                                                                    int num_expert,
                                                                    int hidden_size,
                                                                    int start_expert_id,
                                                                    int expert_size,
                                                                    int HIDDEN_BLOCK,
                                                                    int TOKEN_BLOCK) {}

template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
                                                                    bfloat16_t *input,
                                                                    bfloat16_t *bias,
                                                                    bfloat16_t *residual,
                                                                    float *reduce_weight,
                                                                    int *cusum_token_count,
                                                                    int *gather_ids,
                                                                    int num_token,
                                                                    int topk,
                                                                    int num_expert,
                                                                    int hidden_size,
                                                                    int start_expert_id,
                                                                    int expert_size,
                                                                    int HIDDEN_BLOCK,
                                                                    int TOKEN_BLOCK) {}
#endif
|
||||
|
||||
} // namespace kernels
|
||||
/**
 * @brief Host-side launcher for MLUCombineMoeResultKernel: validates
 * arguments, sizes the NRAM tiling (HIDDEN_BLOCK x TOKEN_BLOCK), picks a
 * launch geometry, and dispatches the kernel template by dtype and by the
 * offset width needed for the gathers.
 *
 * @param queue: cnrt queue the kernel is enqueued on.
 * @param output: Output [num_token, hidden_size], dtype `dtype`.
 * @param input: Gathered expert outputs to combine, dtype `dtype`.
 * @param bias: Optional per-expert bias; currently rejected (must be nullptr).
 * @param residual: Optional residual [num_token, hidden_size], or nullptr.
 * @param reduce_weight: Per-(token, topk) combine weights, float.
 * @param cusum_token_count: Cumulative token counts per expert
 *        ([num_expert + 1] ints); required under expert parallelism.
 * @param gather_idx: Token-slot gather indices ([num_token * topk] ints).
 * @param start_expert_id / expert_size: This rank's expert range; expert
 *        parallelism is active when expert_size < num_expert.
 * @param dtype: CNNL_DTYPE_FLOAT / CNNL_DTYPE_HALF / CNNL_DTYPE_BFLOAT16.
 * @return KERNEL_STATUS_SUCCESS, or KERNEL_STATUS_FAILED on invalid args.
 */
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
                                          void *output,
                                          const void *input,
                                          const void *bias,
                                          const void *residual,
                                          const float *reduce_weight,
                                          const int *cusum_token_count,
                                          const int *gather_idx,
                                          int num_token,
                                          int topk,
                                          int num_expert,
                                          int hidden_size,
                                          int start_expert_id,
                                          int expert_size,
                                          cnnlDataType_t dtype) {
  if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
    std::cerr << "[invokeMoeCombineResultKernel]: "
              << "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (bias != nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // NOTE: the bias half of this condition is currently unreachable (bias is
  // rejected above); kept so the check stays correct if bias support returns.
  if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
              << "cusum_token_count can not be nullptr.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  size_t data_bytes = 0;
  cnnlGetSizeOfDataType(dtype, &data_bytes);
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  // 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
  // TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
  int convert_buffer = data_bytes == 2
                           ? 3 * hidden_size * data_bytes
                           : 2 * hidden_size * data_bytes;  // buffer for convert bf16/fp16->fp32
  int max_input_size = (432 * 1024 - convert_buffer) /
                       (2 * topk * data_bytes + /*input size, double buffer*/
                        (bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
                        (residual != nullptr) * 2 * data_bytes + /*residual size, double buffer*/
                        2 * data_bytes); /*output size, one buffer*/

  // Tiling: if a full hidden row does not fit, split hidden; otherwise batch
  // as many tokens as fit (TOKEN_BLOCK computed after the dim-x rounding).
  int TOKEN_BLOCK = 1;
  int HIDDEN_BLOCK = 1;
  int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
    HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
    TOKEN_BLOCK = 1;
  } else {
    HIDDEN_BLOCK = hidden_size;
  }
  // for latency case, hidden_size is large but token is small.
  if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
    HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
  }

  HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
  // Round task_dim_x up to a divisor of the total core count so clusters are
  // fully occupied, then recompute HIDDEN_BLOCK to match (64-aligned).
  uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
  task_dim_x =
      (task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
  uint32_t pad_dim_x = task_dim_x;
  while (pad_dim_x <= cluster_num * core_num) {
    if ((cluster_num * core_num % pad_dim_x == 0)) {
      task_dim_x = pad_dim_x;
      break;
    }
    pad_dim_x += core_num;
  }
  HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
  HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
    TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
  }
  TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);

  float max_cluster_num = core_num * cluster_num / task_dim_x;
  // Fix: std::min has no mixed-type (float, int) overload — cast num_token to
  // float so template deduction succeeds; result truncates back to uint32_t.
  uint32_t task_dim_y = std::min(max_cluster_num, (float)num_token);
  task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
  cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};

  // Byte offsets overflow uint32 for very large tensors; pick the 64-bit
  // gather-offset instantiation in that case.
  bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
  if (dtype == CNNL_DTYPE_FLOAT) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_HALF) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else {
    std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
              << "among float/half/bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
85
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh
Normal file
85
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh
Normal file
@@ -0,0 +1,85 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
||||
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Sort tokens grouped by different experts based on index. Each token
|
||||
* selects the topk hidden vectors, multiplies them by corresponding weights,
|
||||
* and finally reduces the topk vectors for each token. This process involves
|
||||
* bias and residual, calculated as (x + bias) * weight + residual.
|
||||
* @example
|
||||
* input:
|
||||
* [[[1, 2, 1, 1],
|
||||
* [1, 1, 1, 2]],
|
||||
* [[2, 1, 1, 1],
|
||||
* [1, 1, 1, 1]]]
|
||||
* num_token = 2, topk = 2
|
||||
* cusum_token_count = [0, 2, 4]
|
||||
* index:
|
||||
* [0, 1, 2, 3]
|
||||
* weight:
|
||||
* [0, 0, 1, 1]
|
||||
* bias:
|
||||
* [[0, 0, 0, 0],
|
||||
* [1, 1, 1, 1]]
|
||||
* residual:
|
||||
* [[1, 1, 1, 1],
|
||||
* [0, 0, 0, 0]]
|
||||
* output:
|
||||
* [[1, 1, 1, 1],
|
||||
* [5, 4, 4, 4]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param output: Output. Pointer to the MLU memory that stores the result.
|
||||
* The shape is [num_token, hidden_size].
|
||||
* @param input: Input. Pointer to the MLU memory that stores input tokens.
|
||||
* The shape is [num_token * topk, hidden_size].
|
||||
* @param bias: Input. Pointer to the MLU memory that stores bias.
|
||||
* The shape is [num_expert, hidden_size].
|
||||
* @param residual: Input. Pointer to the MLU memory that stores residual.
|
||||
* The shape is [num_token, hidden_size].
|
||||
* @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight.
|
||||
* The shape is [num_token * topk].
|
||||
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
|
||||
* token number of each expert. The shape is [num_expert + 1].
|
||||
* @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx.
|
||||
* The shape is [num_token * topk].
|
||||
* @param num_token: The total number of tokens.
|
||||
* @param topk: The number of expert.
|
||||
* @param num_expert: The number of expert.
|
||||
* @param hidden_size: The size of lowest dimension.
|
||||
* @param start_expert_id: The id of the first processed expert.
|
||||
* @param expert_size: The number of processed experts.
|
||||
* @param dtype: Data type.
|
||||
* @note Currently does not support bias.
|
||||
*/
|
||||
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
|
||||
void *output,
|
||||
const void *input,
|
||||
const void *bias,
|
||||
const void *residual,
|
||||
const float *reduce_weight,
|
||||
const int *cusum_token_count,
|
||||
const int *gather_idx,
|
||||
int num_token,
|
||||
int topk,
|
||||
int num_expert,
|
||||
int hidden_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
cnnlDataType_t dtype);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
||||
219
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu
Normal file
219
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu
Normal file
@@ -0,0 +1,219 @@
|
||||
#include <bang_device_functions_extra.h>
|
||||
#include <mlu.h>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "expand_input.mluh"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
#define RESERVED_SIZE (32 * 1024)
|
||||
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
|
||||
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
|
||||
#define MEMCPY_BURST_SIZE 128
|
||||
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
|
||||
// T_offset: uint32_t or uint64_t
|
||||
template <size_t data_size, typename T_offset>
|
||||
__mlu_func__ void ExpandInputKernel(void *output,
|
||||
void *input,
|
||||
int *index,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int num_index) {
|
||||
void *input_ptr = input;
|
||||
uint64_t input_size = data_size * num_token * hidden_size;
|
||||
// whether SRAM_BUFFER_SIZE can hold the input data
|
||||
// sram_enable is true ==> is_nram_output is true
|
||||
bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
|
||||
if (sram_enable) {
|
||||
input_ptr = (void *)sram_buffer;
|
||||
__memcpy(input_ptr, input, input_size, GDRAM2SRAM);
|
||||
__sync_cluster();
|
||||
}
|
||||
if (__is_mpu()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Each ipu core processes no less than 128B of data, and the remaining cores can idle
|
||||
int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
|
||||
uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
|
||||
uint32_t total_num = num_index;
|
||||
uint32_t base = total_num / maxTaskDim;
|
||||
uint32_t tail = total_num - base * maxTaskDim;
|
||||
if (taskId >= maxTaskDim) {
|
||||
return;
|
||||
}
|
||||
uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
|
||||
uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
|
||||
|
||||
// nram
|
||||
/*
|
||||
* first: compute offset: index[i] * data_size
|
||||
* second: gather data
|
||||
* -------------------------------------------------
|
||||
* addr || index/offset | output |
|
||||
* type || int32_t/T_offset | T |
|
||||
* num || n | n * hidden_size |
|
||||
* -------------------------------------------------
|
||||
*/
|
||||
|
||||
uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
|
||||
// whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM
|
||||
bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
|
||||
uint32_t per_num =
|
||||
is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
|
||||
|
||||
int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
|
||||
int *index_base = index + batch_step;
|
||||
T_offset *nram_offset = (T_offset *)nram_buffer;
|
||||
int32_t *nram_index;
|
||||
if (std::is_same<T_offset, int64_t>::value) {
|
||||
nram_index = (int32_t *)nram_offset + per_num;
|
||||
} else {
|
||||
nram_index = (int32_t *)nram_offset;
|
||||
}
|
||||
int8_t *nram_output = (int8_t *)(nram_offset + per_num);
|
||||
|
||||
uint32_t repeat = batch_per_core / per_num;
|
||||
uint32_t remain = batch_per_core - repeat * per_num;
|
||||
uint32_t deal_num = per_num;
|
||||
uint32_t is_remain = remain != 0 ? 1 : 0;
|
||||
|
||||
for (int32_t i = 0; i < repeat + is_remain; i++) {
|
||||
if (i == repeat) {
|
||||
deal_num = remain;
|
||||
}
|
||||
int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
|
||||
int32_t *index_ptr = index_base + i * per_num;
|
||||
|
||||
// index -> offset
|
||||
__memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
|
||||
if (std::is_same<T_offset, uint64_t>::value) {
|
||||
#if __BANG_ARCH__ > 592
|
||||
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
|
||||
#else
|
||||
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
|
||||
#endif
|
||||
}
|
||||
__bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
|
||||
|
||||
// copy
|
||||
if (is_nram_output) {
|
||||
__bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
|
||||
mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
|
||||
// GDRAM or SRAM -> NRAM -> GDRAM
|
||||
#if __BANG_ARCH__ >= 592 // gather requires
|
||||
__gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
|
||||
hidden_size * data_size, deal_num);
|
||||
#else
|
||||
for (int32_t j = 0; j < deal_num; j++) {
|
||||
T_offset offset_value = *(nram_offset + j);
|
||||
int8_t *input_offset = (int8_t *)input_ptr + offset_value;
|
||||
__memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
|
||||
dir);
|
||||
}
|
||||
#endif // __BANG_ARCH__
|
||||
__memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
|
||||
} else {
|
||||
// GDRAM -> GDRAM
|
||||
#if __BANG_ARCH__ >= 592 // gather requires
|
||||
__gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
|
||||
hidden_size * data_size, deal_num);
|
||||
#else
|
||||
for (int32_t j = 0; j < deal_num; j++) {
|
||||
T_offset offset_value = *(nram_offset + j);
|
||||
int8_t *input_offset = (int8_t *)input + offset_value;
|
||||
__memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
|
||||
hidden_size * data_size, GDRAM2GDRAM);
|
||||
}
|
||||
#endif // __BANG_ARCH__
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// T_offset: uint32_t or uint64_t
|
||||
template <size_t data_size, typename T_offset>
|
||||
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
|
||||
void *hidden_state,
|
||||
int *gather_idx,
|
||||
int *cusum_token_count,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int topk,
|
||||
int total_expert_num,
|
||||
int start_expert_id,
|
||||
int expert_count) {
|
||||
int32_t num_index = num_token * topk;
|
||||
int *gather_start_idx = (int *)gather_idx;
|
||||
if (cusum_token_count != nullptr) {
|
||||
num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
|
||||
*((int *)cusum_token_count + start_expert_id);
|
||||
gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
|
||||
}
|
||||
ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
|
||||
num_token, hidden_size, num_index);
|
||||
}
|
||||
// instantiate kernels
|
||||
#define INSTANTIATE_ONE(data_size, T_offset) \
|
||||
template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
|
||||
void *, void *, int *, int *, int, int, int, int, int, int);
|
||||
|
||||
INSTANTIATE_ONE(1, uint32_t)
|
||||
INSTANTIATE_ONE(2, uint32_t)
|
||||
INSTANTIATE_ONE(4, uint32_t)
|
||||
INSTANTIATE_ONE(8, uint32_t)
|
||||
// large tensor
|
||||
INSTANTIATE_ONE(1, uint64_t)
|
||||
INSTANTIATE_ONE(2, uint64_t)
|
||||
INSTANTIATE_ONE(4, uint64_t)
|
||||
INSTANTIATE_ONE(8, uint64_t)
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
|
||||
void *expand_hidden_state,
|
||||
const void *hidden_state,
|
||||
const int *gather_idx,
|
||||
const int *cusum_token_count,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int topk,
|
||||
cnnlDataType_t data_type,
|
||||
int total_expert_num,
|
||||
int start_expert_id,
|
||||
int expert_count) {
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
|
||||
size_t type_size = 0;
|
||||
cnnlGetSizeOfDataType(data_type, &type_size);
|
||||
|
||||
int max_cluster_num =
|
||||
(uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
|
||||
cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
|
||||
|
||||
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
|
||||
|
||||
void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
|
||||
kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
|
||||
kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
|
||||
kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
|
||||
kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
|
||||
|
||||
bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
|
||||
int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
|
||||
expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
|
||||
(int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
|
||||
expert_count);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
81
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh
Normal file
81
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh
Normal file
@@ -0,0 +1,81 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
|
||||
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
/**
|
||||
* @brief Gathers slices from hidden_state at axis 1 according to gather_idx and cusum_token_count.
|
||||
* @example
|
||||
* hidden_state:
|
||||
* [[1, 2, 3, 4],
|
||||
* [5, 6, 7, 8],
|
||||
* [9, 10, 11, 12]]
|
||||
* gather_idx:
|
||||
* [[1, 0, 2, 2, 1, 0]]
|
||||
* cusum_token_count: NULL
|
||||
* num_token = 3
|
||||
* hidden_size = 4
|
||||
* topk = 2
|
||||
* expand_hidden_state:
|
||||
* [[5, 6, 7, 8],
|
||||
* [1, 2, 3, 4],
|
||||
* [9, 10, 11, 12],
|
||||
* [9, 10, 11, 12],
|
||||
* [5, 6, 7, 8],
|
||||
* [1, 2, 3, 4]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param hidden_state: Input. Pointer to the MLU memory that store the input,
|
||||
* the shape must be [num_token, hidden_size].
|
||||
* @param gather_idx: Input. Pointer to the MLU memory that stores the index,
|
||||
* the shape must be [num_token * topk].
|
||||
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
|
||||
* token_count. If cusum_token_count is not NULL, the shape must be [total_expert_num + 1]. The
|
||||
* gather operation will be performed as follows: if cusum_token_count is not NULL: index =
|
||||
* gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]]
|
||||
* expand_hidden_state = hidden_state[index]
|
||||
* else:
|
||||
* index = gather_idx[:]
|
||||
* expand_hidden_state = hidden_state[index]
|
||||
* @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output,
|
||||
* if cusum_token_count is not NULL, the shape shoule be [num_index * topk ,hidden_size] in
|
||||
* which num_index =
|
||||
* cusum_token_count[start_expert_id+expert_count]-cusum_token_count[start_expert_id]. Otherwise,
|
||||
* the shape should be [num_token * topk, hidden_size].
|
||||
* @param num_token: the number of token.
|
||||
* @param hidden_size: the slice size.
|
||||
* @param topk: the number of topk.
|
||||
* @param data_type: Data type of hidden_state.
|
||||
* @param total_expert_num: the total number of expert.
|
||||
* @param start_expert_id: the first expert id.
|
||||
* @param expert_count: the number of experts currently being processed.
|
||||
*/
|
||||
|
||||
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
|
||||
void *expand_hidden_state,
|
||||
const void *hidden_state,
|
||||
const int *gather_idx,
|
||||
const int *cusum_token_count,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int topk,
|
||||
cnnlDataType_t data_type,
|
||||
int total_expert_num,
|
||||
int start_expert_id,
|
||||
int expert_count);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
|
||||
935
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
Normal file
935
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
Normal file
@@ -0,0 +1,935 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "gen_idx.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
|
||||
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
|
||||
#define ALIGN_16 (16)
|
||||
|
||||
#define EXPERT_AVG_COUNT_TEST (0)
|
||||
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
||||
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
|
||||
|
||||
// Generate integer sequence data from 0 to length-1
|
||||
__mlu_func__ void generateIntSeq(int *dst, int length) {
|
||||
int count = 64;
|
||||
__bang_move(dst, range, std::min(count, length) * sizeof(int));
|
||||
while (count < length) {
|
||||
__bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count));
|
||||
count *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
// genIdx Block kernel, use only 1 core to process
|
||||
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
/* NRAM space */
|
||||
// Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
|
||||
// --------------------------------------------------------------
|
||||
// | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
|
||||
// | combine_idx | expand_idx | | scatter_offset |
|
||||
// |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
|
||||
// --------------------------------------------------------------
|
||||
// ------------------------------
|
||||
// |token_count|token_count_presum|
|
||||
// | | |
|
||||
// | num_expert| num_expert |
|
||||
// ------------------------------
|
||||
|
||||
uint32_t token_total_num = num_token * topk;
|
||||
// num align to 16, size align to 64B
|
||||
uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
|
||||
|
||||
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
||||
int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
|
||||
int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
|
||||
int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
|
||||
int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
|
||||
int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
|
||||
|
||||
int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space
|
||||
#if __BANG_ARCH__ >= 592
|
||||
int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space
|
||||
#endif
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space
|
||||
|
||||
// Load current core input expert_id and generate int sequence
|
||||
__memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
||||
__sync();
|
||||
|
||||
// Initialize sort idx offset
|
||||
uint32_t sorted_idx_offset = 0;
|
||||
// Initialize token count first presum with 0
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
bool need_cusum_token_count = bool(cusum_token_count != nullptr);
|
||||
|
||||
// Loop on each expert, eq, count, filter index
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
|
||||
token_total_num);
|
||||
// Use filter to sort gen_idx, output with sorted_idx_offset
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, token_total_num);
|
||||
|
||||
sorted_idx_offset += cur_expert_count;
|
||||
((int *)token_count_onchip)[cur_expert] = cur_expert_count;
|
||||
|
||||
// Compute cusum token count and store
|
||||
if (need_cusum_token_count) {
|
||||
((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
|
||||
}
|
||||
}
|
||||
|
||||
#if EXPERT_AVG_COUNT_TEST
|
||||
// NOTE: test avg expert code here:
|
||||
uint32_t token_count_avg = token_total_num / num_expert;
|
||||
uint32_t expert_remain_num = token_total_num % num_expert;
|
||||
|
||||
for (int i = 0; i < num_expert; i++) {
|
||||
((int *)token_count_onchip)[i] =
|
||||
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
|
||||
((int *)token_count_presum_onchip)[i + 1] =
|
||||
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
|
||||
}
|
||||
#endif
|
||||
|
||||
__sync_compute();
|
||||
// Store token_count and cusum token count
|
||||
__memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
if (need_cusum_token_count) {
|
||||
__memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
|
||||
// Use sorted idx to generate gather idx for expand and combine
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// scatter_offset = sorted_idx mul_scalar sizeof(int);
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
token_total_num);
|
||||
#else
|
||||
// scatter dst GDRAM addr should align to 64B
|
||||
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
||||
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
||||
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
|
||||
#endif
|
||||
__sync_compute();
|
||||
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// scatter_async to NRAM
|
||||
__scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
||||
sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
|
||||
#endif
|
||||
// expand_idx = sorted_idx div(topk)
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
|
||||
|
||||
// Store expand_idx and combine_idx
|
||||
__sync_compute();
|
||||
__memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__sync_move();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
|
||||
token_total_num * sizeof(int), NRAM2GDRAM);
|
||||
#else
|
||||
// 370 directly scatter to GDRAM
|
||||
__scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Only MLU500 series support NRAM2SRAM scatter direction
|
||||
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// When length larger than 65535(maximum segnum in bang_scatter),
|
||||
// and src/offset address should align to 64B
|
||||
int seg_repeat = length / 32768;
|
||||
int seg_remain = length % 32768;
|
||||
int seg_offset = 0;
|
||||
|
||||
for (int seg = 0; seg < seg_repeat; seg++) {
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768);
|
||||
seg_offset += 32768;
|
||||
}
|
||||
|
||||
if (seg_remain > 0) {
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Scatter sequence, transfer size is sizeof(int)
|
||||
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
|
||||
// When length larger than 65535(maximum segnum in bang_scatter),
|
||||
// and src/offset address should align to 64B
|
||||
int seg_repeat = length / 32768;
|
||||
int seg_remain = length % 32768;
|
||||
int seg_offset = 0;
|
||||
|
||||
for (int seg = 0; seg < seg_repeat; seg++) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
||||
#else
|
||||
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
||||
NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
||||
#endif
|
||||
seg_offset += 32768;
|
||||
}
|
||||
if (seg_remain > 0) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
#else
|
||||
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
||||
NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Get token count
|
||||
__mlu_func__ void getTokenCount(int *token_count,
|
||||
int *expert_id,
|
||||
int token_cur_core,
|
||||
int cur_token_start,
|
||||
int num_expert) {
|
||||
// 1. Partition on [num_token*topk],
|
||||
// each core for-loop on all expert_id, use eq and count instructions,
|
||||
// use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// ------------------------------------------------------
|
||||
// |expert_id_onchip|cur_expert_result|expert_count_onchip|
|
||||
// | deal_num | deal_num | num_expert |
|
||||
// ------------------------------------------------------
|
||||
|
||||
uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
|
||||
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
||||
int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
|
||||
int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
|
||||
|
||||
// Current core data loop
|
||||
uint32_t repeat = token_cur_core / deal_num;
|
||||
uint32_t remain = token_cur_core % deal_num;
|
||||
uint32_t total_repeat = repeat + (int)(remain > 0);
|
||||
uint32_t token_addr_offset = cur_token_start;
|
||||
|
||||
// Initialize token_count with 0
|
||||
if (taskId == 0) {
|
||||
__gdramset((int *)token_count, num_expert, 0);
|
||||
}
|
||||
// Sync for initialize token_count
|
||||
__sync_all_ipu();
|
||||
|
||||
// Initialize expert count onchip with 0
|
||||
if (token_cur_core > 0) {
|
||||
__bang_write_zero((int *)expert_count_onchip, num_expert);
|
||||
}
|
||||
|
||||
// actual num in loop
|
||||
int cur_deal_num = deal_num;
|
||||
for (int i = 0; i < total_repeat; i++) {
|
||||
if (i == total_repeat - 1 && remain > 0) {
|
||||
cur_deal_num = remain;
|
||||
}
|
||||
// Load current core input expert_id
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
cur_deal_num * sizeof(int), GDRAM2NRAM);
|
||||
token_addr_offset += cur_deal_num;
|
||||
|
||||
// Loop on each expert, eq, count
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
|
||||
// NOTE: __bang_count() only support floating data type
|
||||
uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
|
||||
((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicAdd(reduce) all cores token count results
|
||||
if (token_cur_core > 0) {
|
||||
__bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
|
||||
}
|
||||
// Sync for all cores, get accumulate of token_count
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
// 2. Get token count presum, for each expert index start address after sorting
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
                                      int *token_count,
                                      const int num_expert) {
  // After step 1 (getTokenCount), token_count in GDRAM holds the per-expert
  // token totals. A single core (taskId 0) computes an exclusive prefix sum
  // with a leading zero, so that token_count_presum[i] is the start offset of
  // expert i in the sorted index array and token_count_presum[num_expert] is
  // the total token count. The result is written to token_count_presum
  // (workspace or user output), and all IPU cores synchronize afterwards.
  //
  // NRAM layout (single buffer of num_expert + 1 ints):
  // -------------------------
  // |token_count_presum_onchip|
  // | {0}, num_expert |
  // -------------------------
  if (taskId == 0) {
    // Seed the presum with a leading zero: expert 0 starts at offset 0.
    int8_t *token_count_presum_onchip = nram_buffer;
    ((int *)token_count_presum_onchip)[0] = 0;
    // Load token_count shifted by one slot so the in-place scan below
    // produces presum[i + 1] = sum(count[0..i]).
    __memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int),
             GDRAM2NRAM);

    // In-place sequential prefix sum over the expert counts.
    for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
      ((int *)token_count_presum_onchip)[cur_expert + 1] +=
          ((int *)token_count_presum_onchip)[cur_expert];
    }

    // Store all num_expert + 1 presum entries back to GDRAM.
    __memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
             (num_expert + 1) * sizeof(int), NRAM2GDRAM);
  }
  // Barrier: every core must observe the finished presum before step 3.
  __sync_all_ipu();
}
|
||||
|
||||
// Test-only helper (called under EXPERT_AVG_COUNT_TEST by taskId 0): overwrite
// token_count and token_count_presum with a perfectly balanced distribution of
// token_total_num tokens across num_expert experts. The first
// token_total_num % num_expert experts get one extra token each.
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
                                            int *token_count,
                                            const uint32_t token_total_num,
                                            const int num_expert) {
  uint32_t token_count_avg = token_total_num / num_expert;
  uint32_t expert_remain_num = token_total_num % num_expert;

  // NRAM layout: [token_count (num_expert ints)][presum (num_expert + 1 ints)]
  int8_t *token_count_onchip = nram_buffer;
  int8_t *token_count_presum_onchip = token_count_onchip + num_expert * sizeof(int);

  // Presum starts at zero, matching getTokenCountPresum's convention.
  ((int *)token_count_presum_onchip)[0] = 0;

  for (int i = 0; i < num_expert; i++) {
    // Signed/unsigned comparison is safe here: i >= 0 always.
    ((int *)token_count_onchip)[i] =
        (i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
    ((int *)token_count_presum_onchip)[i + 1] =
        ((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
  }

  // Overwrite both GDRAM outputs with the synthetic balanced values.
  __memcpy((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
  __memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
           (num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
|
||||
|
||||
// 3. Get expert position index after sorting
__mlu_func__ void getSortedIdx(int *sorted_idx,
                               int *expert_id,
                               int *token_count_presum,
                               const int token_total_num,
                               const int num_expert,
                               const int expert_cur_core,
                               const int cur_expert_start,
                               const int cur_expert_end) {
  // Work is partitioned over experts: each core owns expert ids
  // [cur_expert_start, cur_expert_end] (expert_cur_core of them). Every core
  // streams the FULL expert_id array through NRAM in chunks; for each owned
  // expert it compares (eq) against the chunk, filters the matching position
  // indices, and appends them at that expert's running start address in
  // sorted_idx (initialized from token_count_presum). All cores sync at the end.
  //
  // NRAM layout (4 chunk buffers of deal_num ints + num_expert ints):
  // -------------------------------------------------------------------
  // |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
  // | deal_num | deal_num | deal_num | deal_num |
  // -------------------------------------------------------------------
  // |expert_start_addr|
  // | num_expert |
  // -----------------

  // Chunk size: NRAM minus the expert_start_addr table, split across 4 buffers.
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;

  // Each core walks the whole token expert_id stream.
  int repeat = token_total_num / deal_num;
  int remain = token_total_num % deal_num;
  int token_addr_offset = 0;

  int8_t *expert_id_onchip = nram_buffer;
  int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
  int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
  int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
  int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);

  // When num_expert < taskDim, cores with no assigned expert skip the work
  // (but still hit the barrier below).
  if (expert_cur_core > 0) {
    // Generate the running position index sequence 0, 1, 2, ...
    if (deal_num <= token_total_num) {
      generateIntSeq((int *)gen_idx_onchip, deal_num);
    } else {  // the whole stream fits in one (remainder-only) chunk
      generateIntSeq((int *)gen_idx_onchip, token_total_num);
    }

    // Per-expert write cursors into sorted_idx, seeded from the presum.
    __memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
             GDRAM2NRAM);

    // Full chunks.
    for (int i = 0; i < repeat; i++) {
      // Load the next chunk of expert ids.
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               deal_num * sizeof(int), GDRAM2NRAM);
      token_addr_offset += deal_num;

      // For each owned expert: mask matching slots, filter their position
      // indices, and append them at the expert's cursor in sorted_idx.
      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);

        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];

        // NOTE: __bang_filter() only supports floating data types, so the int
        // buffers are reinterpreted; the bit patterns pass through unchanged.
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, deal_num);

        // Append the filtered indices and advance this expert's cursor.
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
                   cur_expert_count * sizeof(int), NRAM2GDRAM);

          ((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
        }
      }

      // Shift the position sequence forward by one chunk for the next pass.
      __bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
    }

    // Remainder chunk (same steps; cursor update is skipped as nothing follows).
    if (remain > 0) {
      __memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
               remain * sizeof(int), GDRAM2NRAM);

      for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
        __bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);

        int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];

        // NOTE: __bang_filter() only supports floating data types.
        uint32_t cur_expert_count =
            __bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
                          (float *)cur_expert_result, remain);
        // Append to this expert's region of sorted_idx.
        if (cur_expert_count > 0) {
          __memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
                   cur_expert_count * sizeof(int), NRAM2GDRAM);
        }
      }
    }
  }

  // Barrier: sorted_idx must be complete before step 4 reads it.
  __sync_all_ipu();
}
|
||||
|
||||
// 4. Get gather index for expand and combine
template <bool is_sram_scatter>
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
                               int *gather_combine_idx,
                               int *sorted_idx,
                               const int token_cur_core,
                               const int cur_token_start,
                               const int topk) {
  // Work is partitioned over [num_token * topk]. Per chunk of sorted_idx:
  //   gather_combine_idx[sorted_idx[j]] = j  (scatter of a running sequence)
  //   gather_expand_idx[j] = sorted_idx[j] / topk
  // is_sram_scatter selects the scatter destination: shared SRAM (MLU500+
  // path) or GDRAM directly.
  //
  // NRAM layout (4 chunk buffers of deal_num ints):
  // -------------------------------------------------------------------
  // |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
  // | deal_num | deal_num | deal_num | deal_num |
  // -------------------------------------------------------------------

  // NOTE: deal_num should align to 64 bytes, per __bang_scatter() constraints.
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;

  // The GDRAM scatter destination must be 64-byte aligned: round the output
  // pointer down and carry the element offset into the scatter offsets.
  int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
  int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);

  int8_t *sorted_idx_onchip = nram_buffer;
  int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);

  // Generate the position sequence 0..k-1 and bias it by this core's start
  // position in [num_token * topk].
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        deal_num);
    } else {  // the whole slice fits in one (remainder-only) chunk
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        token_cur_core);
    }
  }

  // Full chunks.
  for (int i = 0; i < repeat; i++) {
    // Load this core's next chunk of sorted_idx.
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             deal_num * sizeof(int), GDRAM2NRAM);

    // Scatter offsets are in BYTES: sorted_idx * sizeof(int), plus the
    // alignment correction when scattering straight to GDRAM.
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                        deal_num);
    } else {
      // Fused (sorted_idx + align_offset) * sizeof(int) relative to the
      // 64-byte-aligned base address.
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), deal_num);
    }
    // Offsets must be computed before the scatter reads them.
    __sync_compute();

    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                     deal_num);
    } else {
      // Scatter directly into the output gather_combine_idx.
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, deal_num);
    }

    // gather_expand_idx chunk: sorted_idx / topk (maps back to source token).
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
    __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
             deal_num * sizeof(int), NRAM2GDRAM);
    if (is_sram_scatter) {
      // SRAM scatter runs on the move pipe; wait before reusing the buffers.
      __sync_move();
    }
    // Advance the position sequence and the GDRAM cursor by one chunk.
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }

  // Remainder chunk (same steps with `remain` elements).
  if (remain > 0) {
    __memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
             remain * sizeof(int), GDRAM2NRAM);

    // Scatter offsets in bytes, as above.
    if (is_sram_scatter) {
      __bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
                        remain);
    } else {
      // GDRAM destination must honor the 64-byte alignment correction.
      __bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
                    combine_idx_align_offset, (int)(sizeof(int)), remain);
    }

    __sync_compute();

    if (is_sram_scatter) {
      scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                     remain);
    } else {
      scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
                     (uint32_t *)scatter_offset, remain);
    }

    // Final expand-index chunk.
    __bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
    __memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
             remain * sizeof(int), NRAM2GDRAM);
  }
}
|
||||
|
||||
// 4.1 Get gather combine index on SRAM
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
                                    const int token_cur_core,
                                    const int cur_token_start) {
  // Combine-index-only variant of getGatherIdx, run by the first union only
  // (taskId < 4): scatter a running position sequence into shared SRAM at
  // sorted_idx-derived offsets; the MPU later copies SRAM to GDRAM.
  //
  // NRAM layout (2 chunk buffers of deal_num ints):
  // -------------------------------
  // |scatter_offset|scatter_sequence|
  // | deal_num | deal_num |
  // -------------------------------

  // NOTE: deal_num should align to 64 bytes, per __bang_scatter() constraints.
  int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
  int repeat = token_cur_core / deal_num;
  int remain = token_cur_core % deal_num;
  int token_addr_offset = cur_token_start;

  int8_t *scatter_offset = nram_buffer;
  int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);

  // Generate the position sequence and bias it by this core's start position.
  if (token_cur_core > 0) {
    if (deal_num <= token_cur_core) {
      generateIntSeq((int *)scatter_sequence, deal_num);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        deal_num);
    } else {  // the whole slice fits in one (remainder-only) chunk
      generateIntSeq((int *)scatter_sequence, token_cur_core);
      __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
                        token_cur_core);
    }
  }

  // Full chunks.
  for (int i = 0; i < repeat; i++) {
    // Load sorted_idx into the offset buffer (converted in place below).
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
             GDRAM2NRAM);

    // Scatter offsets are in bytes: sorted_idx * sizeof(int).
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
    // Offsets must be ready before the scatter reads them.
    __sync_compute();

    // Scatter the sequence into shared SRAM.
    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
                   deal_num);
    // Wait for the move pipe before the buffers are reused next iteration.
    __sync_move();

    // Advance the position sequence and the GDRAM cursor by one chunk.
    __bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
    token_addr_offset += deal_num;
  }

  // Remainder chunk.
  if (remain > 0) {
    __memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
             GDRAM2NRAM);

    // Scatter offsets in bytes, as above.
    __bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
    __sync_compute();

    scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
  }
}
|
||||
|
||||
// 4.2 Get gather expand index
|
||||
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
|
||||
int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start,
|
||||
const int topk) {
|
||||
// 4.2 Partition on [num_token*topk],
|
||||
// load sorted_idx onchip,
|
||||
// gather_expand_idx = sorted_idx / topk
|
||||
// NRAM:
|
||||
// -----------------------------------
|
||||
// |sorted_idx_onchip|expand_idx_onchip|
|
||||
// | deal_num | deal_num |
|
||||
// -----------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
int8_t *sorted_idx_onchip = nram_buffer;
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
deal_num * sizeof(int), NRAM2GDRAM);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
remain * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
// Device entry point: generate MoE gather indices.
// Pipeline: (1) per-expert token counts, (2) prefix sum of counts,
// (3) stable position indices sorted by expert, (4) expand/combine gather
// indices. Workspace holds [token_count_presum (num_expert + 1)] followed by
// [sorted_idx (num_token * topk)]; the presum slot is bypassed when the caller
// supplies cusum_token_count.
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
                                          int *gather_combine_idx,
                                          int *token_count,
                                          int *cusum_token_count,
                                          void *workspace,
                                          const void *expert_id,
                                          const int num_token,
                                          const int num_expert,
                                          const int topk) {
  // Token count presum destination, shape [num_expert + 1]: user buffer if
  // provided, otherwise the head of the workspace.
  int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
  // Position indices after sorting, shape [num_token * topk], in workspace.
  int *sorted_idx = ((int *)workspace) + num_expert + 1;

  // Partition [num_token * topk] across all cores; the first
  // token_remain_num cores take one extra element.
  uint32_t token_total_num = num_token * topk;
  uint32_t token_cur_core = token_total_num / taskDim;
  uint32_t token_remain_num = token_total_num % taskDim;
  token_cur_core += (uint32_t)(taskId < token_remain_num);
  // Start offset of this core's token slice.
  uint32_t cur_token_start = (taskId < token_remain_num)
                                 ? token_cur_core * taskId
                                 : token_cur_core * taskId + token_remain_num;

  // Partition [num_expert] across all cores the same way (used by step 3).
  uint32_t expert_cur_core = num_expert / taskDim;
  uint32_t expert_remain_num = num_expert % taskDim;
  expert_cur_core += (uint32_t)(taskId < expert_remain_num);
  // Inclusive expert range [cur_expert_start, cur_expert_end] for this core.
  uint32_t cur_expert_start = (taskId < expert_remain_num)
                                  ? expert_cur_core * taskId
                                  : expert_cur_core * taskId + expert_remain_num;
  uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;

  // Use Union1 SRAM as the scatter destination when the whole combine index
  // array fits; only MLU500-series supports this path now.
#if __BANG_ARCH__ >= 592
  bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
#else
  bool is_sram_scatter = false;
#endif

  // Steps 1-3 run on IPU (compute) cores only.
  if (__is_ipu()) {
    // 1. Per-expert token counts.
    getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
                  num_expert);
    // 2. Prefix sum of token counts.
    getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);

    // 3. Position indices sorted by expert.
    getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
                 num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
  }

#if EXPERT_AVG_COUNT_TEST
  // Test hook: replace the real counts with a balanced distribution.
  if (__is_ipu() && taskId == 0) {
    modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
                              num_expert);
  }
  __sync_cluster();
#endif

  // 4. Expand/combine gather indices.
  if (is_sram_scatter) {
    // Combine-index scatter uses only the first union's 4 IPU cores.
    uint32_t scatter_idx_cur_core = token_total_num / 4;
    uint32_t scatter_idx_remain_num = token_total_num % 4;
    scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
    uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
                                 ? scatter_idx_cur_core * taskId
                                 : scatter_idx_cur_core * taskId + scatter_idx_remain_num;

    // Matches deal_num in getGatherIdx: below this size a single pass
    // produces both expand and combine indices from NRAM, so splitting work
    // across unions would not pay off.
    const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
    if (taskDim <= 4 || token_total_num < deal_once_num) {
      // First union computes both indices; its MPU drains SRAM to GDRAM.
      if (taskId < 4) {
        if (__is_ipu()) {
          getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
                             scatter_idx_cur_core, cur_idx_start, topk);
          // Rendezvous with the MPU before it copies SRAM out.
          __sync_cluster();
        } else {
          // MPU: wait for IPU scatters to finish, then copy SRAM to GDRAM.
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      }
    } else {
      // taskDim > 4: first union builds the combine index in SRAM while the
      // remaining unions build the expand index in parallel.
      if (taskId < 4) {
        if (__is_ipu()) {
          // Scatter combine indices into shared SRAM.
          getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
          __sync_cluster();
        } else {
          // MPU: wait, then drain SRAM to the GDRAM output.
          __sync_cluster();
          __memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
                         token_total_num * sizeof(int), SRAM2GDRAM);
        }
      } else {
        // Cores beyond the first union re-partition the full token range
        // among themselves and generate the expand index.
        if (__is_ipu()) {
          uint32_t expand_dim = taskDim - 4;
          uint32_t expand_id = taskId - 4;
          uint32_t expand_token_cur_core = token_total_num / expand_dim;
          uint32_t expand_token_remain_num = token_total_num % expand_dim;
          expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);

          uint32_t expand_cur_token_start =
              (expand_id < expand_token_remain_num)
                  ? expand_token_cur_core * expand_id
                  : expand_token_cur_core * expand_id + expand_token_remain_num;
          getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
                       expand_cur_token_start, topk);
        }
      }
    }
  } else {
    // No SRAM path: every IPU core produces both indices, scattering the
    // combine index straight to GDRAM.
    if (__is_ipu()) {
      getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
                          token_cur_core, cur_token_start, topk);
    }
  }

  // Nothing after step 4 needs the MPU.
  if (__is_mpu()) {
    return;
  }
}  // end of kernel
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
// Host-side launcher for the MoE gather-index generation kernel.
// Chooses between a single-core Block kernel (small inputs) and a UnionX
// kernel sized from the device's cluster count and the amount of work.
// See gen_idx.mluh for full parameter documentation; workspace must hold at
// least (num_expert + 1 + num_token * topk) ints (unused by the Block path).
// Returns KERNEL_STATUS_SUCCESS after enqueuing the kernel on `queue`.
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
                                   int *gather_expand_idx,
                                   int *gather_combine_idx,
                                   int *token_count,
                                   int *cusum_token_count,
                                   void *workspace,
                                   const void *expert_id,
                                   const int num_token,
                                   const int num_expert,
                                   const int topk) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  const int token_total_num = num_token * topk;

  // For the partition on num_token*topk, a single core handles at least
  // single_core_num_limit elements before another core is worth adding.
  const int single_core_num_limit = 1024;
  // Integer ceiling division; the previous std::ceil(float(x) / y) loses
  // precision once token_total_num exceeds 2^24.
  int need_core_num = (token_total_num + single_core_num_limit - 1) / single_core_num_limit;
  // The expert partition needs at least one core per expert.
  need_core_num = std::max(num_expert, need_core_num);

  // Clamp cluster_num to the smallest UnionX type that covers need_core_num.
  if (token_total_num <= 4096) {  // Block
    cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
    cnrtDim3_t k_dim{1, 1, 1};
    // The Block kernel does not need workspace.
    kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
        gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
        num_expert, topk);

    return KernelStatus::KERNEL_STATUS_SUCCESS;
  } else if (need_core_num <= 4) {  // Union1
    cluster_num = 1;
  } else if (need_core_num <= 8) {  // Union2
    cluster_num = std::min(cluster_num, 2);
  } else if (need_core_num <= 16) {  // Union4
    cluster_num = std::min(cluster_num, 4);
  } else if (need_core_num <= 32) {  // Union8
    cluster_num = std::min(cluster_num, 8);
  }

  cnrtFunctionType_t k_type;
  cnrtDim3_t k_dim{1, 1, 1};

  // Map the clamped cluster count to the largest usable UnionX launch.
  if (cluster_num == 1) {
    k_type = cnrtFuncTypeUnion1;
    k_dim.x = 4;
  } else if (cluster_num < 4) {  // cluster num is 2 or 3
    k_type = cnrtFuncTypeUnion2;
    k_dim.x = 8;
  } else if (cluster_num < 8) {  // cluster num is 4,5,6,7
    k_type = cnrtFuncTypeUnion4;
    k_dim.x = 16;
  } else {  // cluster num is 8 or larger
    k_type = cnrtFuncTypeUnion8;
    k_dim.x = 32;
  }

  // expert_id is int-typed device memory of shape [num_token, topk].
  kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
      gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
      num_token, num_expert, topk);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
#undef EXPERT_AVG_COUNT_TEST // undef test macro
|
||||
|
||||
} // namespace tmo
|
||||
58
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
Normal file
58
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
Normal file
@@ -0,0 +1,58 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
|
||||
#include <vector>
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Apply generate MOE index operation, which performs the following
|
||||
* tasks:
|
||||
* - 1. Generate gather_expand_idx and gather_combine_idx.
|
||||
* - 2. Output token_count, the token number of each expert.
|
||||
* - 3. Prepare inputs and outputs address for group_gemm.
|
||||
* @param queue: The queue of mlu.
|
||||
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
|
||||
* gather index for expand hidden state operation, the shape must be
|
||||
* [num_token * topk].
|
||||
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
|
||||
* gather index for combine MOE operation, the shape must be
|
||||
* [num_token * topk].
|
||||
* @param token_count: Output. Pointer to the MLU memory that stores the token
|
||||
* number of each expert, the shape must be [num_expert].
|
||||
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
|
||||
* cumulative sum of the token number of each expert, the shape must be
|
||||
* [num_expert + 1]. It can be set to nullptr if don't need cusum output.
|
||||
* @param workspace: Input. A pointer to the extra workspace required in the
|
||||
* operation, the size must be larger than
|
||||
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
|
||||
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
|
||||
* of each token, the shape must be [num_token, topk].
|
||||
* @param num_token: The number of tokens.
|
||||
* @param num_expert: The number of experts.
|
||||
 * @param topk: The number of experts selected by each token.
|
||||
*/
|
||||
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
|
||||
int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
void *workspace,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
21
torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh
Normal file
21
torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh
Normal file
@@ -0,0 +1,21 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_MOE_MLUH_
|
||||
#define CSRC_KERNELS_MOE_MOE_MLUH_
|
||||
|
||||
#include "add_bias_activation.mluh"
|
||||
#include "combine_result.mluh"
|
||||
#include "expand_input.mluh"
|
||||
#include "gen_idx.mluh"
|
||||
#include "softmax_topk.mluh"
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_MOE_MLUH_
|
||||
602
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu
Normal file
602
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu
Normal file
@@ -0,0 +1,602 @@
|
||||
#include <mlu.h>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "softmax_topk.mluh"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
#define SCATTER_ALIGN (64) // align for __scatter()
|
||||
|
||||
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
|
||||
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
|
||||
#define TILING_ALIGN (64)
|
||||
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
|
||||
__nram__ int8_t nram_buffer[NRAM_SIZE];
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
|
||||
|
||||
#define __TRANS_TILING(TYPE, CONVERT) \
|
||||
__asm__ volatile("trans.tiling." TYPE \
|
||||
" [%[dst]], [%[src]]," \
|
||||
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
|
||||
"%[is4], %[in5], %[is5]," \
|
||||
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
|
||||
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
|
||||
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
|
||||
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
|
||||
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
|
||||
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
|
||||
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
|
||||
|
||||
// Thin wrapper over the trans.tiling inline-asm instruction (__TRANS_TILING),
// issuing a tiled transpose-copy with the given 6-level source/destination
// size-and-stride descriptors. Only the SRAM2NRAM direction with a float
// destination is implemented; the source may be float (raw b32 move), half
// (with f16->f32 convert), or — on MLU500+ — bfloat16 (bf16->f32 convert).
// Any other template combination compiles to a no-op.
// NOTE(review): the identifier uses a reserved double-underscore prefix,
// mirroring the vendor intrinsic naming style of this file.
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
                               const SRC_DTYPE *src,
                               const uint32_t in0,
                               const uint32_t in1,
                               const uint32_t is1,
                               const uint32_t in2,
                               const uint32_t is2,
                               const uint32_t in3,
                               const uint32_t is3,
                               const uint32_t in4,
                               const uint32_t is4,
                               const uint32_t in5,
                               const uint32_t is5,
                               const uint32_t dn0,
                               const uint32_t dn1,
                               const uint32_t ds1,
                               const uint32_t dn2,
                               const uint32_t ds2,
                               const uint32_t dn3,
                               const uint32_t ds3,
                               const uint32_t dn4,
                               const uint32_t ds4,
                               const uint32_t dn5,
                               const uint32_t ds5) {
  if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
    if (std::is_same<SRC_DTYPE, float>::value) {
      // Same-width move, no conversion.
      __TRANS_TILING("nram.sram.b32", ";")
    } else if (std::is_same<SRC_DTYPE, half>::value) {
      // 16-bit move with on-the-fly half -> float conversion.
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
    } else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
      // bfloat16 -> float conversion is only available on MLU500+.
      __TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
    }
  }
}
|
||||
|
||||
/* Transpose data of shape [h, w] into [w, h] (with dtype conversion via
 * __mlvm_trans), processed as 4 separate sub-tiles:
 *   1. [h_align, w_align] body whose extents are multiples of the tile size,
 *   2. right tail  [h_align, w_rem],
 *   3. bottom tail [h_rem,  w_align],
 *   4. corner tail [h_rem,  w_rem].
 * dst: destination address
 * src: source address
 * h: size along h
 * w: size along w
 */
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
  // Elements per 64-byte tiling unit of the source dtype.
  uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
  uint32_t w_align = w / align_num;  // full tiles along w
  uint32_t w_rem = w % align_num;    // tail along w
  uint32_t h_align = h / align_num;  // full tiles along h
  uint32_t h_rem = h % align_num;    // tail along h
  // Source tiling counts/strides (in*/is*) and destination ones (dn*/ds*);
  // strides is1/ds1 step one row, is4/ds4 step one tile.
  uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
  uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
  uint32_t in3 = w_align, is3 = TILING_ALIGN;
  uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
  uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
  uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
  uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
  /* 1. h_align * w_align */
  if (w_align > 0 && h_align > 0) {
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
                                            dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
  }
  /* 2. h_align * w_rem */
  if (w_rem > 0 && h_align > 0) {
    SRC_DTYPE *src_temp = src + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = TILING_ALIGN;
    in1 = align_num;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = h_align;
    is4 = w * TILING_ALIGN;
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = align_num * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
                                            1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 3. h_rem * w_align */
  if (w_align > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w;
    DST_DTYPE *dst_temp = dst + h_align * align_num;
    in0 = TILING_ALIGN;
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    in4 = w_align;
    is4 = TILING_ALIGN;
    dn1 = align_num;
    ds1 = h * sizeof(DST_DTYPE);
    dn4 = in4;
    ds4 = h * align_num * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
                                            1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
  }
  /* 4. h_rem * w_rem */
  if (w_rem > 0 && h_rem > 0) {
    SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
    DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
    in0 = w_rem * sizeof(SRC_DTYPE);
    dn0 = h_rem * sizeof(SRC_DTYPE);
    in1 = h_rem;
    is1 = w * sizeof(SRC_DTYPE);
    dn1 = w_rem;
    ds1 = h * sizeof(DST_DTYPE);
    __mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
                                            0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
  }
}
|
||||
|
||||
/* Iteratively extract the top-`topk` entries for each of `col` tokens.
 * Data layout of max_buffer is [row, col] (token index is the fast dim), so
 * each maxpool pass reduces over the expert/group dimension per token; the
 * found maxima are then overwritten with -INFINITY so the next pass yields
 * the next-largest value.
 * - is_deal_group == false: plain per-token top-k. Values land in
 *   value_buffer (+k*col per iteration); indices presumably land via
 *   value_index_stride relative to the value output -- confirm against the
 *   __bang_maxpool_value_index documentation.
 * - is_deal_group == true: grouped top-k over num_expert_group groups of
 *   group_size experts. Selected group ids accumulate in value_buffer; after
 *   the loop the arch-specific tail copies only the chosen groups into the
 *   working buffer (src under -inf background) so a later plain top-k only
 *   sees experts from selected groups.
 * i_buffer must hold 0..col-1 (set by the caller); col_buffer is scratch for
 * flattened offsets (row_index * col + token_index).
 */
__mlu_func__ void getTopk(float *value_buffer,
                          uint32_t *index_buffer,
                          float *src_buffer,
                          float *compute_buffer,
                          float *max_buffer,
                          float *temp_buffer,
                          uint32_t *i_buffer,
                          uint32_t *col_buffer,
                          uint32_t topk,
                          uint32_t num_expert_group,
                          uint32_t col,
                          uint32_t row,
                          uint32_t value_index_stride,
                          uint32_t group_size,
                          bool is_deal_group) {
  __bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
  for (int k = 0; k < topk; k++) {
    if (is_deal_group) {
      // Per-token argmax over the group maxima; record the winning group id.
      __bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
                           1, num_expert_group, 1, 1);
      // col_buffer = group_id * col + token_id (flat offset into [row, col]).
      __bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
                    col);
    } else {
      // Per-token max value + index over all `row` experts.
      __bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
                                 value_index_stride);
      __bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
    }
#if __BANG_ARCH__ >= 592
    __bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte
    __scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
              col); // replace max value with -inf
#else
    // No __scatter on older archs: knock out the maxima one by one.
    for (int i = 0; i < col; i++) {
      uint32_t index = __load_nram(col_buffer + i);
      max_buffer[index] = -INFINITY;
    }
#endif
#if __BANG_ARCH__ < 500
    // Old-arch grouped path: copy each newly selected group from src to the
    // compute buffer immediately (no gather/scatter tail below).
    if (is_deal_group) {
      for (int i = 0; i < col; i++) {
        uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
        __memcpy(compute_buffer + i * row + index * group_size,
                 src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
      }
    }
#endif
  }
#if __BANG_ARCH__ >= 592
  if (is_deal_group) {
    // Build flat byte offsets of every selected group, gather those groups,
    // then scatter them back onto a -inf background so unselected groups are
    // masked out. src_buffer here is laid out [col, row] (token-major).
    __bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
    // Per-token base offset (start of each token's expert row, in bytes).
    __bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
    // Replicate the base offsets for all topk selections per token
    // (presumably a broadcast copy -- confirm __bang_move semantics).
    __bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
                topk - 1);
    __bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
    // offset = group_id * group_size_bytes + token_base_offset.
    __bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
                  (uint32_t *)compute_buffer, col * topk, col * topk);
    __gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
             NRAM2NRAM, group_size * sizeof(float), col * topk);
    __bang_write_value(src_buffer, row * col, -INFINITY);
    __scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
              group_size * sizeof(float), col * topk);
  }
#endif
}
|
||||
|
||||
/* Per-tile pipeline: widen to float, softmax over the expert dimension,
 * optional mask multiply, (grouped) top-k selection, optional normalization.
 * load_buffer holds [nram_compute_col_num, row] (one row of `row` experts per
 * token); most compute runs on the [row, col] transpose so maxpool/sumpool
 * reduce per token. On return, compute_buffer holds the top-k values and
 * nramout_value (reinterpreted as uint32_t) the top-k indices, both
 * token-major [col, topk], ready for the caller to copy out.
 */
template <typename T>
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
                                     T *load_buffer,
                                     float *src_buffer,
                                     float *compute_buffer,
                                     float *group_max_buffer,
                                     float *nramout_value,
                                     uint32_t *nramout_index,
                                     uint32_t *i_buffer,
                                     uint32_t *col_buffer,
                                     float *softmax_buffer,
                                     uint32_t row,
                                     uint32_t nram_compute_col_num,
                                     uint32_t mask_num,
                                     uint32_t nram_max_col_num,
                                     uint32_t topk,
                                     int num_expert_group,
                                     uint32_t topk_group,
                                     uint32_t top_num,
                                     uint32_t nram_col_offset,
                                     int normalize_mode,
                                     bool valid_mask,
                                     bool split_mask) {
  uint32_t nram_compute_num = nram_compute_col_num * row;
  // convert to float for half/bf16 datatype
  if (std::is_same<T, half>::value) {
    __bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
  }
  // transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
  __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);

  // compute softmax
  int tmp = 0x3fb8aa3b;
  // Bit pattern of log2(e): exp(x) is computed as 2^(x * log2e) via pow2.
  float log2e = *(float *)&tmp; // for exp
  // src_buffer reuse as buffer for max/sum.
  __bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
  // Fused (input - max) * log2e per token.
  __bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
                nram_compute_col_num);
  __bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
  __bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
  __bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
  __bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
                   nram_compute_col_num);
  // Wait for the async mask copy into SRAM issued by the caller.
  __sync_cluster();
  // move mask and compute
  if (valid_mask) {
    if (!split_mask) {
      // Whole mask fits in NRAM: widen it (for half/bf16) and cycle-multiply
      // it over all batches in this tile.
      __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
      if (std::is_same<T, half>::value) {
        __memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
                 SRAM2NRAM);
        __bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
                          mask_num * row);
      } else if (std::is_same<T, bfloat16_t>::value) {
        __memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
                 mask_num * row * sizeof(T), SRAM2NRAM);
        __bang_bfloat162float((float *)compute_buffer,
                              (bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
      } else {
        __memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
      }
      __bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
                       mask_num * row);
      __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    } else {
      // Mask is split across cores: pull this core's slice from SRAM, already
      // transposed/widened, then elementwise multiply.
      transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
                                      nram_compute_col_num, row);
      __sync();
      __bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
    }
  }
  if (normalize_mode == 2) {
    // Keep per-token sum of (softmax * mask) as the normalization denominator.
    __bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
  }

  if (num_expert_group <= 1) {
    // num_expert_group <= 1, maintain original topk calculation logic
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * topk * sizeof(float), 0, false);
  } else {
    // num_expert_group > 1, use grouped_topk calculation logic
    uint32_t group_size = row / num_expert_group;
    __bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
    // Per-token maximum within each expert group.
    __bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
                   group_size, 1, group_size, 1, 1);
    __bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
    // get topk_group
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
            (float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
            nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
    // get topk
#if __BANG_ARCH__ < 500
    __bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * top_num * sizeof(float), 0, false);
#else
    __bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
    getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
            i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
            nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
  } // end else

  // normalize result
  if (normalize_mode == 1) {
    // compute_buffer reuse as buffer for sum.
    __bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
    __bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  } else if (normalize_mode == 2) {
    __bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
    __bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
                     nram_compute_col_num);
  }

  // transpose back. src and dst of transpose can not be the same address.
  __bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
  __bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}
|
||||
|
||||
/* Device entry point: sizes the NRAM buffers, distributes tokens across
 * tasks, and drives load -> computeSoftmaxTopk -> store.
 * input is [col, row] = [num_token, num_expert]; value_out/index_out are
 * [col, topk]. Two schedules:
 *  - a whole mask row block fits in NRAM (nram_max_col_num >= mask_num):
 *    each task processes whole batches of mask_num tokens;
 *  - otherwise (split_mask) the mask itself is split across the cores of a
 *    cluster via SRAM.
 */
template <typename T>
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
                                         T *mask,
                                         int *index_out,
                                         float *value_out,
                                         int col,
                                         int row,
                                         int mask_num,
                                         int topk,
                                         int num_expert_group,
                                         int topk_group,
                                         int normalize_mode) {
  bool valid_mask = (mask != nullptr);
  int top_num = topk >= topk_group ? topk : topk_group;
  // Floats of NRAM needed per token (two row copies + top-k bookkeeping);
  // nram_max_col_num is the number of tokens a tile can hold.
  uint32_t nram_low_space =
      PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
             SCATTER_ALIGN);
  if (num_expert_group <= 1) {
    nram_low_space =
        PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
  }
  uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
  if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
    nram_max_col_num = col / taskDim + (col % taskDim > 0);
  }
  nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
  if (nram_max_col_num <= 0) {
    nram_max_col_num = SCATTER_ALIGN / sizeof(float);
  }
  uint32_t nram_deal_num = nram_max_col_num * row;
  uint32_t batch = col / mask_num;

  // nram split:
  // |--------------------------|--------------------------|--------------------|...
  // | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
  // | src_buffer | compute_buffer | group_max_buffer |...
  // |--------------------------|--------------------------|--------------------|...

  // |----------------------------------------|---------------|--------------|
  // | nram_col_num*3 | col*topk | col*topk |
  // | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
  // |----------------------------------------|---------------|--------------|
  float *src_buffer = (float *)nram_buffer;
  float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
  float *group_max_buffer = compute_buffer + nram_deal_num;
  uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
  if (num_expert_group <= 1) {
    // No grouping: group_max_buffer is unused, overlay i_buffer on it.
    i_buffer = (uint32_t *)group_max_buffer;
  }
  uint32_t *col_buffer = i_buffer + nram_max_col_num;
  float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
  if (normalize_mode != 2) {
    // No mode-2 denominator needed: overlay softmax_buffer on col_buffer.
    softmax_buffer = (float *)col_buffer;
  }

  float *nramout_value = softmax_buffer + nram_max_col_num;
  uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
  if (num_expert_group <= 1) {
    nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
  }
  // half/bf16 load into the upper half of src_buffer, then widen in place.
  T *load_buffer = (T *)src_buffer;
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    load_buffer = load_buffer + nram_deal_num;
  }

  // set i_buffer
  for (uint32_t i = 0; i < nram_max_col_num; i++) {
    i_buffer[i] = i;
  }

  // input[batch, mask, low], mask[mask, low]
  if (nram_max_col_num >= mask_num) { // nram can deal complete mask
    bool split_mask = false;
    uint32_t batch_seg = nram_max_col_num / mask_num;   // batches per tile
    uint32_t batch_rem = batch % batch_seg;
    uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
    int repeat = DIV_UP(batch_seg_num, taskDim);
    for (int i = 0; i < repeat; i++) {
      uint32_t seg_id = i * taskDim + taskId;
      uint32_t sram_load_num = mask_num * row;
      uint32_t sram_load_offset = 0;
      uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
                                          ? batch_rem * mask_num
                                          : batch_seg * mask_num;
      // Tasks past the last segment still iterate (load/store 0 elements) so
      // every task reaches the __sync_cluster below.
      uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
      uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
      uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
      uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;

      // Load
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }

      // Compute
      computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                            group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                            softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
                            topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
                            valid_mask, split_mask);

      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  } else {
    // One mask row block does not fit in a single core's NRAM: split each
    // SRAM mask slice further across the cores of the cluster (taskDimX).
    bool split_mask = true;
    uint32_t mask_seg = nram_max_col_num;
    uint32_t mask_rem = mask_num % mask_seg;
    uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
    uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
    uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
    uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
    for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
      uint32_t batch_idx = i / sram_mask_seg_num;
      uint32_t mask_idx = i % sram_mask_seg_num;
      // Distribute the remainder: the first sram_mask_rem slices get one extra.
      uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
      uint32_t sram_load_num = sram_deal_mask_num * row;
      uint32_t sram_mask_offset = mask_idx < sram_mask_rem
                                      ? mask_idx * (sram_average_mask_num + 1)
                                      : mask_idx * sram_average_mask_num + sram_mask_rem;
      uint32_t sram_load_offset = sram_mask_offset * row;
      // Split this SRAM slice across the cores of the cluster.
      uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
      uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
      uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
      uint32_t nram_load_num = nram_deal_mask_num * row;
      uint32_t nram_col_offset = taskIdX < nram_mask_rem
                                     ? taskIdX * (nram_average_mask_num + 1)
                                     : taskIdX * nram_average_mask_num + nram_mask_rem;
      uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
      uint32_t nram_store_num = nram_deal_mask_num * topk;
      uint32_t nram_store_offset =
          (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
      // Load
      if (valid_mask) {
        __memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
      }
      if (nram_load_num > 0) {
        __memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
      }

      // Compute
      computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
                            group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
                            softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
                            topk, num_expert_group, topk_group, top_num, nram_col_offset,
                            normalize_mode, valid_mask, split_mask);

      // Store
      if (nram_store_num > 0) {
        __memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
                 NRAM2GDRAM);
        __memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
                 NRAM2GDRAM);
      }
      __sync_cluster();
    }
  }
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
/**
 * Host launcher for the MOE softmax top-k kernel.
 *
 * Queries the device layout, validates that num_expert fits the on-chip NRAM
 * budget and that the top-k / grouping parameters are consistent, then
 * launches the templated device kernel for the requested dtype
 * (float / half / bfloat16).
 *
 * @return KERNEL_STATUS_SUCCESS on launch, KERNEL_STATUS_FAILED (with a
 *         diagnostic on stderr) for invalid arguments or unsupported dtype.
 */
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
                                        float *reduce_weight,
                                        int *expert_id,
                                        const void *input,
                                        const void *mask,
                                        const int num_token,
                                        const int num_expert,
                                        const int num_mask,
                                        const int topk,
                                        const int num_expert_group,
                                        const int topk_group,
                                        const cnnlDataType_t dtype,
                                        const int normalize_mode) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // One Union1 launch covering every core: x = cores per cluster, y = clusters.
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  int top_num = topk >= topk_group ? topk : topk_group;
  // NRAM budget check; the limits mirror the nram_low_space computation in
  // MLUSoftmaxTopkKernel (two float copies of one expert row per token plus
  // the per-token top-k bookkeeping must fit in NRAM).
  if (num_expert_group <= 1) {
    const size_t max_num_expert =
        (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float);
    if ((size_t)num_expert > max_num_expert) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
                << "Supported max num_expert:" << max_num_expert
                << ". Current num_expert:" << num_expert << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  } else {
    // Fix: report the grouped-path bound that is actually checked here (the
    // message previously printed the num_expert_group <= 1 formula).
    const size_t max_num_expert =
        (NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float);
    if ((size_t)num_expert > max_num_expert) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
                << "Supported max num_expert:" << max_num_expert
                << ". Current num_expert:" << num_expert << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }
  if (topk > num_expert) {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
              << "topk:" << topk << ". num_expert:" << num_expert << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (num_expert_group > 1) {
    if (mask != nullptr) {
      // Warning only (no failure): the original contract tolerates a non-null
      // mask with grouping and proceeds.
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr"
                << std::endl;
    }
    if (num_expert % num_expert_group != 0) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be"
                << "divisible by num_expert_group, but now num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk_group <= 0 || topk_group > num_expert_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be"
                << "larger than 0 and less than or equal to num_expert_group, but now topk_group"
                << topk_group << ", num_expert group:" << num_expert_group << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (topk > (num_expert / num_expert_group) * topk_group) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less"
                << "than or equal to (num_expert / num_expert_group) * topk_group, but now"
                << "topk :" << topk << ", num_expert:" << num_expert
                << ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
  }

  // Dispatch on dtype; the device kernel widens half/bf16 to float internally.
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
        topk, num_expert_group, topk_group, normalize_mode);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
        num_mask, topk, num_expert_group, topk_group, normalize_mode);
  } else {
    std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported " << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
66
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh
Normal file
66
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh
Normal file
@@ -0,0 +1,66 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_

#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
 * @brief Execute MOE Softmax Top-K Kernel.
 *
 * This function executes the MOE Softmax Top-K Kernel, which computes
 * the Top-K values along a specified dimension after applying softmax to the input data.
 * It is specifically designed for reduction along the lowest dimension.
 *
 * @param queue CNRT queue used to specify the queue for execution.
 * @param reduce_weight Pointer to store the Top-K values.
 *        The shape must be [num_token, topk].
 * @param expert_id Pointer to store the indices of the Top-K values.
 *        The shape must be [num_token, topk].
 * @param input Pointer to the input data containing the values to be computed.
 *        The shape must be [num_token, num_expert].
 * @param mask Pointer to the input data containing the mask values to be applied after
 *        computing softmax. Mask can be nullptr, which means no mask is applied;
 *        otherwise the shape and datatype of mask should be the same as input.
 * @param num_token Number of tokens (rows) in the input data.
 * @param num_expert Size of the reduced (lowest) dimension. Note that num_expert
 *        should not exceed 32768.
 * @param num_mask Number of rows in the mask data.
 * @param topk Number of Top-K values to compute. topk should not be larger than num_expert.
 * @param num_expert_group Number of groups num_expert is divided into. If
 *        num_expert_group > 1, num_expert should be divisible by num_expert_group;
 *        otherwise num_expert_group and topk_group are not used.
 * @param topk_group Number of Top-K groups to select. topk_group should not be larger
 *        than num_expert_group.
 * @param dtype Data type of the input data, should match the actual data type.
 *        float, half, bfloat16 is supported.
 * @param normalize_mode Whether and how to normalize the output: if normalize_mode == 0, no
 *        normalization is performed; if normalize_mode == 1, the normalization denominator is
 *        the sum of the topk values; if normalize_mode == 2, the normalization denominator is
 *        the sum of the products of the softmax result and the mask.
 */
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
                                        float *reduce_weight,
                                        int *expert_id,
                                        const void *input,
                                        const void *mask,
                                        const int num_token,
                                        const int num_expert,
                                        const int num_mask,
                                        const int topk,
                                        const int num_expert_group,
                                        const int topk_group,
                                        const cnnlDataType_t dtype,
                                        const int normalize_mode);
}  // namespace tmo

#endif  // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|
||||
@@ -0,0 +1,425 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include "offline_quant_to_linear_cache.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
// Per-core NRAM scratch: 480 KB.
#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];

// sizeof() narrowed to uint32_t for 32-bit arithmetic on device.
#define sizeof_(T) (uint32_t)sizeof(T)
|
||||
|
||||
// Quantize `input_num` values to int8: widen half/bf16 input to float,
// multiply by nram_scale (cyclically broadcast, one scale per
// input_num/scale_num elements), then round-to-nearest to int8.
// For T == float, callers are expected to already have the data in
// nram_input_float (no widening branch runs).
template <typename T>
__mlu_func__ void quantify(int8_t *nram_output,
                           float *nram_input_float,
                           T *nram_input,
                           float *nram_scale,
                           int input_num,
                           int scale_num) {
  if (std::is_same<half, T>::value) {
    __bang_half2float(nram_input_float, (half *)nram_input, input_num);
  } else if (std::is_same<bfloat16_t, T>::value) {
// NOTE(review): the softmax_topk kernel gates its bf16 conversion on
// __BANG_ARCH__ >= 500, but this uses > 500 -- on arch exactly 500 the bf16
// input would be left unconverted here. Confirm whether this is intentional.
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(nram_input_float, (bfloat16_t *)nram_input, input_num);
#endif
  }
  __bang_cycle_mul(nram_input_float, nram_input_float, nram_scale, input_num, scale_num);
  __bang_float2int8_rn(nram_output, (float *)nram_input_float, input_num, 0);
}
|
||||
|
||||
/* Quantize one (head_num, seq, head_size) slice of K or V into the int8
 * cache, using one scale per (head, seq) position.
 * Steps: strided gather from GDRAM into trans_nram; load the scales and take
 * their reciprocal; transpose to [head_size, head_num*seq] so one scale maps
 * to one column under the cyclic multiply; quantify() widens (if needed),
 * scales and rounds to int8; transpose back and scatter into the cache with
 * its own strides.
 */
template <typename T>
__mlu_func__ void quantPerHead(int8_t *output_gdram,
                               int8_t *output_nram,
                               const T *input_gdram,
                               T *input_nram,
                               const float *scale_gdram,
                               float *scale_nram,
                               float *input_nram_float,
                               T *trans_nram,
                               int seq,
                               int head_num,
                               int head_size,
                               size_t in_hstr_bytes, // context head_num stride bytes
                               size_t in_sstr_bytes, // context seq stride bytes
                               size_t scale_hstr_bytes, // scale head_num stride bytes
                               size_t out_hstr_bytes, // cache head_num stride bytes
                               size_t out_sstr_bytes // cache seq stride bytes
) {
  constexpr int dtype_size = sizeof_(T);
  // nram_input: (head_num, seq, head_size)
  int io1_size = head_size * dtype_size;
  __memcpy(trans_nram, input_gdram, io1_size, GDRAM2NRAM, seq * io1_size, head_num - 1, io1_size,
           seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);

  // nram_scale:(head_num, seq);
  int io2_size = seq * sizeof_(float);
  __memcpy(scale_nram, scale_gdram, io2_size, GDRAM2NRAM, io2_size, scale_hstr_bytes, head_num - 1);
  // Reciprocal so quantify() can multiply instead of divide.
  __bang_recip(scale_nram, scale_nram, head_num * seq);

  // Transpose to [head_size, head_num*seq]; half/bf16 stay 16-bit here and
  // are widened inside quantify().
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    __bang_transpose((half *)input_nram, (half *)trans_nram, head_num * seq, head_size);
  } else {
    __bang_transpose(input_nram_float, (float *)trans_nram, head_num * seq, head_size);
  }
  quantify<T>(output_nram, input_nram_float, input_nram, scale_nram, head_size * head_num * seq,
              head_num * seq);
  // Back to (head_num, seq, head_size) order for the strided store.
  __bang_transpose((int8_t *)trans_nram, output_nram, head_size, head_num * seq);

  __memcpy(output_gdram, trans_nram, head_size, NRAM2GDRAM, out_hstr_bytes, head_num - 1,
           out_sstr_bytes, seq - 1, seq * head_size, head_num - 1, head_size, seq - 1);
}
|
||||
|
||||
// Device entry point for per-head offline quantization into a linear cache.
// Work split: taskIdY iterates batches, taskIdZ iterates seq_block-sized
// chunks of each sequence. Key and value are handled independently; a side
// is skipped entirely when any of its three pointers is nullptr.
template <typename T>
__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerHead(int8_t *key_cache,
                                                              int8_t *value_cache,
                                                              const float *key_cache_scale,
                                                              const float *value_cache_scale,
                                                              const int *cache_bs_offsets,
                                                              const int *cache_seq_offsets,
                                                              const T *key,
                                                              const T *value,
                                                              const int *context_seq_offsets,
                                                              const int *context_lens,
                                                              const int batch,
                                                              const int head_num,
                                                              const int head_size,
                                                              const int max_context_len,
                                                              const int cache_mem_len,
                                                              const size_t context_bs_stride,
                                                              const size_t context_head_stride,
                                                              const size_t context_seq_stride,
                                                              const size_t cache_bs_stride,
                                                              const size_t cache_head_stride,
                                                              const size_t cache_seq_stride,
                                                              const size_t cache_scale_head_stride,
                                                              const bool packed,
                                                              const int seq_block) {
  bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr);
  bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr);
  if ((!handle_key) && (!handle_value)) {
    return;
  }

  // Element strides -> byte strides for the DMA descriptors.
  constexpr int dtype_size = sizeof_(T);
  size_t in_hstr_bytes = context_head_stride * dtype_size;
  size_t in_sstr_bytes = context_seq_stride * dtype_size;
  size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t);
  size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t);
  size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float);

  /* ***************************nram space ****************************
   * | scale | input/output | trans |
   * scale size: [head_num, seq_block], float
   * input size: [head_num, seq_block, head_size], float
   * trans size: [head_size, head_num, seq_block], T
   */
  float *scale_nram = (float *)nram_buffer;
  float *input_nram_float = nullptr;
  T *trans_nram = nullptr, *input_nram = nullptr;
  input_nram_float = scale_nram + head_num * seq_block;
  trans_nram = (T *)(input_nram_float + head_num * seq_block * head_size);

  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    // 16-bit inputs need a separate staging area: quantify() widens from
    // input_nram into input_nram_float.
    input_nram = (T *)input_nram_float + seq_block * head_num * head_size;
  } else {
    // float input is consumed in place.
    input_nram = (T *)input_nram_float;
  }
  int8_t *output_nram = (int8_t *)input_nram_float;  // output and input share nram space

  for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
    int context_len = __load_gdram(context_lens + bs_idx);
    // packed mode: context_lens holds cumulative lengths ([batch + 1] entries).
    int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
    int task_seq_begin = taskIdZ * seq_block;
    if (task_seq_begin >= seq_len) continue;
    int seq = std::min(seq_len - task_seq_begin, seq_block);

    // context offset (in elements of T)
    size_t context_offset = 0;
    if (packed) {
      // In packed mode context_len is this batch's cumulative start.
      context_offset = (context_len + task_seq_begin) * context_seq_stride;
    } else {
      int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
      context_offset =
          (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
    }

    // cache offset; negative offsets mark batches to skip.
    int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
    int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
    if (cache_seq_offset < 0 || cache_bs_offset < 0) {
      continue;
    }
    cache_seq_offset += task_seq_begin;
    size_t cache_offset = (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);
    // per_head, nram input[head_num, seq, head_size], nram scale[head_num, seq]
    if (handle_key) {
      quantPerHead(key_cache + cache_offset, output_nram, key + context_offset, input_nram,
                   key_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram,
                   seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes,
                   out_hstr_bytes, out_sstr_bytes);
    }

    if (handle_value) {
      quantPerHead(value_cache + cache_offset, output_nram, value + context_offset, input_nram,
                   value_cache_scale + cache_seq_offset, scale_nram, input_nram_float, trans_nram,
                   seq, head_num, head_size, in_hstr_bytes, in_sstr_bytes, scale_hstr_bytes,
                   out_hstr_bytes, out_sstr_bytes);
    }
  }
}
|
||||
|
||||
// Device entry point for per-channel offline quantization into a linear
// cache. Unlike the per-head kernel, the scale is one float per
// (head, channel) — shape [head_num, head_size] — so it is loaded and
// inverted once per side (key/value) and reused for every batch/seq chunk.
// Work split mirrors the per-head kernel: taskIdY over batches, taskIdZ over
// seq_block chunks.
template <typename T>
__mlu_global__ void MLUOfflineQuantToLinearCacheKernelPerChannel(
    int8_t *key_cache,
    int8_t *value_cache,
    const float *key_cache_scale,
    const float *value_cache_scale,
    const int *cache_bs_offsets,
    const int *cache_seq_offsets,
    const T *key,
    const T *value,
    const int *context_seq_offsets,
    const int *context_lens,
    const int batch,
    const int head_num,
    const int head_size,
    const int max_context_len,
    const int cache_mem_len,
    const size_t context_bs_stride,
    const size_t context_head_stride,
    const size_t context_seq_stride,
    const size_t cache_bs_stride,
    const size_t cache_head_stride,
    const size_t cache_seq_stride,
    const size_t cache_scale_head_stride,
    const bool packed,
    const int seq_block) {
  bool handle_key = (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr);
  bool handle_value = (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr);
  if ((!handle_key) && (!handle_value)) {
    return;
  }

  // Element strides -> byte strides for the DMA descriptors.
  constexpr int dtype_size = sizeof_(T);
  size_t in_hstr_bytes = context_head_stride * dtype_size;
  size_t in_sstr_bytes = context_seq_stride * dtype_size;
  size_t out_hstr_bytes = cache_head_stride * sizeof_(int8_t);
  size_t out_sstr_bytes = cache_seq_stride * sizeof_(int8_t);
  size_t scale_hstr_bytes = cache_scale_head_stride * sizeof_(float);

  /* *********************************nram space **************************************
   * per_channel: | scale[head_num, head_size] | input[seq_block, head_num, head_size] |
   */
  float *scale_nram = (float *)nram_buffer;
  float *input_nram_float = scale_nram + head_num * head_size;

  T *input_nram = (T *)input_nram_float;
  if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
    // 16-bit inputs need a separate staging area: quantify() widens from
    // input_nram into input_nram_float.
    input_nram = (T *)input_nram_float + seq_block * head_num * head_size;
  }
  int8_t *output_nram = (int8_t *)input_nram_float;  // output and input share nram space
  int size1 = head_size * sizeof_(float);            // one head row of scales, bytes
  int size2 = head_size * dtype_size;                // one head row of context, bytes
  int scale_num = head_num * head_size;              // scale vector length (cycled per token)

  if (handle_key) {
    // load offline scale nram_scale:(head_num, head_size), then invert so
    // quantify() can multiply instead of divide.
    __memcpy(scale_nram, key_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes, head_num - 1);
    __bang_recip(scale_nram, scale_nram, scale_num);

    for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
      int context_len = __load_gdram(context_lens + bs_idx);
      // packed mode: context_lens holds cumulative lengths ([batch + 1] entries).
      int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
      int task_seq_begin = taskIdZ * seq_block;
      if (task_seq_begin >= seq_len) continue;
      int seq = std::min(seq_len - task_seq_begin, seq_block);

      // context offset (in elements of T)
      size_t context_offset = 0;
      if (packed) {
        context_offset = (context_len + task_seq_begin) * context_seq_stride;
      } else {
        int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
        context_offset =
            (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
      }

      // cache offset; negative offsets mark batches to skip.
      int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
      int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
      if (cache_seq_offset < 0 || cache_bs_offset < 0) {
        continue;
      }
      cache_seq_offset += task_seq_begin;
      size_t cache_offset =
          (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);

      // Gather context as (seq, head_num, head_size) — token-major, so the
      // (head_num * head_size) scale vector cycles once per token.
      __memcpy(input_nram, key + context_offset, size2, GDRAM2NRAM, size2, head_num - 1,
               head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);

      quantify<T>((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num,
                  scale_num);

      // Scatter the int8 result back with the cache's head/seq strides.
      __memcpy(key_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes,
               head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1);
    }
  }

  if (handle_value) {
    // load offline scale nram_scale:(head_num, head_size), then invert.
    __memcpy(scale_nram, value_cache_scale, size1, GDRAM2NRAM, size1, scale_hstr_bytes,
             head_num - 1);
    __bang_recip(scale_nram, scale_nram, scale_num);

    for (int bs_idx = taskIdY; bs_idx < batch; bs_idx += taskDimY) {
      int context_len = __load_gdram(context_lens + bs_idx);
      int seq_len = packed ? (__load_gdram(context_lens + bs_idx + 1) - context_len) : context_len;
      int task_seq_begin = taskIdZ * seq_block;
      if (task_seq_begin >= seq_len) continue;
      int seq = std::min(seq_len - task_seq_begin, seq_block);

      // context offset (in elements of T)
      size_t context_offset = 0;
      if (packed) {
        context_offset = (context_len + task_seq_begin) * context_seq_stride;
      } else {
        int seq_offset = (context_seq_offsets == nullptr) ? 0 : context_seq_offsets[bs_idx];
        context_offset =
            (bs_idx * context_bs_stride + (seq_offset + task_seq_begin) * context_seq_stride);
      }

      // cache offset; negative offsets mark batches to skip.
      int cache_seq_offset = (cache_seq_offsets == nullptr ? 0 : cache_seq_offsets[bs_idx]);
      int cache_bs_offset = cache_bs_offsets == nullptr ? bs_idx : cache_bs_offsets[bs_idx];
      if (cache_seq_offset < 0 || cache_bs_offset < 0) {
        continue;
      }
      cache_seq_offset += task_seq_begin;
      size_t cache_offset =
          (cache_bs_offset * cache_bs_stride + cache_seq_offset * cache_seq_stride);

      __memcpy(input_nram, value + context_offset, size2, GDRAM2NRAM, size2, head_num - 1,
               head_num * size2, seq - 1, in_hstr_bytes, head_num - 1, in_sstr_bytes, seq - 1);

      quantify<T>((int8_t *)output_nram, input_nram_float, input_nram, scale_nram, seq * scale_num,
                  scale_num);

      __memcpy(value_cache + cache_offset, output_nram, head_size, NRAM2GDRAM, out_hstr_bytes,
               head_num - 1, out_sstr_bytes, seq - 1, head_size, head_num - 1, scale_num, seq - 1);
    }
  }
}
|
||||
} // namespace kernels
|
||||
|
||||
#define LAUNCH_OFFLINE_QUANT_KERNEL(Dtype, Name) \
|
||||
kernels::MLUOfflineQuantToLinearCacheKernel##Name<Dtype><<<dim, cnrtFuncTypeBlock, queue>>>( \
|
||||
(int8_t *)key_cache, (int8_t *)value_cache, (float *)key_cache_scale, \
|
||||
(float *)value_cache_scale, (int *)cache_bs_offsets, (int *)cache_seq_offsets, (Dtype *)key, \
|
||||
(Dtype *)value, (int *)context_seq_offsets, (int *)context_lens, batch, head_num, head_size, \
|
||||
max_context_len, cache_mem_len, context_bs_stride, context_head_stride, context_seq_stride, \
|
||||
cache_bs_stride, cache_head_stride, cache_seq_stride, cache_scale_head_stride, packed, \
|
||||
seq_block);
|
||||
|
||||
KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue,
|
||||
void *key_cache,
|
||||
void *value_cache,
|
||||
const void *key_cache_scale,
|
||||
const void *value_cache_scale,
|
||||
const void *cache_bs_offsets,
|
||||
const void *cache_seq_offsets,
|
||||
const void *key,
|
||||
const void *value,
|
||||
const void *context_seq_offsets,
|
||||
const void *context_lens,
|
||||
const cnnlDataType_t dtype,
|
||||
const int batch,
|
||||
const int head_num,
|
||||
const int head_size,
|
||||
const int max_context_len,
|
||||
const int cache_mem_len,
|
||||
const size_t context_bs_stride,
|
||||
const size_t context_head_stride,
|
||||
const size_t context_seq_stride,
|
||||
const size_t cache_bs_stride,
|
||||
const size_t cache_head_stride,
|
||||
const size_t cache_seq_stride,
|
||||
const size_t cache_scale_head_stride,
|
||||
const bool packed,
|
||||
const int quant_mode) {
|
||||
constexpr int nram_size = 480 * 1024;
|
||||
int dtype_size = dtype == CNNL_DTYPE_FLOAT ? sizeof(float) : sizeof(half);
|
||||
int seq_block = 0;
|
||||
if (quant_mode == 0) {
|
||||
seq_block = nram_size / (head_num * head_size * sizeof(float)) - 1;
|
||||
if (seq_block <= 0) {
|
||||
std::cerr << __func__ << "," << __LINE__
|
||||
<< " :head_num * head_size * sizeof(float) should be less than 240KB when "
|
||||
"quant_mode is 0."
|
||||
<< std::endl;
|
||||
}
|
||||
} else {
|
||||
seq_block = nram_size /
|
||||
(head_num * sizeof(float) + head_num * head_size * (sizeof(float) + dtype_size));
|
||||
if (seq_block <= 0) {
|
||||
std::cerr << __func__ << "," << __LINE__
|
||||
<< " :head_num * sizeof(float) + head_num * head_size * (sizeof(float) + "
|
||||
"context_dtype_size)) "
|
||||
<< " should be less than 480KB when quant_mode is not 0." << std::endl;
|
||||
}
|
||||
}
|
||||
seq_block = std::min(seq_block, max_context_len);
|
||||
if (seq_block > 16 && seq_block < max_context_len) {
|
||||
seq_block = seq_block / 16 * 16;
|
||||
}
|
||||
int seq_seg = (max_context_len + seq_block - 1) / seq_block;
|
||||
|
||||
CNdev dev;
|
||||
int cluster_dim, core_dim;
|
||||
cnCtxGetDevice(&dev);
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_dim, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
|
||||
|
||||
uint32_t core_num = cluster_dim * core_dim;
|
||||
uint32_t task_y_dim = std::min((uint32_t)batch, core_num);
|
||||
cnrtDim3_t dim{1, task_y_dim, (uint32_t)seq_seg};
|
||||
|
||||
if (dtype == CNNL_DTYPE_HALF) {
|
||||
if (quant_mode == 0) {
|
||||
LAUNCH_OFFLINE_QUANT_KERNEL(half, PerChannel);
|
||||
} else {
|
||||
LAUNCH_OFFLINE_QUANT_KERNEL(half, PerHead);
|
||||
}
|
||||
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
|
||||
if (quant_mode == 0) {
|
||||
LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerChannel);
|
||||
} else {
|
||||
LAUNCH_OFFLINE_QUANT_KERNEL(bfloat16_t, PerHead);
|
||||
}
|
||||
} else {
|
||||
if (quant_mode == 0) {
|
||||
LAUNCH_OFFLINE_QUANT_KERNEL(float, PerChannel);
|
||||
} else {
|
||||
LAUNCH_OFFLINE_QUANT_KERNEL(float, PerHead);
|
||||
}
|
||||
}
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,103 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_
#define CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_

#include "kernel_utils.h"

namespace tmo {
/**
 * @brief Quantize the current key and value, then store them into key_cache and value_cache.
 * @param queue: The queue for mlu.
 * @param key_cache: Pointer to the MLU memory that stores the key cache,
 * the shape must be [max_batch, head_num, cache_mem_len, head_size].
 * Data type of key_cache must be int8. key_cache could be nullptr.
 * @param value_cache: Pointer to the MLU memory that stores the value cache,
 * the shape must be [max_batch, head_num, cache_mem_len, head_size].
 * Data type of value_cache must be int8. value_cache could be nullptr.
 * @param key_cache_scale: Pointer to the MLU memory that stores the key cache scale,
 * the shape must be [head_num, cache_mem_len] when quant_mode is not zero,
 * and [head_num, head_size] when quant_mode is zero. Data type of key_cache_scale
 * must be float. key_cache_scale could be nullptr.
 * @param value_cache_scale: Pointer to the MLU memory that stores the value cache scale,
 * the shape must be [head_num, cache_mem_len] when quant_mode is not zero,
 * and [head_num, head_size] when quant_mode is zero. Data type of value_cache_scale
 * must be float. value_cache_scale could be nullptr.
 * @param cache_bs_offsets: Pointer to the MLU memory that stores the batch
 * offset of cache, the shape must be [batch], if it's nullptr, the
 * default value is {0, 1, 2 ... batch - 1}.
 * @param cache_seq_offsets: Pointer to the MLU memory that stores the sequence
 * offset of cache, the shape must be [batch], if it's nullptr, the
 * default value is 0 for every batch.
 * @param key: Pointer to the MLU memory that stores the key,
 * the shape must be [batch, max_context_len, head_num, head_size].
 * Data type of key could be float/half/bfloat16. key could be nullptr.
 * @param value: Pointer to the MLU memory that stores the value,
 * the shape must be [batch, max_context_len, head_num, head_size].
 * Data type of value could be float/half/bfloat16, value could be nullptr.
 * @param context_seq_offsets: Pointer to the MLU memory that stores the
 * sequence offset of context, the shape must be [batch]. if it's nullptr,
 * the default value is 0 for every batch. It must be nullptr when packed is true.
 * @param context_lens: Pointer to the MLU memory that stores the sequence length or cumulative
 * sequence length of context. when packed is false, the shape must be [batch], which
 * indicates sequence length of context. when packed is true, the shape must be [batch + 1],
 * which indicates cumulative sequence length of context.
 * @param dtype: Data type.
 * @param batch: Batch size.
 * @param head_num: Head number.
 * @param head_size: Head size.
 * @param max_context_len: The maximum sequence length of context.
 * @param cache_mem_len: The maximum sequence length of cache.
 * @param context_bs_stride: The stride of batch in context, does not work when packed is true.
 * @param context_head_stride: The stride of head_num in context.
 * @param context_seq_stride: The stride of max_context_len in context.
 * @param cache_bs_stride: The stride of batch in cache.
 * @param cache_head_stride: The stride of head_num in cache.
 * @param cache_seq_stride: The stride of cache_mem_len in cache.
 * @param cache_scale_head_stride: The stride of head in cache scale.
 * @param packed: A boolean value that indicates whether to use pack mode.
 * @param quant_mode: An int value that indicates the quantization mode; 0 means quantize by
 * per_channel, and any other value means quantize by per_head.
 * @note If one of key/key_cache/key_cache_scale is nullptr, nothing is done for key.
 * If one of value/value_cache/value_cache_scale is nullptr, nothing is done for value.
 */
KernelStatus invokeOfflineQuantToLinearCache(cnrtQueue_t queue,
                                             void *key_cache,
                                             void *value_cache,
                                             const void *key_cache_scale,
                                             const void *value_cache_scale,
                                             const void *cache_bs_offsets,
                                             const void *cache_seq_offsets,
                                             const void *key,
                                             const void *value,
                                             const void *context_seq_offsets,
                                             const void *context_lens,
                                             const cnnlDataType_t dtype,
                                             const int batch,
                                             const int head_num,
                                             const int head_size,
                                             const int max_context_len,
                                             const int cache_mem_len,
                                             const size_t context_bs_stride,
                                             const size_t context_head_stride,
                                             const size_t context_seq_stride,
                                             const size_t cache_bs_stride,
                                             const size_t cache_head_stride,
                                             const size_t cache_seq_stride,
                                             const size_t cache_scale_head_stride,
                                             const bool packed,
                                             const int quant_mode);
}  // namespace tmo

#endif  // CSRC_KERNELS_OFFLINE_QUANT_TO_LINEAR_CACHE_MLUH_
|
||||
@@ -0,0 +1,232 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <climits>
|
||||
#include "offline_quant_to_paged_cache.mluh"
|
||||
|
||||
namespace tmo {
namespace kernels {

// sizeof with an explicit uint32_t result for the BANG intrinsics.
#define sizeof_(T) (uint32_t)sizeof(T)
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Reserve part of NRAM for the call stack; the remainder is scratch space.
#define REM_FOR_STACK (32 * 1024)
__nram__ int8_t nram_buffer[__MLU_NRAM_SIZE__ * 1024 - REM_FOR_STACK];
// Seed table 0..31 used to build the per-head index vector
// (0, 1, ..., head_num - 1) on chip.
__nram__ int nram_range_32[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
|
||||
// Quantize `token_handle * head_len` values in place to int8.
// The caller stages half/bfloat16 input in the UPPER half of the float-sized
// nram_input region (offset token_handle * head_len * sizeof(T) bytes); this
// routine widens it down into the region's base as float, multiplies by the
// (pre-inverted) per-channel scales — cycled every `head_len` elements —
// and converts round-to-nearest to int8 at the region's base.
// For T == float the input is consumed in place with no widening step.
template <typename T>
__mlu_func__ void quantifyToInt8(T *nram_input, float *nram_scale, int token_handle, int head_len) {
  // quantify
  if (std::is_same<half, T>::value) {
    __bang_half2float((float *)nram_input,
                      (half *)((int8_t *)nram_input + token_handle * head_len * sizeof_(half)),
                      token_handle * head_len);
  }
  if (std::is_same<bfloat16_t, T>::value) {
#if __BANG_ARCH__ > 500
    // bfloat16 -> float conversion is only available on newer architectures.
    __bang_bfloat162float(
        (float *)nram_input,
        (bfloat16_t *)((int8_t *)nram_input + token_handle * head_len * sizeof_(bfloat16_t)),
        token_handle * head_len);
#endif
  }
  __bang_cycle_mul((float *)nram_input, (float *)nram_input, nram_scale, token_handle * head_len,
                   head_len);
  __bang_float2int8_rn((int8_t *)nram_input, (float *)nram_input, token_handle * head_len, 0);
}
|
||||
|
||||
// Device entry point: quantize a block of tokens with offline per-channel
// scales and scatter them into a paged KV cache via slot_mapping.
// Each task handles `tokens_block` consecutive tokens starting at
// taskId * tokens_block. Negative slot_mapping entries are masked out of the
// scatter (their tokens are skipped). Compiled out on __BANG_ARCH__ <= 500.
template <typename T>
__mlu_global__ void MLUOfflineQuantToPagedCacheKernel(T *key,
                                                      T *value,
                                                      int8_t *key_cache,
                                                      int8_t *value_cache,
                                                      float *key_cache_scale,
                                                      float *value_cache_scale,
                                                      int *slot_mapping,
                                                      size_t key_stride0,
                                                      size_t value_stride0,
                                                      int tokens_num,
                                                      int head_num,
                                                      int block_size,
                                                      int head_size,
                                                      int tokens_block) {
  /*******************************************************nram space***********************
   * nram: | input | scale | cache_offset | scale_offset | mask | temp | index |
   * input size: tokens_block * head_num * head_size * sizeof(float)
   * scale size: head_num * head_size * sizeof(float)
   * cache_offset size: tokens_block * head_num * sizeof(float)
   * scale_offset size: equal to cache_offset size
   * mask size: CEIL_DIV(tokens_size * head_num, 8) * sizeof(int8_t)
   * temp size: CEIL_ALIGN(token_size * head_num, 8) * sizeof(int)
   * index size: head_num * sizeof(int)
   ****************************************************************************************/
#if __BANG_ARCH__ > 500
  int token_begin = taskId * tokens_block;
  if (token_begin >= tokens_num) return;
  int token_handle = std::min(tokens_block, tokens_num - token_begin);

  int seq_len = token_handle * head_num;   // one entry per (token, head)
  int head_len = head_num * head_size;     // elements per token
  int pad8_num = CEIL_DIV(seq_len, CHAR_BIT) * CHAR_BIT;  // 8-aligned for the bit mask
  int input_size = seq_len * head_size * sizeof_(float);
  int8_t *nram_input = nram_buffer;
  float *nram_scale = (float *)(nram_buffer + input_size);
  int *cache_offset = (int *)(nram_scale + head_len);
  int *scale_offset = cache_offset + pad8_num;
  int *nram_mask = scale_offset + pad8_num;
  int *nram_temp = nram_mask + pad8_num;
  int *head_index = nram_temp + pad8_num;

  // generate range: (0, 1, 2, ..., (head_num - 1)) from the 0..31 seed table,
  // doubling the filled prefix each iteration.
  __memcpy(head_index, nram_range_32, std::min(head_num, 32) * sizeof_(int), NRAM2NRAM);
  int begin = 32;
  while (begin < head_num) {
    int count = std::min(begin, head_num - begin);
    __bang_add_scalar(head_index + begin, head_index, begin, count);
    begin += count;
  }

  // load slot(token_handle) -> expand(head_num, token_handle) -> transpose(token_handle, head_num)
  int token_size = token_handle * sizeof_(int);
  __memcpy(scale_offset, slot_mapping + token_begin, token_size, GDRAM2NRAM);
  __memcpy(nram_temp, scale_offset, token_size, NRAM2NRAM, token_size, 0, head_num - 1);
  __bang_transpose(scale_offset, nram_temp, head_num, token_handle);

  // Build a bit mask of valid entries: slot >= 0 (negative slots are
  // "skip this token").
  __bang_write_zero((float *)nram_temp, pad8_num);
  __bang_ge_bitindex((float *)nram_mask, (float *)scale_offset, (float *)nram_temp, pad8_num);

  // calculate cache/scale scatter offset:
  // slot -> (page = slot / block_size, in-page = slot % block_size), then
  // flatten page/head/in-page into a per-(token, head) cache slot index;
  // cache_offset is scaled to head_size int8 elements, scale_offset to bytes.
  __bang_div(cache_offset, scale_offset, (int)block_size, seq_len);
  __bang_rem(scale_offset, scale_offset, (int)block_size, seq_len);
  __bang_mul_scalar(cache_offset, cache_offset, head_num * block_size, seq_len);
  __bang_mul_scalar(head_index, head_index, block_size, head_num);
  __bang_cycle_add(cache_offset, cache_offset, head_index, seq_len, head_num);
  __bang_add(scale_offset, cache_offset, scale_offset, seq_len);
  __bang_mul_scalar(cache_offset, scale_offset, head_size, seq_len);
  __bang_mul_scalar(scale_offset, scale_offset, sizeof_(float), seq_len);

  int hidden_bytes = head_num * head_size * sizeof_(T);
  // 16-bit inputs are staged in the upper half of the float-sized input
  // region so quantifyToInt8() can widen downward in place.
  bool half_size = (sizeof(T) == sizeof(half));
  if (key != nullptr && key_cache != nullptr && key_cache_scale != nullptr) {
    // load key_cache_scale and invert so quantization can multiply.
    __memcpy(nram_scale, key_cache_scale, head_len * sizeof_(float), GDRAM2NRAM);
    __bang_recip(nram_scale, nram_scale, head_len);
    // (token_handle, head_num, head_size)
    __memcpy(nram_input + half_size * token_handle * hidden_bytes, key + token_begin * key_stride0,
             hidden_bytes, GDRAM2NRAM, hidden_bytes, key_stride0 * sizeof_(T), token_handle - 1);
    // quantify
    quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len);
    // scatter to gdram, skipping masked (negative-slot) entries
    __scatter(key_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size,
              NRAM2GDRAM, head_size, seq_len);
  }

  if (value != nullptr && value_cache != nullptr && value_cache_scale != nullptr) {
    // load value_cache_scale and invert so quantization can multiply.
    __memcpy(nram_scale, value_cache_scale, head_len * sizeof_(float), GDRAM2NRAM);
    __bang_recip(nram_scale, nram_scale, head_len);
    // (token_handle, head_num, head_size)
    __memcpy(nram_input + half_size * token_handle * hidden_bytes,
             value + token_begin * value_stride0, hidden_bytes, GDRAM2NRAM, hidden_bytes,
             value_stride0 * sizeof_(T), token_handle - 1);
    // quantify
    quantifyToInt8((T *)nram_input, nram_scale, token_handle, head_len);
    // scatter to gdram, skipping masked (negative-slot) entries
    __scatter(value_cache, (int8_t *)nram_input, (uint32_t *)cache_offset, nram_mask, head_size,
              NRAM2GDRAM, head_size, seq_len);
  }
#endif
}
|
||||
} // namespace kernels
|
||||
|
||||
/**
 * Host entry point: quantize key/value tokens to int8 with offline
 * per-channel scales and scatter them into a paged KV cache according to
 * slot_mapping.
 *
 * Returns KERNEL_STATUS_FAILED on MLU300 devices, unsupported dtypes, a
 * cache addressing range over 4 GB (the kernel uses 32-bit offsets), or a
 * per-token working set that does not fit in NRAM; KERNEL_STATUS_SUCCESS
 * otherwise (including num_tokens <= 0, where there is nothing to do).
 */
KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue,
                                            cnnlDataType_t data_type,
                                            void *key,
                                            void *value,
                                            void *key_cache,
                                            void *value_cache,
                                            void *key_cache_scale,
                                            void *value_cache_scale,
                                            void *slot_mapping,
                                            size_t key_stride0,
                                            size_t value_stride0,
                                            int num_tokens,
                                            int num_heads,
                                            int block_num,
                                            int block_size,
                                            int head_size) {
  if (is_arch300()) {
    std::cerr << "[invokeOfflineQuantToPagedCache]: kernel does not support MLU300 devices."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // BUGFIX: with num_tokens == 0 the grid sizing below computed
  // CEIL_DIV(0, seq_block) with seq_block clamped to 0 — division by zero.
  if (num_tokens <= 0) {
    return KernelStatus::KERNEL_STATUS_SUCCESS;
  }
  int dtype_size = 1;
  if (data_type == CNNL_DTYPE_HALF || data_type == CNNL_DTYPE_BFLOAT16) {
    dtype_size = 2;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    dtype_size = 4;
  } else {
    std::cerr << "invokeOfflineQuantToPagedCache: unsupport data type\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // BUGFIX: widen BEFORE multiplying — the old form multiplied five ints and
  // could overflow (signed-overflow UB) for exactly the large caches this
  // check is meant to reject.
  int64_t kv_cache_range =
      (int64_t)block_num * block_size * num_heads * head_size * dtype_size;
  if (kv_cache_range > UINT32_MAX) {
    std::cerr
        << "invokeOfflineQuantToPagedCache: The addressing range of kv_cache cannot exceed 4G."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // nram_size_need: token_block * head_num * head_size + head_num * head_size * sizeof(float)
  //                 token_block * head_num * 4 * sizeof(int) + head_num * sizeof(int)
  // nram used: 480KB
  int nram_size = 480 * 1024 - num_heads * sizeof(int) - num_heads * head_size * sizeof(float);
  int hidden_bytes = num_heads * head_size * sizeof(float) +
                     4 * CEIL_DIV(num_heads, CHAR_BIT) * CHAR_BIT * sizeof(int);
  int seq_block = nram_size / hidden_bytes;
  if (seq_block <= 0) {
    std::cerr << "invokeOfflineQuantToPagedCache: "
              << "num_heads * head_size * dtype_size + 4 * num_heads * sizeof(int) "
              << "should be less than 480KB.\n";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Round down to a multiple of 16 for better DMA alignment.
  if (seq_block > 16) {
    seq_block = seq_block / 16 * 16;
  }
  int cluster_num, core_dim;
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, dev));
  int core_num = core_dim * cluster_num;
  // Shrink blocks so the work spreads across all cores; num_tokens > 0
  // guarantees seq_block >= 1 here.
  seq_block = std::min(seq_block, CEIL_DIV(num_tokens, core_num));
  uint32_t task_dim = CEIL_DIV(num_tokens, seq_block);
  cnrtDim3_t dim{1, task_dim, 1};

  if (data_type == CNNL_DTYPE_FLOAT) {
    kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
        (float *)key, (float *)value, (int8_t *)key_cache, (int8_t *)value_cache,
        (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
        value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
  } else if (data_type == CNNL_DTYPE_HALF) {
    kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)key, (half *)value, (int8_t *)key_cache, (int8_t *)value_cache,
        (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
        value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
  } else {
    kernels::MLUOfflineQuantToPagedCacheKernel<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bfloat16_t *)key, (bfloat16_t *)value, (int8_t *)key_cache, (int8_t *)value_cache,
        (float *)key_cache_scale, (float *)value_cache_scale, (int *)slot_mapping, key_stride0,
        value_stride0, num_tokens, num_heads, block_size, head_size, seq_block);
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,62 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
|
||||
#define CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
|
||||
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Perform offline_quant_to_paged_cache operation.
|
||||
* @param queue[in]: The queue for mlu.
|
||||
* @param data_type[in]: The cnnl data type of key.
|
||||
* @param key[in]: Pointer to the MLU memory that stores the key tensor which has shape [num_tokens,
|
||||
* num_heads, head_size]. Data type of key must be half/bfloat16_t/float.
|
||||
* @param value[in]: Pointer to the MLU memory that stores the value tensor which has shape
|
||||
* [num_tokens, num_heads, head_size]. Data type of value must be half/bfloat16_t/float.
|
||||
* @param key_cache[out]: Pointer to the MLU memory that stores the key_cache tensor which has
|
||||
* shape [num_blocks, num_heads, block_size, head_size]. Data type of key cache must be int8_t.
|
||||
* @param value_cache[out]: Pointer to the MLU memory that stores the value_cache tensor which has
|
||||
* shape [num_blocks, num_heads, block_size, head_size]. Data type of value cache must be int8_t.
|
||||
* @param key_cache_scale[in]: Pointer to the MLU memory that stores the key_cache_scale tensor
|
||||
* which has shape [num_heads, head_size]. Data type of key cache scale must be float.
|
||||
* @param value_cache_scale[in]: Pointer to the MLU memory that stores the value_cache_scale tensor
|
||||
* which has shape [num_heads, head_size]. Data type of value cache scale must be float.
|
||||
* @param slot_mapping[in]: Pointer to the MLU memory that stores the slot_mapping tensor which has
|
||||
* shape [num_tokens]. Data type of slot mapping must be int32_t.
|
||||
* @param key_stride0[in]: The first dimension stride length of key_cache tensor.
|
||||
* @param value_stride0[in]: The first dimension stride length of value_cache tensor.
|
||||
* @param num_tokens[in]: Total number of tokens.
|
||||
* @param num_heads[in]: Head number.
|
||||
* @param block_num[in]: Total number of blocks.
|
||||
* @param block_size[in]: Number of tokens per block.
|
||||
* @param head_size[in]: Head size.
|
||||
* @note: offline_quant_to_paged_cache does not support MLU300 device.
|
||||
*/
|
||||
KernelStatus invokeOfflineQuantToPagedCache(cnrtQueue_t queue,
|
||||
cnnlDataType_t data_type,
|
||||
void *key,
|
||||
void *value,
|
||||
void *key_cache,
|
||||
void *value_cache,
|
||||
void *key_cache_scale,
|
||||
void *value_cache_scale,
|
||||
void *slot_mapping,
|
||||
size_t key_stride0,
|
||||
size_t value_stride0,
|
||||
int num_tokens,
|
||||
int num_heads,
|
||||
int block_num,
|
||||
int block_size,
|
||||
int head_size);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_OFFLINE_QUANT_TO_PAGED_CACHE_MLUH_
|
||||
156
torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mlu
Normal file
156
torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mlu
Normal file
@@ -0,0 +1,156 @@
|
||||
#include <algorithm>
|
||||
#include "cnrt.h"
|
||||
#include "operate_cu_seq_lens.mluh"
|
||||
|
||||
namespace {
|
||||
constexpr int pair_elem_num = 2;
|
||||
}
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
#define ONCHIP_DATA_NUM ((int)((__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) / sizeof(int)))
|
||||
__nram__ int nram_buffer[ONCHIP_DATA_NUM];
|
||||
__nram__ const int acc_seq_lens[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
|
||||
|
||||
// Fill seq_len_nram with the arithmetic sequence
//   start + multi * i   for i = 1 .. elem_count.
// Seeds the first (up to) 16 values from the acc_seq_lens table {1..16},
// then doubles the filled region on each pass by adding a scalar offset to
// the already-generated prefix, so the fill completes in O(log elem_count)
// vector ops instead of elem_count scalar writes.
__mlu_func__ void genSeqLens(int *seq_len_nram, int start, int multi, int elem_count) {
  constexpr int acc_seq_lens_size = 16;
  int count = std::min(acc_seq_lens_size, elem_count);
  // Offset between one 16-wide seed block and the next.
  int add_on = multi * acc_seq_lens_size;
  // seq_len_nram[k] = multi * (k + 1) + start for the first `count` slots.
  __bang_mul_scalar(seq_len_nram, acc_seq_lens, multi, count);
  __bang_add_scalar(seq_len_nram, seq_len_nram, start, count);
  while (count < elem_count) {
    // Extend: the next `count` elements are the first `count` plus add_on
    // (clamped so we never write past elem_count).
    __bang_add_scalar(seq_len_nram + count, seq_len_nram, add_on,
                      std::min(count, elem_count - count));
    count *= 2;
    add_on *= 2;
  }
}
|
||||
|
||||
// Rebase the cumulative sequence-length array into `loop` independent
// slices: each slice copies a window of the prefix sums and subtracts the
// window's first element so every slice restarts at 0.
// Example (from operate_cu_seq_lens.mluh): [0,2,5,10,20,33,46,51,77] with
// batch=8, parallel_num=3 -> [0,2,5,10, 0,10,23,36, 0,5,31].
// Runs as a single block task; the whole arrays must fit in NRAM.
__mlu_global__ void MLUSliceCuSeqlens(int *cu_seq_lens,
                                      int *sliced_cu_seq_lens,
                                      int batch,
                                      int every,
                                      int remain,
                                      int loop) {
  int cu_seq_lens_elem_count = batch + 1;
  // The output carries one extra leading 0 per slice: batch + loop entries.
  int sliced_cu_seq_lens_elem_count = batch + loop;

  // Partition the NRAM scratch buffer: input copy first, rebased output
  // immediately after. ("narm" appears to be a typo for "nram"; names kept
  // unchanged.)
  int *cu_seq_lens_narm = nram_buffer;
  int *sliced_cu_seq_lens_narm = cu_seq_lens_narm + cu_seq_lens_elem_count;
  int *sliced_cu_seq_lens_narm_start = sliced_cu_seq_lens_narm;

  __memcpy(cu_seq_lens_narm, cu_seq_lens, cu_seq_lens_elem_count * sizeof(int), GDRAM2NRAM);
  __bang_write_zero(sliced_cu_seq_lens_narm, sliced_cu_seq_lens_elem_count);
  for (int i = 0; i < loop; ++i) {
    // The last slice holds only `remain` sequences when the split is uneven;
    // +1 accounts for the slice's leading boundary element.
    int elem_num = 1 + (i == loop - 1 && remain != 0 ? remain : every);
    // Rebase the window so its first cumulative length becomes 0.
    __bang_sub_scalar(sliced_cu_seq_lens_narm, cu_seq_lens_narm, cu_seq_lens_narm[0], elem_num);
    // Consecutive slices share one boundary element of the input, hence -1.
    cu_seq_lens_narm += elem_num - 1;
    sliced_cu_seq_lens_narm += elem_num;
  }
  __memcpy(sliced_cu_seq_lens, sliced_cu_seq_lens_narm_start,
           sliced_cu_seq_lens_elem_count * sizeof(int), NRAM2GDRAM);
}
|
||||
|
||||
// Produce per-slice [start, kv_len] cumulative-length pairs for the KV side.
// Output is laid out as pairs {0, len_0, 0, len_1, ...}. Per the header
// examples: with a causal mask the KV length grows with the slice index
// (seq_len=11, parallel_num=3 -> [0,4, 0,8, 0,11]); without it every slice
// sees the full sequence (-> [0,11, 0,11, 0,11]).
// Each task writes one contiguous segment of at most seg_data_num elements.
__mlu_global__ void MLUGenerateKVCuSeqlens(int *gen_cu_seq_lens,
                                           int every,
                                           int remain,
                                           int loop,
                                           int seq_len,
                                           bool is_causal_mask,
                                           int seg_data_num,
                                           int task_num) {
  // This task's window into the global pair array.
  int offset = seg_data_num * taskIdX;
  int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset);
  int seq_len_elem_num = total_elem_num / pair_elem_num;

  int *gen_cu_seq_lens_narm = nram_buffer;
  // Zero everything so the even (start) slot of each pair stays 0.
  __bang_write_zero(gen_cu_seq_lens_narm, total_elem_num);

  if (is_causal_mask) {
    int *seq_lens_narm = gen_cu_seq_lens_narm + total_elem_num;
    // Generate the growing lengths every*(k+1) for this task's slice range;
    // offset/pair_elem_num converts the element offset to a slice index.
    genSeqLens(seq_lens_narm, every * offset / pair_elem_num, every, seq_len_elem_num);
    // Scatter the lengths into the odd slots (destination stride = one pair).
    __memcpy(gen_cu_seq_lens_narm + 1, seq_lens_narm, sizeof(int), NRAM2NRAM,
             pair_elem_num * sizeof(int), sizeof(int), seq_len_elem_num - 1);
    if (remain != 0 && taskIdX == task_num - 1) {
      // The final slice is shorter: trim its KV length by (every - remain).
      gen_cu_seq_lens_narm[total_elem_num - 1] -= (every - remain);
    }
  } else {
    // No causal mask: write the full seq_len into every odd slot.
    __bang_write_value(gen_cu_seq_lens_narm + 1, 1, seq_len, pair_elem_num * sizeof(int),
                       seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0);
  }
  __memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int),
           NRAM2GDRAM);
}
|
||||
|
||||
// Produce per-slice [0, q_len] cumulative-length pairs for the query side:
// each slice gets `every` tokens except the last, which gets `remain`.
// Per the header example: seq_len=11, parallel_num=3 -> [0,4, 0,4, 0,3].
// Each task writes one contiguous segment of at most seg_data_num elements.
__mlu_global__ void MLUGenerateQCuSeqlens(int *gen_cu_seq_lens,
                                          int every,
                                          int remain,
                                          int loop,
                                          int seg_data_num,
                                          int task_num) {
  // This task's window into the global pair array.
  int offset = seg_data_num * taskIdX;
  int total_elem_num = std::min(seg_data_num, loop * pair_elem_num - offset);
  int seq_len_elem_num = total_elem_num / pair_elem_num;

  int *gen_cu_seq_lens_narm = nram_buffer;

  // Zero the even (start) slots, then write `every` into every odd slot.
  __bang_write_zero(gen_cu_seq_lens_narm, total_elem_num);
  __bang_write_value(gen_cu_seq_lens_narm + 1, 1, every, pair_elem_num * sizeof(int),
                     seq_len_elem_num - 1, seq_len_elem_num * pair_elem_num * sizeof(int), 0);
  if (remain != 0 && taskIdX == task_num - 1) {
    // Uneven split: the very last slice only holds `remain` tokens.
    gen_cu_seq_lens_narm[total_elem_num - 1] = remain;
  }
  __memcpy(gen_cu_seq_lens + offset, gen_cu_seq_lens_narm, total_elem_num * sizeof(int),
           NRAM2GDRAM);
}
|
||||
} // namespace kernels
|
||||
|
||||
// Host launcher: split `batch` cumulative sequence lengths into
// `parallel_num` rebased slices (see MLUSliceCuSeqlens and the example in
// operate_cu_seq_lens.mluh). Launches a single block task on `queue`.
KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue,
                                  int *cu_seq_lens,
                                  int *sliced_cu_seq_lens,
                                  int batch,
                                  int parallel_num) {
  // Chunk the batch: `chunk` sequences per slice, last slice holds `tail`.
  const int chunk = (batch + parallel_num - 1) / parallel_num;
  const int full_chunks = batch / chunk;
  const int tail = batch % chunk;
  const int slice_count = full_chunks + (tail != 0 ? 1 : 0);

  cnrtDim3_t dim{1, 1, 1};
  kernels::MLUSliceCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
      cu_seq_lens, sliced_cu_seq_lens, batch, chunk, tail, slice_count);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
// Host launcher: generate [start, len] cumulative-length pairs for either
// the query or KV side of a sequence split across `parallel_num` slices
// (see MLUGenerateQCuSeqlens / MLUGenerateKVCuSeqlens and the examples in
// operate_cu_seq_lens.mluh).
KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue,
                                     int *gen_cu_seq_lens,
                                     int seq_len,
                                     int parallel_num,
                                     bool is_causal_mask,
                                     bool is_kv_seq_len) {
  // Chunk the sequence: `chunk` tokens per slice, last slice holds `tail`.
  const int chunk = (seq_len + parallel_num - 1) / parallel_num;
  const int tail = seq_len % chunk;
  const int slice_count = seq_len / chunk + (tail != 0 ? 1 : 0);

  // Each launched task emits at most seg_data_num output elements.
  int seg_data_num = ONCHIP_DATA_NUM / 2;
  if (is_kv_seq_len && is_causal_mask) {
    // max segnum for 2d memcpy is 64k
    seg_data_num = std::min(pair_elem_num * 64 * 1024, seg_data_num);
  }
  const int total_elem_num = slice_count * pair_elem_num;
  const int task_num = (total_elem_num + seg_data_num - 1) / seg_data_num;

  cnrtDim3_t dim{(unsigned int)task_num, 1, 1};

  if (is_kv_seq_len) {
    kernels::MLUGenerateKVCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
        gen_cu_seq_lens, chunk, tail, slice_count, seq_len, is_causal_mask, seg_data_num,
        task_num);
  } else {
    kernels::MLUGenerateQCuSeqlens<<<dim, cnrtFuncTypeBlock, queue>>>(
        gen_cu_seq_lens, chunk, tail, slice_count, seg_data_num, task_num);
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
66
torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mluh
Normal file
66
torch_mlu_ops-v1.3.2/csrc/kernels/operate_cu_seq_lens.mluh
Normal file
@@ -0,0 +1,66 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
|
||||
#define CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
/**
|
||||
* @brief slice cu_seq_lens and cu_k_seq_lens for parallel context when attention split batch;
|
||||
* @example
|
||||
* cu_seq_lens: [0, 2, 5, 10, 20, 33, 46, 51, 77]
|
||||
* batch: 8, parallel_num: 3
|
||||
* sliced_cu_seq_lens: [0, 2, 5, 10, 0, 10, 23, 36, 0, 5, 31]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param cu_seq_lens: Input. Pointer to the MLU memory that stores the current seq lens, the shape
|
||||
* is [batch + 1].
|
||||
* @param sliced_cu_seq_lens: Output. Pointer to the MLU memory that stores the sliced current seq
|
||||
* lens, the shape is [batch + loop_time].
|
||||
* @param batch: Batch size.
|
||||
* @param parallel_num: Parallel num of batch.
|
||||
*/
|
||||
KernelStatus invokeSliceCuSeqlens(cnrtQueue_t queue,
|
||||
int *cu_seq_lens,
|
||||
int *sliced_cu_seq_lens,
|
||||
int batch,
|
||||
int parallel_num);
|
||||
|
||||
/**
|
||||
* @brief generate cu_seq_lens and cu_k_seq_lens for parallel context when attention split seq;
|
||||
* @example
|
||||
* seq_len: 11, parallel_num: 3
|
||||
* gen_cu_seq_lens for q: [0, 4, 0, 4, 0, 3]
|
||||
* @example
|
||||
* seq_len: 11, parallel_num: 3
|
||||
* is_causal_mask false, gen_cu_seq_lens for kv: [0, 11, 0, 11, 0, 11]
|
||||
* is_causal_mask true , gen_cu_seq_lens for kv: [0, 4, 0, 8, 0, 11]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param gen_cu_seq_lens: Output. Pointer to the MLU memory that stores the generated current seq
|
||||
* lens, the shape is [2 * loop_time].
|
||||
* @param seq_len: Sequence length.
|
||||
* @param parallel_num: Parallel num of sequence length.
|
||||
* @param is_causal_mask: Whether self attention use causal mask.
|
||||
* @param is_kv_seq_len: The gen_cu_seq_lens is for q or kv.
|
||||
*/
|
||||
KernelStatus invokeGenerateCuSeqlens(cnrtQueue_t queue,
|
||||
int *gen_cu_seq_lens,
|
||||
int seq_len,
|
||||
int parallel_num,
|
||||
bool is_causal_mask,
|
||||
bool is_kv_seq_len);
|
||||
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_OPERATE_CU_SEQ_LENS_MLUH_
|
||||
81
torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu
Normal file
81
torch_mlu_ops-v1.3.2/csrc/kernels/preload.mlu
Normal file
@@ -0,0 +1,81 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "preload.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
#define SRAM_SIZE ((__MLU_SRAM_SIZE__ - 32) * 1024)
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
|
||||
|
||||
// Distribute `total` items over `num` workers as evenly as possible: the
// first (total % num) workers receive one extra item. Writes this worker's
// item count to `every` and its starting index to `offset`.
__mlu_func__ void split(const int64_t total,
                        const int64_t num,
                        const int64_t id,
                        size_t &every,
                        size_t &offset) {
  const int64_t quotient = total / num;
  const int64_t extras = total % num;
  if (id < extras) {
    // This worker is one of the `extras` that take an extra item.
    every = quotient + 1;
    offset = (quotient + 1) * id;
  } else {
    every = quotient;
    offset = quotient * id + extras;
  }
}
|
||||
|
||||
// Stream `preload_size` bytes of the weight buffer from GDRAM into SRAM in
// SRAM_SIZE chunks, discarding the copies; per preload.mluh the intent is to
// pull the ffn/self-attention weights into cache ahead of use. The work is
// split across clusters (taskDimY). Compiled to a no-op on arch <= 372.
__mlu_global__ void MLUUnion1Preload(void *filter_ptr, size_t preload_size) {
#if __BANG_ARCH__ > 372
  size_t cluster_preload_size = 0;
  size_t cluster_preload_offset = 0;
  // Balanced split of the preload range across clusters.
  split(preload_size, taskDimY, taskIdY, cluster_preload_size, cluster_preload_offset);

  size_t load_repeat = cluster_preload_size / SRAM_SIZE;
  size_t load_remain = cluster_preload_size % SRAM_SIZE;

  // load_repeat full chunks plus an optional partial tail chunk.
  for (size_t i = 0; i < load_repeat + 1; i++) {
    if (i == load_repeat && load_remain == 0) {
      break;
    }
    size_t loop_load_size = (i < load_repeat ? SRAM_SIZE : load_remain);
    int8_t *gdram_ptr = (int8_t *)filter_ptr + cluster_preload_offset + i * SRAM_SIZE;
    if (loop_load_size > 0) {
      // Copy only for the side effect of touching the memory.
      __memcpy(sram_buffer, gdram_ptr, loop_load_size, GDRAM2SRAM);
    }
  }
#endif
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
// Host launcher: warm the cache with the first `preload_size` bytes of the
// ffn/self-attention weights (see MLUUnion1Preload and preload.mluh).
// Fails when preload_size is 0; clamps preload_size to filter_size.
KernelStatus invokePreload(cnrtQueue_t queue,
                           void *filter_ptr,
                           size_t filter_size,
                           size_t preload_size) {
  // An empty preload request is a caller error.
  if (preload_size == 0) {
    std::cerr << "[invokePreload]: preload_size must be greater than 0." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Never read past the end of the weight buffer.
  preload_size = (preload_size > filter_size) ? filter_size : preload_size;

  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));

  // 4 cores per cluster; use at most two clusters regardless of how many
  // the device has (one cluster when only one exists).
  cnrtDim3_t dim{.x = 4, .y = (uint32_t)cluster_num, .z = 1};
  if (cluster_num == 1) {
    dim.y = 1;
  } else if (cluster_num >= 2) {
    dim.y = 2;
  }

  kernels::MLUUnion1Preload<<<dim, cnrtFuncTypeUnion1, queue>>>(filter_ptr, preload_size);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
||||
|
||||
} // namespace tmo
|
||||
34
torch_mlu_ops-v1.3.2/csrc/kernels/preload.mluh
Normal file
34
torch_mlu_ops-v1.3.2/csrc/kernels/preload.mluh
Normal file
@@ -0,0 +1,34 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_PRELOAD_MLUH_
|
||||
#define CSRC_KERNELS_PRELOAD_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief When tp is greater than 1, while executing reducesum, the weight of ffn
|
||||
* or selfattention to be calculated is loaded into LLC in advance.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param filter_ptr: Input. Pointer to the MLU memory that stores the weight of ffn or
|
||||
* selfattention.
|
||||
* @param filter_size: The weight size of ffn or selfattention.
|
||||
* @param preload_size: The size of the preload weight.
|
||||
* @note The weights of ffn or selfattention must be continuous in filter_ptr.
|
||||
*/
|
||||
KernelStatus invokePreload(cnrtQueue_t queue,
|
||||
void *filter_ptr,
|
||||
size_t filter_size,
|
||||
size_t preload_size);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_PRELOAD_MLUH_
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user