[Model] Support DeepSeek-V4
new file: .gitignore (vendored), 206 lines
@@ -0,0 +1,206 @@
# version file generated by setuptools-scm
/vllm/_version.py

# vllm-flash-attn built from source
vllm/vllm_flash_attn/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/.deps/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# generated files
**/generated/**

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv
# but may add new accuracy reference
!benchmarks/TruthfulQA/*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
benchmarks/*.json

# Linting
actionlint
shellcheck*/
new file: CMakeLists.txt, 76 lines
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
cmake_minimum_required(VERSION 3.16)
project(vllm_mlu_C)

function(detect_debian10)
  if(EXISTS "/etc/os-release")
    file(READ "/etc/os-release" os_release)
    if(os_release MATCHES "PRETTY_NAME=\"Debian GNU/Linux 10" OR
       os_release MATCHES "VERSION_ID=\"10")
      set(DEBIAN_10 TRUE PARENT_SCOPE)
      message(STATUS "Detected Debian 10 (buster)")
    endif()
  endif()
endfunction()

detect_debian10()

if(DEBIAN_10)
  find_program(GCC_PATH "gcc" PATHS "/usr/local/bin")
  if(GCC_PATH)
    message(STATUS "Using GCC on Debian 10: ${GCC_PATH}")
    set(CMAKE_C_COMPILER "${GCC_PATH}")
    set(CMAKE_CXX_COMPILER "/usr/local/bin/g++")
  else()
    message(WARNING "Debian 10 detected but gcc not found!")
  endif()
endif()

set(CMAKE_CXX_STANDARD 17)

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")

set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")

find_package(pybind11 REQUIRED)

append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
set(VLLM_MLU_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}")

find_package(Torch REQUIRED)

if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type Release/Debug (default Release)" FORCE)
endif()

file(GLOB VLLM_MLU_SRC ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)

include_directories(
  ${pybind11_INCLUDE_DIRS}
  ${PYTHON_INCLUDE_PATH}
  ${TORCH_INCLUDE_DIRS}
  $ENV{NEUWARE_HOME}/include
)

pybind11_add_module(vllm_mlu_C ${VLLM_MLU_SRC})

target_link_directories(
  vllm_mlu_C
  PRIVATE
  $ENV{NEUWARE_HOME}/lib64
)

target_link_libraries(
  vllm_mlu_C
  PUBLIC
  ${TORCH_LIBRARIES}
  libcndrv.so
)

target_link_options(vllm_mlu_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")

install(TARGETS vllm_mlu_C DESTINATION ${VLLM_MLU_INSTALL_PATH})
new file: LICENSE, 201 lines
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Modifications made by Cambricon Technologies Corporation Limited. All rights reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
new file: README.md, 116 lines
@@ -0,0 +1,116 @@
<!-- SPDX-License-Identifier: Apache-2.0 -->
<!-- SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project -->
### Cambricon vLLM (vllm_mlu)

#### 1. Project Description

Cambricon vLLM (vllm_mlu) is built on the [plugin system](https://docs.vllm.ai/en/latest/design/plugin_system.html) provided by community vLLM. It enables users to run large language model (LLM) inference and serving efficiently on Cambricon MLU hardware.

vllm_mlu supports native vLLM features including, but not limited to, Chunked Prefill, Prefix Caching, Spec Decode, Graph Mode, and Sleep Mode.

#### 2. Release History

[2026.04.24] vllm_mlu day-0 support for DeepSeek-V4
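
The plugin system mentioned in section 1 works through Python entry points: at startup, vLLM calls every registered `vllm.platform_plugins` hook and loads the platform class it names. As a minimal sketch, the registration hook exposed by this package (only the entry-point name `register_mlu_platform` appears in this commit's `setup.py`) could look like the following; the returned class path is an illustrative assumption, not the actual vllm_mlu layout.

```python
from typing import Optional

# Hypothetical sketch of a vLLM platform-plugin hook. The returned class path
# is an assumption for illustration only.
def register_mlu_platform() -> Optional[str]:
    # vLLM invokes every "vllm.platform_plugins" entry point at startup.
    # Returning a fully-qualified class path tells vLLM which Platform to load;
    # returning None means the plugin does not apply on the current host.
    return "vllm_mlu.platform.MLUPlatform"
```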

#### 3. Usage

Software requirement: the Cambricon SDK. To obtain the SDK, contact Cambricon's official support channel: [ecosystem@cambricon.com](mailto:ecosystem@cambricon.com).

*NOTE: the vllm-mlu repository only supports MLU370 and newer devices.*

##### 3.1 Using the Container Image

Use the Cambricon vLLM Container image provided with the Cambricon SDK.

```bash
# Load the image
docker load -i cambricon_vllm_container.tar.gz

# Start a container
docker run -it --net=host \
    --shm-size '64gb' --privileged \
    --ulimit memlock=-1 ${IMAGE_NAME} \
    /bin/bash

# Activate the inference environment
source /torch/venv3/pytorch_infer/bin/activate
```

##### 3.2 Custom Installation

Before installing Cambricon vLLM, make sure its dependencies are installed correctly.

Installation steps:

```bash
# Assumes the Cambricon vLLM sources (including the vllm sources) have already been obtained

# Install from the vllm sources
cd vllm-v{community vLLM version}/
VLLM_TARGET_DEVICE=empty pip install -e .  # install in developer (editable) mode

# Install from the vllm-mlu sources
git clone https://github.com/Cambricon/vllm-mlu
cd vllm-mlu
pip install -e .  # install in developer (editable) mode

# Install ray
# 1. Enter the vllm-mlu source tree.
cd tools/ray_mlu/
# 2. Install the Ray version that the adaptation is based on
pip install --no-cache-dir --force-reinstall ray==2.51.1
# 3. To run on Cambricon devices, Ray also needs the Cambricon adaptation.
# RAY_DIR points to the installed ray package directory under your pip installation path
cp __init__.py ${RAY_DIR}/_private/accelerators/__init__.py
cp mlu.py ${RAY_DIR}/_private/accelerators/
cp nsight.py ${RAY_DIR}/_private/runtime_env/nsight.py
cp node.py ${RAY_DIR}/_private/node.py
cp worker.py ${RAY_DIR}/_private/worker.py
cp device_manager/__init__.py ${RAY_DIR}/air/_internal/device_manager/__init__.py
cp device_manager/mlu.py ${RAY_DIR}/air/_internal/device_manager/
```

##### 3.3 Running

Running Cambricon vLLM works the same way as community vLLM.

###### 3.3.1 Offline Inference

```bash
# Run the offline inference example
python examples/offline_inference/offline_inference.py ${MODEL_PATH}
```

###### 3.3.2 Online Inference

Start the server and the client separately to serve inference requests, for example:

```bash
# server
vllm serve ${MODEL_PATH} \
    --port 8100 \
    --block-size 1 \
    --max-model-len 4096 \
    --tensor-parallel-size 8 \
    --gpu-memory-utilization 0.96 \
    --trust-remote-code \
    --enable-expert-parallel \
    --no-enable-prefix-caching \
    --disable-log-requests \
    --enforce-eager

# client, we post a single request here.
curl -X POST http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "'"${MODEL_PATH}"'",
         "prompt": "The future of AI is",
         "max_tokens": 128, "temperature": 0.7}'
```
new file: cmake/utils.cmake, 50 lines
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#
# Attempt to find the python package that uses the same python executable as
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
#
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
  file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
  set(Python_EXECUTABLE ${EXECUTABLE})
  find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
  if (NOT Python_FOUND)
    message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
  endif()
  set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
  set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
  if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
    message(FATAL_ERROR
      "Python version (${_VER}) is not one of the supported versions: "
      "${_SUPPORTED_VERSIONS_LIST}.")
  endif()
  message(STATUS "Found python matching: ${EXECUTABLE}.")
endmacro()

#
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
# has trailing whitespace stripped. If an error is encountered when running
# python, a fatal message `ERR_MSG` is issued.
#
function (run_python OUT EXPR ERR_MSG)
  execute_process(
    COMMAND
    "${PYTHON_EXECUTABLE}" "-c" "${EXPR}"
    OUTPUT_VARIABLE PYTHON_OUT
    RESULT_VARIABLE PYTHON_ERROR_CODE
    ERROR_VARIABLE PYTHON_STDERR
    OUTPUT_STRIP_TRAILING_WHITESPACE)

  if(NOT PYTHON_ERROR_CODE EQUAL 0)
    message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
  endif()
  set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
endfunction()

# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
macro (append_cmake_prefix_path PKG EXPR)
  run_python(_PREFIX_PATH
    "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
  list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
new file: csrc/cnmem_allocator.cpp, 310 lines
@@ -0,0 +1,310 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
// A MLU PluggableAllocator based on cn_api APIs.

#include <iostream>

extern "C" {

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cn_api.h>

#define DRV_CHECK_GET_RETURN(...) \
  DRV_CHECK_GET_RETURN_IMPL(__VA_ARGS__, return, )
#define DRV_CHECK_GET_RETURN_IMPL(_1, _2, ...) _2

#define CN_CHECK(return_code, ...)            \
  do {                                        \
    CNresult rc = (return_code);              \
    if (rc) {                                 \
      const char *error_str;                  \
      cnGetErrorString(rc, &error_str);       \
      std::cout << "Error: " << error_str     \
                << " at " << __FILE__         \
                << ":" << __LINE__            \
                << std::endl;                 \
      DRV_CHECK_GET_RETURN(__VA_ARGS__)       \
      __VA_ARGS__;                            \
    }                                         \
  } while (0)

// Global references to Python callables
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;

// ---------------------------------------------------------------------------
// Helper functions:
void ensure_context(CNdev device) {
  CNcontext pctx;
  CN_CHECK(cnCtxGetCurrent(&pctx));
  if (!pctx) {
    // Ensure device context;
    CN_CHECK(cnCtxCreate(&pctx, 0, device));
    CN_CHECK(cnCtxSetCurrent(pctx));
  }
}

void create_and_map(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) {
  ensure_context(device);
  // Define memory allocation properties
  CNmemAllocationProp prop = {};
  // The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to the cndrv developer guide.
  prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT;  // CU_MEM_ALLOCATION_TYPE_PINNED
  prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;
  prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE;
  prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE;

  // Allocate memory using cnMemCreate
  CN_CHECK(cnMemCreate(p_memHandle, size, &prop, 0));
  CN_CHECK(cnMemMap(d_mem, size, 0, *p_memHandle, 0));

  CNmemAccessDesc accessDesc = {};
  accessDesc.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
  accessDesc.location.id = device;
  accessDesc.accessFlags = CN_MEM_ACCESS_FLAGS_PROT_READWRITE;
  CN_CHECK(cnMemSetAccess(d_mem, size, &accessDesc, 1));
}

void unmap_and_release(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) {
  ensure_context(device);
  CN_CHECK(cnMemUnmap(d_mem, size));
  CN_CHECK(cnMemRelease(*p_memHandle));
}

PyObject* create_tuple_from_c_integers(unsigned long long a,
                                       unsigned long long b,
                                       unsigned long long c,
                                       unsigned long long d) {
  // Create a new tuple of size 4
  PyObject* tuple = PyTuple_New(4);
  if (!tuple) {
    return NULL;
  }
  // Convert integers to Python objects and set them in the tuple
  // Steals reference to the PyLong
  PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a));
  PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
  PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
  PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));

  // Note: PyTuple_SetItem "steals" a reference to each object,
  // so we do not need to Py_DECREF the PyLong objects explicitly.

  return tuple;
}

// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, CNqueue stream) {
  ensure_context(device);
  // first allocation, align the size, and reserve an address, and also allocate
  // a CNmemGenericAllocationHandle

  // Define memory allocation properties
  CNmemAllocationProp prop = {};
  // The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to the cndrv developer guide.
  prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT;  // CU_MEM_ALLOCATION_TYPE_PINNED
  prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;
  prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE;
  prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE;

  // Check if the allocation is supported
  size_t granularity;
  CN_CHECK(cnMemGetAllocationGranularity(&granularity, &prop, CN_MEM_ALLOC_GRANULARITY_MINIMUM), nullptr);

  size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
  CNaddr d_mem;
  CN_CHECK(cnMemAddressReserve(&d_mem, alignedSize, 0, 0, 0), nullptr);

  // allocate the CNmemGenericAllocationHandle
  CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)malloc(sizeof(CNmemGenericAllocationHandle));

  if (!g_python_malloc_callback) {
    std::cerr << "ERROR: g_python_malloc_callback not set.\n";
    return nullptr;
  }
  // Acquire GIL (not in stable ABI officially, but often works)
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* arg_tuple = create_tuple_from_c_integers(
      (unsigned long long)device, (unsigned long long)alignedSize,
      (unsigned long long)d_mem, (unsigned long long)p_memHandle);

  // Call g_python_malloc_callback
  PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
  Py_DECREF(arg_tuple);

  if (!py_result) {
    PyErr_Print();
    PyGILState_Release(gstate);
    return nullptr;
  }

  PyGILState_Release(gstate);

  // do the final mapping
  create_and_map(device, alignedSize, d_mem, p_memHandle);

  return (void*)d_mem;
}

__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, CNqueue stream) {
  // get memory handle from the pointer
  if (!g_python_free_callback) {
    std::cerr << "ERROR: g_python_free_callback not set.\n";
    return;
  }

  // Acquire GIL (not in stable ABI officially, but often works)
  PyGILState_STATE gstate = PyGILState_Ensure();

  PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
  PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
  if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return;
  }

  unsigned long long recv_device, recv_size;
  unsigned long long recv_d_mem, recv_p_memHandle;
  if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) {
    // PyArg_ParseTuple sets an error if it fails
    return;
  }

  PyGILState_Release(gstate);

  // Free memory
  CNaddr d_mem = (CNaddr)recv_d_mem;
  CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;
  unmap_and_release(device, size, d_mem, p_memHandle);

  // free address and the handle
  CN_CHECK(cnMemAddressFree(d_mem, size));
  free(p_memHandle);
}

// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
static PyObject* py_init_module(PyObject* self, PyObject* args) {
  PyObject* malloc_callback = nullptr;
  PyObject* free_callback = nullptr;

  if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
    return nullptr;
  }

  if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
    PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
    return nullptr;
  }

  // Save the Python callables
  // This module keeps a strong reference to prevent premature GC
  Py_XINCREF(malloc_callback);
  Py_XINCREF(free_callback);

  Py_XDECREF(g_python_malloc_callback);
  Py_XDECREF(g_python_free_callback);

  g_python_malloc_callback = malloc_callback;
  g_python_free_callback = free_callback;

  Py_RETURN_NONE;
}

static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }

  unsigned long long recv_device, recv_size;
  unsigned long long recv_d_mem, recv_p_memHandle;
  // Unpack the tuple into four C integers
  if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
                        &recv_p_memHandle)) {
    // PyArg_ParseTuple sets an error if it fails
    return nullptr;
  }

  CNaddr d_mem_ptr = (CNaddr)recv_d_mem;
  CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;

  unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);

  Py_RETURN_NONE;
}

static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }

  unsigned long long recv_device, recv_size;
  unsigned long long recv_d_mem, recv_p_memHandle;
  // Unpack the tuple into four C integers
  if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
                        &recv_p_memHandle)) {
    // PyArg_ParseTuple sets an error if it fails
    return nullptr;
  }

  CNaddr d_mem_ptr = (CNaddr)recv_d_mem;
  CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;

  create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);

  Py_RETURN_NONE;
}

static PyObject* python_cn_memcpy(PyObject* self, PyObject* args) {
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 3) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 3");
    return nullptr;
  }

  CNaddr dst, src;
  cn_uint64_t bytes;
  if (!PyArg_ParseTuple(args, "KKK", &dst, &src, &bytes)) {
    // PyArg_ParseTuple sets an error if it fails
    return nullptr;
  }

  CN_CHECK(cnMemcpy(dst, src, bytes), nullptr);

  Py_RETURN_NONE;
}

static PyMethodDef module_methods[] = {
    {"init_module", (PyCFunction)py_init_module, METH_VARARGS,
     "Initialize module with python_malloc and python_free callables."},
    {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
     "Create and map memory on the device."},
    {"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
     METH_VARARGS, "Unmap and release memory on the device."},
    {"python_cn_memcpy", (PyCFunction)python_cn_memcpy, METH_VARARGS,
     "Copies data from source address to destination address."},
    {NULL, NULL, 0, NULL}  // sentinel
};

static struct PyModuleDef cnmem_allocator_module = {
    PyModuleDef_HEAD_INIT, "cnmem_allocator",
    "cnapi-mem-based allocator for MLUPluggableAllocator", -1, module_methods};

PyMODINIT_FUNC PyInit_vllm_mlu_C(void) {
  // Initialize the module
  PyObject* module = PyModule_Create(&cnmem_allocator_module);
  if (!module) {
    return NULL;
  }
  return module;
}

}  // extern "C"
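
On the Python side, `init_module` receives the two callbacks that bookkeep each `(device, aligned_size, address, handle)` tuple, while the exported `my_malloc`/`my_free` symbols are intended to be handed to the framework's pluggable-allocator mechanism. The snippet below is a hedged sketch of that wiring only; the real integration lives in the vllm_mlu Python package and torch_mlu, which are not part of this excerpt, and the module import path is an assumption.

```python
# Sketch of the Python-side callbacks expected by init_module() above.
import vllm_mlu.vllm_mlu_C as allocator_ext  # assumed import path of the built extension

# Maps reserved device address -> (device, aligned_size, d_mem, handle)
allocation_table = {}

def python_malloc_callback(args):
    # my_malloc() passes a single 4-tuple of unsigned integers.
    device, aligned_size, d_mem, p_mem_handle = args
    allocation_table[d_mem] = (device, aligned_size, d_mem, p_mem_handle)

def python_free_callback(ptr):
    # my_free() parses the return value with "KKKK", so the same 4-tuple
    # shape must come back for the pointer being freed.
    return allocation_table.pop(ptr)

allocator_ext.init_module(python_malloc_callback, python_free_callback)
```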
new file: csrc/ops.h, 34 lines
@@ -0,0 +1,34 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#pragma once

#include <optional>
#include <torch/library.h>

#include <vector>

namespace vllm_mlu {

torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  // Ensure tensor is on MLU
  if (!tensor.is_privateuseone()) {
    throw std::runtime_error("Tensor must be on MLU device");
  }

  // Get the raw data pointer
  void* data_ptr = tensor.data_ptr();

  // Get tensor sizes and strides
  std::vector<int64_t> sizes = tensor.sizes().vec();
  std::vector<int64_t> strides = tensor.strides().vec();

  // Get tensor options (dtype, device)
  auto options = tensor.options();

  // Create a new tensor from the raw data pointer
  auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);

  return new_tensor;
}

}  // namespace vllm_mlu
new file: csrc/torch_bindings.cpp, 18 lines
@@ -0,0 +1,18 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <pybind11/pybind11.h>
#include "ops.h"
#include "utils.h"

TORCH_LIBRARY_EXPAND(_C, ops)
{
  // vLLM-MLU custom ops
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_mlu::weak_ref_tensor);
}

REGISTER_EXTENSION(_C)
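
Once the shared library has been loaded, the op registered above is reachable through the PyTorch dispatcher under the `_C` namespace. A minimal usage sketch follows; the import path of the built `.so` is an assumption, only the `_C` namespace and op name come from the binding itself.

```python
# Hedged usage sketch: weak_ref_tensor returns an alias built with
# torch::from_blob, i.e. a tensor that does not own its storage.
import torch
import vllm_mlu.vllm_mlu_C  # assumed module path; importing it loads the library and its ops

def make_weak_ref(t: torch.Tensor) -> torch.Tensor:
    # Dispatches to vllm_mlu::weak_ref_tensor for PrivateUse1 (MLU) tensors.
    return torch.ops._C.weak_ref_tensor(t)
```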
new file: csrc/utils.h, 29 lines
@@ -0,0 +1,29 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                       \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                             \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,         \
                                        STRINGIFY(NAME), nullptr, 0,   \
                                        nullptr};                      \
    return PyModule_Create(&module);                                   \
  }
new file: examples/offline_inference/offline_inference.py, 48 lines
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import sys
from vllm import LLM, SamplingParams


def main(model_path):
    # Sample prompts.
    prompts = [
        "The benefits of exercise include",
        "The importance of reading books is",
        "Gardening can be relaxing because",
        "A good night's sleep is essential for",
    ]
    sampling_params = SamplingParams(
        temperature=0.6, top_p=0.95, max_tokens=10)

    # Create an LLM.
    engine_args_dict = {
        "model": model_path,
        "tensor_parallel_size": 8,
        "enable_expert_parallel": True,
        "enable_prefix_caching": False,
        "enforce_eager": True,
        "trust_remote_code": True,
        "max_num_seqs": len(prompts),
        "max_model_len": 4096,
        "block_size": 1,
        "gpu_memory_utilization": 0.96,
    }
    llm = LLM(**engine_args_dict)

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("Usage: python offline_inference.py <model_path>")
        sys.exit(1)
    main(sys.argv[1])
new file: requirements.txt, 14 lines
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# Dependencies for Cambricon MLUs
ray == 2.51.1
click == 8.2.1
triton >= 3.2.0
torch == 2.9.1
torch-mlu >= 1.29.1
torch_mlu_ops >= 1.8.1

matplotlib == 3.10.3
datasets == 3.6.0
blobfile == 3.0.0
scipy == 1.10.1
new file: setup.py, 276 lines
@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import importlib.util
import io
import logging
import os
import re
import subprocess
import sys

from sysconfig import get_paths
from typing import List, Dict
from setuptools import Extension
from setuptools import find_namespace_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.install import install
from setuptools.command.develop import develop

ROOT_DIR = os.path.dirname(__file__)
logger = logging.getLogger(__name__)


def check_or_set_default_env(cmake_args, env_name, env_variable, default_path=""):
    if env_variable is None:
        logging.warning(f"Set default {env_name}: {default_path}")
        env_variable = default_path
    else:
        logging.info(f"Found existing {env_name}: {env_variable}")
    cmake_args += [f"-D{env_name}={env_variable}"]
    return cmake_args


def load_module_from_path(module_name, path):
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm_mlu", "envs.py"))


class CMakeExtension(Extension):

    def __init__(self,
                 name: str,
                 cmake_lists_dir: str = ".",
                 **kwargs) -> None:
        super().__init__(name, sources=[], py_limited_api=False, **kwargs)
        self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


def get_path(*filepath) -> str:
    return os.path.join(ROOT_DIR, *filepath)


def get_vllm_version() -> str:
    """Read the combined vllm/vllm_mlu/pytorch version string from tools/build.property."""
    with open(get_path("tools/build.property"), 'r') as file:
        content = file.read()

    results = re.findall(r'VLLM_VERSION=([\d|\.]+)\+mlu([\d|\.]+)\.pt(\d+)', content)

    assert results, "failed to parse the vllm, vllm_mlu and pytorch versions."

    version = f"{results[-1][0]}+mlu{results[-1][1]}.pt{results[-1][2]}"

    return version
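
The regex above expects a `VLLM_VERSION=` line combining the community vLLM version, the vllm_mlu version and the PyTorch tag. `tools/build.property` itself is not part of this commit excerpt, so the following is only a hypothetical example of a line the pattern would accept.

```python
# Hypothetical build.property line (values invented for illustration):
#
#   VLLM_VERSION=0.8.5+mlu1.0.0.pt29
#
# With that input, get_vllm_version() returns "0.8.5+mlu1.0.0.pt29".
```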

def read_readme() -> str:
    """Read the README file if present."""
    p = get_path("README.md")
    if os.path.isfile(p):
        return io.open(get_path("README.md"), "r", encoding="utf-8").read()
    else:
        return ""


def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt."""

    def _read_requirements(filename: str) -> List[str]:
        with open(get_path(filename)) as f:
            requirements = f.read().strip().split("\n")
        resolved_requirements = []
        for line in requirements:
            if line.startswith("-r "):
                resolved_requirements += _read_requirements(line.split()[1])
            elif line.startswith("--"):
                continue
            else:
                resolved_requirements.append(line)
        return resolved_requirements

    return _read_requirements("requirements.txt")


class cmake_build_ext(build_ext):
    # A dict of extension directories that have been configured.
    did_config: Dict[str, bool] = {}

    # Determine the number of compilation jobs.
    def compute_num_jobs(self):
        # `num_jobs` is either the value of the MAX_JOBS environment variable
        # (if defined) or the number of CPUs available.
        num_jobs = envs.MAX_JOBS
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                # back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                num_jobs = os.cpu_count()
        num_jobs = max(1, num_jobs)

        return num_jobs

    #
    # Perform cmake configuration for a single extension.
    #
    def configure(self, ext: CMakeExtension) -> None:
        os.makedirs(self.build_temp, exist_ok=True)
        source_dir = os.path.abspath(ROOT_DIR)
        python_executable = sys.executable
        cmake_args = ["cmake"]
        # Compile the csrc code in Release mode by default; Release, Debug and
        # RelWithDebugInfo are accepted.
        if envs.CMAKE_BUILD_TYPE is None or envs.CMAKE_BUILD_TYPE not in [
                "Debug",
                "Release",
                "RelWithDebugInfo",
        ]:
            envs.CMAKE_BUILD_TYPE = "Release"
        cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"]
        # Dump the compile commands for LSP tooling by default.
        cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"]
        if envs.CXX_COMPILER is not None:
            cmake_args += [f"-DCMAKE_CXX_COMPILER={envs.CXX_COMPILER}"]
        if envs.C_COMPILER is not None:
            cmake_args += [f"-DCMAKE_C_COMPILER={envs.C_COMPILER}"]
        if envs.VERBOSE:
            cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]

        # find PYTHON_EXECUTABLE
        check_or_set_default_env(cmake_args, "PYTHON_EXECUTABLE", sys.executable)

        # find PYTHON_INCLUDE_PATH
        check_or_set_default_env(cmake_args, "PYTHON_INCLUDE_PATH",
                                 get_paths()["include"])

        try:
            # Install pybind11 via pip and ask it for its CMake config path.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "pybind11==2.13.6"])
            pybind11_cmake_path = (subprocess.check_output(
                [python_executable, "-m", "pybind11", "--cmake"]).decode().strip())
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to install or locate pybind11: {e}")

        install_path = os.path.join(ROOT_DIR, self.build_lib)
        if isinstance(self.distribution.get_command_obj("develop"), develop):
            install_path = os.path.join(ROOT_DIR, "vllm_mlu")

        # add CMAKE_INSTALL_PATH
        cmake_args += [f"-DCMAKE_INSTALL_PREFIX={install_path}"]

        cmake_args += [f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"]

        cmake_args += [source_dir]
        logging.info(f"cmake config command: {cmake_args}")
        try:
            subprocess.check_call(cmake_args, cwd=self.build_temp)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"CMake configuration failed: {e}")

    def build_extensions(self) -> None:
        if not envs.COMPILE_CUSTOM_KERNELS:
            return
        # Ensure that CMake is present and working
        try:
            subprocess.check_output(["cmake", "--version"])
        except OSError as e:
            raise RuntimeError(f"Cannot find CMake executable: {e}")

        # Create build directory if it does not exist.
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)
        os.makedirs(os.path.join(self.build_lib, "vllm_mlu"), exist_ok=True)

        targets = []

        def get_target_name(s: str) -> str:
            return s.removeprefix("vllm_mlu.")

        # Build all the extensions
        for ext in self.extensions:
            self.configure(ext)
            targets.append(get_target_name(ext.name))

        num_jobs = self.compute_num_jobs()

        build_args = ["--build", ".", f"-j={num_jobs}",
                      *[f"--target={name}" for name in targets],
                      ]
        logger.info(build_args)
        try:
            subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
        except OSError as e:
            raise RuntimeError(f"Build library failed: {e}")

        # Install the libraries
        install_args = ["--install", ".", ]
        try:
            subprocess.check_call(["cmake", *install_args], cwd=self.build_temp)
        except OSError as e:
            raise RuntimeError(f"Install library failed: {e}")

        # Copy back to the build folder for editable builds.
        if isinstance(self.distribution.get_command_obj("develop"), develop):
            for root, _, files in os.walk(self.build_temp):
                for file in files:
                    if file.endswith(".so"):
                        src_path = os.path.join(root, file)
                        dst_path = os.path.join(self.build_lib, "vllm_mlu", file)
                        self.copy_file(src_path, dst_path)
                        logger.info(f"Copy: {src_path} -> {dst_path}")

    def run(self):
        # First, run the standard build_ext command to compile the extensions
        super().run()


class custom_install(install):
    def run(self):
        self.run_command("build_ext")
        install.run(self)


ext_modules = []
if envs.COMPILE_CUSTOM_KERNELS:
    ext_modules = [CMakeExtension(name="vllm_mlu.vllm_mlu_C")]
cmdclass = {"build_ext": cmake_build_ext, "install": custom_install}

setup(
    name="vllm_mlu",
    version=get_vllm_version(),
    author="Cambricon vLLM Team",
    license="Apache 2.0",
    description=("A high-throughput and memory-efficient inference and "
                 "serving engine for LLMs on MLU backend"),
    long_description=read_readme(),
    long_description_content_type="text/markdown",
    url="",
    project_urls={
        "Homepage": "https://github.com/vllm-project/vllm",
        "Documentation": "https://vllm.readthedocs.io/en/latest/",
    },
    classifiers=[
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    packages=find_namespace_packages(exclude=("docs", "examples", "tests*", "csrc")),
    include_package_data=True,
    python_requires=">=3.8",
    install_requires=get_requirements(),
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    entry_points={
        'vllm.platform_plugins': ["mlu = vllm_mlu:register_mlu_platform"],
        'vllm.general_plugins': ["mlu_hijack = vllm_mlu:register_mlu_hijack"]
    }
)
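
setup.py reads its build knobs (`MAX_JOBS`, `CMAKE_BUILD_TYPE`, `CXX_COMPILER`, `C_COMPILER`, `VERBOSE`, `COMPILE_CUSTOM_KERNELS`) from `vllm_mlu/envs.py`, which is not shown in this commit excerpt. The sketch below is only an assumption about that module's shape, inferred from how the attributes are consumed above; the real variable defaults and parsing may differ.

```python
# Hypothetical vllm_mlu/envs.py sketch (not the actual file in this commit).
import os

MAX_JOBS = os.getenv("MAX_JOBS")                  # string or None; int()'d by setup.py
CMAKE_BUILD_TYPE = os.getenv("CMAKE_BUILD_TYPE")  # "Release" / "Debug" / "RelWithDebugInfo"
CXX_COMPILER = os.getenv("CXX_COMPILER")          # optional explicit C++ compiler path
C_COMPILER = os.getenv("C_COMPILER")              # optional explicit C compiler path
VERBOSE = os.getenv("VERBOSE", "0") == "1"        # verbose makefiles when truthy
COMPILE_CUSTOM_KERNELS = os.getenv("COMPILE_CUSTOM_KERNELS", "1") == "1"  # build the csrc extension
```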
new file: tools/ray_mlu/__init__.py, 89 lines
@@ -0,0 +1,89 @@
from typing import Optional, Set

from ray._private.accelerators.accelerator import (
    RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR,
    AcceleratorManager,
)
from ray._private.accelerators.amd_gpu import AMDGPUAcceleratorManager
from ray._private.accelerators.hpu import HPUAcceleratorManager
from ray._private.accelerators.intel_gpu import IntelGPUAcceleratorManager
from ray._private.accelerators.neuron import NeuronAcceleratorManager
from ray._private.accelerators.npu import NPUAcceleratorManager
from ray._private.accelerators.nvidia_gpu import NvidiaGPUAcceleratorManager
from ray._private.accelerators.rbln import RBLNAcceleratorManager
from ray._private.accelerators.tpu import TPUAcceleratorManager
from ray._private.accelerators.mlu import MLUAcceleratorManager


def get_all_accelerator_managers() -> Set[AcceleratorManager]:
    """Get all accelerator managers supported by Ray."""
    return {
        NvidiaGPUAcceleratorManager,
        IntelGPUAcceleratorManager,
        AMDGPUAcceleratorManager,
        TPUAcceleratorManager,
        NeuronAcceleratorManager,
        HPUAcceleratorManager,
        NPUAcceleratorManager,
        RBLNAcceleratorManager,
        MLUAcceleratorManager,
    }


def get_all_accelerator_resource_names() -> Set[str]:
    """Get all resource names for accelerators."""
    return {
        accelerator_manager.get_resource_name()
        for accelerator_manager in get_all_accelerator_managers()
    }


def get_accelerator_manager_for_resource(
    resource_name: str,
) -> Optional[AcceleratorManager]:
    """Get the corresponding accelerator manager for the given
    accelerator resource name.

    E.g., TPUAcceleratorManager is returned if resource name is "TPU".
    """
    try:
        return get_accelerator_manager_for_resource._resource_name_to_accelerator_manager.get(  # noqa: E501
            resource_name, None
        )
    except AttributeError:
        # Lazy initialization.
        resource_name_to_accelerator_manager = {
            accelerator_manager.get_resource_name(): accelerator_manager
            for accelerator_manager in get_all_accelerator_managers()
        }
        # Special handling for GPU resource name since multiple accelerator managers
        # have the same GPU resource name.
        if AMDGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
            resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager
        elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
            resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager
        elif MLUAcceleratorManager.get_current_node_num_accelerators() > 0:
            resource_name_to_accelerator_manager["GPU"] = MLUAcceleratorManager
        else:
            resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager
        get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = (
            resource_name_to_accelerator_manager
        )
        return resource_name_to_accelerator_manager.get(resource_name, None)


__all__ = [
    "NvidiaGPUAcceleratorManager",
    "IntelGPUAcceleratorManager",
    "AMDGPUAcceleratorManager",
    "TPUAcceleratorManager",
    "NeuronAcceleratorManager",
    "HPUAcceleratorManager",
    "NPUAcceleratorManager",
    "RBLNAcceleratorManager",
    "MLUAcceleratorManager",
    "get_all_accelerator_managers",
    "get_all_accelerator_resource_names",
    "get_accelerator_manager_for_resource",
    "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR",
]
114
tools/ray_mlu/device_manager/__init__.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
import ray._private.ray_constants as ray_constants
|
||||
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.mlu import MLUTorchDeviceManager
|
||||
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
|
||||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use MLUTorchDeviceManager when key="GPU"
|
||||
'''
|
||||
SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
|
||||
ray_constants.GPU: MLUTorchDeviceManager,
|
||||
ray_constants.HPU: HPUTorchDeviceManager,
|
||||
ray_constants.NPU: NPUTorchDeviceManager,
|
||||
}
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
|
||||
if backend == "hccl":
|
||||
# The name for the communication backend of Habana and torch-npu is the same.
|
||||
HPUTorchDeviceManager.register_custom_torch_dist_backend()
|
||||
|
||||
NPUTorchDeviceManager.register_custom_torch_dist_backend()
|
||||
|
||||
|
||||
_torch_device_manager = None
|
||||
_torch_device_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_torch_device_manager_by_context() -> TorchDeviceManager:
|
||||
global _torch_device_manager
|
||||
|
||||
with _torch_device_manager_lock:
|
||||
if not _torch_device_manager:
|
||||
existing_device_manager_cls = None
|
||||
resources = ray.get_runtime_context().get_accelerator_ids()
|
||||
|
||||
# select correct accelerator type from resources
|
||||
for resource_type, resource_value in resources.items():
|
||||
device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
|
||||
resource_type, None
|
||||
)
|
||||
if resource_value and device_manager_cls:
|
||||
# An error will be raised when multiple accelerators are specified.
|
||||
if existing_device_manager_cls:
|
||||
raise RuntimeError(
|
||||
"Unable to determine the appropriate DeviceManager "
|
||||
f"for the specified resources {resources}."
|
||||
)
|
||||
else:
|
||||
existing_device_manager_cls = device_manager_cls
|
||||
|
||||
device_manager_cls = (
|
||||
existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
|
||||
)
|
||||
|
||||
_torch_device_manager = device_manager_cls()
|
||||
|
||||
return _torch_device_manager
|
||||
|
||||
|
||||
def get_torch_device_manager_by_device_type(device_type: str):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use MLUTorchDeviceManager when key="GPU"
|
||||
'''
|
||||
if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda":
|
||||
return MLUTorchDeviceManager()
|
||||
elif device_type.lower() == ray_constants.NPU.lower():
|
||||
return NPUTorchDeviceManager()
|
||||
elif device_type.lower() == ray_constants.HPU.lower():
|
||||
return HPUTorchDeviceManager()
|
||||
elif device_type.lower() == "cpu":
|
||||
return CPUTorchDeviceManager()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
raise RuntimeError(f"Device type {device_type} cannot be recognized.")
|
||||
|
||||
|
||||
__all__ = [
|
||||
TorchDeviceManager,
|
||||
CPUTorchDeviceManager,
|
||||
CUDATorchDeviceManager,
|
||||
HPUTorchDeviceManager,
|
||||
NPUTorchDeviceManager,
|
||||
MLUTorchDeviceManager,
|
||||
register_custom_torch_dist_backend,
|
||||
get_torch_device_manager_by_context,
|
||||
get_torch_device_manager_by_device_type,
|
||||
]
|
||||
103
tools/ray_mlu/device_manager/mlu.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
import ray
|
||||
import ray._private.ray_constants as ray_constants
|
||||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
|
||||
from ray._private.accelerators.mlu import MLU_VISIBLE_DEVICES_ENV_VAR
|
||||
|
||||
|
||||
def is_package_present(package_name: str) -> bool:
|
||||
try:
|
||||
return find_spec(package_name) is not None
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
MLU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_mlu")
|
||||
|
||||
|
||||
if MLU_TORCH_PACKAGE_AVAILABLE:
|
||||
import torch_mlu # noqa: F401
|
||||
|
||||
|
||||
class MLUTorchDeviceManager(TorchDeviceManager):
|
||||
"""Cambricon MLU device manager"""
|
||||
|
||||
@staticmethod
|
||||
def register_custom_torch_dist_backend():
|
||||
if MLU_TORCH_PACKAGE_AVAILABLE:
|
||||
import torch_mlu # noqa: F401, F811
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if not MLU_TORCH_PACKAGE_AVAILABLE:
|
||||
return False
|
||||
|
||||
return torch.mlu.is_available()
|
||||
|
||||
def get_devices(self) -> List[torch.device]:
|
||||
"""Gets the correct torch device list configured for this process.
|
||||
Returns a list of torch MLU devices allocated for the current worker.
|
||||
If no MLUs are assigned to the worker, it returns a list with a single MLU device (mlu:0).
|
||||
"""
|
||||
if MLU_TORCH_PACKAGE_AVAILABLE and torch.mlu.is_available():
|
||||
mlu_ids = [
|
||||
str(id)
|
||||
for id in ray.get_runtime_context().get_accelerator_ids()[
|
||||
ray_constants.GPU
|
||||
]
|
||||
]
|
||||
|
||||
device_ids = []
|
||||
|
||||
if len(mlu_ids) > 0:
|
||||
mlu_visible_str = os.environ.get(MLU_VISIBLE_DEVICES_ENV_VAR, "")
|
||||
if mlu_visible_str and mlu_visible_str != "NoDevFiles":
|
||||
mlu_visible_list = mlu_visible_str.split(",")
|
||||
else:
|
||||
mlu_visible_list = []
|
||||
|
||||
for mlu_id in mlu_ids:
|
||||
try:
|
||||
device_ids.append(mlu_visible_list.index(mlu_id))
|
||||
except ValueError:
|
||||
raise RuntimeError(
|
||||
"MLU_VISIBLE_DEVICES set incorrectly. "
|
||||
f"Got {mlu_visible_str}, expected to include {mlu_id}. "
|
||||
"Did you override the `MLU_VISIBLE_DEVICES` "
|
||||
"environment variable?"
|
||||
)
|
||||
else:
|
||||
# If called on the driver or outside of Ray Train, return the
|
||||
# 0th device.
|
||||
device_ids.append(0)
|
||||
|
||||
devices = [torch.device(f"mlu:{device_id}") for device_id in device_ids]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Using MLUTorchDeviceManager but torch mlu is not available."
|
||||
)
|
||||
|
||||
return devices
|
||||
|
||||
def set_device(self, device: Union[torch.device, int]):
|
||||
torch.mlu.set_device(device)
|
||||
|
||||
def supports_stream(self) -> bool:
|
||||
"""Validate if the device type support to create a stream"""
|
||||
return True
|
||||
|
||||
def create_stream(self, device):
|
||||
"""Create a stream on MLU device"""
|
||||
return torch.mlu.Stream(device)
|
||||
|
||||
def get_stream_context(self, stream):
|
||||
"""Get a torch.stream context on MLU device"""
|
||||
return torch.mlu.stream(stream)
|
||||
|
||||
def get_current_stream(self):
|
||||
"""Get current stream for MLU device"""
|
||||
return torch.mlu.current_stream()
|
||||
243
tools/ray_mlu/diff.patch
Normal file
@@ -0,0 +1,243 @@
|
||||
commit 7376225d16e381ecae5cc07d84db9eed043ed06a
|
||||
Author: tanhaojue <tanhaojue@cambricon.com>
|
||||
Date: Thu Mar 7 15:54:09 2024 +0800
|
||||
|
||||
support mlu
|
||||
|
||||
diff --git a/python/ray/_private/accelerators/__init__.py b/python/ray/_private/accelerators/__init__.py
|
||||
index 71550bc..07bdcd6 100644
|
||||
--- a/python/ray/_private/accelerators/__init__.py
|
||||
+++ b/python/ray/_private/accelerators/__init__.py
|
||||
@@ -8,6 +8,7 @@ from ray._private.accelerators.tpu import TPUAcceleratorManager
|
||||
from ray._private.accelerators.neuron import NeuronAcceleratorManager
|
||||
from ray._private.accelerators.hpu import HPUAcceleratorManager
|
||||
from ray._private.accelerators.npu import NPUAcceleratorManager
|
||||
+from ray._private.accelerators.mlu import MLUAcceleratorManager
|
||||
|
||||
|
||||
def get_all_accelerator_managers() -> Set[AcceleratorManager]:
|
||||
@@ -20,6 +21,7 @@ def get_all_accelerator_managers() -> Set[AcceleratorManager]:
|
||||
NeuronAcceleratorManager,
|
||||
HPUAcceleratorManager,
|
||||
NPUAcceleratorManager,
|
||||
+ MLUAcceleratorManager,
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +57,8 @@ def get_accelerator_manager_for_resource(
|
||||
resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager
|
||||
elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
|
||||
resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager
|
||||
+ elif MLUAcceleratorManager.get_current_node_num_accelerators() > 0:
|
||||
+ resource_name_to_accelerator_manager["GPU"] = MLUAcceleratorManager
|
||||
else:
|
||||
resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager
|
||||
get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = (
|
||||
@@ -71,6 +75,7 @@ __all__ = [
|
||||
"NeuronAcceleratorManager",
|
||||
"HPUAcceleratorManager",
|
||||
"NPUAcceleratorManager",
|
||||
+ "MLUAcceleratorManager",
|
||||
"get_all_accelerator_managers",
|
||||
"get_all_accelerator_resource_names",
|
||||
"get_accelerator_manager_for_resource",
|
||||
diff --git a/python/ray/_private/accelerators/mlu.py b/python/ray/_private/accelerators/mlu.py
|
||||
new file mode 100755
|
||||
index 0000000..21a5771
|
||||
--- /dev/null
|
||||
+++ b/python/ray/_private/accelerators/mlu.py
|
||||
@@ -0,0 +1,92 @@
|
||||
+import os
|
||||
+import glob
|
||||
+import logging
|
||||
+from typing import Optional, List, Tuple
|
||||
+import torch
|
||||
+import torch_mlu
|
||||
+from ray._private.accelerators.accelerator import AcceleratorManager
|
||||
+
|
||||
+logger = logging.getLogger(__name__)
|
||||
+
|
||||
+MLU_VISIBLE_DEVICES_ENV_VAR = "MLU_VISIBLE_DEVICES"
|
||||
+NOSET_MLU_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_MLU_VISIBLE_DEVICES"
|
||||
+
|
||||
+
|
||||
+class MLUAcceleratorManager(AcceleratorManager):
|
||||
+ """Cambricon MLU accelerators."""
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def get_resource_name() -> str:
|
||||
+ return "GPU"
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def get_visible_accelerator_ids_env_var() -> str:
|
||||
+ return MLU_VISIBLE_DEVICES_ENV_VAR
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
|
||||
+ mlu_visible_devices = os.environ.get(
|
||||
+ MLUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
|
||||
+ )
|
||||
+
|
||||
+ if mlu_visible_devices is None:
|
||||
+ return None
|
||||
+
|
||||
+ if mlu_visible_devices == "":
|
||||
+ return []
|
||||
+
|
||||
+ if mlu_visible_devices == "NoDevFiles":
|
||||
+ return []
|
||||
+
|
||||
+ return list(mlu_visible_devices.split(","))
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def get_current_node_num_accelerators() -> int:
|
||||
+ """Attempt to detect the number of MLUs on this machine.
|
||||
+
|
||||
+ MLU chips are represented as devices within `/dev/`, e.g. `/dev/cambricon_dev?`.
|
||||
+
|
||||
+ Returns:
|
||||
+ The number of MLUs if any were detected, otherwise 0.
|
||||
+ """
|
||||
+ try:
|
||||
+ return torch.mlu.device_count()
|
||||
+ except Exception as e:
|
||||
+ logger.debug("Could not import CambriconCL: %s", e)
|
||||
+
|
||||
+ try:
|
||||
+ mlu_files = glob.glob("/dev/cambricon_dev?")
|
||||
+ return len(mlu_files)
|
||||
+ except Exception as e:
|
||||
+ logger.debug("Failed to detect number of MLUs: %s", e)
|
||||
+ return 0
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def get_current_node_accelerator_type() -> Optional[str]:
|
||||
+ """Get the type of the Cambricon MLU on the current node.
|
||||
+
|
||||
+ Returns:
|
||||
+ A string of the type, such as "MLU370".
|
||||
+ """
|
||||
+ try:
|
||||
+ return torch.mlu.get_device_name(0)
|
||||
+ except Exception:
|
||||
+ logger.exception("Failed to detect MLU type.")
|
||||
+ return None
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def validate_resource_request_quantity(
|
||||
+ quantity: float,
|
||||
+ ) -> Tuple[bool, Optional[str]]:
|
||||
+ return (True, None)
|
||||
+
|
||||
+ @staticmethod
|
||||
+ def set_current_process_visible_accelerator_ids(
|
||||
+ visible_mlu_devices: List[str],
|
||||
+ ) -> None:
|
||||
+ if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR):
|
||||
+ return
|
||||
+
|
||||
+ os.environ[
|
||||
+ MLUAcceleratorManager.get_visible_accelerator_ids_env_var()
|
||||
+ ] = ",".join([str(i) for i in visible_mlu_devices])
|
||||
diff --git a/python/ray/tests/accelerators/test_mlu.py b/python/ray/tests/accelerators/test_mlu.py
|
||||
new file mode 100755
|
||||
index 0000000..70e81f7
|
||||
--- /dev/null
|
||||
+++ b/python/ray/tests/accelerators/test_mlu.py
|
||||
@@ -0,0 +1,92 @@
|
||||
+import os
|
||||
+import sys
|
||||
+import pytest
|
||||
+from unittest.mock import patch
|
||||
+
|
||||
+import ray
|
||||
+from ray._private.accelerators import MLUAcceleratorManager as Accelerator
|
||||
+
|
||||
+
|
||||
+@patch("glob.glob")
|
||||
+@patch("os.listdir")
|
||||
+def test_autodetect_num_mlus(mock_list, mock_glob):
|
||||
+ mock_glob.return_value = [f"/dev/davinci{i}" for i in range(4)]
|
||||
+ # mock_list.return_value = []
|
||||
+ assert Accelerator.get_current_node_num_accelerators() == 4
|
||||
+
|
||||
+
|
||||
+@patch("glob.glob")
|
||||
+@patch("os.listdir")
|
||||
+def test_autodetect_num_mlus_without_devices(mock_list, mock_glob):
|
||||
+ mock_glob.side_effect = Exception
|
||||
+ # mock_list.return_value = []
|
||||
+ assert Accelerator.get_current_node_num_accelerators() == 0
|
||||
+
|
||||
+
|
||||
+def test_mlu_accelerator_manager_api():
|
||||
+ assert Accelerator.get_resource_name() == "MLU"
|
||||
+ assert Accelerator.get_visible_accelerator_ids_env_var() == "MLU_VISIBLE_DEVICES"
|
||||
+ assert Accelerator.validate_resource_request_quantity(0.5) == (True, None)
|
||||
+ assert Accelerator.validate_resource_request_quantity(1) == (True, None)
|
||||
+
|
||||
+
|
||||
+def test_visible_mlu_type(monkeypatch, shutdown_only):
|
||||
+ with patch.object(
|
||||
+ Accelerator, "get_current_node_num_accelerators", return_value=4
|
||||
+ ), patch.object(
|
||||
+ Accelerator, "get_current_node_accelerator_type", return_value="MLU370"
|
||||
+ ):
|
||||
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
+ manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
|
||||
+ assert manager.get_current_node_accelerator_type() == "MLU370"
|
||||
+
|
||||
+@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
|
||||
+def test_visible_mlu_ids(monkeypatch, shutdown_only):
|
||||
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
+ with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
|
||||
+
|
||||
+ ray.init()
|
||||
+ manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
|
||||
+ assert manager.get_current_node_num_accelerators() == 4
|
||||
+ assert manager.__name__ == "MLUAcceleratorManager"
|
||||
+ assert ray.available_resources()["MLU"] == 3
|
||||
+
|
||||
+def test_get_current_process_visible_accelerator_ids(monkeypatch, shutdown_only):
|
||||
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
+ assert Accelerator.get_current_process_visible_accelerator_ids() == ["0", "1", "2"]
|
||||
+
|
||||
+ monkeypatch.delenv("MLU_VISIBLE_DEVICES")
|
||||
+ assert Accelerator.get_current_process_visible_accelerator_ids() is None
|
||||
+
|
||||
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "")
|
||||
+ assert Accelerator.get_current_process_visible_accelerator_ids() == []
|
||||
+
|
||||
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "NoDevFiles")
|
||||
+ assert Accelerator.get_current_process_visible_accelerator_ids() == []
|
||||
+
|
||||
+
|
||||
+def test_set_current_process_visible_accelerator_ids(shutdown_only):
|
||||
+ Accelerator.set_current_process_visible_accelerator_ids(["0"])
|
||||
+ assert os.environ["MLU_VISIBLE_DEVICES"] == "0"
|
||||
+
|
||||
+ Accelerator.set_current_process_visible_accelerator_ids(["0", "1"])
|
||||
+ assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1"
|
||||
+
|
||||
+ Accelerator.set_current_process_visible_accelerator_ids(["0", "1", "2"])
|
||||
+ assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1,2"
|
||||
+
|
||||
+
|
||||
+@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
|
||||
+def test_auto_detected_more_than_visible(monkeypatch, shutdown_only):
|
||||
+ with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
|
||||
+ # If more MLUs are detected than visible.
|
||||
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
+
|
||||
+ ray.init()
|
||||
+ assert ray.available_resources()["MLU"] == 3
|
||||
+
|
||||
+if __name__ == "__main__":
|
||||
+ if os.environ.get("PARALLEL_CI"):
|
||||
+ sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
|
||||
+ else:
|
||||
+ sys.exit(pytest.main(["-sv", __file__]))
|
||||
diff --git a/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl b/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
|
||||
new file mode 100644
|
||||
index 0000000..8628a88
|
||||
Binary files /dev/null and b/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl differ
|
||||
11
tools/ray_mlu/diff_for_dump_info.patch
Normal file
@@ -0,0 +1,11 @@
diff --git a/ray_mlu/mlu.py b/ray_mlu/mlu.py
index 21a57719..2c63fd5b 100755
--- a/ray_mlu/mlu.py
+++ b/ray_mlu/mlu.py
@@ -87,6 +87,3 @@ class MLUAcceleratorManager(AcceleratorManager):
         if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR):
             return
 
-        os.environ[
-            MLUAcceleratorManager.get_visible_accelerator_ids_env_var()
-        ] = ",".join([str(i) for i in visible_mlu_devices])
94
tools/ray_mlu/mlu.py
Executable file
@@ -0,0 +1,94 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
import torch
|
||||
import torch_mlu
|
||||
from ray._private.accelerators.accelerator import AcceleratorManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MLU_VISIBLE_DEVICES_ENV_VAR = "MLU_VISIBLE_DEVICES"
|
||||
NOSET_MLU_VISIBLE_DEVICES_ENV_VAR = (
|
||||
"RAY_EXPERIMENTAL_NOSET_MLU_VISIBLE_DEVICES"
|
||||
)
|
||||
|
||||
|
||||
class MLUAcceleratorManager(AcceleratorManager):
|
||||
"""Cambricon MLU accelerators."""
|
||||
|
||||
@staticmethod
|
||||
def get_resource_name() -> str:
|
||||
return "GPU"
|
||||
|
||||
@staticmethod
|
||||
def get_visible_accelerator_ids_env_var() -> str:
|
||||
return MLU_VISIBLE_DEVICES_ENV_VAR
|
||||
|
||||
@staticmethod
|
||||
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
|
||||
mlu_visible_devices = os.environ.get(
|
||||
MLUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
|
||||
)
|
||||
|
||||
if mlu_visible_devices is None:
|
||||
return None
|
||||
|
||||
if mlu_visible_devices == "":
|
||||
return []
|
||||
|
||||
if mlu_visible_devices == "NoDevFiles":
|
||||
return []
|
||||
|
||||
return list(mlu_visible_devices.split(","))
|
||||
|
||||
@staticmethod
|
||||
def get_current_node_num_accelerators() -> int:
|
||||
"""Attempt to detect the number of MLUs on this machine.
|
||||
|
||||
MLU chips are represented as devices within `/dev/`, e.g. `/dev/cambricon_dev?`.
|
||||
|
||||
Returns:
|
||||
The number of MLUs if any were detected, otherwise 0.
|
||||
"""
|
||||
try:
|
||||
return torch.mlu.device_count()
|
||||
except Exception as e:
|
||||
logger.debug("Could not import CambriconCL: %s", e)
|
||||
|
||||
try:
|
||||
mlu_files = glob.glob("/dev/cambricon_dev?")
|
||||
return len(mlu_files)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to detect number of MLUs: %s", e)
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def get_current_node_accelerator_type() -> Optional[str]:
|
||||
"""Get the type of the Cambricon MLU on the current node.
|
||||
|
||||
Returns:
|
||||
A string of the type, such as "MLU370".
|
||||
"""
|
||||
try:
|
||||
return torch.mlu.get_device_name(0)
|
||||
except Exception:
|
||||
logger.exception("Failed to detect MLU type.")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def validate_resource_request_quantity(
|
||||
quantity: float,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
return (True, None)
|
||||
|
||||
@staticmethod
|
||||
def set_current_process_visible_accelerator_ids(
|
||||
visible_mlu_devices: List[str],
|
||||
) -> None:
|
||||
if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR):
|
||||
return
|
||||
|
||||
os.environ[
|
||||
MLUAcceleratorManager.get_visible_accelerator_ids_env_var()
|
||||
] = ",".join([str(i) for i in visible_mlu_devices])
|
||||
1890
tools/ray_mlu/node.py
Normal file
File diff suppressed because it is too large
142
tools/ray_mlu/nsight.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from ray._common.utils import (
|
||||
try_to_create_directory,
|
||||
)
|
||||
from ray._private.runtime_env.context import RuntimeEnvContext
|
||||
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
|
||||
from ray.exceptions import RuntimeEnvSetupError
|
||||
|
||||
default_logger = logging.getLogger(__name__)
|
||||
|
||||
# Nsight options used when runtime_env={"_nsight": "default"}
|
||||
# use default cnperf config, no need to specify any options
|
||||
NSIGHT_DEFAULT_CONFIG = {}
|
||||
|
||||
def parse_nsight_config(nsight_config: Dict[str, str]) -> List[str]:
|
||||
"""
|
||||
Function to convert a dictionary of nsight options into
a cnperf-cli command line
|
||||
|
||||
The function returns:
|
||||
- List[str]: the cnperf-cli record command line split into a list of str
|
||||
"""
|
||||
nsight_cmd = ["cnperf-cli", "record"]
|
||||
for option, option_val in nsight_config.items():
|
||||
# option standard based on
|
||||
# https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html
|
||||
if len(option) > 1:
|
||||
nsight_cmd.append(f"--{option}={option_val}")
|
||||
else:
|
||||
nsight_cmd += [f"-{option}", option_val]
|
||||
return nsight_cmd
|
||||
|
||||
|
||||
class NsightPlugin(RuntimeEnvPlugin):
|
||||
name = "_nsight"
|
||||
|
||||
def __init__(self, resources_dir: str):
|
||||
self.nsight_cmd = []
|
||||
|
||||
# replace this with better way to get logs dir
|
||||
session_dir, runtime_dir = os.path.split(resources_dir)
|
||||
self._nsight_dir = Path(session_dir) / "logs" / "nsight"
|
||||
try_to_create_directory(self._nsight_dir)
|
||||
|
||||
async def _check_nsight_script(
|
||||
self, nsight_config: Dict[str, str]
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Function to validate if nsight_config is a valid nsight profile options
|
||||
Args:
|
||||
nsight_config: dictionary mapping each nsight option to its value
|
||||
Returns:
|
||||
a tuple consisting of a boolean indicating whether the nsight_config
is a valid set of options, and an error message if it is invalid
|
||||
"""
|
||||
|
||||
# use empty as nsight report test filename
|
||||
nsight_config_copy = copy.deepcopy(nsight_config)
|
||||
try_to_create_directory(Path(self._nsight_dir) / "empty")
|
||||
nsight_config_copy["o"] = str(Path(self._nsight_dir) / "empty/test")
|
||||
nsight_cmd = parse_nsight_config(nsight_config_copy)
|
||||
try:
|
||||
nsight_cmd = nsight_cmd + ["python", "-c", '""']
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*nsight_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
error_msg = stderr.strip() if stderr.strip() != "" else stdout.strip()
|
||||
|
||||
# cleanup test.cnperf-rep file
|
||||
clean_up_cmd = ["rm", f"{nsight_config_copy['o']}.cnperf-rep"]
|
||||
cleanup_process = await asyncio.create_subprocess_exec(
|
||||
*clean_up_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
_, _ = await cleanup_process.communicate()
|
||||
if process.returncode == 0:
|
||||
return True, None
|
||||
else:
|
||||
return False, error_msg
|
||||
except FileNotFoundError:
|
||||
return False, ("cnperf-cli is not installed")
|
||||
|
||||
async def create(
|
||||
self,
|
||||
uri: Optional[str],
|
||||
runtime_env: "RuntimeEnv", # noqa: F821
|
||||
context: RuntimeEnvContext,
|
||||
logger: logging.Logger = default_logger,
|
||||
) -> int:
|
||||
nsight_config = runtime_env.nsight()
|
||||
if not nsight_config:
|
||||
return 0
|
||||
|
||||
if nsight_config and sys.platform != "linux":
|
||||
raise RuntimeEnvSetupError(
|
||||
"CNPerf CLI is only available in Linux.\n"
|
||||
"More information can be found in "
|
||||
"https://docs.nvidia.com/nsight-compute/NsightComputeCli/index.html"
|
||||
)
|
||||
|
||||
if isinstance(nsight_config, str):
|
||||
if nsight_config == "default":
|
||||
nsight_config = NSIGHT_DEFAULT_CONFIG
|
||||
else:
|
||||
raise RuntimeEnvSetupError(
|
||||
f"Unsupported nsight config: {nsight_config}. "
|
||||
"The supported config is 'default' or "
|
||||
"Dictionary of cnperf options"
|
||||
)
|
||||
|
||||
is_valid_nsight_cmd, error_msg = await self._check_nsight_script(nsight_config)
|
||||
if not is_valid_nsight_cmd:
|
||||
logger.warning(error_msg)
|
||||
raise RuntimeEnvSetupError(
|
||||
"cnperf-cli failed to run with the following "
|
||||
f"error message:\n {error_msg}"
|
||||
)
|
||||
self.nsight_cmd = parse_nsight_config(nsight_config)
|
||||
return 0
|
||||
|
||||
def modify_context(
|
||||
self,
|
||||
uris: List[str],
|
||||
runtime_env: "RuntimeEnv", # noqa: F821
|
||||
context: RuntimeEnvContext,
|
||||
logger: Optional[logging.Logger] = default_logger,
|
||||
):
|
||||
context.py_executable = " ".join(self.nsight_cmd) + " python"
|
||||
logger.info("Running CNPerf cmd: %s", context.py_executable)
|
||||
|
||||
92
tools/ray_mlu/test_mlu.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray._private.accelerators import MLUAcceleratorManager as Accelerator
|
||||
|
||||
|
||||
@patch("glob.glob")
|
||||
@patch("os.listdir")
|
||||
def test_autodetect_num_mlus(mock_list, mock_glob):
|
||||
mock_glob.return_value = [f"/dev/davinci{i}" for i in range(4)]
|
||||
# mock_list.return_value = []
|
||||
assert Accelerator.get_current_node_num_accelerators() == 4
|
||||
|
||||
|
||||
@patch("glob.glob")
|
||||
@patch("os.listdir")
|
||||
def test_autodetect_num_mlus_without_devices(mock_list, mock_glob):
|
||||
mock_glob.side_effect = Exception
|
||||
# mock_list.return_value = []
|
||||
assert Accelerator.get_current_node_num_accelerators() == 0
|
||||
|
||||
|
||||
def test_mlu_accelerator_manager_api():
|
||||
assert Accelerator.get_resource_name() == "MLU"
|
||||
assert Accelerator.get_visible_accelerator_ids_env_var() == "MLU_VISIBLE_DEVICES"
|
||||
assert Accelerator.validate_resource_request_quantity(0.5) == (True, None)
|
||||
assert Accelerator.validate_resource_request_quantity(1) == (True, None)
|
||||
|
||||
|
||||
def test_visible_mlu_type(monkeypatch, shutdown_only):
|
||||
with patch.object(
|
||||
Accelerator, "get_current_node_num_accelerators", return_value=4
|
||||
), patch.object(
|
||||
Accelerator, "get_current_node_accelerator_type", return_value="MLU370"
|
||||
):
|
||||
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
|
||||
assert manager.get_current_node_accelerator_type() == "MLU370"
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
|
||||
def test_visible_mlu_ids(monkeypatch, shutdown_only):
|
||||
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
|
||||
|
||||
ray.init()
|
||||
manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
|
||||
assert manager.get_current_node_num_accelerators() == 4
|
||||
assert manager.__name__ == "MLUAcceleratorManager"
|
||||
assert ray.available_resources()["MLU"] == 3
|
||||
|
||||
def test_get_current_process_visible_accelerator_ids(monkeypatch, shutdown_only):
|
||||
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
assert Accelerator.get_current_process_visible_accelerator_ids() == ["0", "1", "2"]
|
||||
|
||||
monkeypatch.delenv("MLU_VISIBLE_DEVICES")
|
||||
assert Accelerator.get_current_process_visible_accelerator_ids() is None
|
||||
|
||||
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "")
|
||||
assert Accelerator.get_current_process_visible_accelerator_ids() == []
|
||||
|
||||
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "NoDevFiles")
|
||||
assert Accelerator.get_current_process_visible_accelerator_ids() == []
|
||||
|
||||
|
||||
def test_set_current_process_visible_accelerator_ids(shutdown_only):
|
||||
Accelerator.set_current_process_visible_accelerator_ids(["0"])
|
||||
assert os.environ["MLU_VISIBLE_DEVICES"] == "0"
|
||||
|
||||
Accelerator.set_current_process_visible_accelerator_ids(["0", "1"])
|
||||
assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1"
|
||||
|
||||
Accelerator.set_current_process_visible_accelerator_ids(["0", "1", "2"])
|
||||
assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1,2"
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
|
||||
def test_auto_detected_more_than_visible(monkeypatch, shutdown_only):
|
||||
with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
|
||||
# If more MLUs are detected than visible.
|
||||
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
|
||||
|
||||
ray.init()
|
||||
assert ray.available_resources()["MLU"] == 3
|
||||
|
||||
if __name__ == "__main__":
|
||||
if os.environ.get("PARALLEL_CI"):
|
||||
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
|
||||
else:
|
||||
sys.exit(pytest.main(["-sv", __file__]))
|
||||
3785
tools/ray_mlu/worker.py
Normal file
File diff suppressed because it is too large
15
vllm_mlu/__init__.py
Normal file
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project


def register_mlu_platform():
    """Register the MLU platform."""
    return "vllm_mlu.platforms.mlu.MLUPlatform"


def register_mlu_hijack():
    """Register the MLU models and hijack."""
    from vllm_mlu import mlu_hijack
    from vllm_mlu.model_executor.models import register_model
    register_model()
    return
1853
vllm_mlu/_mlu_ops.py
Normal file
File diff suppressed because it is too large
107
vllm_mlu/_mlu_utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import os
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
|
||||
|
||||
def _check_env(env, default=False):
|
||||
if env in os.environ:
|
||||
return os.environ[env].lower() in ["true", "1"]
|
||||
return default
|
||||
|
||||
|
||||
def _check_env_value(env, default=0):
|
||||
if env in os.environ:
|
||||
if not os.environ[env].isdigit():
|
||||
raise ValueError(f"'{env}' should be set with integer")
|
||||
value = int(os.environ[env])
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
def _check_env_float(env, default=0):
|
||||
if env in os.environ:
|
||||
try:
|
||||
value = float(os.environ[env])
|
||||
except ValueError:
|
||||
raise ValueError(f"'{env}' should be set with float")
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency.
|
||||
VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False)
|
||||
|
||||
# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info(without device) for benchmark latency.
|
||||
VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False)
|
||||
|
||||
# VLLM_DUMP_OUTPUTS: Dump each layer's outputs when running vLLM inference.
|
||||
VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False)
|
||||
|
||||
# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference.
|
||||
VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False)
|
||||
|
||||
# VLLM_DUMP_MLU_INFO_DEBUG: Dump device debug info when running vLLM inference.
|
||||
VLLM_DUMP_MLU_INFO_DEBUG = _check_env("VLLM_DUMP_MLU_INFO_DEBUG", default=False)
|
||||
|
||||
# VLLM_SCHEDULER_PROFILE: Profiling vLLM scheduler.
|
||||
VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False)
|
||||
|
||||
# VLLM_GRAPH_DEBUG: Debug the graph status when running decoder, default value is True.
|
||||
# Set to False to disable warning messages.
|
||||
VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)
|
||||
|
||||
# VLLM_AVG_MOE_EN: make moe experts workload balance, default value is False.
|
||||
VLLM_AVG_MOE_EN = _check_env("VLLM_AVG_MOE_EN", default=False) or _check_env("VLLM_RANDOM_MOE_EN", default=False)
|
||||
VLLM_RANDOM_MOE_EN = _check_env("VLLM_RANDOM_MOE_EN", default=False)
|
||||
|
||||
# VLLM_LOGITS_USE_ALL_GATHER: use allgather for logits collection, default value is False.
|
||||
VLLM_LOGITS_USE_ALL_GATHER = _check_env("VLLM_LOGITS_USE_ALL_GATHER", default=False)
|
||||
|
||||
VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
|
||||
VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
|
||||
VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
|
||||
VLLM_DUMP_MLU_INFO_DEBUG = (VLLM_DUMP_MLU_INFO_DEBUG and VLLM_DUMP_MLU_INFO_EN)
|
||||
|
||||
# VLLM_V1_USE_UNCHUNK_SCHED: v1 use unchunk scheduler, default value is True.
|
||||
VLLM_V1_USE_UNCHUNK_SCHED = _check_env("VLLM_V1_USE_UNCHUNK_SCHED", default=True)
|
||||
|
||||
# VLLM_V1_MIN_PREFILL_BATCH: the min scheduling batch in v1, default is 1.
|
||||
VLLM_V1_MIN_PREFILL_BATCH = _check_env_value("VLLM_V1_MIN_PREFILL_BATCH", default=1)
|
||||
|
||||
# VLLM_V1_USE_FULL_GRAPH: v1 use full graph capture, default value is True.
|
||||
VLLM_V1_USE_FULL_GRAPH = _check_env("VLLM_V1_USE_FULL_GRAPH", default=True)
|
||||
|
||||
# VLLM_V1_BENCHMARK: v1 benchmark, default value is False.
|
||||
VLLM_V1_BENCHMARK = _check_env("VLLM_V1_BENCHMARK", default=False)
|
||||
|
||||
# VLLM_MTP_DEBUG: use to show mtp accepted rate, default value is False.
|
||||
VLLM_MTP_DEBUG = _check_env("VLLM_MTP_DEBUG", default=False)
|
||||
|
||||
# VLLM_MTP_NO_QUANT: mtp use origin dtype, quant_config use None
|
||||
VLLM_MTP_NO_QUANT = _check_env("VLLM_MTP_NO_QUANT", default=False)
|
||||
|
||||
# VLLM_MTP_FIXED_ACCEPTANCE_RATE: use fixed acceptance rate, default value is None.
|
||||
VLLM_MTP_FIXED_ACCEPTANCE_RATE = _check_env_float("VLLM_MTP_FIXED_ACCEPTANCE_RATE", default=None)
|
||||
|
||||
# VLLM_V1_UNCHUNK_SCHED_LOG: print v1 unchunk scheduler state
|
||||
VLLM_V1_UNCHUNK_SCHED_LOG = _check_env("VLLM_V1_UNCHUNK_SCHED_LOG", default=False)
|
||||
|
||||
# VLLM_MOE_PREFILL_CHUNK_SIZE: in number of tokens. enabled when > 0.
|
||||
VLLM_MOE_PREFILL_CHUNK_SIZE = _check_env_value("VLLM_MOE_PREFILL_CHUNK_SIZE", default=0)
|
||||
|
||||
# VLLM_CI_ACCURACY_TEST: CI accuracy test, default value is False.
|
||||
VLLM_CI_ACCURACY_TEST = _check_env("VLLM_CI_ACCURACY_TEST", default=False)
|
||||
|
||||
# VLLM_DISAGG_TRANS_ALL_BLOCKS: optimize the performance of disagg
|
||||
VLLM_DISAGG_TRANS_ALL_BLOCKS = _check_env("VLLM_DISAGG_TRANS_ALL_BLOCKS", default=True)
|
||||
|
||||
# vllm disagg debug
|
||||
VLLM_DISAGG_CNPX_EXECUTE = _check_env("VLLM_DISAGG_CNPX_EXECUTE", default=False)
|
||||
VLLM_DISAGG_CNPX_REQUEST = _check_env("VLLM_DISAGG_CNPX_REQUEST", default=False)
|
||||
VLLM_DISAGG_FAKE_DECODER = _check_env("VLLM_DISAGG_FAKE_DECODER", default=False)
|
||||
3
vllm_mlu/attention/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
351
vllm_mlu/attention/layer.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import MLAAttentionImpl
|
||||
from vllm.attention.layer import Attention, MLAAttention, _init_kv_cache_quant
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
from vllm.config.cache import CacheConfig
|
||||
from vllm.config.vllm import QuantizationConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm_mlu.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.v1.kv_cache_interface import (
|
||||
MLUFullAttentionSpec,
|
||||
MLUMLAAttentionSpec,
|
||||
MLUSlidingWindowSpec,
|
||||
)
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def unified_attention_with_output(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
kwargs: dict[str, Any] = {},
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add a return value for self.impl.forward and its kwargs param
|
||||
'''
|
||||
output = self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return output
|
||||
|
||||
class Attention_MluHijack(Attention):
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Block size may get updated after model loading, refresh it
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
if self.sliding_window is not None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace SlidingWindowSpec with MLUSlidingWindowSpec.
|
||||
'''
|
||||
return MLUSlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace FullAttentionSpec with MLUFullAttentionSpec.
|
||||
'''
|
||||
return MLUFullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
class MLAAttention_MluHijack(MLAAttention):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_sparse: bool = False,
|
||||
indexer: object | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.num_heads = num_heads
|
||||
self.scale = scale
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
# self.head_size = kv_lora_rank + qk_rope_head_dim
|
||||
self.layer_name = prefix
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: insert num_kv_heads for mlu platform
|
||||
'''
|
||||
self.head_size = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.num_kv_heads = extra_impl_args.pop("num_kv_heads", None)
|
||||
if self.num_kv_heads is None:
|
||||
self.num_kv_heads = num_heads
|
||||
|
||||
self.decoder_attn_dtype = None
|
||||
decoder_attn_dtype = get_current_vllm_config().mlu_config.decoder_attn_dtype
|
||||
if decoder_attn_dtype in ["int8", "fp8_e4m3", "fp8"]:
|
||||
self.decoder_attn_dtype = (
|
||||
torch.int8 if decoder_attn_dtype == "int8"
|
||||
else torch.float8_e4m3fn
|
||||
)
|
||||
extra_impl_args['decoder_attn_dtype'] = self.decoder_attn_dtype
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
calculate_kv_scales = False
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(
|
||||
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
|
||||
)
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=True,
|
||||
use_sparse=use_sparse,
|
||||
)
|
||||
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
|
||||
self.impl = impl_cls(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.scale,
|
||||
self.num_kv_heads,
|
||||
None,  # alibi_slopes
|
||||
None, # sliding_window
|
||||
kv_cache_dtype,
|
||||
None, # logits_soft_cap
|
||||
AttentionType.DECODER,  # attn_type
|
||||
None, # kv_sharing_target_layer_name
|
||||
**extra_impl_args,
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
if current_platform.is_out_of_tree():
|
||||
self.use_direct_call = False
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support kv8 and deepseek v3.2
|
||||
'''
|
||||
self.kv_cache = [
|
||||
[torch.tensor([]), torch.tensor([]), torch.tensor([])]
|
||||
for _ in range(
|
||||
get_current_vllm_config().parallel_config.pipeline_parallel_size
|
||||
)
|
||||
]
|
||||
self.impl.use_mla = True
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
self.use_sparse = use_sparse
|
||||
|
||||
# Initialize q/k/v range constants.
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
|
||||
self.kv_cache_dtype, vllm_config.model_config
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace MLAAttentionSpec with MLUMLAAttentionSpec.
|
||||
'''
|
||||
index_head_dim, index_n_heads = 0, 0
|
||||
if vllm_config.model_config.hf_text_config.model_type == "deepseek_v32":
|
||||
index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
|
||||
index_n_heads = 1
|
||||
|
||||
if vllm_config.model_config.hf_text_config.model_type == "deepseek_v4":
|
||||
index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
|
||||
index_n_heads = 1
|
||||
|
||||
return MLUMLAAttentionSpec(
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=self.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
cache_dtype_str=vllm_config.cache_config.cache_dtype,
|
||||
index_head_dim=index_head_dim,
|
||||
index_n_heads=index_n_heads,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
kwargs: dict[str, Any] = {},
|
||||
) -> torch.Tensor:
|
||||
if self.calculate_kv_scales:
|
||||
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
|
||||
|
||||
assert not self.use_direct_call, "MLU-V1 does not support direct call."
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output_lse = None
|
||||
output_shape = (output_shape if output_shape is not None else query.shape)
|
||||
output_shape = [output_shape[0], self.num_heads * self.v_head_dim]
|
||||
|
||||
output = torch.empty(
|
||||
output_shape,
|
||||
dtype=self.dtype if query.dtype == torch.int8 else query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
hidden_size = output_shape[-1]
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.v_head_dim)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.v_head_dim)
|
||||
if not kwargs:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query, key, value, output, self.layer_name
|
||||
)
|
||||
attn_output_list = output
|
||||
else:
|
||||
attn_output_list = unified_attention_with_output(
|
||||
query, key, value, output, self.layer_name, kwargs=kwargs)
|
||||
if isinstance(attn_output_list, (list, tuple)) and len(attn_output_list) > 1:
|
||||
output_lse = attn_output_list[1]
|
||||
if output_lse is not None:
|
||||
return output.view(-1, hidden_size), output_lse
|
||||
else:
|
||||
return output.view(-1, hidden_size)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name
|
||||
)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Attention,
|
||||
Attention.get_kv_cache_spec,
|
||||
Attention_MluHijack.get_kv_cache_spec,
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
MLAAttention,
|
||||
MLAAttention.__init__,
|
||||
MLAAttention_MluHijack.__init__,
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
MLAAttention,
|
||||
MLAAttention.get_kv_cache_spec,
|
||||
MLAAttention_MluHijack.get_kv_cache_spec,
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
MLAAttention,
|
||||
MLAAttention.forward,
|
||||
MLAAttention_MluHijack.forward,
|
||||
)
|
||||
3
vllm_mlu/attention/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
62
vllm_mlu/attention/utils/kv_transfer_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
|
||||
|
||||
def maybe_transfer_kv_layer(func: Callable) -> Callable:
|
||||
"""Decorator that handles KV layer transfer prior and after execution of
|
||||
an attention layer, if enabled. Otherwise, the wrapper is a no-op.
|
||||
|
||||
On entry: waits for the KV layer from the connector.
|
||||
On exit: saves the KV layer to the connector.
|
||||
"""
|
||||
# Import at runtime to avoid circular dependency
|
||||
from vllm.attention.layer import get_attention_context
|
||||
|
||||
# Inspect the signature ONCE when the decorator is applied.
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
# Find the index of 'layer_name' parameter.
|
||||
try:
|
||||
layer_name_index = param_names.index("layer_name")
|
||||
except ValueError as e:
|
||||
raise TypeError(
|
||||
f"Function {func.__name__} must have a 'layer_name' parameter"
|
||||
) from e
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
layer_name: str = args[layer_name_index]
|
||||
|
||||
# Extract attention context (layer-specific metadata, layer, and kv_cache)
|
||||
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
|
||||
connector = get_kv_transfer_group()
|
||||
if attn_metadata is None or not connector.has_connector_metadata():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Wait for KV layer on entry
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
# Execute the function
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Save KV cache layer on exit
|
||||
|
||||
if kwargs is None or kwargs.get("save_kv_layer", True):
|
||||
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
3
vllm_mlu/benchmarks/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
72
vllm_mlu/benchmarks/datasets.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
This module defines a framework for sampling benchmark requests from various
|
||||
datasets. Each dataset subclass of BenchmarkDataset must implement sample
|
||||
generation. Supported dataset types include:
|
||||
- ShareGPT
|
||||
- Random (synthetic)
|
||||
- Sonnet
|
||||
- BurstGPT
|
||||
- HuggingFace
|
||||
- VisionArena
|
||||
"""
|
||||
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.benchmarks.datasets import RandomMultiModalDataset
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video(
|
||||
self, width: int, height: int, num_frames: int
|
||||
) -> dict:
|
||||
"""Generate synthetic video with random values.
|
||||
|
||||
Creates a video with random pixel values, encodes it to MP4 format,
|
||||
and returns the content as bytes.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
random_pixels = self._rng.integers(
|
||||
0,
|
||||
256,
|
||||
(num_frames, height, width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
# Create a temporary video file in memory
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
fps = 30 # frames per second
|
||||
|
||||
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Create video writer
|
||||
video_writer = cv2.VideoWriter(
|
||||
temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height)
|
||||
)
|
||||
|
||||
if not video_writer.isOpened():
|
||||
raise RuntimeError("Failed to create video writer")
|
||||
|
||||
for frame in random_pixels:
|
||||
video_writer.write(frame)
|
||||
|
||||
video_writer.release()
|
||||
temp_file.close()
|
||||
|
||||
# Read the video file content
|
||||
with open(temp_path, "rb") as f:
|
||||
video_content = f.read()
|
||||
|
||||
return {"bytes": video_content}
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
RandomMultiModalDataset,
|
||||
RandomMultiModalDataset.generate_synthetic_video,
|
||||
vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video,
|
||||
)
|
||||
3
vllm_mlu/compilation/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
185
vllm_mlu/compilation/fix_functionalization.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import operator
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fx_utils import is_func
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass_MluHijack(FixFunctionalizationPass):
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
logger.debug(
|
||||
"XPU platform does not support fix functionalizationpass currently."
|
||||
)
|
||||
return
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: skip custom op on mlu
|
||||
'''
|
||||
if current_platform.is_out_of_tree():
|
||||
continue # skip the count on mlu
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
if (
|
||||
is_func(query, operator.getitem)
|
||||
and is_func(key, operator.getitem)
|
||||
and query.args[0] == key.args[0]
|
||||
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
|
||||
and all(
|
||||
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||
for getitem_node in getitem_nodes.values()
|
||||
for user in getitem_node.users
|
||||
)
|
||||
):
|
||||
# Pattern where query and key are slices of an mm_node.
|
||||
# While functionalized, results at [1] and [2] are scattered
|
||||
# back into mm_node. So after de-functionalization, we can
|
||||
# just use mm_node directly.
|
||||
|
||||
mm_node = query.args[0].args[0]
|
||||
for user in getitem_nodes.values():
|
||||
for user_of_getitem in user.users:
|
||||
if is_func(
|
||||
user_of_getitem, torch.ops.aten.slice_scatter.default
|
||||
):
|
||||
user_of_getitem.replace_all_uses_with(mm_node)
|
||||
self._remove(user_of_getitem)
|
||||
self._remove(user)
|
||||
|
||||
self.insert_defunctionalized(graph, node)
|
||||
self._remove(node)
|
||||
|
||||
else:
|
||||
# Directly replace the auto_functionalize(rotary_embedding)
|
||||
# with the inplace rotary_embedding. In theory, we shouldn't
|
||||
# do this blindly, but in practice in vLLM it's ok. The best
|
||||
# solution is to use auto_functionalization_v2 and then use
|
||||
# inductor's builtin defunctionalization (reinplacing) pass.
|
||||
mutated_args = {1: "query", 2: "key"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
|
||||
# rms_norm replacements avoid the most copies for LLaMa.
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||
mutated_args = {1: "input", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
|
||||
mutated_args = {1: "result", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
|
||||
mutated_args = {1: "result", 2: "scale", 3: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input")
|
||||
)
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input", "scale")
|
||||
)
|
||||
elif (
|
||||
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
|
||||
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
|
||||
):
|
||||
mutated_args = {1: "result", 2: "result_block_scale"}
|
||||
self.defunctionalize(
|
||||
graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=(
|
||||
"result",
|
||||
"result_block_scale",
|
||||
"input",
|
||||
"input_global_scale",
|
||||
),
|
||||
)
|
||||
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
|
||||
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
|
||||
mutated_args = {1: "qkv"}
|
||||
args = (
|
||||
"qkv",
|
||||
"num_heads_q",
|
||||
"num_heads_k",
|
||||
"num_heads_v",
|
||||
"head_dim",
|
||||
"eps",
|
||||
"q_weight",
|
||||
"k_weight",
|
||||
"cos_sin_cache",
|
||||
"is_neox",
|
||||
"position_ids",
|
||||
)
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug(
|
||||
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
||||
)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
FixFunctionalizationPass,
|
||||
FixFunctionalizationPass.__call__,
|
||||
FixFunctionalizationPass_MluHijack.__call__
|
||||
)
|
||||
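The pass above rewrites `auto_functionalized(op, ...)` wrapper nodes back into the underlying in-place custom ops so that the extra output copies disappear from the compiled graph. The snippet below is a conceptual, PyTorch-only illustration of the difference between the functionalized and the in-place form of one op handled above (`silu_and_mul`); it is not the FX rewrite itself, and the helper names are made up.

```python
# Conceptual sketch: functionalized vs. in-place silu_and_mul.
import torch


def silu_and_mul_functional(x: torch.Tensor) -> torch.Tensor:
    # Functionalized form: allocates and returns a fresh output tensor.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


def silu_and_mul_inplace(result: torch.Tensor, x: torch.Tensor) -> None:
    # In-place form: writes into a caller-provided buffer, like the mutating
    # _C custom op the pass rewrites the graph to call directly.
    d = x.shape[-1] // 2
    torch.mul(torch.nn.functional.silu(x[..., :d]), x[..., d:], out=result)


x = torch.randn(4, 8)
out = torch.empty(4, 4)
silu_and_mul_inplace(out, x)
assert torch.allclose(out, silu_and_mul_functional(x))
```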
242
vllm_mlu/compilation/mlu_graph.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import weak_ref_tensors
|
||||
from vllm.compilation.cuda_graph import (
|
||||
CUDAGraphEntry,
|
||||
CUDAGraphWrapper,
|
||||
CUDAGraphOptions,
|
||||
)
|
||||
from vllm_mlu.v1.attention.backends.utils import MLUInferMode
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: specialized graph entry for prefill graphs
|
||||
'''
|
||||
@dataclasses.dataclass
|
||||
class PrefillGraphEntry:
|
||||
batch_size: int = 0
|
||||
seq_len: int = 0
|
||||
cudagraph: torch.mlu.MLUGraph | None = None
|
||||
output: Any | None = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: list[int] | None = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
class MLUGraphWrapper(CUDAGraphWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: CUDAGraphOptions | None = None,
|
||||
):
|
||||
super().__init__(runnable, vllm_config, runtime_mode, cudagraph_options)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add separate dict for prefill graph entries
|
||||
'''
|
||||
self.prefill_mlugraph_entry: PrefillGraphEntry | None = None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: check if running in prefill mode
|
||||
'''
|
||||
def is_running_in_prefill(self, entry: PrefillGraphEntry | None = None) -> bool:
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.attn_metadata is None:
|
||||
return False
|
||||
infer_mode = forward_context.attn_metadata['common_metadata'].infer_mode
|
||||
seq_lens_cpu = forward_context.attn_metadata['common_metadata'].seq_lens_cpu
|
||||
if entry is not None \
|
||||
and infer_mode == MLUInferMode.PREFILL_ONLY \
|
||||
and seq_lens_cpu.size(0) == entry.batch_size \
|
||||
and (seq_lens_cpu == entry.seq_len).all().item():
|
||||
return True
|
||||
return False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
is_capturing_prefill: bool = False,
|
||||
prefill_enable_mlugraph: bool = False,
|
||||
prefill_batch_size: int = 0,
|
||||
prefill_seq_len: int = 0,
|
||||
is_running_drafter: bool = False,
|
||||
*args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||
or cudagraph_runtime_mode != self.runtime_mode
|
||||
):
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without cudagraphs.
|
||||
# We do not trigger capture/replay if the runtime mode does not
|
||||
# match. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: handle prefill graph separately
|
||||
@brief: skip check in running drafter model
|
||||
'''
|
||||
if is_capturing_prefill: # PREFILL capture
|
||||
self.prefill_mlugraph_entry = PrefillGraphEntry(
|
||||
batch_size=prefill_batch_size,
|
||||
seq_len=prefill_seq_len)
|
||||
else: # FULL/DECODE capture
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
|
||||
batch_descriptor=batch_descriptor
|
||||
)
|
||||
|
||||
if ((self.is_running_in_prefill(self.prefill_mlugraph_entry) and prefill_enable_mlugraph)
|
||||
or is_capturing_prefill):
|
||||
entry = self.prefill_mlugraph_entry
|
||||
logger.debug(
|
||||
f"Hitting a prefill cudagraph on {self.runtime_mode.name}, "
|
||||
f"batch_size: {entry.batch_size}, seq_len: {entry.seq_len}")
|
||||
else: # FULL/DECODE capture
|
||||
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||
logger.debug(
|
||||
"Hitting a decode cudagraph on (%s, %s)",
|
||||
self.runtime_mode.name,
|
||||
entry.batch_descriptor,
|
||||
)
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if self.cudagraph_options.debug_log_enable:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
if is_capturing_prefill:
|
||||
logger.debug(
|
||||
"Capturing a prefill cudagraph on (%s, batch_size=%d, seq_len=%d)",
|
||||
self.runtime_mode.name,
|
||||
entry.batch_size,
|
||||
entry.seq_len,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Capturing a decode cudagraph on (%s, %s)",
|
||||
self.runtime_mode.name,
|
||||
entry.batch_descriptor,
|
||||
)
|
||||
if ((not is_capturing_prefill) and (not is_running_drafter)):
|
||||
# validate that cudagraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.mlu.MLUGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.cudagraph_options.gc_disable:
|
||||
# during every model forward for piecewise cudagraph
|
||||
# mode, we will capture many pieces of cudagraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(patch("torch.mlu.empty_cache", lambda: None))
|
||||
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.mlu.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.cudagraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph in piecewise cudagraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for cudagraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
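MLUGraphWrapper reuses the upstream CUDA-graph capture/replay flow, adding a separate `PrefillGraphEntry` keyed by (batch_size, seq_len). The sketch below reproduces only the bookkeeping invariants (one capture per key, replay requires the same input buffers as capture); it is a plain-Python illustration with hypothetical names, not the MLU graph API.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class Entry:
    batch_size: int
    seq_len: int
    output: Any = None
    input_addresses: list[int] | None = None


class GraphCache:
    """Toy stand-in for the wrapper: capture once per key, then replay."""

    def __init__(self, runnable: Callable[..., Any]):
        self.runnable = runnable
        self.entries: dict[tuple[int, int], Entry] = {}

    def run(self, batch_size: int, seq_len: int, *args: Any) -> Any:
        key = (batch_size, seq_len)
        entry = self.entries.get(key)
        if entry is None:
            # "Capture": run eagerly once, remember output and input buffers.
            entry = Entry(batch_size, seq_len, output=self.runnable(*args),
                          input_addresses=[id(a) for a in args])
            self.entries[key] = entry
        else:
            # "Replay": inputs must live in the same buffers as at capture time.
            assert [id(a) for a in args] == entry.input_addresses
        return entry.output


buf = bytearray(16)
cache = GraphCache(lambda b: len(b))
assert cache.run(1, 16, buf) == 16   # capture
assert cache.run(1, 16, buf) == 16   # replay with the same buffer
```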
3
vllm_mlu/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
71
vllm_mlu/config/model.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__config__model__ModelConfig__is_embedding_task(self) -> bool:
|
||||
return self.runner_type == "pooling"
|
||||
|
||||
def vllm__config__model__ModelConfig__get_head_size(self) -> int:
|
||||
# TODO remove hard code
|
||||
if self.is_deepseek_mla:
|
||||
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
|
||||
if self.use_mla:
|
||||
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
|
||||
else:
|
||||
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0)
|
||||
if qk_rope_head_dim and qk_nope_head_dim:
|
||||
return qk_rope_head_dim + qk_nope_head_dim
|
||||
|
||||
if hasattr(self.hf_text_config, "model_type") and (
|
||||
self.hf_text_config.model_type == "zamba2"
|
||||
):
|
||||
return self.hf_text_config.attention_head_dim
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
# NOTE: Some configs may set head_dim=None in the config
|
||||
if getattr(self.hf_text_config, "head_dim", None) is not None:
|
||||
return self.hf_text_config.head_dim
|
||||
|
||||
# NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head`
|
||||
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
|
||||
return self.hf_text_config.hidden_size_per_head
|
||||
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust num_heads and num_attention_heads.
|
||||
'''
|
||||
if hasattr(self.hf_text_config, "num_heads"):
|
||||
num_attention_heads = self.hf_text_config.num_heads
|
||||
else:
|
||||
num_attention_heads = self.hf_text_config.num_attention_heads
|
||||
|
||||
return (self.hf_text_config.hidden_size // num_attention_heads)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
ModelConfig,
|
||||
"is_embedding_task",
|
||||
vllm__config__model__ModelConfig__is_embedding_task,
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
ModelConfig,
|
||||
ModelConfig.get_head_size,
|
||||
vllm__config__model__ModelConfig__get_head_size,
|
||||
)
|
||||
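To make the MLA branch above concrete, the following worked example plugs in DeepSeek-style numbers (the values are illustrative, not taken from a specific checkpoint): with MLA the cached head is `kv_lora_rank + qk_rope_head_dim`, otherwise the hijack falls back to `hidden_size // num_attention_heads`.

```python
def head_size(hidden_size=None, num_attention_heads=None,
              kv_lora_rank=None, qk_rope_head_dim=0, use_mla=False):
    # MLA: the cached "head" is the compressed latent plus the rope part.
    if use_mla and kv_lora_rank is not None:
        return kv_lora_rank + qk_rope_head_dim
    # Fallback used above when no explicit head_dim is present in the config.
    return hidden_size // num_attention_heads


assert head_size(kv_lora_rank=512, qk_rope_head_dim=64, use_mla=True) == 576
assert head_size(hidden_size=4096, num_attention_heads=32) == 128
```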
86
vllm_mlu/config/scheduler.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.scheduler import SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu._mlu_utils import VLLM_V1_BENCHMARK
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__config__scheduler__SchedulerConfig__verify_max_model_len(
|
||||
self, max_model_len: int,
|
||||
) -> Self:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: This restriction is removed when VLLM_V1_BENCHMARK is set to True
|
||||
'''
|
||||
if not VLLM_V1_BENCHMARK:
|
||||
if (
|
||||
self.max_num_batched_tokens < max_model_len
|
||||
and not self.enable_chunked_prefill
|
||||
):
|
||||
raise ValueError(
|
||||
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||
f"smaller than max_model_len ({max_model_len}). "
|
||||
"This effectively limits the maximum sequence length to "
|
||||
"max_num_batched_tokens and makes vLLM reject longer "
|
||||
"sequences. Please increase max_num_batched_tokens or "
|
||||
"decrease max_model_len."
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.max_num_batched_tokens < self.max_num_seqs:
|
||||
raise ValueError(
|
||||
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
||||
"be greater than or equal to max_num_seqs "
|
||||
f"({self.max_num_seqs})."
|
||||
)
|
||||
|
||||
if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
|
||||
logger.warning(
|
||||
"max_num_batched_tokens (%d) exceeds max_num_seqs "
|
||||
"* max_model_len (%d). This may lead to unexpected behavior.",
|
||||
self.max_num_batched_tokens,
|
||||
self.max_num_seqs * max_model_len,
|
||||
)
|
||||
|
||||
if self.max_num_partial_prefills > 1:
|
||||
if not self.enable_chunked_prefill:
|
||||
raise ValueError(
|
||||
"Chunked prefill must be enabled to set "
|
||||
"max_num_partial_prefills > 1."
|
||||
)
|
||||
|
||||
if self.long_prefill_token_threshold > max_model_len:
|
||||
raise ValueError(
|
||||
"long_prefill_token_threshold "
|
||||
f"({self.long_prefill_token_threshold}) cannot be greater "
|
||||
f"than the max_model_len ({max_model_len})."
|
||||
)
|
||||
|
||||
if self.max_long_partial_prefills > self.max_num_partial_prefills:
|
||||
raise ValueError(
|
||||
f"{self.max_long_partial_prefills=} must be less than or equal to "
|
||||
f"{self.max_num_partial_prefills=}."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
SchedulerConfig,
|
||||
SchedulerConfig.verify_max_model_len,
|
||||
vllm__config__scheduler__SchedulerConfig__verify_max_model_len,
|
||||
)
|
||||
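The invariants enforced above can be summarized with a few illustrative numbers (all values hypothetical):

```python
max_model_len = 4096
max_num_seqs = 256
max_num_batched_tokens = 8192

# Always required: a batch must be able to hold at least one token per sequence.
assert max_num_batched_tokens >= max_num_seqs

# Only checked when VLLM_V1_BENCHMARK is unset: without chunked prefill, a
# prompt longer than max_num_batched_tokens would be rejected outright.
needs_chunked_prefill = max_num_batched_tokens < max_model_len
assert not needs_chunked_prefill
```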
66
vllm_mlu/config/speculative.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@staticmethod
|
||||
def vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_tensor_parallel_size: int,
|
||||
) -> ParallelConfig:
|
||||
"""Create a parallel config for use by the draft worker.
|
||||
|
||||
This is mostly a copy of the target parallel config, except the tp_size.
|
||||
"""
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
@brief: add draft data parallel parameters
|
||||
=============================
|
||||
'''
|
||||
draft_parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
|
||||
tensor_parallel_size=speculative_draft_tensor_parallel_size,
|
||||
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
|
||||
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
|
||||
ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
|
||||
placement_group=target_parallel_config.placement_group,
|
||||
# add draft data parallel parameters
|
||||
data_parallel_size=target_parallel_config.data_parallel_size,
|
||||
data_parallel_size_local=target_parallel_config.data_parallel_size_local,
|
||||
data_parallel_master_ip=target_parallel_config.data_parallel_master_ip,
|
||||
data_parallel_rpc_port=target_parallel_config.data_parallel_rpc_port,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return draft_parallel_config
|
||||
|
||||
|
||||
vllm__config__speculative__SpeculativeConfig____post_init___org = SpeculativeConfig.__post_init__
|
||||
def vllm__config__speculative__SpeculativeConfig____post_init__(self):
|
||||
if self.model is None and self.num_speculative_tokens is not None and self.method is None:
|
||||
self.method = "mtp"
|
||||
vllm__config__speculative__SpeculativeConfig____post_init___org(self)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
SpeculativeConfig,
|
||||
SpeculativeConfig.create_draft_parallel_config,
|
||||
vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config,
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
SpeculativeConfig,
|
||||
SpeculativeConfig.__post_init__,
|
||||
vllm__config__speculative__SpeculativeConfig____post_init__,
|
||||
)
|
||||
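The `__post_init__` hijack above keeps a reference to the original method and delegates to it after filling in `method = "mtp"`. Below is a self-contained sketch of that wrap-and-delegate pattern, with a made-up `Spec` dataclass standing in for `SpeculativeConfig`.

```python
import dataclasses


@dataclasses.dataclass
class Spec:
    model: str | None = None
    num_speculative_tokens: int | None = None
    method: str | None = None

    def __post_init__(self):
        if self.method is None:
            self.method = "draft_model"


_original_post_init = Spec.__post_init__


def _patched_post_init(self):
    # Default to "mtp" when only a token count is given, then delegate.
    if self.model is None and self.num_speculative_tokens is not None \
            and self.method is None:
        self.method = "mtp"
    _original_post_init(self)


Spec.__post_init__ = _patched_post_init
assert Spec(num_speculative_tokens=2).method == "mtp"
```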
213
vllm_mlu/config/vllm.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import os
|
||||
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__config__vllm__VllmConfig___set_cudagraph_sizes(self):
|
||||
"""
|
||||
vLLM defines the default candidate list of batch sizes for CUDA graph
|
||||
capture as:
|
||||
|
||||
```python
|
||||
max_graph_size = min(max_num_seqs * 2, 512)
|
||||
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
|
||||
# up to max_graph_size
|
||||
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
|
||||
range(256, max_graph_size + 1, 16))
```
|
||||
|
||||
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
|
||||
will be the final sizes to capture cudagraph (in ascending order).
|
||||
|
||||
These sizes are used to capture and reuse CUDA graphs for
|
||||
performance-critical paths (e.g., decoding). Capturing enables
|
||||
significantly faster kernel dispatch by avoiding Python overhead. The
|
||||
list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
|
||||
most GPUs), which controls the total allowed number of tokens in a
|
||||
batch. Since each sequence may have a variable number of tokens, the
|
||||
maximum usable batch size will depend on actual sequence lengths.
|
||||
|
||||
Example:
|
||||
With `max_num_batched_tokens = 8192`, and typical sequences
|
||||
averaging ~32 tokens, most practical batch sizes fall below 256.
|
||||
However, the system will still allow capture sizes up to 512 if
|
||||
shape and memory permit.
|
||||
|
||||
Note:
|
||||
If users explicitly specify cudagraph capture sizes in the
|
||||
compilation config, those will override this default logic.
|
||||
At runtime:
|
||||
|
||||
- If batch size <= one of the `cudagraph_capture_sizes`, the closest
|
||||
padded CUDA graph will be used.
|
||||
- If batch size > largest `cudagraph_capture_sizes`, cudagraph will
|
||||
not be used.
|
||||
"""
|
||||
if hasattr(self.compilation_config, "_has_set_capture_list"):
|
||||
# avoid setting the capture list twice during init
|
||||
return
|
||||
|
||||
if (
|
||||
self.model_config is not None
|
||||
and not self.model_config.enforce_eager
|
||||
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
):
|
||||
# determine the initial max_cudagraph_capture_size
|
||||
max_cudagraph_capture_size = (
|
||||
self.compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
if max_cudagraph_capture_size is None:
|
||||
max_cudagraph_capture_size = min(
|
||||
self.scheduler_config.max_num_seqs * 2, 512
|
||||
)
|
||||
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
|
||||
|
||||
assert max_cudagraph_capture_size >= 1, (
|
||||
"Maximum cudagraph size should be greater than or equal to 1 "
|
||||
"when using cuda graph."
|
||||
)
|
||||
|
||||
# determine the cudagraph_capture_sizes
|
||||
if self.compilation_config.cudagraph_capture_sizes is not None:
|
||||
assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
|
||||
"cudagraph_capture_sizes should contain at least one element "
|
||||
"when using cuda graph."
|
||||
)
|
||||
# de-duplicate the sizes provided by the config
|
||||
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
|
||||
cudagraph_capture_sizes = [
|
||||
i for i in dedup_sizes if i <= max_num_tokens
|
||||
]
|
||||
# sort to make sure the sizes are in ascending order
|
||||
cudagraph_capture_sizes.sort()
|
||||
else:
|
||||
cudagraph_capture_sizes = [
|
||||
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
|
||||
]
|
||||
if max_cudagraph_capture_size >= 8:
|
||||
# Step size 8 for small batch sizes, up to 256 (exclusive)
|
||||
cudagraph_capture_sizes += list(
|
||||
range(8, min(max_cudagraph_capture_size + 1, 256), 8)
|
||||
)
|
||||
if max_cudagraph_capture_size >= 256:
|
||||
# Step size 16 for larger batch sizes
|
||||
cudagraph_capture_sizes += list(
|
||||
range(256, max_cudagraph_capture_size + 1, 16)
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief:
|
||||
1) check batch_size_capture_list when MTP is enabled, because bs * (K + 1)
|
||||
may be greater than max_num_batched_tokens
|
||||
2) capture MLUGraph by given batch list
|
||||
'''
|
||||
mlu_graph_capture_list = os.getenv("MLU_GRAPH_CAPTURE_LIST", None)
|
||||
if mlu_graph_capture_list:
|
||||
if "-" in mlu_graph_capture_list:
|
||||
batch_info = mlu_graph_capture_list.split("-")
|
||||
assert len(batch_info) == 3, \
|
||||
f"Got invalid graph_capture_list={mlu_graph_capture_list}, " + \
|
||||
f"but expected format 'min_bs-max_bs(may not include)-step'."
|
||||
start, end, step = mlu_graph_capture_list.split("-")
|
||||
cudagraph_capture_sizes = [1, 2, 4] + [
|
||||
i for i in range(int(start), int(end), int(step))
|
||||
]
|
||||
cudagraph_capture_sizes = sorted(list(set(cudagraph_capture_sizes)))
|
||||
else:
|
||||
cudagraph_capture_sizes = [int(x) for x in mlu_graph_capture_list.split(",")]
|
||||
|
||||
if (self.speculative_config is not None
|
||||
and self.speculative_config.num_speculative_tokens > 0
|
||||
):
|
||||
K = self.speculative_config.num_speculative_tokens
|
||||
cudagraph_capture_sizes = [x * (1 + K) for x in cudagraph_capture_sizes]
|
||||
|
||||
cudagraph_capture_sizes = [
|
||||
size for size in cudagraph_capture_sizes
|
||||
if size <= self.scheduler_config.max_num_batched_tokens
|
||||
]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if (
|
||||
self.parallel_config.tensor_parallel_size > 1
|
||||
and self.compilation_config.pass_config.enable_sequence_parallelism
|
||||
):
|
||||
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
|
||||
cudagraph_capture_sizes
|
||||
)
|
||||
|
||||
# user-specified compilation_config.max_cudagraph_capture_size gets
|
||||
# truncated to valid_max_size when they are inconsistent.
|
||||
valid_max_size = (
|
||||
cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
|
||||
)
|
||||
if (
|
||||
self.compilation_config.max_cudagraph_capture_size is not None
|
||||
and self.compilation_config.max_cudagraph_capture_size != valid_max_size
|
||||
):
|
||||
# raise an error only when both flags are user-specified
|
||||
# and they are inconsistent with each other
|
||||
if self.compilation_config.cudagraph_capture_sizes is not None:
|
||||
raise ValueError(
|
||||
"customized max_cudagraph_capture_size"
|
||||
f"(={self.compilation_config.max_cudagraph_capture_size}) "
|
||||
"should be consistent with the max value of "
|
||||
f"cudagraph_capture_sizes(={valid_max_size})"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"Truncating max_cudagraph_capture_size to %d",
|
||||
valid_max_size,
|
||||
)
|
||||
# always set the final max_cudagraph_capture_size
|
||||
self.compilation_config.max_cudagraph_capture_size = valid_max_size
|
||||
|
||||
if self.compilation_config.cudagraph_capture_sizes is not None and len(
|
||||
cudagraph_capture_sizes
|
||||
) < len(self.compilation_config.cudagraph_capture_sizes):
|
||||
# If users have specified capture sizes, we only need to
|
||||
# compare the lengths before and after modification since the modified
|
||||
# list is only the subset of the original list.
|
||||
logger.warning(
|
||||
(
|
||||
"cudagraph_capture_sizes specified in compilation_config"
|
||||
" %s is overridden by config %s"
|
||||
),
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
cudagraph_capture_sizes,
|
||||
)
|
||||
# always write back the final sizes
|
||||
self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes
|
||||
|
||||
else:
|
||||
# no cudagraph in use
|
||||
self.compilation_config.max_cudagraph_capture_size = 0
|
||||
self.compilation_config.cudagraph_capture_sizes = []
|
||||
|
||||
# complete the remaining process.
|
||||
self.compilation_config.post_init_cudagraph_sizes()
|
||||
|
||||
setattr(self.compilation_config, "_has_set_capture_list", True)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
VllmConfig,
|
||||
VllmConfig._set_cudagraph_sizes,
|
||||
vllm__config__vllm__VllmConfig___set_cudagraph_sizes,
|
||||
)
|
||||
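The capture-size schedule and the `MLU_GRAPH_CAPTURE_LIST` override can be reproduced in isolation. The helpers below are illustrative only; the env-var format `min_bs-max_bs-step` is the one asserted above, where `max_bs` is exclusive.

```python
def default_capture_sizes(max_size: int) -> list[int]:
    # 1, 2, 4, then multiples of 8 up to 256, then multiples of 16.
    sizes = [i for i in (1, 2, 4) if i <= max_size]
    if max_size >= 8:
        sizes += list(range(8, min(max_size + 1, 256), 8))
    if max_size >= 256:
        sizes += list(range(256, max_size + 1, 16))
    return sizes


def parse_capture_list(spec: str) -> list[int]:
    if "-" in spec:
        start, end, step = (int(x) for x in spec.split("-"))
        return sorted({1, 2, 4, *range(start, end, step)})
    return [int(x) for x in spec.split(",")]


assert default_capture_sizes(16) == [1, 2, 4, 8, 16]
assert parse_capture_list("32-128-32") == [1, 2, 4, 32, 64, 96]
```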
3
vllm_mlu/device_allocator/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
319
vllm_mlu/device_allocator/cnmem.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# cn_api based pytorch pluggable allocator to implement sleep mode.
|
||||
|
||||
import dataclasses
|
||||
import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> str | None:
|
||||
"""
|
||||
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found_line = None
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found_line = line
|
||||
break
|
||||
if found_line is None:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = found_line.index("/")
|
||||
path = found_line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), (
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
cnmem_available = False
|
||||
try:
|
||||
from vllm_mlu.vllm_mlu_C import (
|
||||
init_module,
|
||||
python_create_and_map,
|
||||
python_unmap_and_release,
|
||||
python_cn_memcpy,
|
||||
)
|
||||
|
||||
lib_name = find_loaded_library("vllm_mlu_C")
|
||||
cnmem_available = True
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error("Failed to import cnmem_allocator:%s", e)
|
||||
init_module = None
|
||||
python_create_and_map = None
|
||||
python_unmap_and_release = None
python_cn_memcpy = None
|
||||
lib_name = None
|
||||
|
||||
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
||||
HandleType = tuple[int, int, int, int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocationData:
|
||||
handle: HandleType
|
||||
tag: str
|
||||
cpu_backup_tensor: torch.Tensor | None = None
|
||||
|
||||
|
||||
def create_and_map(allocation_handle: HandleType) -> None:
|
||||
python_create_and_map(*allocation_handle)
|
||||
|
||||
|
||||
def unmap_and_release(allocation_handle: HandleType) -> None:
|
||||
python_unmap_and_release(*allocation_handle)
|
||||
|
||||
|
||||
def get_pluggable_allocator(
|
||||
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
|
||||
python_free_func: Callable[[int], tuple[int, int, int, int]]
|
||||
) -> torch.mlu.memory.MLUPluggableAllocator:
|
||||
init_module(python_malloc_fn, python_free_func)
|
||||
new_alloc = torch.mlu.memory.MLUPluggableAllocator(
|
||||
lib_name, "my_malloc", "my_free"
|
||||
)
|
||||
return new_alloc
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool_with_allocator(
|
||||
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
|
||||
python_free_func: Callable[[int], tuple[int, int, int, int]]):
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.mlu.memory.MemPool(new_alloc._allocator)
|
||||
with torch.mlu.memory.use_mem_pool(mem_pool):
|
||||
yield mem_pool, new_alloc
|
||||
|
||||
|
||||
class CnMemAllocator:
|
||||
"""
|
||||
A singleton class that manages a memory pool for MLU tensors.
|
||||
The memory in this pool can be offloaded or discarded when the
|
||||
allocator sleeps.
|
||||
|
||||
Inside the `use_memory_pool(tag)` context, all tensors created will
|
||||
be allocated in the memory pool, and have the same tag as the
|
||||
tag passed to the context.
|
||||
|
||||
When we call `sleep`, all tensors with the specified tag will be
|
||||
offloaded to CPU memory, and the rest of the tensors will be discarded.
|
||||
When we call `wake_up`, all tensors that were previously offloaded
|
||||
will be loaded back to GPU memory, and the rest of the tensors will
|
||||
have empty memory.
|
||||
|
||||
Why does it need to be a singleton?
|
||||
When allocated tensors are garbage collected, PyTorch will call
|
||||
the free callback, which will call the `python_free_callback` method.
|
||||
The C-extension uses a global variable to store the function of an
|
||||
instance of this class. If we create multiple instances of this class,
|
||||
the global variable will be overwritten and the free callback will
|
||||
not work as expected.
|
||||
"""
|
||||
|
||||
instance: "CnMemAllocator" = None
|
||||
default_tag: str = "default"
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> "CnMemAllocator":
|
||||
"""
|
||||
CnMemAllocator is a singleton class.
|
||||
We cannot call the constructor directly.
|
||||
Call this method to get the instance.
|
||||
"""
|
||||
assert cnmem_available, "cnmem allocator is not available"
|
||||
if CnMemAllocator.instance is None:
|
||||
CnMemAllocator.instance = CnMemAllocator()
|
||||
return CnMemAllocator.instance
|
||||
|
||||
def __init__(self):
|
||||
conf = os.environ.get("PYTORCH_MLU_ALLOC_CONF", "")
|
||||
assert "expandable_segments:True" not in conf, (
|
||||
"Expandable segments are not compatible with memory pool. "
|
||||
"Please track https://github.com/pytorch/pytorch/issues/147851 "
|
||||
"for the latest updates."
|
||||
)
|
||||
|
||||
self.pointer_to_data: dict[int, AllocationData] = {}
|
||||
self.current_tag: str = CnMemAllocator.default_tag
|
||||
self.allocator_and_pools: dict[str, Any] = {}
|
||||
# Creating strong references to the two callbacks here to prevent
|
||||
# these ephemeral bound-method objects from being garbage collected.
|
||||
# See discussions in https://github.com/vllm-project/vllm/pull/22724
|
||||
self.python_malloc_callback = self._python_malloc_callback
|
||||
self.python_free_callback = self._python_free_callback
|
||||
|
||||
def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
||||
"""
|
||||
Internal method to store the allocation data
|
||||
when memory is allocated in the memory pool."""
|
||||
py_d_mem = allocation_handle[2]
|
||||
self.pointer_to_data[py_d_mem] = AllocationData(
|
||||
allocation_handle, self.current_tag
|
||||
)
|
||||
logger.debug(
|
||||
"Allocated %s bytes for %s with address %s from cnmem allocator",
|
||||
allocation_handle[1],
|
||||
self.current_tag,
|
||||
py_d_mem,
|
||||
)
|
||||
return
|
||||
|
||||
def _python_free_callback(self, ptr: int) -> HandleType:
|
||||
"""
|
||||
Internal method to look up the allocation data
|
||||
when memory is freed in the memory pool."""
|
||||
data = self.pointer_to_data.pop(ptr)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
data.cpu_backup_tensor = None
|
||||
logger.debug(
|
||||
"Freed %s bytes for %s with address %s from cnmem allocator",
|
||||
data.handle[1],
|
||||
data.tag,
|
||||
ptr,
|
||||
)
|
||||
return data.handle
|
||||
|
||||
def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
|
||||
"""
|
||||
Put the allocator in sleep mode.
|
||||
All data in the memory allocation with the specified tag will be
|
||||
offloaded to CPU memory, and others will be discarded.
|
||||
|
||||
:param offload_tags: The tags of the memory allocation that will be
|
||||
offloaded. The rest of the memory allocation will be discarded.
|
||||
"""
|
||||
if offload_tags is None:
|
||||
# by default, allocated tensors are offloaded
|
||||
# when the allocator sleeps
|
||||
offload_tags = (CnMemAllocator.default_tag, )
|
||||
elif isinstance(offload_tags, str):
|
||||
offload_tags = (offload_tags,)
|
||||
|
||||
assert isinstance(offload_tags, tuple)
|
||||
|
||||
total_bytes = 0
|
||||
backup_bytes = 0
|
||||
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
total_bytes += handle[1]
|
||||
if data.tag in offload_tags:
|
||||
backup_bytes += handle[1]
|
||||
size_in_bytes = handle[1]
|
||||
cpu_backup_tensor = torch.empty(
|
||||
size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device="cpu",
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
python_cn_memcpy(cpu_ptr, ptr, size_in_bytes)
|
||||
data.cpu_backup_tensor = cpu_backup_tensor
|
||||
unmap_and_release(handle)
|
||||
|
||||
logger.info(
|
||||
"CnMemAllocator: sleep freed %.2f GiB memory in total, of which "
|
||||
"%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
|
||||
"directly.",
|
||||
total_bytes / 1024**3,
|
||||
backup_bytes / 1024**3,
|
||||
(total_bytes - backup_bytes) / 1024**3,
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
torch.mlu.empty_cache()
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
"""
|
||||
Wake up the allocator from sleep mode.
|
||||
All data that is previously offloaded will be loaded back to GPU
|
||||
memory, and the rest of the data will have empty memory.
|
||||
|
||||
:param tags: The tags of the memory allocation that will be loaded
|
||||
back to GPU memory. If None, all memory allocations will be loaded
|
||||
back to GPU memory.
|
||||
"""
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
if tags is None or data.tag in tags:
|
||||
handle = data.handle
|
||||
create_and_map(handle)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
cpu_backup_tensor = data.cpu_backup_tensor
|
||||
if cpu_backup_tensor is not None:
|
||||
size_in_bytes = (
|
||||
cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
|
||||
)
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
python_cn_memcpy(ptr, cpu_ptr, size_in_bytes)
|
||||
data.cpu_backup_tensor = None
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool(self, tag: str | None = None):
|
||||
"""
|
||||
A context manager to use the memory pool.
|
||||
All memory allocations created inside the context will be allocated
|
||||
in the memory pool, and have the specified tag.
|
||||
|
||||
:param tag: The tag of the memory allocation. If None, the default tag
|
||||
will be used.
|
||||
"""
|
||||
if tag is None:
|
||||
tag = CnMemAllocator.default_tag
|
||||
|
||||
assert isinstance(tag, str)
|
||||
|
||||
old_tag = self.current_tag
|
||||
self.current_tag = tag
|
||||
with use_memory_pool_with_allocator(
|
||||
self.python_malloc_callback, self.python_free_callback
|
||||
) as data:
|
||||
# start to hit another PyTorch bug in PyTorch 2.6,
|
||||
# possibly because of gc-related issue w.r.t. the allocator and
|
||||
# the memory pool.
|
||||
# to avoid the issue, we keep a reference of the data.
|
||||
# see https://github.com/pytorch/pytorch/issues/146431 .
|
||||
self.allocator_and_pools[tag] = data
|
||||
yield
|
||||
# PyTorch's bug, calling torch.cuda.empty_cache() will error
|
||||
# when using pluggable allocator, see
|
||||
# https://github.com/pytorch/pytorch/issues/145168 .
|
||||
# if we have some memory allocated and then freed,
|
||||
# the memory will not be released, e.g. in online quantization,
|
||||
# where the model is created in higher precision, and then
|
||||
# quantized in lower precision.
|
||||
# Find all unused allocations and manually release them.
|
||||
# TODO: we should expose `empty_cache` method in the memory pool.
|
||||
# TODO: ask for help from PyTorch team to expose this method.
|
||||
allocations = data[0].snapshot()
|
||||
for allocation in allocations:
|
||||
if allocation["allocated_size"] == 0:
|
||||
handle = self._python_free_callback(allocation["address"])
|
||||
unmap_and_release(handle)
|
||||
self.current_tag = old_tag
|
||||
|
||||
def get_current_usage(self) -> int:
|
||||
"""
|
||||
Get the total number of bytes allocated in the memory pool.
|
||||
"""
|
||||
sum_bytes: int = 0
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
sum_bytes += handle[1]
|
||||
return sum_bytes
|
||||
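A hypothetical end-to-end use of the allocator above (it needs MLU hardware plus the `vllm_mlu_C` extension, so treat this purely as a sketch of the intended call sequence; the import path is assumed from the file location in this diff):

```python
import torch
from vllm_mlu.device_allocator.cnmem import CnMemAllocator  # assumed path

allocator = CnMemAllocator.get_instance()
with allocator.use_memory_pool(tag="weights"):
    # Allocations made inside the context are tagged "weights".
    weights = torch.empty(1 << 20, dtype=torch.uint8, device="mlu")

allocator.sleep(offload_tags=("weights",))   # tagged memory is backed up to CPU
allocator.wake_up(tags=["weights"])          # copied back to the same device pointers
print(allocator.get_current_usage())
```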
3
vllm_mlu/distributed/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
3
vllm_mlu/distributed/device_communicators/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase,
|
||||
)
|
||||
|
||||
|
||||
class MLUCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = ""
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
# init device according to rank
|
||||
self.device = torch.mlu.current_device()
|
||||
self.ca_comm: CustomAllreduce | None = None
|
||||
20
vllm_mlu/distributed/kv_transfer/kv_connector/factory.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
|
||||
|
||||
MLUKVConnectors: dict[str, tuple[str, str]] = {
|
||||
"MLUSharedStorageConnector": (
|
||||
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
|
||||
"SharedStorageConnector"
|
||||
),
|
||||
"MLUNixlConnector": (
|
||||
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.nixl_connector",
|
||||
"MLUNixlConnector"
|
||||
),
|
||||
}
|
||||
|
||||
for name, (module_path, class_name) in MLUKVConnectors.items():
|
||||
if name not in KVConnectorFactory._registry:
|
||||
KVConnectorFactory.register_connector(name, module_path, class_name)
|
||||
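The registration above is deliberately idempotent: a connector name is only added when the upstream factory does not already know it, so re-imports are safe. A minimal stand-alone sketch of that pattern (registry and connector names are made up for illustration):

```python
registry: dict[str, tuple[str, str]] = {}


def register_connector(name: str, module_path: str, class_name: str) -> None:
    registry[name] = (module_path, class_name)


connectors = {
    "ExampleConnector": ("example.module.path", "ExampleConnector"),
}
for name, (module_path, class_name) in connectors.items():
    if name not in registry:  # never overwrite an existing registration
        register_connector(name, module_path, class_name)

assert registry["ExampleConnector"] == ("example.module.path", "ExampleConnector")
```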
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
class LMCacheConnectorV1_MluHijack(LMCacheConnectorV1):
|
||||
|
||||
def response_remote_alloc_once(self) -> None:
|
||||
self._lmcache_engine.response_remote_alloc_once()
|
||||
|
||||
def request_remote_memory_send(self) -> None:
|
||||
self._lmcache_engine.request_remote_memory_send()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(LMCacheConnectorV1,
|
||||
"response_remote_alloc_once",
|
||||
LMCacheConnectorV1_MluHijack.response_remote_alloc_once)
|
||||
MluHijackObject.apply_hijack(LMCacheConnectorV1,
|
||||
"request_remote_memory_send",
|
||||
LMCacheConnectorV1_MluHijack.request_remote_memory_send)
|
||||
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
EngineId, NixlConnectorWorker, NixlAgentMetadata, NixlConnectorScheduler, NixlConnector)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||
try:
|
||||
from nixl._api import nixl_agent as NixlWrapper
|
||||
logger.info("NIXL is available")
|
||||
except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
|
||||
|
||||
class MLUNixlConnector(NixlConnector):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super(NixlConnector, self).__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
assert vllm_config.kv_transfer_config.engine_id is not None
|
||||
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler : MLUNixlConnectorScheduler | None = (
|
||||
MLUNixlConnectorScheduler(vllm_config, self.engine_id)
|
||||
)
|
||||
self.connector_worker: MLUNixlConnectorWorker | None = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MLUNixlConnectorWorker(vllm_config, self.engine_id)
|
||||
|
||||
|
||||
class MLUNixlConnectorScheduler(NixlConnectorScheduler):
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: kv transfer info
|
||||
'''
|
||||
if request.kv_transfer_params.get("do_remote_prefill", False):
|
||||
logger.info(f"NIXLConnector update_state_after_alloc: request_id={request.request_id}, "
|
||||
f"num_prompt_tokens={request.num_prompt_tokens}, "
|
||||
f"num_external_tokens={num_external_tokens}, "
|
||||
f"kv_transfer_params={request.kv_transfer_params}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"NIXLConnector update_state_after_alloc: "
|
||||
"num_external_tokens=%s, kv_transfer_params=%s",
|
||||
num_external_tokens,
|
||||
params,
|
||||
)
|
||||
|
||||
if not params:
|
||||
return
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
self._reqs_in_batch.add(request.request_id)
|
||||
if self.use_host_buffer and params.get("do_remote_decode"):
|
||||
# NOTE: when accelerator is not directly supported by Nixl,
|
||||
# prefilled blocks need to be saved to host memory before transfer.
|
||||
|
||||
# save all blocks
|
||||
block_ids = blocks.get_block_ids()[0]
|
||||
# TODO: skip the blocks that are already in the host xfer buffer.
|
||||
# Currently, the host xfer buffer block is 1-to-1 mapped to device
|
||||
# kv blocks, so host blocks won't be flushed as long as its device
|
||||
# block is not overwritten; and it will be safe to skip saving them
|
||||
# to host xfer buffer.
|
||||
if block_ids:
|
||||
self._reqs_need_save[request.request_id] = (request, block_ids)
|
||||
elif params.get("do_remote_prefill"):
|
||||
if params.get("remote_block_ids"):
|
||||
if all(
|
||||
p in params
|
||||
for p in ("remote_engine_id", "remote_host", "remote_port")
|
||||
):
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
local_block_ids = (
|
||||
blocks.get_unhashed_block_ids()
|
||||
if num_external_tokens > 0
|
||||
else []
|
||||
)
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request,
|
||||
local_block_ids,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer",
|
||||
params,
|
||||
)
|
||||
else:
|
||||
assert num_external_tokens == 0
|
||||
# Only trigger 1 KV transfer per request.
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
|
||||
class MLUNixlConnectorWorker(NixlConnectorWorker):
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in nixl."""
|
||||
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: kv8 is not supported
|
||||
'''
|
||||
if not isinstance(first_kv_cache, torch.Tensor):
|
||||
kv_caches = {key: value[0] for key, value in kv_caches.items()}
|
||||
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
|
||||
# KV memory layout is HND, as opposed to the default NHD. Note that it
|
||||
# will only affect the strides. For MLA instead, we require no
|
||||
# such thing and resort to the standard layout.
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: support mla
|
||||
'''
|
||||
use_mla = first_kv_cache.shape[0] == 1
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
assert use_mla == self.use_mla
|
||||
|
||||
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
|
||||
# once it goes live, as a single kv layout is expected for xfers.
|
||||
if use_mla:
|
||||
# MLA case.
|
||||
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: support mla
|
||||
'''
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
block_rank = 2 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, kv_latent_dim = block_shape
|
||||
self.slot_size_bytes = kv_elem_size * kv_latent_dim
|
||||
else:
|
||||
# [2 (k and v), num_blocks, ...]
|
||||
if self._use_flashinfer:
|
||||
# FlashInfer swaps 2<->num_blocks dimensions.
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 4 # [2, block_size, kv_heads, head_dim]
|
||||
else:
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
'''
|
||||
=============================
|
||||
Add by vllm_mlu
|
||||
=============================
|
||||
@brief: MLU kv_cache layout is [2 (k and v), num_blocks, kv_heads, block_size, head_dim]
|
||||
'''
|
||||
n_kv_heads, block_size, head_dim = block_shape[-3:]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# per-token KV slot size in bytes.
|
||||
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
|
||||
assert block_size == self.block_size
|
||||
# TODO(tms): self.block_len needs to be per-layer for sliding window,
|
||||
# hybrid attn, etc
|
||||
# block size in bytes
|
||||
self.block_len = kv_elem_size * math.prod(block_shape)
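# Illustrative numbers (assumed): with fp16 caches (2-byte elements),
# 8 kv_heads, head_dim=128 and block_size=16, slot_size_bytes is
# 2 * 8 * 128 = 2048 bytes and block_len is 2 * 8 * 16 * 128 = 32768 bytes
# per block of a single K (or V) region.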
|
||||
logger.info(
|
||||
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
|
||||
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
|
||||
self.num_blocks, block_shape, first_kv_cache.shape)
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
caches_data = []
|
||||
|
||||
# Note(tms): I modified this from the original region setup code.
|
||||
# K and V are now in different regions. Advantage is that we can
|
||||
# elegantly support MLA and any cases where the K and V tensors
|
||||
# are non-contiguous (it's not locally guaranteed that they will be)
|
||||
# Disadvantage is that the encoded NixlAgentMetadata is now larger
|
||||
# (roughly 8KB vs 5KB).
|
||||
# Conversely for FlashInfer, K and V are transferred in the same tensor
|
||||
# to better exploit the memory layout (ie num_blocks is the first dim).
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
|
||||
else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
caches_data.append(
|
||||
(base_addr, region_len, cache.device.index, ""))
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||
self.num_regions = len(caches_data)
|
||||
self.num_layers = len(self.kv_caches.keys())
|
||||
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
assert isinstance(self.vllm_config.model_config.hf_text_config,
|
||||
Llama4TextConfig)
|
||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
for layer_idx in range(self.num_layers):
|
||||
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||
# Any other value means RoPE (local chunked)
|
||||
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||
block_window = chunk_block_size if is_local_attention else None
|
||||
self.block_window_per_layer.append(block_window)
|
||||
logger.debug("Llama 4 block window per layer mapping: %s",
|
||||
self.block_window_per_layer)
|
||||
assert len(self.block_window_per_layer) == self.num_layers
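# Illustrative (assumed values): attention_chunk_size=8192 with
# block_size=16 gives chunk_block_size=512, so local-attention (RoPE)
# layers get a 512-block window and global (NoPE) layers get None.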
|
||||
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
self.nixl_wrapper.register_memory(descs)
|
||||
logger.debug("Done registering descs")
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
||||
# could create fewer, but then _get_block_descs_ids needs to
|
||||
# select agent_meta.num_blocks instead of self.num_blocks for
|
||||
# local descr, and that makes handling regular flow less clean.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
addr = base_addr + block_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for src engine %s and rank %s",
|
||||
len(blocks_data), self.engine_id, self.tp_rank)
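# Illustrative (assumed): with 61 MLA layers (one region each) and
# num_blocks=1000, the loop above creates 61 * 1000 = 61000 local block
# descriptors, each block_len bytes long, for the source side of NIXL xfers.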
|
||||
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs)
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
num_blocks=self.num_blocks,
|
||||
tp_size=self.world_size,
|
||||
block_len=self.block_len,
|
||||
attn_backend_name=self.backend_name)
|
||||
ready_event = threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
|
||||
daemon=True,
|
||||
name="nixl_handshake_listener")
|
||||
self._nixl_handshake_listener_t.start()
|
||||
ready_event.wait()
|
||||
@@ -0,0 +1,450 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
from vllm_mlu.v1.attention.backends.flash_mla import MLAFlashAttentionCommonMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Is store or load
|
||||
is_store: bool
|
||||
mm_hashes: list[str]
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
mm_hashes: list[str],
|
||||
) -> "ReqMeta":
|
||||
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
|
||||
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape((1, block_size))
|
||||
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
)
|
||||
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
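# Illustrative (assumed): with block_size=4 and block_ids=[7, 2],
# slot_mapping before truncation is [28, 29, 30, 31, 8, 9, 10, 11];
# token i of the request is stored at paged slot slot_mapping[i].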
|
||||
return ReqMeta(
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
is_store=is_store,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedStorageConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta] = field(default_factory=list)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
mm_hashes: list[str],
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
|
||||
)
|
||||
|
||||
|
||||
class SharedStorageConnector(KVConnectorBase_V1):
|
||||
# NOTE: This is a simple debug implementation of the KV connector.
|
||||
# It saves / loads the KV cache to / from disk.
|
||||
# It does extra work that overwrites the existing prefix cache on the GPU
|
||||
# - to remove this overhead, a "mask" field would need to be added to the ReqMeta class
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Request] = {}
|
||||
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp")
|
||||
logger.info(self._kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||
paged KV buffer.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
Args:
|
||||
dst_kv_cache_layer (torch.Tensor): the destination KV cache
|
||||
layer. In shape [2, num_pages, page_size, xxx] if not
|
||||
using MLA, [num_pages, page_size, xxx] otherwise.
|
||||
src_kv_cache (torch.Tensor): the source KV cache. In shape
|
||||
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
|
||||
otherwise.
|
||||
slot_mapping (torch.Tensor): the slot mapping. In shape
|
||||
[num_tokens].
|
||||
"""
|
||||
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
|
||||
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
|
||||
num_pages = dst_kv_cache_layer_shape[0]
|
||||
page_size = dst_kv_cache_layer_shape[1]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
2, num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
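# Illustrative (assumed): a non-MLA layer of shape [2, 4, 16, H] is viewed
# as [2, 64, H]; with slot_mapping=[32, 33] the two source tokens land in
# page 2 at offsets 0 and 1, for both K and V.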
|
||||
|
||||
# Get the metadata
|
||||
metadata: KVConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, SharedStorageConnectorMetadata)
|
||||
|
||||
if metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_kv, but the connector metadata is None"
|
||||
)
|
||||
return
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
|
||||
return
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
if request.is_store:
|
||||
continue
|
||||
logger.info(
|
||||
"Inject KV cache of %d tokens to the paged memory",
|
||||
len(request.slot_mapping),
|
||||
)
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
layer = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
# Only process layers that have kv_cache
|
||||
# attribute (attention layers). Skip non-attention
|
||||
# layers like FusedMoE/MLP etc.
|
||||
kv_cache_attr = getattr(layer, "kv_cache", None)
|
||||
if kv_cache_attr is None:
|
||||
continue
|
||||
|
||||
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
|
||||
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
|
||||
def extract_kv_from_layer(
|
||||
layer: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the KV cache from the layer.
|
||||
|
||||
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||
"""
|
||||
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
|
||||
for request in connector_metadata.requests:
|
||||
if request.is_store:
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
|
||||
tensors = {"kv_cache": kv_cache.detach().cpu()}
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
|
||||
def wait_for_save(self):
|
||||
return
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# NOTE: in this debug implementation, we assume that the prompt is
|
||||
# cached_prompt + newly_generated_single_token
|
||||
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
|
||||
|
||||
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
|
||||
# with the block granularity. And it expects the returned blocks and
|
||||
# num_computed_tokens to also be aligned with the block granularity.
|
||||
if not self._found_match_for_request(request):
|
||||
return 0, False
|
||||
|
||||
logger.info("External Cache Hit!")
|
||||
|
||||
# Now, first num_tokens_to_check tokens are hit, we need to prepare
|
||||
# the metadata for the worker connector to correctly load the KV
|
||||
token_ids = request.prompt_token_ids or []
|
||||
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
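# Illustrative (assumed): block_size=16 and a 33-token prompt with 0 tokens
# computed locally gives num_tokens_to_check = align_to_block_size(32, 16)
# = 16, so 16 external tokens are reported as loadable.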
|
||||
|
||||
return num_tokens_to_check - num_computed_tokens, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If blocks were allocated, add to _requests_need_load,
|
||||
such that we load the KVs in the next forward pass.
|
||||
"""
|
||||
if num_external_tokens > 0:
|
||||
self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = SharedStorageConnectorMetadata()
|
||||
|
||||
total_need_load = 0
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
token_ids = new_req.prompt_token_ids or []
|
||||
mm_hashes = [f.identifier for f in new_req.mm_features]
|
||||
if new_req.req_id in self._requests_need_load:
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
total_need_load += 1
|
||||
else:
|
||||
# NOTE: here, we set the store and load being exclusive,
|
||||
# but a single request can have both store and load.
|
||||
# NOTE(rob): for this debug implementation, we only cache
|
||||
# the original prompt tokens.
|
||||
if not self._found_match_for_prompt(token_ids, mm_hashes):
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=True,
|
||||
mm_hashes=mm_hashes,
|
||||
)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
|
||||
if not resumed_from_preemption or req_id not in self._requests_need_load:
|
||||
continue
|
||||
|
||||
num_computed_tokens = cached_reqs.num_computed_tokens[i]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[req_id]
|
||||
total_tokens = num_computed_tokens + num_new_tokens
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=[f.identifier for f in request.mm_features],
|
||||
)
|
||||
total_need_load += 1
|
||||
|
||||
assert total_need_load == len(self._requests_need_load)
|
||||
self._requests_need_load.clear()
|
||||
return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_request(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> bool:
|
||||
"""Check if the cache is hit for the request."""
|
||||
return self._found_match_for_prompt(
|
||||
list(request.prompt_token_ids or []),
|
||||
[f.identifier for f in request.mm_features],
|
||||
)
|
||||
|
||||
def _found_match_for_prompt(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
mm_hashes: list[str],
|
||||
) -> bool:
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(prompt_token_ids) - 1, self._block_size
|
||||
)
|
||||
foldername = self._generate_foldername_debug(
|
||||
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
|
||||
mm_hashes,
|
||||
create_folder=False,
|
||||
)
|
||||
return os.path.exists(foldername)
|
||||
|
||||
def _generate_foldername_debug(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
create_folder=False,
|
||||
) -> str:
|
||||
"""Generate a folder name based on the hash of the bytes of the input
|
||||
ids.
|
||||
"""
|
||||
token_bytes = token_ids.numpy().tobytes()
|
||||
# Add mm_hashes to the bytes being hashed to avoid path traversal and
|
||||
# to create a canonical key.
|
||||
if mm_hashes:
|
||||
mm_str = "-".join(mm_hashes)
|
||||
token_bytes += mm_str.encode("utf-8")
|
||||
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
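# Illustrative (assumed): the same prompt tokens combined with different
# mm_hashes produce different digests, so text-identical requests with
# different multimodal inputs map to distinct folders under _storage_path.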
|
||||
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
os.makedirs(foldername, exist_ok=True)
|
||||
return foldername
|
||||
|
||||
def _generate_filename_debug(
|
||||
self,
|
||||
layer_name: str,
|
||||
token_ids: torch.Tensor,
|
||||
mm_hashes: list[str],
|
||||
) -> str:
|
||||
"""Generate a file name based on the layer name and the hash
|
||||
of the bytes of the input ids.
|
||||
"""
|
||||
foldername = self._generate_foldername_debug(
|
||||
token_ids, mm_hashes=mm_hashes, create_folder=True
|
||||
)
|
||||
return os.path.join(foldername, f"{layer_name}.safetensors")
|
||||
|
||||
|
||||
def align_to_block_size(num_tokens: int, block_size) -> int:
|
||||
"""Align the number of tokens to the block size."""
|
||||
return (num_tokens - 1) // block_size * block_size
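# Illustrative (per the formula above): with block_size=16, 33 tokens align
# to 32 and exactly 16 tokens align to 0, i.e. the count is rounded down to
# a block boundary after dropping the final token.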
|
||||
286
vllm_mlu/distributed/parallel_state.py
Normal file
286
vllm_mlu/distributed/parallel_state.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
GraphCaptureContext,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.mlu_parallel_state import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_moe_expert_parallel_rank,
|
||||
get_moe_expert_parallel_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class MLUGraphCaptureContext:
|
||||
stream: torch.mlu.Stream
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mlu_graph_capture(device: torch.device):
|
||||
"""
|
||||
`graph_capture` is a context manager which should surround the code that
|
||||
is capturing the MLU graph. Its main purpose is to ensure that
|
||||
some operations will be run after the graph is captured, before the graph
|
||||
is replayed. It returns a `GraphCaptureContext` object which contains the
|
||||
necessary data for the graph capture. Currently, it only contains the
|
||||
stream that the graph capture is running on. This stream is set to the
|
||||
current MLU stream when the context manager is entered and reset to the
|
||||
default stream when the context manager is exited. This is to ensure that
|
||||
the graph capture is running on a separate stream from the default stream,
|
||||
in order to explicitly distinguish the kernels to capture
|
||||
from other kernels possibly launched on background in the default stream.
|
||||
"""
|
||||
context = MLUGraphCaptureContext(torch.mlu.Stream(device=device))
|
||||
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context):
|
||||
yield context
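# Illustrative usage (hypothetical call site; the real capture happens in the
# MLU model runner):
#   with mlu_graph_capture(torch.device("mlu:0")) as ctx:
#       # kernels launched here run on ctx.stream, a stream separate from the
#       # default one, so only the intended work is captured into the graph.
#       ...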
|
||||
|
||||
|
||||
@contextmanager
|
||||
def vllm__distributed__parallel_state__GroupCoordinator__graph_capture(
|
||||
self,
|
||||
graph_capture_context: GraphCaptureContext | None = None,
|
||||
):
|
||||
if graph_capture_context is None:
|
||||
stream = torch.mlu.Stream()
|
||||
graph_capture_context = GraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
|
||||
# only cuda uses this function,
|
||||
# so we don't abstract it into the base class
|
||||
maybe_ca_context = nullcontext()
|
||||
from vllm_mlu.distributed.device_communicators.mlu_communicator import (
|
||||
MLUCommunicator,
|
||||
)
|
||||
|
||||
if self.device_communicator is not None:
|
||||
assert isinstance(self.device_communicator, MLUCommunicator)
|
||||
ca_comm = self.device_communicator.ca_comm
|
||||
if ca_comm is not None:
|
||||
maybe_ca_context = ca_comm.capture() # type: ignore
|
||||
|
||||
# ensure all initialization operations complete before attempting to
|
||||
# capture the graph on another stream
|
||||
curr_stream = torch.mlu.current_stream()
|
||||
if curr_stream != stream:
|
||||
stream.wait_stream(curr_stream)
|
||||
|
||||
with torch.mlu.stream(stream), maybe_ca_context:
|
||||
yield graph_capture_context
|
||||
|
||||
@dataclass
|
||||
class CnclEPBuffer:
|
||||
dispatch_send_token_tensor: torch.Tensor
|
||||
dispatch_recv_token_tensor: torch.Tensor
|
||||
combine_send_token_tensor: torch.Tensor
|
||||
combine_recv_token_tensor: torch.Tensor
|
||||
|
||||
class CnclEP:
|
||||
|
||||
def __init__(self,
|
||||
dispatch_token_size: int,
|
||||
combine_token_size: int,
|
||||
max_num_tokens_per_rank: int,
|
||||
num_global_experts: int,
|
||||
use_quant_dispatch: bool = True) -> None:
|
||||
nranks = get_moe_expert_parallel_world_size()
|
||||
rank = get_moe_expert_parallel_rank()
|
||||
moe_ep_group = get_moe_expert_parallel_group()
|
||||
self.max_num_tokens_per_rank = max_num_tokens_per_rank
|
||||
self.use_quant_dispatch = use_quant_dispatch
|
||||
|
||||
(
|
||||
handle,
|
||||
exchange_info_size,
|
||||
exchange_info,
|
||||
dispatch_send_token_tensor,
|
||||
dispatch_recv_token_tensor,
|
||||
combine_send_token_tensor,
|
||||
combine_recv_token_tensor
|
||||
) = mlu_ops.moe_all2all_create(dispatch_token_size,
|
||||
combine_token_size,
|
||||
num_global_experts,
|
||||
max_num_tokens_per_rank,
|
||||
rank,
|
||||
nranks)
|
||||
self.handle = handle
|
||||
self.buffer = CnclEPBuffer(
|
||||
dispatch_send_token_tensor,
|
||||
dispatch_recv_token_tensor,
|
||||
combine_send_token_tensor,
|
||||
combine_recv_token_tensor)
|
||||
|
||||
assert exchange_info.ndim == 1, "exchange_info should be 1D"
|
||||
all_exchange_info = torch.empty((nranks, exchange_info.size(0)),
|
||||
dtype=exchange_info.dtype,
|
||||
device=exchange_info.device)
|
||||
exchange_info = exchange_info.unsqueeze(0)
|
||||
torch.distributed.all_gather_into_tensor(all_exchange_info,
|
||||
exchange_info,
|
||||
group=moe_ep_group.cpu_group,
|
||||
async_op=False)
|
||||
mlu_ops.moe_all2all_init(self.handle, all_exchange_info)
|
||||
torch.distributed.barrier(group=moe_ep_group.cpu_group)
|
||||
|
||||
def dispatch(self,
|
||||
token_byte: int,
|
||||
token_num: int,
|
||||
send_layout: torch.Tensor,
|
||||
send_token_num: torch.Tensor,
|
||||
recv_layout: torch.Tensor,
|
||||
recv_token_num: torch.Tensor,
|
||||
send_token: Optional[torch.Tensor] = None,
|
||||
recv_token: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
'''
|
||||
The returned tensors are modified in place, so we can use them directly
|
||||
after dispatch finishes.
|
||||
'''
|
||||
mlu_ops.moe_all2all_dispatch(self.handle,
|
||||
token_byte,
|
||||
token_num,
|
||||
send_layout,
|
||||
send_token_num,
|
||||
recv_layout,
|
||||
recv_token_num,
|
||||
send_token,
|
||||
recv_token)
|
||||
|
||||
def combine(self,
|
||||
token_byte: int,
|
||||
token_num: int,
|
||||
send_src_layout: torch.Tensor,
|
||||
send_dst_layout: torch.Tensor,
|
||||
send_token: Optional[torch.Tensor] = None,
|
||||
recv_token: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
mlu_ops.moe_all2all_combine(self.handle,
|
||||
token_byte,
|
||||
token_num,
|
||||
send_src_layout,
|
||||
send_dst_layout,
|
||||
send_token,
|
||||
recv_token)
|
||||
|
||||
def destroy(self) -> None:
|
||||
mlu_ops.moe_all2all_destroy(self.handle)
|
||||
|
||||
_CNCLEP: CnclEP | None = None
|
||||
_CNCLEP_BF16: CnclEP | None = None
|
||||
|
||||
def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP:
|
||||
if use_quant_dispatch:
|
||||
assert _CNCLEP is not None, "cnclep is not initialized"
|
||||
return _CNCLEP
|
||||
else:
|
||||
assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized"
|
||||
return _CNCLEP_BF16
|
||||
|
||||
def init_cnclep(dispatch_token_size: int,
|
||||
combine_token_size: int,
|
||||
max_num_tokens_per_rank: int,
|
||||
num_global_experts: int,
|
||||
use_quant_dispatch: bool = True):
|
||||
if use_quant_dispatch:
|
||||
global _CNCLEP
|
||||
assert _CNCLEP is None, "cnclep has been initialized"
|
||||
_CNCLEP = CnclEP(dispatch_token_size,
|
||||
combine_token_size,
|
||||
max_num_tokens_per_rank,
|
||||
num_global_experts,
|
||||
use_quant_dispatch)
|
||||
else:
|
||||
global _CNCLEP_BF16
|
||||
assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized"
|
||||
_CNCLEP_BF16 = CnclEP(dispatch_token_size,
|
||||
combine_token_size,
|
||||
max_num_tokens_per_rank,
|
||||
num_global_experts,
|
||||
use_quant_dispatch)
|
||||
|
||||
def cnclep_dispatch(token_byte: int,
|
||||
token_num: int,
|
||||
send_layout: torch.Tensor,
|
||||
send_token_num: torch.Tensor,
|
||||
recv_layout: torch.Tensor,
|
||||
recv_token_num: torch.Tensor,
|
||||
send_token: Optional[torch.Tensor] = None,
|
||||
recv_token: Optional[torch.Tensor] = None,
|
||||
use_quant_dispatch: bool = True,
|
||||
):
|
||||
if use_quant_dispatch:
|
||||
_CNCLEP.dispatch(token_byte,
|
||||
token_num,
|
||||
send_layout,
|
||||
send_token_num,
|
||||
recv_layout,
|
||||
recv_token_num,
|
||||
send_token,
|
||||
recv_token)
|
||||
else:
|
||||
_CNCLEP_BF16.dispatch(token_byte,
|
||||
token_num,
|
||||
send_layout,
|
||||
send_token_num,
|
||||
recv_layout,
|
||||
recv_token_num,
|
||||
send_token,
|
||||
recv_token)
|
||||
|
||||
def cnclep_combine(token_byte: int,
|
||||
token_num: int,
|
||||
send_src_layout: torch.Tensor,
|
||||
send_dst_layout: torch.Tensor,
|
||||
send_token: Optional[torch.Tensor] = None,
|
||||
recv_token: Optional[torch.Tensor] = None,
|
||||
use_quant_dispatch: bool = True,
|
||||
):
|
||||
if use_quant_dispatch:
|
||||
_CNCLEP.combine(token_byte,
|
||||
token_num,
|
||||
send_src_layout,
|
||||
send_dst_layout,
|
||||
send_token,
|
||||
recv_token)
|
||||
else:
|
||||
_CNCLEP_BF16.combine(token_byte,
|
||||
token_num,
|
||||
send_src_layout,
|
||||
send_dst_layout,
|
||||
send_token,
|
||||
recv_token)
|
||||
|
||||
|
||||
def destroy_cnclep():
|
||||
global _CNCLEP
|
||||
|
||||
if _CNCLEP:
|
||||
_CNCLEP.destroy()
|
||||
_CNCLEP = None
|
||||
|
||||
global _CNCLEP_BF16
|
||||
|
||||
if _CNCLEP_BF16:
|
||||
_CNCLEP_BF16.destroy()
|
||||
_CNCLEP_BF16 = None
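# Illustrative lifecycle (assumed sizes) on each expert-parallel rank:
#   init_cnclep(dispatch_token_size=7168, combine_token_size=7168,
#               max_num_tokens_per_rank=4096, num_global_experts=256)
#   ep = get_cnclep()
#   ep.dispatch(...)   # routes tokens to the ranks owning their experts
#   ep.combine(...)    # gathers expert outputs back to the source ranks
#   destroy_cnclep()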
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(GroupCoordinator,
|
||||
GroupCoordinator.graph_capture,
|
||||
vllm__distributed__parallel_state__GroupCoordinator__graph_capture)
|
||||
|
||||
3
vllm_mlu/engine/__init__.py
Normal file
3
vllm_mlu/engine/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
294
vllm_mlu/engine/arg_utils.py
Normal file
294
vllm_mlu/engine/arg_utils.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import get_args
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.config import (
|
||||
ModelConfig,
|
||||
VllmConfig,
|
||||
SchedulerConfig,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.engine.arg_utils import (
|
||||
EngineArgs,
|
||||
_raise_unsupported_error,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
import vllm_mlu._mlu_utils as mlu_envs
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@classmethod
|
||||
def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults(
|
||||
cls,
|
||||
model_config: ModelConfig,
|
||||
) -> tuple[bool, bool]:
|
||||
if model_config.runner_type != "pooling":
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: MLU V1 defaults to the unchunked scheduler
|
||||
'''
|
||||
if mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED:
|
||||
default_chunked_prefill = False
|
||||
else:
|
||||
default_chunked_prefill = True
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# Disable prefix caching default for hybrid models
|
||||
# since the feature is still experimental.
|
||||
default_prefix_caching = not model_config.is_hybrid
|
||||
else:
|
||||
assert model_config.pooler_config is not None
|
||||
|
||||
pooling_type = model_config.pooler_config.pooling_type
|
||||
incremental_prefill_supported = (
|
||||
pooling_type is not None
|
||||
and pooling_type.lower() == "last"
|
||||
and getattr(model_config.hf_config, "is_causal", True)
|
||||
)
|
||||
|
||||
default_chunked_prefill = incremental_prefill_supported
|
||||
default_prefix_caching = incremental_prefill_supported
|
||||
|
||||
return default_chunked_prefill, default_prefix_caching
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs___set_default_args(
|
||||
self, usage_context: UsageContext, model_config: ModelConfig
|
||||
) -> None:
|
||||
"""Set Default Arguments for V1 Engine."""
|
||||
(
|
||||
default_chunked_prefill,
|
||||
default_prefix_caching,
|
||||
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
|
||||
|
||||
if self.enable_chunked_prefill is None:
|
||||
self.enable_chunked_prefill = default_chunked_prefill
|
||||
|
||||
logger.debug(
|
||||
"%s chunked prefill by default",
|
||||
"Enabling" if default_chunked_prefill else "Disabling",
|
||||
)
|
||||
elif (
|
||||
model_config.runner_type == "pooling"
|
||||
and self.enable_chunked_prefill
|
||||
and not default_chunked_prefill
|
||||
):
|
||||
logger.warning(
|
||||
"This model does not officially support chunked prefill. "
|
||||
"Enabling this manually may cause the engine to crash "
|
||||
"or produce incorrect outputs.",
|
||||
)
|
||||
|
||||
if self.enable_prefix_caching is None:
|
||||
self.enable_prefix_caching = default_prefix_caching
|
||||
|
||||
logger.debug(
|
||||
"%s prefix caching by default",
|
||||
"Enabling" if default_prefix_caching else "Disabling",
|
||||
)
|
||||
elif (
|
||||
model_config.runner_type == "pooling"
|
||||
and self.enable_prefix_caching
|
||||
and not default_prefix_caching
|
||||
):
|
||||
logger.warning(
|
||||
"This model does not officially support prefix caching. "
|
||||
"Enabling this manually may cause the engine to crash "
|
||||
"or produce incorrect outputs.",
|
||||
)
|
||||
|
||||
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
(
|
||||
default_max_num_batched_tokens,
|
||||
default_max_num_seqs,
|
||||
) = self.get_batch_defaults(world_size)
|
||||
|
||||
orig_max_num_batched_tokens = self.max_num_batched_tokens
|
||||
orig_max_num_seqs = self.max_num_seqs
|
||||
|
||||
if self.max_num_seqs is None:
|
||||
self.max_num_seqs = default_max_num_seqs.get(
|
||||
usage_context,
|
||||
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: only set max_num_batched_tokens when chunked_prefill is enabled
|
||||
'''
|
||||
if self.max_num_batched_tokens is None:
|
||||
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
|
||||
usage_context,
|
||||
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
|
||||
if orig_max_num_batched_tokens is None:
|
||||
if not self.enable_chunked_prefill:
|
||||
# If max_model_len is too short, use the default for higher throughput.
|
||||
self.max_num_batched_tokens = max(
|
||||
model_config.max_model_len,
|
||||
self.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# When using default settings,
|
||||
# Ensure max_num_batched_tokens does not exceed model limit.
|
||||
# Some models (e.g., Whisper) have embeddings tied to max length.
|
||||
self.max_num_batched_tokens = min(
|
||||
self.max_num_seqs * model_config.max_model_len,
|
||||
self.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Defaulting max_num_batched_tokens to %d for %s usage context.",
|
||||
self.max_num_batched_tokens,
|
||||
usage_context.value if usage_context else None,
|
||||
)
|
||||
|
||||
if orig_max_num_seqs is None:
|
||||
if self.max_num_batched_tokens is not None: # For type checking
|
||||
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
|
||||
|
||||
logger.debug(
|
||||
"Defaulting max_num_seqs to %d for %s usage context.",
|
||||
self.max_num_seqs,
|
||||
usage_context.value if usage_context else None,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3']
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs__create_engine_config(
|
||||
self,
|
||||
usage_context: UsageContext | None = None,
|
||||
headless: bool = False,
|
||||
) -> VllmConfig:
|
||||
"""
|
||||
Create the VllmConfig.
|
||||
|
||||
NOTE: If VllmConfig is incompatible, we raise an error.
|
||||
"""
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add data parallel params to parallel config.
|
||||
'''
|
||||
if self.mlu_config and "decoder_attn_dtype" in self.mlu_config:
|
||||
if self.mlu_config.get("decoder_attn_dtype") in ["int8", "fp8", "fp8_e4m3"]:
|
||||
self.kv_cache_dtype = self.mlu_config.get("decoder_attn_dtype")
|
||||
|
||||
engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(
|
||||
self, usage_context, headless)
|
||||
|
||||
world_size = engine_config.parallel_config.world_size_across_dp
|
||||
tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size
|
||||
embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size
|
||||
if embedding_tp_size:
|
||||
assert embedding_tp_size >= tensor_parallel_size and embedding_tp_size <= world_size, (
|
||||
f"embedding_tp_size = {embedding_tp_size} out of bounds. "
|
||||
f"Require {tensor_parallel_size} ≤ size ≤ {world_size}")
|
||||
dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size
|
||||
if dense_mlp_tp_size:
|
||||
assert dense_mlp_tp_size >= 1 and dense_mlp_tp_size <= world_size, (
|
||||
f"dense_mlp_tp_size = {dense_mlp_tp_size} out of bounds. Require 1 ≤ size ≤ {world_size}")
|
||||
if dense_mlp_tp_size != world_size:
|
||||
assert not engine_config.mlu_config.is_dpsk_mcc_enabled, (
|
||||
"dense_mlp_tp_size is not supported when dpsk mcc is enabled.")
|
||||
if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1:
|
||||
raise ValueError("For now, for longcat model, custom dense mlp tp split in data parallel requires dpXtp1. "
|
||||
"Necessity of this constraint requires further investigation.")
|
||||
if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size:
|
||||
raise ValueError(f"For longcat model, custom dense mlp tp_size {dense_mlp_tp_size} "
|
||||
f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}")
|
||||
if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0:
|
||||
raise ValueError(f"For deepseek mla model, custom mlp tp size {dense_mlp_tp_size} must "
|
||||
f"be divisible by {tensor_parallel_size}")
|
||||
|
||||
if ((engine_config.parallel_config.data_parallel_size > 1 or engine_config.speculative_config is not None
|
||||
or engine_config.mlu_config.prefill_use_sequence_parallel) and engine_config.mlu_config.prefill_enable_mlugraph):
|
||||
logger.info("Data parallel or sequence parallel or speculative is enabled, forcing context mlugraph to be disabled.")
|
||||
engine_config.mlu_config.prefill_enable_mlugraph = False
|
||||
if engine_config.mlu_config.decoder_attn_dtype:
|
||||
if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType):
|
||||
raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} "
|
||||
f"decoder_attn_dtype for now")
|
||||
is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and
|
||||
engine_config.model_config.hf_text_config.model_type == "glm4_moe")
|
||||
if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe)
|
||||
and engine_config.mlu_config.decoder_attn_dtype != "auto"):
|
||||
raise ValueError(f"mlu_config.decoder_attn_dtype only support deepseek_mla and glm4_moe model")
|
||||
|
||||
# sequence parallel checks
|
||||
if (engine_config.mlu_config.prefill_use_sequence_parallel
|
||||
and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]):
|
||||
raise ValueError("Prefill sequence parallel can only use in deepseek model.")
|
||||
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.scheduler_config.enable_chunked_prefill:
|
||||
raise ValueError("Prefill sequence parallel can not use with chunked prefill for now.")
|
||||
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled:
|
||||
raise ValueError("Prefill sequence parallel can not use with mcc.")
|
||||
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.parallel_config.data_parallel_size > 1:
|
||||
raise ValueError("Prefill sequence parallel can not use with data parallel.")
|
||||
if (engine_config.mlu_config.prefill_use_sequence_parallel
|
||||
and engine_config.model_config.hf_text_config.model_type == "deepseek_v3"
|
||||
and engine_config.quant_config.get_name() != "SmoothQuant"):
|
||||
raise ValueError("Prefill sequence parallel can only use SmoothQuant for deepseek_v3.")
|
||||
|
||||
# disagg constraint
|
||||
# 1. only supports DeepSeek-V3/R1
|
||||
# 2. kv8 is not supported
|
||||
if self.kv_transfer_config is not None:
|
||||
if engine_config.model_config.hf_config.model_type != "deepseek_v3":
|
||||
raise ValueError("Disagg only support DeepDeek-V3/R1")
|
||||
if engine_config.cache_config.cache_dtype == "int8":
|
||||
raise ValueError("Disagg does not support KV cache dtype is int8")
|
||||
if engine_config.cache_config.enable_prefix_caching:
|
||||
raise ValueError("Disagg does not support prefix caching")
|
||||
|
||||
if isinstance(self.kv_transfer_config, dict):
|
||||
kv_connector = self.kv_transfer_config.get("kv_connector")
|
||||
kv_role = self.kv_transfer_config.get("kv_role")
|
||||
else:
|
||||
kv_connector = self.kv_transfer_config.kv_connector
|
||||
kv_role = self.kv_transfer_config.kv_role
|
||||
|
||||
if kv_connector != "LMCacheConnectorV1":
|
||||
raise ValueError("Disagg only support LMCacheConnectorV1 connector")
|
||||
if kv_role == "kv_consumer":
|
||||
if not self.enable_chunked_prefill:
|
||||
raise ValueError("Disagg decoder only support chunk scheduler")
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return engine_config
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs._set_default_args,
|
||||
vllm__engine__arg_utils__EngineArgs___set_default_args)
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs.create_engine_config,
|
||||
vllm__engine__arg_utils__EngineArgs__create_engine_config)
|
||||
MluHijackObject.apply_hijack(EngineArgs,
|
||||
EngineArgs.get_chunked_prefill_prefix_caching_defaults,
|
||||
vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults)
|
||||
3
vllm_mlu/entrypoints/__init__.py
Normal file
3
vllm_mlu/entrypoints/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
158
vllm_mlu/entrypoints/llm.py
Normal file
158
vllm_mlu/entrypoints/llm.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from tqdm import tqdm
|
||||
from typing import Callable
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.logger import init_logger
|
||||
|
||||
import vllm_mlu._mlu_utils as mlu_envs
|
||||
from vllm_mlu.mlu_metric import LLMMetric
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def vllm__entrypoints__llm__LLM__get_mlu_metrics(
|
||||
self,
|
||||
metrics_idx_start,
|
||||
only_average,
|
||||
input_len,
|
||||
output_len,
|
||||
tp_nums,
|
||||
quantization,
|
||||
show_per_iter=False,
|
||||
is_embedding_task=False,
|
||||
mm_kwargs=None,
|
||||
total_prefill_steps=1,
|
||||
num_speculative_tokens=0,
|
||||
dp_size=1,
|
||||
) -> None:
|
||||
'''
|
||||
@brief: this function prints the performance metrics collected by the code while vLLM runs the generate API.
|
||||
@params:
|
||||
metrics_idx_start: some generate calls may be warmup runs,
|
||||
so this parameter can be set to ignore the data in [0, metrics_idx_start); the default is 0, i.e. all performance data is counted.
|
||||
only_average: True prints only the average performance over the N generate calls; False prints the performance of each call plus the average. If the N runs fluctuate significantly, check whether the test environment is stable.
|
||||
remaining parameters: all are model configuration parameters.
|
||||
'''
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
|
||||
batch_size = self.metric.batch_size_list[-1] * dp_size
|
||||
if mm_kwargs or is_embedding_task:
|
||||
# Multimodal and pooling models don't support the HFU feature yet.
|
||||
hfu_info, io_efficiency = None, None
|
||||
else:
|
||||
hfu_info, io_efficiency = self.llm_engine.get_hfu_info(batch_size, input_len, output_len)
|
||||
self.metric.calc_metric(
|
||||
self.llm_engine.model_config.model,
|
||||
self.llm_engine.model_config.dtype,
|
||||
metrics_idx_start, only_average,
|
||||
input_len, output_len, tp_nums,
|
||||
quantization, show_per_iter,
|
||||
is_embedding_task, mm_kwargs, total_prefill_steps,
|
||||
num_speculative_tokens, dp_size=dp_size, hfu_info=hfu_info, io_efficiency=io_efficiency)
|
||||
else:
|
||||
print("Warnning:please set VLLM_LATENCY_DEBUG=true!")
|
||||
|
||||
|
||||
def vllm__entrypoints__llm__LLM___run_engine(
|
||||
self, *, use_tqdm: bool | Callable[..., tqdm] = True
|
||||
) -> list[RequestOutput | PoolingRequestOutput]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
pbar = tqdm_func(
|
||||
total=num_requests,
|
||||
desc="Processed prompts",
|
||||
dynamic_ncols=True,
|
||||
postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
|
||||
total_request_num = self.llm_engine.get_num_unfinished_requests()
|
||||
e2e_start_time = self.metric.get_mlu_cost_time()
|
||||
if not self.llm_engine.model_config.is_embedding_task():
|
||||
peak_memory, block_memory, num_gpu_blocks, num_cpu_blocks = \
|
||||
self.llm_engine.get_memory_usage()
|
||||
self.metric.update_memory_usage(peak_memory, block_memory,
|
||||
num_gpu_blocks, num_cpu_blocks)
|
||||
'''
|
||||
==================
|
||||
End of addition
|
||||
==================
|
||||
'''
|
||||
|
||||
# Run the engine.
|
||||
outputs: list[RequestOutput | PoolingRequestOutput] = []
|
||||
total_in_toks = 0
|
||||
total_out_toks = 0
|
||||
while self.llm_engine.has_unfinished_requests():
|
||||
step_outputs = self.llm_engine.step()
|
||||
for output in step_outputs:
|
||||
if output.finished:
|
||||
outputs.append(output)
|
||||
if use_tqdm:
|
||||
if isinstance(output, RequestOutput):
|
||||
# Calculate tokens only for RequestOutput
|
||||
n = len(output.outputs)
|
||||
assert output.prompt_token_ids is not None
|
||||
total_in_toks += len(output.prompt_token_ids) * n
|
||||
in_spd = total_in_toks / pbar.format_dict["elapsed"]
|
||||
total_out_toks += sum(
|
||||
len(stp.token_ids) for stp in output.outputs
|
||||
)
|
||||
out_spd = total_out_toks / pbar.format_dict["elapsed"]
|
||||
pbar.postfix = (
|
||||
f"est. speed input: {in_spd:.2f} toks/s, "
|
||||
f"output: {out_spd:.2f} toks/s"
|
||||
)
|
||||
pbar.update(n)
|
||||
else:
|
||||
pbar.update(1)
|
||||
if pbar.n == num_requests:
|
||||
pbar.refresh()
|
||||
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
'''
|
||||
=============================
|
||||
Added by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
|
||||
e2e_end_time = self.metric.get_mlu_cost_time()
|
||||
e2e_latency = e2e_end_time - e2e_start_time
|
||||
|
||||
engine_step_latency, model_forward_latency, mm_encoder_latency = self.llm_engine.get_latency()
|
||||
self.metric.update_step_latency(engine_step_latency)
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
self.metric.update_step_latency_device(model_forward_latency)
|
||||
self.metric.update_mm_encoder_latency_device(mm_encoder_latency)
|
||||
|
||||
self.metric.add_metrics(total_request_num, e2e_latency)
|
||||
'''
|
||||
==================
|
||||
End of addition
|
||||
==================
|
||||
'''
|
||||
# Sort the outputs by request ID.
|
||||
# This is necessary because some requests may be finished earlier than
|
||||
# its previous requests.
|
||||
return sorted(outputs, key=lambda x: int(x.request_id))
|
||||
|
||||
|
||||
LLM.metric = LLMMetric()
|
||||
MluHijackObject.apply_hijack(LLM,
|
||||
"get_mlu_metrics",
|
||||
vllm__entrypoints__llm__LLM__get_mlu_metrics)
|
||||
MluHijackObject.apply_hijack(LLM,
|
||||
LLM._run_engine,
|
||||
vllm__entrypoints__llm__LLM___run_engine)
|
||||
3
vllm_mlu/entrypoints/openai/__init__.py
Normal file
3
vllm_mlu/entrypoints/openai/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
29
vllm_mlu/entrypoints/openai/api_server.py
Normal file
29
vllm_mlu/entrypoints/openai/api_server.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
import vllm_mlu._mlu_utils as mlu_envs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
router, engine_client
|
||||
)
|
||||
from vllm_mlu.logger import logger
|
||||
|
||||
if mlu_envs.VLLM_SCHEDULER_PROFILE:
|
||||
logger.info(
|
||||
"vLLM V1 Scheduler Profiler is enabled in the API server. Please use "
|
||||
"'tools/utils/post_scheduler_view_action.py' to dump profiling data "
|
||||
"after all requests finished.")
|
||||
|
||||
@router.post("/v1/start_scheduler_profile")
|
||||
async def start_scheduler_profile(raw_request: Request):
|
||||
logger.info("VLLM-V1 starting scheduler profiler...")
|
||||
await engine_client(raw_request).start_scheduler_profile()
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/v1/stop_scheduler_profile")
|
||||
async def stop_scheduler_profile(raw_request: Request):
|
||||
logger.info("VLLM-V1 scheduler stopping profiler...")
|
||||
await engine_client(raw_request).stop_scheduler_profile()
|
||||
return Response(status_code=200)
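# Illustrative client calls (assumed host/port of the API server):
#   curl -X POST http://localhost:8000/v1/start_scheduler_profile
#   ... send inference requests ...
#   curl -X POST http://localhost:8000/v1/stop_scheduler_profile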
|
||||
41
vllm_mlu/envs.py
Normal file
41
vllm_mlu/envs.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the used env vars.
|
||||
|
||||
# begin-env-vars-definition
|
||||
|
||||
env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# max compile thread num
|
||||
"MAX_JOBS":
|
||||
lambda: os.getenv("MAX_JOBS", None),
|
||||
"CMAKE_BUILD_TYPE":
|
||||
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
||||
"COMPILE_CUSTOM_KERNELS":
|
||||
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
||||
"VERBOSE":
|
||||
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
||||
"LD_LIBRARY_PATH":
|
||||
lambda: os.getenv("LD_LIBRARY_PATH", None),
|
||||
"CXX_COMPILER":
|
||||
lambda: os.getenv("CXX_COMPILER", None),
|
||||
"C_COMPILER":
|
||||
lambda: os.getenv("C_COMPILER", None)
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
# lazy evaluation of environment variables
|
||||
if name in env_variables:
|
||||
return env_variables[name]()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(env_variables.keys())
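# Illustrative (assumed usage): `from vllm_mlu import envs; envs.MAX_JOBS`
# re-reads os.environ on every attribute access via __getattr__ above, so
# changes to MAX_JOBS take effect without re-importing the module.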
|
||||
3
vllm_mlu/executor/__init__.py
Normal file
3
vllm_mlu/executor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
47
vllm_mlu/logger.py
Normal file
47
vllm_mlu/logger.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
from vllm.logger import _VllmLogger
|
||||
|
||||
|
||||
class _ColorFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if not record.name.startswith('vllm_mlu'):
|
||||
return True
|
||||
if record.levelno == logging.INFO:
|
||||
record.msg = f"\033[32m{record.msg}\033[0m"
|
||||
elif record.levelno == logging.WARNING:
|
||||
record.msg = f"\033[33m{record.msg}\033[0m"
|
||||
return True
|
||||
|
||||
|
||||
def _apply_mlu_color(logger):
|
||||
if not logger.handlers:
|
||||
return
|
||||
for h in logger.handlers:
|
||||
if any(isinstance(f, _ColorFilter) for f in h.filters):
|
||||
return
|
||||
h.addFilter(_ColorFilter())
|
||||
|
||||
|
||||
def _mlu_init_logger(name: str) -> logging.Logger:
|
||||
"""Initialize loggers for vllm_mlu module,
|
||||
and keep the configuration consistent with the vllm module"""
|
||||
mlu_logger = logging.getLogger(name)
|
||||
vllm_logger = logging.Logger.manager.loggerDict.get('vllm', None)
|
||||
if vllm_logger:
|
||||
mlu_logger.setLevel(vllm_logger.level)
|
||||
mlu_logger.propagate = vllm_logger.propagate
|
||||
mlu_logger.handlers = vllm_logger.handlers
|
||||
return mlu_logger
|
||||
|
||||
|
||||
def init_logger(name: str) -> _VllmLogger:
|
||||
vllm_logger = cast(_VllmLogger, _mlu_init_logger(name))
|
||||
_apply_mlu_color(vllm_logger)
|
||||
return vllm_logger
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
41
vllm_mlu/lora/__init__.py
Normal file
41
vllm_mlu/lora/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.lora.layers.column_parallel_linear import (
|
||||
ColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithShardedLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithShardedLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
|
||||
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
|
||||
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
|
||||
from vllm.lora.layers.row_parallel_linear import (
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.utils import LoRAMapping
|
||||
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
|
||||
|
||||
__all__ = [
|
||||
"BaseLayerWithLoRA",
|
||||
"VocabParallelEmbeddingWithLoRA",
|
||||
"LogitsProcessorWithLoRA",
|
||||
"ColumnParallelLinearWithLoRA",
|
||||
"ColumnParallelLinearWithShardedLoRA",
|
||||
"MergedColumnParallelLinearWithLoRA",
|
||||
"MergedColumnParallelLinearWithShardedLoRA",
|
||||
"MergedQKVParallelLinearWithLoRA",
|
||||
"MergedQKVParallelLinearWithShardedLoRA",
|
||||
"QKVParallelLinearWithLoRA",
|
||||
"QKVParallelLinearWithShardedLoRA",
|
||||
"RowParallelLinearWithLoRA",
|
||||
"RowParallelLinearWithShardedLoRA",
|
||||
"ReplicatedLinearWithLoRA",
|
||||
"LoRAMapping",
|
||||
"FusedMoEWithLoRA",
|
||||
]
|
||||
3
vllm_mlu/lora/layers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
50
vllm_mlu/lora/layers/base_linear.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual in matmul
|
||||
'''
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# In transformers backend, x and output have extra batch dimension like
|
||||
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
|
||||
# therefore we need to flatten the batch dimensions.
|
||||
if x.ndim == 3 and output.ndim == 3:
|
||||
output = output.flatten(0, 1)
|
||||
x = x.flatten(0, 1)
|
||||
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
|
||||
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
|
||||
)
|
||||
if not current_platform.can_update_inplace():
|
||||
output = lora_output
|
||||
|
||||
return output
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
BaseLinearLayerWithLoRA,
|
||||
BaseLinearLayerWithLoRA.apply,
|
||||
vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply,
|
||||
)
|
||||
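MluHijackObject.apply_hijack is used throughout these files with the call pattern (target class, original method, replacement function); its implementation is not part of this section. A hypothetical minimal sketch of such a helper, shown only to illustrate the monkey-patching pattern, not the real utility:

class HijackRegistry:
    """Hypothetical stand-in for MluHijackObject: swap a method on a class
    and remember the original so it could be restored later."""
    _originals = []

    @classmethod
    def apply_hijack(cls, target, original, replacement):
        # Record (class, attribute name, original callable) for potential rollback.
        name = original.__name__
        cls._originals.append((target, name, getattr(target, name)))
        setattr(target, name, replacement)

class Example:
    def apply(self, x):
        return x

def patched_apply(self, x):
    return x + 1

HijackRegistry.apply_hijack(Example, Example.apply, patched_apply)
assert Example().apply(1) == 2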
39
vllm_mlu/lora/layers/column_parallel_linear.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers.column_parallel_linear import ColumnParallelLinearWithLoRA
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add smooth_quant_scale and use_tp_weight parameters.
|
||||
'''
|
||||
def vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward(
|
||||
self,
|
||||
input_,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert not use_tp_weight, "LoRA does not support use_tp_weight yet."
|
||||
assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
|
||||
return vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org(self, input_)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
ColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinearWithLoRA.forward,
|
||||
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward,
|
||||
)
|
||||
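The file above follows a second, simpler pattern: capture the original function in a *_forward_org module attribute, then install a wrapper that rejects the new parameters and delegates to the captured original. A small self-contained sketch of that pattern (toy class and names, not the real layers):

class Layer:
    def forward(self, x):
        return x * 2

# Capture the original before patching, exactly as the diff does with *_forward_org,
# so the wrapper can delegate without recursing into itself.
_forward_org = Layer.forward

def forward_with_extra_args(self, x, extra=None):
    assert extra is None, "extra is not supported in this sketch"
    return _forward_org(self, x)

Layer.forward = forward_with_extra_args
assert Layer().forward(3) == 6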
163
vllm_mlu/lora/layers/row_parallel_linear.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.lora.layers.row_parallel_linear import (
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual and bias in matmul
|
||||
'''
|
||||
output = self.base_layer.quant_method.apply(
|
||||
self.base_layer, x, bias, residual)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
|
||||
buffer = torch.zeros(
|
||||
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
|
||||
buffer, x, self.lora_a_stacked, 1.0
|
||||
)
|
||||
if not current_platform.can_update_inplace():
|
||||
buffer = shrunk_buffer
|
||||
if self.tp_size > 1:
|
||||
buffer = tensor_model_parallel_all_reduce(buffer)
|
||||
|
||||
# following S-LoRA, allows the fusing of all_gather and all_reduce
|
||||
# by adding the column partitioned lora output to a slice of output
|
||||
# tensor, which is a partial sum due to row parallel. All that
|
||||
# remains is a standard all_reduce. User should be aware though that
|
||||
# the output is not the same as a normal row_parallel, it should be
|
||||
# reduced before being used
|
||||
# NOTE offset are based on the rank.
|
||||
shard_size = self.lora_b_stacked[0].shape[2]
|
||||
offset_start = self.tp_rank * shard_size
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
|
||||
output,
|
||||
buffer,
|
||||
self.lora_b_stacked,
|
||||
self.output_slices,
|
||||
offset_start=offset_start,
|
||||
add_input=True,
|
||||
)
|
||||
|
||||
if not current_platform.can_update_inplace():
|
||||
output = lora_output
|
||||
|
||||
output = output.view(*out_orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
output: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add parameters `residual`, `smooth_quant_scale`, `use_tp_weight` and `output`
|
||||
to keep parameters consistent with RowParallelLinear.forward.
|
||||
'''
|
||||
assert (not use_tp_weight) and output is None, (
|
||||
f"RowParallelLinearWithLoRA.forward does not support use_tp_wight=True"
|
||||
f" or pass output parameters.")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# Set up backprop all-reduce.
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size
|
||||
)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) apply residual fusion in matmul like RowParallelLinear
|
||||
2) add bias in matmul, not after all reduce
|
||||
'''
|
||||
# Matrix multiply.
|
||||
bias_ = (
|
||||
None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add)
|
||||
else self.base_layer.bias
|
||||
)
|
||||
residual_ = None if self.base_layer.tp_rank > 0 else residual
|
||||
output_parallel = self.apply(input_parallel, bias_, residual_)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.base_layer.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: do not add bias after all_reduce
|
||||
'''
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if not self.base_layer.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
RowParallelLinearWithShardedLoRA.apply,
|
||||
vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply,
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA.forward,
|
||||
vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward,
|
||||
)
|
||||
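A quick sanity check of why the hijacked forward adds bias and residual only on tp_rank 0: the per-rank partial matmul results are summed by all-reduce, so a bias added on every rank would appear tp_size times in the reduced output. A toy calculation in plain Python with illustrative values:

# Toy check of the bias handling in the hijacked RowParallelLinearWithLoRA.forward.
tp_size = 2
partial_outputs = [3.0, 4.0]          # per-rank partial matmul results
bias = 0.5

wrong = sum(p + bias for p in partial_outputs)        # bias counted tp_size times
right = sum(p + (bias if rank == 0 else 0.0)
            for rank, p in enumerate(partial_outputs))

assert wrong == 8.0   # 7.0 + 2 * 0.5
assert right == 7.5   # 7.0 + bias exactly once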
3
vllm_mlu/lora/ops/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
16
vllm_mlu/lora/ops/triton_ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.sgmv_expand import sgmv_expand_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops.lora_shrink_op import lora_shrink
|
||||
from vllm_mlu.lora.ops.triton_ops.lora_expand_op import lora_expand
|
||||
|
||||
__all__ = [
|
||||
"sgmv_expand_mlu",
|
||||
"sgmv_expand_slice_mlu",
|
||||
"sgmv_shrink_mlu",
|
||||
"lora_expand",
|
||||
"lora_shrink"
|
||||
]
|
||||
308
vllm_mlu/lora/ops/triton_ops/kernel_utils.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Utilities for Punica kernel construction.
|
||||
"""
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: modify the mm triton kernel
1) add parameter offset_n: the column offsets into matrix B on MLU,
value: tl.arange(0, BLOCK_N) + pid_n * BLOCK_N, shape: [BLOCK_N]
add parameter N: the number of columns of matrix B
2) tiled_b always needs a mask in case offset_n >= N
|
||||
'''
|
||||
|
||||
@triton.jit
|
||||
def mm_k(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
ak_stride,
|
||||
bk_stride,
|
||||
offset_n,
|
||||
offset_k,
|
||||
K: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
b_dtype: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Given a_ptr and b_ptr, which identify the rows of A (m x k) and columns of
B (k x n), iterate through the K dimension to compute the partial/complete
matrix block product.
|
||||
If SPLIT_K == 1, the output m x n product is complete.
|
||||
If SPLIT_K > 1, the thread block computes partial outputs. The partial
|
||||
outputs are then atomically summed in the caller code.
|
||||
Args:
|
||||
a_ptr: Array of pointers, identifying rows of A
|
||||
b_ptr: Array of pointers, identifying columns of B
|
||||
ak_stride: K dimension stride of the A matrix
|
||||
bk_stride: K dimension stride of the B matrix
|
||||
K: Length of the K dimension
|
||||
BLOCK_M: M dimension of the output block m x n
|
||||
BLOCK_N: N dimension of the output block m x n
|
||||
BLOCK_K: K dimension atom
|
||||
EVEN_K: True if the blocks of A and B can be loaded without any
|
||||
masking.
|
||||
SPLIT_K: Parameter signifying parallelism in the K dimension.
|
||||
CAST_TYPE: if True, cast the values from the A matrix to the B
|
||||
matrix dtype.
|
||||
b_dtype: datatype of the B matrix
|
||||
"""
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N, other=0.0)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=offset_k[None, :]
|
||||
< K - k * (BLOCK_K * SPLIT_K),
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=(offset_k[:, None]
|
||||
< K - k * (BLOCK_K * SPLIT_K)) & (offset_n < N)[None, :],
|
||||
other=0.0)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(b_dtype)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * SPLIT_K * ak_stride
|
||||
b_ptr += BLOCK_K * SPLIT_K * bk_stride
|
||||
return accumulator
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@triton.jit
|
||||
def do_expand_kernel(
|
||||
pid_n,
|
||||
lora_index,
|
||||
slice_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
M_LEN,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
slice_start_loc,
|
||||
# input ptr strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride,
|
||||
# lora ptr strides
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr,
|
||||
# out ptr strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
# constants
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SAME_STRIDE: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Given an array of integers that identifies the rows of A, ram,
|
||||
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
|
||||
a slice_id that identifies the input/output slice,
|
||||
compute the matrix product and store in the appropriate output location.
|
||||
Given that this is an expand kernel, we don't perform any split-K reduction
|
||||
as the K dimension is assumed to be small.
|
||||
"""
|
||||
|
||||
# ls_d*_ptr can be either an integer or a pointer
|
||||
if SAME_STRIDE: # 'same_stride': True
|
||||
# integer
|
||||
cur_lora_d0_stride = ls_d0_ptr
|
||||
cur_lora_d1_stride = ls_d1_ptr
|
||||
cur_lora_d2_stride = ls_d2_ptr
|
||||
else:
|
||||
# pointer
|
||||
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
|
||||
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
|
||||
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
|
||||
|
||||
# Identify the input_ptr and lora_ptr from slice_id.
|
||||
if SLICE_NUM == 1:
|
||||
cur_input_ptr = input_ptr
|
||||
cur_lora_ptr = lora_ptr
|
||||
else:
|
||||
cur_input_ptr = input_ptr + slice_id * input_d0_stride
|
||||
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
||||
tl.pointer_type(out_ptr.dtype.element_ty))
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) remove rbn definition: mlu doesn't support contiguous and
|
||||
will handle as head corruption
|
||||
2) re-write b_ptr, use offset_n to identify its position
|
||||
'''
|
||||
|
||||
# Identify the column indices of B to process.
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
# rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
# Identify A and B block pointers
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
|
||||
offset_k[None, :] * input_d2_stride)
|
||||
# b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
|
||||
# offset_k[:, None] * cur_lora_d2_stride +
|
||||
# rbn[None, :] * cur_lora_d1_stride)
|
||||
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
|
||||
offset_k[:, None] * cur_lora_d2_stride +
|
||||
offset_n[None, :] * cur_lora_d1_stride)
|
||||
|
||||
# Compute the block matrix product.
|
||||
SPLIT_K = 1
|
||||
accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, offset_n,
|
||||
offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N,
|
||||
CAST_TYPE, cur_lora_ptr.dtype.element_ty)
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
|
||||
if SLICE_NUM == 1:
|
||||
cur_slice_start = slice_start_loc
|
||||
else:
|
||||
cur_slice_start = tl.load(slice_start_loc + slice_id)
|
||||
|
||||
# Identify the C output pointers to store the results of the accumulator.
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
|
||||
offset_cm = tl.arange(0, BLOCK_M)
|
||||
c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
|
||||
offset_cn[None, :] * output_d1_stride)
|
||||
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
|
||||
< (cur_slice_start + N))
|
||||
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def do_shrink_kernel(
|
||||
pid_n,
|
||||
pid_sk,
|
||||
slice_id,
|
||||
lora_index,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
M_LEN,
|
||||
ram,
|
||||
# input strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
# lora strides
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride,
|
||||
# output strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride,
|
||||
scaling,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Given an array of integers that identifies the rows of A, ram,
|
||||
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
|
||||
a slice_id that identifies the input/output slice, compute the
|
||||
matrix product and store in the appropriate output location.
|
||||
"""
|
||||
|
||||
# Identify the lora_ptr from slice_id.
|
||||
if SLICE_NUM == 1:
|
||||
# current lora ptr
|
||||
cur_lora_ptr = lora_ptr
|
||||
else:
|
||||
# current lora ptr
|
||||
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
||||
tl.pointer_type(input_ptr.dtype.element_ty))
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1) remove rbn definition: mlu doesn't support contiguous and
|
||||
will handle as head corruption
|
||||
2) re-write b_ptr, use offset_n to identify its position
|
||||
'''
|
||||
|
||||
# Identify the column indices of B to process.
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
# rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
# Identify A and B block pointers
|
||||
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
|
||||
offset_k[None, :] * input_d1_stride)
|
||||
# b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
|
||||
# rbn[None, :] * lora_d1_stride +
|
||||
# offset_k[:, None] * lora_d2_stride)
|
||||
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
|
||||
offset_n[None, :] * lora_d1_stride +
|
||||
offset_k[:, None] * lora_d2_stride)
|
||||
|
||||
# Compute partial/complete block matrix product.
|
||||
accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_n, offset_k,
|
||||
K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N, False,
|
||||
cur_lora_ptr.dtype.element_ty)
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# Identify the C output pointers to store the results of the accumulator.
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_cm = tl.arange(0, BLOCK_M)
|
||||
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
|
||||
slice_id * output_d0_stride)
|
||||
c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
|
||||
None, :] * output_d2_stride
|
||||
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
|
||||
|
||||
accumulator *= scaling
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
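For readers unfamiliar with the blocked accumulation that mm_k performs, the following pure-NumPy sketch mirrors the SPLIT_K == 1 case: walk K in BLOCK_K chunks and accumulate partial products in float32. The offset_n < N mask added for MLU plays the same role as the stop clamp below. Illustrative only, not the kernel itself:

import numpy as np

def mm_k_reference(a, b, block_k):
    """Pure-NumPy sketch of the blocked accumulation mm_k performs for SPLIT_K == 1:
    walk the K dimension in BLOCK_K chunks and accumulate partial products in fp32."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2
    acc = np.zeros((m, n), dtype=np.float32)
    for start in range(0, k, block_k):
        stop = min(start + block_k, k)  # the load masks in the kernel play this role
        acc += a[:, start:stop].astype(np.float32) @ b[start:stop, :].astype(np.float32)
    return acc

a = np.random.rand(4, 10).astype(np.float16)
b = np.random.rand(10, 6).astype(np.float16)
ref = a.astype(np.float32) @ b.astype(np.float32)
assert np.allclose(mm_k_reference(a, b, block_k=3), ref, atol=1e-3)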
308
vllm_mlu/lora/ops/triton_ops/lora_expand_op.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use vllm_mlu hijacked kernel
|
||||
'''
|
||||
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _lora_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
slice_start_loc,
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride, # 1
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr, # 1
|
||||
output_d0_stride,
|
||||
output_d1_stride, # 1
|
||||
output_hs_ptr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
SAME_STRIDE: tl.constexpr,
|
||||
):
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
|
||||
pid_mn = tl.program_id(axis=0)
|
||||
pid_m = pid_mn % cta_m_num
|
||||
pid_n = (pid_mn // cta_m_num) % cta_n_num
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
|
||||
lora_id = tl.load(lora_ids + lora_idx)
|
||||
if lora_id == -1:
|
||||
# Early exit for the no-lora case.
|
||||
return
|
||||
|
||||
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
|
||||
|
||||
cta_m_offset = pid_m * BLOCK_M
|
||||
if cta_m_offset >= lora_m_size:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# When the output dimensions of each slice are the same, curr_N = N; otherwise
# curr_N = tl.load(output_hs_ptr + slice_id). This situation arises in GQA's
# qkv linear.
|
||||
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
|
||||
if pid_n * BLOCK_N >= curr_N:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# num rows this CTA should process.
|
||||
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
|
||||
|
||||
# Identify all rows that this CTA should process.
|
||||
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
|
||||
cta_lora_seq_indices = (
|
||||
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
|
||||
)
|
||||
|
||||
# Load all relevant row indices.
|
||||
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
|
||||
ram = tl.load(cta_lora_seq_indices + offset_m)
|
||||
|
||||
do_expand_kernel(
|
||||
pid_n,
|
||||
lora_id,
|
||||
slice_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
curr_N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
slice_start_loc,
|
||||
# input ptr strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride,
|
||||
# lora ptr strides
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr,
|
||||
# out ptr strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
# constants
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
SAME_STRIDE,
|
||||
SLICE_NUM,
|
||||
EVEN_K,
|
||||
CAST_TYPE,
|
||||
ADD_INPUTS,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _lora_expand(
|
||||
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
|
||||
output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices]
|
||||
token_lora_mapping: torch.Tensor, # shape [num_tokens]
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
|
||||
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (list[torch.Tensor]): LoRA B weights
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
token_lora_mapping (torch.Tensor): A tensor mapping each input token
|
||||
to the lora-id related to that token. A value of -1 indicates that
|
||||
LoRA doesn't apply to that token.
|
||||
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
|
||||
the A matrix grouped by LoRA IDs.
|
||||
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
|
||||
of tokens that are to be processed by LoRA ID lora_ids[i]
|
||||
lora_token_start_loc (torch.Tensor): A cumulative sum of
|
||||
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
|
||||
lora_token_start_loc[i], along with num_tokens_per_lora[i]
|
||||
identifies the region in token_indices_sorted_by_lora_ids that
|
||||
LoRA lora_ids[i] should process.
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
|
||||
if there are any requests that require LoRA.
|
||||
offset_start (int, optional): Offset start for output_tensor.
|
||||
Defaults to 0.
|
||||
add_inputs (bool, optional): Whether to add the input tensor to the
|
||||
output tensor. Defaults to False.
|
||||
"""
|
||||
|
||||
assert no_lora_flag_cpu.numel() == 1
|
||||
if no_lora_flag_cpu.item():
|
||||
# None of the inputs require LoRA.
|
||||
return
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
for weight in lora_b_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(0) == len(lora_b_weights)
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
# metadata sanity check.
|
||||
M = inputs.size(1)
|
||||
assert token_lora_mapping.size(0) == M
|
||||
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
|
||||
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
|
||||
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
|
||||
|
||||
(
|
||||
slice_start_tensor,
|
||||
lora_ptr_tensor,
|
||||
lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor,
|
||||
hidden_sizes_tensor,
|
||||
same_stride,
|
||||
MAX_N,
|
||||
) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)
|
||||
|
||||
K = lora_b_weights[0].shape[-1] # K= rank
|
||||
ADD_INPUTS = add_inputs
|
||||
MAX_LORAS = lora_ids.size(0)
|
||||
CAST_TYPE = False
|
||||
NUM_SLICES = len(lora_b_weights)
|
||||
|
||||
# Triton kernel configs.
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 16
|
||||
NUM_WARPS = 4
|
||||
NUM_CTAS = 1
|
||||
NUM_STAGES = 2
|
||||
|
||||
EVEN_K = K % BLOCK_K == 0 # type: ignore
|
||||
|
||||
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
# cost of wasteful thread block launch when only a few input tokens require
|
||||
# LoRA. This might not be the best in all cases.
|
||||
grid = (
|
||||
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks simply exit.
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_lora_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
M,
|
||||
MAX_N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
slice_start_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
inputs.stride(2),
|
||||
lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
hidden_sizes_tensor,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
NUM_SLICES,
|
||||
same_stride,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
num_stages=NUM_STAGES,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _lora_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use only vllm operand
|
||||
'''
|
||||
|
||||
lora_expand = _lora_expand
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
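The metadata described in _lora_expand's docstring can be related back to token_lora_mapping with a small sketch. The tensors below use made-up values and a simplified layout (the production tensors are sized by max-loras and built by vLLM's punica wrapper), purely to show how the three index tensors fit together:

import torch

# Hypothetical example of the LoRA metadata layout, illustrative values only.
token_lora_mapping = torch.tensor([1, 0, 1, -1, 0])   # lora id per token, -1 = no LoRA

lora_ids, counts = torch.unique(token_lora_mapping, return_counts=True)
token_indices_sorted_by_lora_ids = torch.argsort(token_lora_mapping, stable=True)
num_tokens_per_lora = counts
lora_token_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.long), torch.cumsum(counts, dim=0)])

# Tokens handled by lora_ids[i] live in
# token_indices_sorted_by_lora_ids[lora_token_start_loc[i] :
#                                  lora_token_start_loc[i] + num_tokens_per_lora[i]]
i = (lora_ids == 1).nonzero().item()
rows = token_indices_sorted_by_lora_ids[
    lora_token_start_loc[i] : lora_token_start_loc[i] + num_tokens_per_lora[i]]
assert torch.equal(rows.sort().values, torch.tensor([0, 2]))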
258
vllm_mlu/lora/ops/triton_ops/lora_shrink_op.py
Normal file
@@ -0,0 +1,258 @@
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use vllm_mlu hijacked kernel
|
||||
'''
|
||||
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
|
||||
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
|
||||
lora_token_start_loc, lora_ids, scaling,
|
||||
input_d0_stride, input_d1_stride, lora_d0_stride,
|
||||
lora_d1_stride, lora_d2_stride, output_d0_stride,
|
||||
output_d1_stride, output_d2_stride,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
|
||||
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
|
||||
pid_sk_m_n = tl.program_id(axis=0)
|
||||
pid_sk = pid_sk_m_n % SPLIT_K
|
||||
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
|
||||
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
|
||||
lora_id = tl.load(lora_ids + lora_idx)
|
||||
if lora_id == -1:
|
||||
# Early exit for the no-lora case.
|
||||
return
|
||||
|
||||
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
|
||||
|
||||
cta_m_offset = pid_m * BLOCK_M
|
||||
if cta_m_offset >= lora_m_size:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# num rows this CTA should process.
|
||||
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
|
||||
|
||||
# Identify all rows that this CTA should process.
|
||||
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
|
||||
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
|
||||
lora_m_indices_start + cta_m_offset)
|
||||
|
||||
# Load all relevant row indices.
|
||||
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
|
||||
ram = tl.load(cta_lora_seq_indices + offset_m)
|
||||
|
||||
do_shrink_kernel(
|
||||
pid_n,
|
||||
pid_sk,
|
||||
slice_id,
|
||||
lora_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
# input strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
# lora strides
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride,
|
||||
# output strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride,
|
||||
scaling,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
SLICE_NUM)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _lora_shrink(
|
||||
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
|
||||
lora_a_weights: list[
|
||||
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
|
||||
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
token_lora_mapping: torch.Tensor, # shape [num_tokens]
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
|
||||
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor
|
||||
lora_a_weights (list[torch.Tensor]): LoRA weights
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
token_lora_mapping (torch.Tensor): A tensor mapping each input token
|
||||
to the lora-id related to that token. A value of -1 indicates that
|
||||
LoRA doesn't apply to that token.
|
||||
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
|
||||
the A matrix grouped by LoRA IDs.
|
||||
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
|
||||
of tokens that are to be processed by LoRA ID lora_ids[i]
|
||||
lora_token_start_loc (torch.Tensor): A cumulative sum of
|
||||
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
|
||||
lora_token_start_loc[i], along with num_tokens_per_lora[i]
|
||||
identifies the region in token_indices_sorted_by_lora_ids that
|
||||
LoRA lora_ids[i] should process.
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
|
||||
if there are any requests that require LoRA.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
|
||||
assert no_lora_flag_cpu.numel() == 1
|
||||
if no_lora_flag_cpu.item():
|
||||
# None of the inputs require LoRA.
|
||||
return
|
||||
|
||||
assert inputs.dtype == lora_a_weights[0].dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
for weight in lora_a_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(1) == lora_a_weights[0].size(-1)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
# metadata sanity check
|
||||
M = inputs.size(0)
|
||||
assert token_lora_mapping.size(0) == M
|
||||
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
|
||||
0)
|
||||
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
|
||||
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
|
||||
|
||||
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
|
||||
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
|
||||
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
|
||||
NUM_SLICES = len(lora_a_weights)
|
||||
MAX_LORAS = lora_ids.size(0)
|
||||
|
||||
# Triton kernel configs
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 16
|
||||
BLOCK_K = 256 if M < 128 else 32
|
||||
SPLIT_K = 64 if M < 128 else 8
|
||||
NUM_WARPS = 4
|
||||
NUM_CTAS = 1
|
||||
NUM_STAGES = 2
|
||||
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
# cost of wasteful thread block launch when only few of the input tokens
|
||||
# require LoRA. This might not be the best in all cases.
|
||||
grid = (
|
||||
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks exit early.
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_lora_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_strides_d0,
|
||||
lora_strides_d1,
|
||||
lora_strides_d2,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
output_tensor.stride(2),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
NUM_SLICES,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
num_stages=NUM_STAGES,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _lora_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use only vllm operand
|
||||
'''
|
||||
|
||||
lora_shrink = _lora_shrink
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
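Taken together, the shrink and expand kernels implement the usual LoRA decomposition from the Punica paper cited above: shrink projects tokens into the small LoRA rank with a scaling factor, expand projects back to the hidden size and adds onto the base output. A shape-level torch sketch with illustrative sizes, not the kernels themselves:

import torch

num_tokens, hidden_size, rank, scaling = 4, 16, 8, 0.5
x = torch.randn(num_tokens, hidden_size)
lora_a = torch.randn(rank, hidden_size)     # shrink weight, [rank, hidden_size]
lora_b = torch.randn(hidden_size, rank)     # expand weight, [hidden_size, rank]
base_out = torch.randn(num_tokens, hidden_size)

shrunk = scaling * (x @ lora_a.t())         # what _lora_shrink computes per slice
out = base_out + shrunk @ lora_b.t()        # what _lora_expand adds onto the output

assert out.shape == (num_tokens, hidden_size)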
238
vllm_mlu/lora/ops/triton_ops/sgmv_expand.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_expand_kernel_mlu(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
The sgmv expand Triton kernel is based on GroupGEMM.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
|
||||
offset_k[None, :] * xk_stride
|
||||
b_ptr = lora_ptr + l0_stride * lora_index + \
|
||||
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
|
||||
other=0)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * xk_stride
|
||||
b_ptr += BLOCK_K * lora_n_stride
|
||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_expand_mlu(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): LoRA B weights
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Workaround: adjust block size to meet MLU restrictions.

The grid of an MLU Triton kernel must be smaller than 65536; a very long input
sequence would push the grid past that bound and cause a runtime error, so we
adjust the block size to avoid it.
|
||||
'''
|
||||
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
batches,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: call _sgmv_expand_kernel_mlu
|
||||
'''
|
||||
_sgmv_expand_kernel_mlu[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return
|
||||
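adjust_kernel_block_size is imported from vllm_mlu/lora/ops/triton_ops/utils.py, which is not part of this section, so the sketch below is only a guess at the kind of adjustment the workaround comment describes: grow the block sizes until the launch grid fits under the 65536 limit. Hypothetical helper, not the real implementation:

import math

def adjust_block_size_sketch(max_seq_length, block_m, n, block_n, grid_limit=65536):
    # Illustration only: grow the block sizes until
    # ceil(max_seq_length / BLOCK_M) * ceil(N / BLOCK_N) fits under the MLU grid limit.
    while math.ceil(max_seq_length / block_m) * math.ceil(n / block_n) >= grid_limit:
        if block_m <= block_n:
            block_m *= 2
        else:
            block_n *= 2
    return block_m, block_n

# With a very long sequence the default 32x32 blocks would exceed the limit.
bm, bn = adjust_block_size_sketch(max_seq_length=4_000_000, block_m=32, n=4096, block_n=32)
assert math.ceil(4_000_000 / bm) * math.ceil(4096 / bn) < 65536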
248
vllm_mlu/lora/ops/triton_ops/sgmv_expand_slice.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_expand_slice_kernel_mlu(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
slice_offset,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
|
||||
Similar to the 'sgmv_expand' operator, but with an added parameter
|
||||
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
|
||||
might be that in the future, we could implement a fusion operator to
|
||||
achieve the current functionality instead of having to call it multiple
|
||||
times.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
|
||||
offset_k[None, :] * xk_stride
|
||||
b_ptr = lora_ptr + l0_stride * lora_index + \
|
||||
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
|
||||
other=0)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * xk_stride
|
||||
b_ptr += BLOCK_K * lora_n_stride
|
||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
||||
(slice_offset + N))
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_expand_slice_mlu(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): LoRA B weights
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences
|
||||
in the batch
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert slice_size == lora_b_weights.size(-2)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Workaround: adjust block size to meet MLU restrictions.

The grid of an MLU Triton kernel must be smaller than 65536; a very long input
sequence would push the grid past that bound and cause a runtime error, so we
adjust the block size to avoid it.
|
||||
'''
|
||||
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
batches,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: call _sgmv_expand_slice_kernel_mlu
|
||||
'''
|
||||
_sgmv_expand_slice_kernel_mlu[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
slice_offset,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return
|
||||
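The slice_offset parameter exists so that several expand calls (for example the q, k and v LoRA slices of a merged QKV projection) can write into disjoint column ranges of one output tensor. A torch sketch with made-up sizes, illustrating only the offset bookkeeping:

import torch

num_tokens = 4
slice_sizes = [6, 3, 3]                       # hidden sizes of the q, k, v slices
output = torch.zeros(num_tokens, sum(slice_sizes))

slice_offset = 0
for size in slice_sizes:
    expanded = torch.ones(num_tokens, size)   # stand-in for one sgmv_expand_slice_mlu result
    output[:, slice_offset:slice_offset + size] += expanded
    slice_offset += size                      # next slice starts where this one ended

assert torch.equal(output, torch.ones(num_tokens, 12))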
231
vllm_mlu/lora/ops/triton_ops/sgmv_shrink.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_shrink_kernel_mlu(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
scaling,
|
||||
xm_stride, # hidden_size
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
The SGMV shrink Triton kernel is based on GroupGEMM + SPLIT-K.
|
||||
The GEMM of Multi-LoRA can be considered a GroupGEMM. Additionally,
|
||||
introducing SPLIT-K can improve performance.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
pid_sk = tl.program_id(axis=1)
|
||||
cur_batch = tl.program_id(axis=2)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
|
||||
offset_k[None, :] * xk_stride
|
||||
b_ptr = lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride + \
|
||||
offset_k[:, None] * lora_n_stride
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: adjust kernel impl to fit mlu.
|
||||
'''
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
|
||||
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)),
|
||||
other=0.0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)),
|
||||
other=0.0)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
accumulator += tl.dot(tiled_a, tiled_b)
|
||||
|
||||
a_ptr += BLOCK_K * SPLIT_K * xk_stride
|
||||
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
accumulator *= scaling
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_shrink_mlu(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (torch.Tensor): LoRA A weights.
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The maximum sequence length of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The number of tokens in the batch. Used to verify that the
|
||||
number of tokens in the inputs matches the count in the metadata.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
assert lora_a_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
|
||||
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
|
||||
assert lora_a_weights.size(1) == 1
|
||||
lora_a_weights = lora_a_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
|
||||
assert lora_a_weights.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
# TODO tuning this config
|
||||
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Workaround: adjust block size to meet mlu restrictions.
|
||||
|
||||
The grid of an MLU Triton kernel must be smaller than 65536. With a very long
|
||||
input sequence the grid would exceed this bound and cause a runtime error,
|
||||
so the block size is adjusted to avoid this.
|
||||
'''
|
||||
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
BLOCK_K = 32
|
||||
SPLIT_K = 8
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
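# NOTE: with BLOCK_K = 32 and SPLIT_K = 8, EVEN_K is True only when the hidden
# size K is a multiple of 256; otherwise the kernel takes its masked-load path.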
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
SPLIT_K,
|
||||
batches,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: call _sgmv_shrink_kernel_mlu
|
||||
'''
|
||||
_sgmv_shrink_kernel_mlu[grid](
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_a_weights.stride(0),
|
||||
lora_a_weights.stride(1),
|
||||
lora_a_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return
|
||||
41
vllm_mlu/lora/ops/triton_ops/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
from math import ceil
|
||||
|
||||
_MLU_MAX_GRID_SIZE = 65536
|
||||
|
||||
def adjust_kernel_block_size(
|
||||
m: int,
|
||||
block_m: int,
|
||||
n: int,
|
||||
block_n: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Adjust block size to meet mlu triton grid restrictions.
|
||||
|
||||
Calculation of the max block size in candidates list:
|
||||
|
||||
LLama3.1-8b-tp1 max n is 14336
|
||||
LLama3.1-70b-tp4 max n is 7168
|
||||
LLama3.1-405b-tp8 max n is 6656
|
||||
|
||||
when n is 14336, the max sequence length of block size 256 can be
|
||||
floor(65536 / ceil(14336 / 256)) * 256 = 299520.
|
||||
"""
|
||||
candidates_list = [16, 32, 64, 96, 128, 192, 256]
|
||||
candidates_list_len = len(candidates_list)
|
||||
m_idx = 1
|
||||
n_idx = 0 if block_n == 16 else 1
|
||||
while m_idx < candidates_list_len and n_idx < candidates_list_len:
|
||||
block_m = candidates_list[m_idx]
|
||||
block_n = candidates_list[n_idx]
|
||||
if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
|
||||
break
|
||||
if m_idx < candidates_list_len:
|
||||
m_idx += 1
|
||||
if n_idx < candidates_list_len:
|
||||
n_idx += 1
|
||||
if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE:
|
||||
raise ValueError(f"the max seq len {m} is too long for lora triton kernel")
|
||||
return block_m, block_n
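# Worked example (illustrative, reusing the numbers from the docstring above):
# called with block_m = 32, block_n = 32, n = 14336 and m = 299520, the first
# candidate pair gives ceil(299520/32) * ceil(14336/32) = 9360 * 448 programs,
# far above 65536, so both tile sizes keep growing until (256, 256), where
# ceil(299520/256) * ceil(14336/256) = 1170 * 56 = 65520 fits under the limit.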
|
||||
3
vllm_mlu/lora/punica_wrapper/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
89
vllm_mlu/lora/punica_wrapper/punica_mlu.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
|
||||
|
||||
|
||||
@final
|
||||
class PunicaWrapperMLU(PunicaWrapperCPU):
|
||||
"""
|
||||
PunicaWrapperMLU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the punica triton kernel.
|
||||
"""
|
||||
|
||||
def _shrink_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
# No LoRA request, so return directly.
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_shrink_mlu(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
scale,
|
||||
)
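# NOTE: self.prefill_metadata is expected to expand to (b_seq_start_loc,
# seq_len_tensor, lora_indices_tensor, batches, max_seq_length, token_nums),
# i.e. the positional arguments that sit between the output tensor and the
# trailing scaling / add_inputs argument of the sgmv_*_mlu helpers above.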
|
||||
|
||||
def _expand_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
# No LoRA request, so return directly.
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_expand_mlu(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_slice_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
# No LoRA request, so return directly.
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_expand_slice_mlu(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
120
vllm_mlu/mlu_forward_context.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from vllm.forward_context import DPMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLUDPMetadata(DPMetadata):
|
||||
# mlu platform arguments
|
||||
# token num for current dp group
|
||||
token_num: int = None
|
||||
# token num offset for current dp group
|
||||
token_num_offset: int = None
|
||||
# whether we can use reduce scatter for both attn layer and mlp layer
|
||||
layer_use_reduce_scatter: bool = False
|
||||
# token num to pad to for prefill, so we can do reduce scatter +
|
||||
# all gather to optimize communication time
|
||||
prefill_pad_to_token_num: int = -1
|
||||
# token num in each dp group, the list length is attn data parallel size
|
||||
# used to do all gather in dp groups after all reduce in attn
|
||||
token_split_list: List[int] = None
|
||||
# token num in each card, the list length is world size
|
||||
# used to do all gather in all cards after reduce scatter in attn
|
||||
attn_token_split_list_reduce_scatter: List[int] = None
|
||||
# token num in each tp group, the list length is tensor parallel size
|
||||
# used to do all gather in tp groups after reduce scatter in moe
|
||||
moe_token_split_list_reduce_scatter: List[int] = None
|
||||
# prefill or decode stage in each dp group
|
||||
dp_is_prefill: List[bool] = None
|
||||
|
||||
# ADDITIONAL fields for merged compute and communication.
|
||||
# Global sequence lengths for each batch size for prefill stage.
|
||||
seq_lens: List[int] = None
|
||||
# Batch sizes for each attn dp rank for prefill stage.
|
||||
batch_sizes: List[int] = None
|
||||
|
||||
# ADDITIONAL fields for custom split for embedding, logits and dense mlp layer
|
||||
# token num in each emb tp group, the list length is tensor parallel size
|
||||
# used to do all gather in emb tp groups after reduce scatter in moe
|
||||
emb_token_split_list: List[int] = None
|
||||
# batch sizes in each logits tp group, the list length is tensor parallel size
|
||||
# used to do all gather in logits tp groups after reduce scatter in moe
|
||||
logits_batch_split_list: List[int] = None
|
||||
# token num in each dense mlp group, the list length is dense mlp tp size
|
||||
# used to do one more all gather after dense mlp and before reduce scatter
|
||||
dense_attn_token_split_list: List[int] = None
|
||||
|
||||
@staticmethod
|
||||
def make_oot(
|
||||
data_parallel_rank: int,
|
||||
data_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
dp_token_nums: List[int],
|
||||
dp_is_prefill: List[bool],
|
||||
prefill_dispatch_use_RS_AG: bool,
|
||||
seq_lens: List[int] = None,
|
||||
batch_sizes: List[int] = None,
|
||||
emb_query_lens: List[int] = None,
|
||||
logits_batch_sizes: List[int] = None,
|
||||
dense_attn_token_split_list: List[int] = None,
|
||||
) -> "MLUDPMetadata":
|
||||
token_num_offset = sum(dp_token_nums[:data_parallel_rank])
|
||||
token_num = dp_token_nums[data_parallel_rank]
|
||||
token_split_list = dp_token_nums
|
||||
|
||||
attn_can_use_reduce_scatter = all(
|
||||
(num != 0 and num % tensor_parallel_size == 0)
|
||||
for num in token_split_list
|
||||
)
|
||||
all_split_token_num_equal = all(
|
||||
num == token_split_list[0] for num in token_split_list
|
||||
)
|
||||
layer_can_use_reduce_scatter = (
|
||||
attn_can_use_reduce_scatter and all_split_token_num_equal
|
||||
)
|
||||
|
||||
attn_token_split_list_reduce_scatter = None
|
||||
moe_token_split_list_reduce_scatter = None
|
||||
prefill_pad_to_token_num = -1
|
||||
tp_world_size = data_parallel_size * tensor_parallel_size
|
||||
if layer_can_use_reduce_scatter:
|
||||
attn_token_split_list_reduce_scatter = (
|
||||
[token_split_list[0] // tensor_parallel_size] * tp_world_size
|
||||
)
|
||||
moe_token_split_list_reduce_scatter = (
|
||||
attn_token_split_list_reduce_scatter[:tensor_parallel_size]
|
||||
)
|
||||
elif (
|
||||
prefill_dispatch_use_RS_AG
|
||||
and all(is_prefill for is_prefill in dp_is_prefill)
|
||||
):
|
||||
dp_group_max_token_nums = max(dp_token_nums)
|
||||
prefill_pad_to_token_num = (
|
||||
(dp_group_max_token_nums + tensor_parallel_size - 1)
|
||||
// tensor_parallel_size
|
||||
) * tensor_parallel_size
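# e.g. with tensor_parallel_size = 8 and a maximum of 1001 tokens across the
# DP groups, this rounds up to 1008 so the padded tokens split evenly across
# the 8 ranks for reduce-scatter.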
|
||||
attn_token_split_list_reduce_scatter = (
|
||||
[prefill_pad_to_token_num // tensor_parallel_size] * tp_world_size
|
||||
)
|
||||
|
||||
return MLUDPMetadata(
|
||||
max_tokens_across_dp_cpu=None,
|
||||
num_tokens_across_dp_cpu=None,
|
||||
token_num=token_num,
|
||||
token_num_offset=token_num_offset,
|
||||
token_split_list=token_split_list,
|
||||
layer_use_reduce_scatter=layer_can_use_reduce_scatter,
|
||||
prefill_pad_to_token_num=prefill_pad_to_token_num,
|
||||
attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter,
|
||||
moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter,
|
||||
seq_lens=seq_lens,
|
||||
batch_sizes=batch_sizes,
|
||||
dp_is_prefill=dp_is_prefill,
|
||||
emb_token_split_list=emb_query_lens,
|
||||
logits_batch_split_list=logits_batch_sizes,
|
||||
dense_attn_token_split_list=dense_attn_token_split_list,
|
||||
)
|
||||
79
vllm_mlu/mlu_hijack.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import importlib.util
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.logger import logger
|
||||
|
||||
|
||||
def is_module_available(module_name):
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
return spec is not None
|
||||
|
||||
def check_environ_compatibility():
|
||||
if is_module_available('apex'):
|
||||
logger.error(f"The `apex` package is currently present in your environment, "
|
||||
f"which may cause model accuracy issues or other problems. It is "
|
||||
f"strongly recommended that you uninstall it before using vLLM.")
|
||||
|
||||
# Check environment compatibility first before applying mlu hijack.
|
||||
check_environ_compatibility()
|
||||
|
||||
logger.info(f"[MLU] Apply Monkey Patch.")
|
||||
|
||||
# Apply v1 hijack
|
||||
import vllm_mlu.v1.engine.core
|
||||
import vllm_mlu.v1.engine.core_client
|
||||
import vllm_mlu.v1.engine.llm_engine
|
||||
import vllm_mlu.v1.engine.async_llm
|
||||
import vllm_mlu.v1.core.sched.scheduler
|
||||
import vllm_mlu.v1.core.single_type_kv_cache_manager
|
||||
import vllm_mlu.v1.core.kv_cache_utils
|
||||
import vllm_mlu.v1.core.kv_cache_manager
|
||||
import vllm_mlu.v1.executor.abstract
|
||||
import vllm_mlu.v1.executor.ray_executor
|
||||
import vllm_mlu.v1.executor.multiproc_executor
|
||||
import vllm_mlu.v1.sample.rejection_sampler
|
||||
import vllm_mlu.v1.worker.lora_model_runner_mixin
|
||||
import vllm_mlu.v1.worker.block_table
|
||||
import vllm_mlu.v1.worker.gpu_input_batch
|
||||
import vllm_mlu.v1.worker.kv_connector_model_runner_mixin
|
||||
import vllm_mlu.v1.attention.backends.gdn_attn
|
||||
import vllm_mlu.v1.attention.backends.mla.flashmla
|
||||
import vllm_mlu.compilation.fix_functionalization
|
||||
|
||||
# Apply common hijack
|
||||
import vllm_mlu.attention.layer
|
||||
import vllm_mlu.benchmarks.datasets
|
||||
import vllm_mlu.config.model
|
||||
import vllm_mlu.config.scheduler
|
||||
import vllm_mlu.config.speculative
|
||||
import vllm_mlu.config.vllm
|
||||
import vllm_mlu.utils
|
||||
import vllm_mlu.distributed.parallel_state
|
||||
import vllm_mlu.distributed.kv_transfer.kv_connector.factory
|
||||
import vllm_mlu.engine.arg_utils
|
||||
import vllm_mlu.entrypoints.llm
|
||||
import vllm_mlu.lora.layers.base_linear
|
||||
import vllm_mlu.lora.layers.row_parallel_linear
|
||||
import vllm_mlu.lora.layers.column_parallel_linear
|
||||
import vllm_mlu.model_executor.parameter
|
||||
import vllm_mlu.model_executor.layers.linear
|
||||
import vllm_mlu.model_executor.layers.rotary_embedding
|
||||
import vllm_mlu.model_executor.layers.quantization.utils.w8a8_utils
|
||||
import vllm_mlu.model_executor.layers.quantization.fp8
|
||||
import vllm_mlu.model_executor.layers.activation
|
||||
import vllm_mlu.model_executor.layers.layernorm
|
||||
import vllm_mlu.model_executor.layers.fused_moe.layer
|
||||
import vllm_mlu.model_executor.model_loader.tensorizer_loader
|
||||
import vllm_mlu.model_executor.models.registry
|
||||
import vllm_mlu.model_executor.models.config
|
||||
import vllm_mlu.multimodal.utils
|
||||
if is_module_available('lmcache'):
|
||||
import vllm_mlu.distributed.kv_transfer.kv_connector.v1.lmcache_connector
|
||||
|
||||
if VLLM_CI_ACCURACY_TEST:
|
||||
import vllm_mlu.model_executor.model_loader.dummy_loader
|
||||
|
||||
if VLLM_SCHEDULER_PROFILE:
|
||||
import vllm_mlu.entrypoints.openai.api_server
|
||||
104
vllm_mlu/mlu_hijack_utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IS_GATED = False
|
||||
|
||||
class MluHijackObject:
|
||||
hijack_objs = []
|
||||
|
||||
@classmethod
|
||||
def apply_hijack(cls, obj, org_func, hijack_func,
|
||||
verify_orig_func_exists: bool = False):
|
||||
"""
|
||||
Optional Args:
|
||||
verify_orig_func_exists (bool): If True, verifies that hijack succeeds
|
||||
"""
|
||||
cls.hijack_objs.append((obj, org_func, hijack_func))
|
||||
|
||||
if type(org_func) == str:
|
||||
org_func_name = org_func
|
||||
else:
|
||||
if isinstance(org_func, property):
|
||||
split_name = org_func.fget.__name__.split('__')
|
||||
else:
|
||||
split_name = org_func.__name__.split('__')
|
||||
org_func_name = split_name[-1]
|
||||
if org_func_name == "":
|
||||
assert split_name[-2] != "", f"invalid {org_func.__name__} to apply hijack"
|
||||
org_func_name = split_name[-2] + "__"
|
||||
if len(split_name) >= 3 and split_name[-3] == "":
|
||||
org_func_name = "__" + org_func_name
|
||||
|
||||
if verify_orig_func_exists and not hasattr(obj, org_func_name):
|
||||
raise AttributeError(f"function {org_func_name} is not part of {obj}")
|
||||
|
||||
setattr(obj, org_func_name, hijack_func)
|
||||
|
||||
if (verify_orig_func_exists and getattr(obj, org_func_name) is not hijack_func):
|
||||
raise AttributeError(
|
||||
f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def undo_hijack(cls, obj_=None, hijack_func_=None):
|
||||
if obj_ and hijack_func_:
|
||||
for obj, org_func, hijack_func in cls.hijack_objs:
|
||||
if obj_ == obj and hijack_func == hijack_func_:
|
||||
if type(org_func) == str:
|
||||
if hasattr(obj, org_func):
|
||||
delattr(obj, org_func)
|
||||
else:
|
||||
org_func_name = org_func.__name__
|
||||
setattr(obj, org_func_name, org_func)
|
||||
return
|
||||
for obj, org_func, hijack_func in cls.hijack_objs:
|
||||
if type(org_func) == str:
|
||||
if hasattr(obj, org_func):
|
||||
delattr(obj, org_func)
|
||||
else:
|
||||
org_func_name = org_func.__name__
|
||||
setattr(obj, org_func_name, org_func)
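# Usage sketch (mirrors the hijack modules below, e.g.
# vllm_mlu/model_executor/layers/activation.py):
#
#   MluHijackObject.apply_hijack(
#       QuickGELU,
#       QuickGELU.forward_oot,
#       vllm__model_executor__activation__QuickGELU__forward_oot)
#
# apply_hijack records (obj, original, replacement), derives the attribute
# name from the original function's __name__ (splitting on '__' so dunder
# methods are handled), and setattr's the replacement on the target;
# undo_hijack() restores every recorded original.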
|
||||
|
||||
|
||||
TypedDict = {
|
||||
"hidden_size": 0,
|
||||
"vocab_size": 0,
|
||||
"ffn_inner_size": 0,
|
||||
"moe_inner_size": 0,
|
||||
"layer_num": 0,
|
||||
"moe_layer_num": 0,
|
||||
"head_num": 0,
|
||||
"head_size": 0,
|
||||
"head_num_kv": 0,
|
||||
"tp_num": 0,
|
||||
"shared_expert_intermediate_size": 0,
|
||||
"shared_experts": 0,
|
||||
"qk_nope_head_dim": 0,
|
||||
"qk_rope_head_dim": 0,
|
||||
"q_lora_rank": 0.0,
|
||||
"num_attention_heads": 0,
|
||||
"kv_lora_rank": 0,
|
||||
"v_head_dim": 0,
|
||||
"use_gated_ffn": False,
|
||||
"experts_num": 0,
|
||||
"topk_num": 0,
|
||||
"use_causal_mask": False,
|
||||
"cla_coeffient": 0,
|
||||
"kv_cache_dtype": "",
|
||||
"smooth_quant_type": "",
|
||||
"data_type": "",
|
||||
"model_type": "",
|
||||
"filter_data_type": "",
|
||||
}
|
||||
|
||||
|
||||
def set_is_gated(flag):
|
||||
global IS_GATED
|
||||
IS_GATED = flag
|
||||
|
||||
def get_is_gated():
|
||||
return IS_GATED
|
||||
412
vllm_mlu/mlu_metric.py
Normal file
@@ -0,0 +1,412 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import time
|
||||
import statistics
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG_WITH_DEVICE_EN, VLLM_DUMP_MLU_INFO_EN
|
||||
from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
millisecond2second_unit = 1000
|
||||
|
||||
class LLMMetric:
|
||||
def __init__(self) -> None:
|
||||
self.batch_size_list = []
|
||||
self.context_latency_list = []
|
||||
self.e2e_latency_list = []
|
||||
self.per_token_latency_list = [ [] ]
|
||||
self.per_token_latency_device_list = [ [] ]
|
||||
self.mm_encoder_latency_device_list = [ [] ]
|
||||
self.peak_memory = 0
|
||||
self.block_memory = 0
|
||||
self.num_total_gpu_blocks = 0
|
||||
self.num_total_cpu_blocks = 0
|
||||
self.num_free_gpu_blocks_list = [ [] ]
|
||||
self.num_free_cpu_blocks_list = [ [] ]
|
||||
self.num_spec_tokens = 0
|
||||
self.draft_acceptance_rate = 0.0
|
||||
self.context_latency_device = 0.0
|
||||
self.generate_latency_device = 0.0
|
||||
self.mm_encoder_latency_device = 0.0
|
||||
|
||||
def reset_metric(self):
|
||||
self.batch_size_list = []
|
||||
self.context_latency_list = []
|
||||
self.e2e_latency_list = []
|
||||
self.per_token_latency_list = [ [] ]
|
||||
self.per_token_latency_device_list = [ [] ]
|
||||
self.mm_encoder_latency_device_list = [ [] ]
|
||||
self.num_free_gpu_blocks_list = [ [] ]
|
||||
self.num_free_cpu_blocks_list = [ [] ]
|
||||
self.num_spec_tokens = 0
|
||||
self.draft_acceptance_rate = 0.0
|
||||
|
||||
@classmethod
|
||||
def get_mlu_cost_time(cls):
|
||||
torch.mlu.synchronize()
|
||||
return time.perf_counter()
|
||||
|
||||
def is_prefill_stage(self):
|
||||
return len(self.per_token_latency_list[-1]) == 0
|
||||
|
||||
def update_memory_usage(self, peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks):
|
||||
self.peak_memory = peak_memory
|
||||
self.block_memory = block_memory
|
||||
self.num_total_gpu_blocks = num_total_gpu_blocks
|
||||
self.num_total_cpu_blocks = num_total_cpu_blocks
|
||||
|
||||
def update_step_block_usage(self, num_free_gpu_blocks, num_free_cpu_blocks):
|
||||
self.num_free_gpu_blocks_list[-1].append(num_free_gpu_blocks)
|
||||
self.num_free_cpu_blocks_list[-1].append(num_free_cpu_blocks)
|
||||
|
||||
def update_step_latency(self, step_latency):
|
||||
if isinstance(step_latency, list):
|
||||
self.per_token_latency_list[-1].extend(step_latency)
|
||||
else:
|
||||
self.per_token_latency_list[-1].append(step_latency)
|
||||
|
||||
def update_step_latency_device(self, step_latency):
|
||||
if isinstance(step_latency, list):
|
||||
self.per_token_latency_device_list[-1].extend(step_latency)
|
||||
else:
|
||||
self.per_token_latency_device_list[-1].append(step_latency)
|
||||
|
||||
def update_mm_encoder_latency_device(self, step_latency):
|
||||
if isinstance(step_latency, list):
|
||||
if len(step_latency) == 0:
|
||||
return
|
||||
assert len(step_latency) == 1, f"Not supported! Model with multi mm encoder steps. {len(step_latency)} {step_latency}"
|
||||
self.mm_encoder_latency_device_list[-1].extend(step_latency)
|
||||
else:
|
||||
self.mm_encoder_latency_device_list[-1].append(step_latency)
|
||||
|
||||
def update_spec_decode_metrics(self, spec_decode_metrics):
|
||||
self.num_spec_tokens = spec_decode_metrics.num_spec_tokens
|
||||
self.draft_acceptance_rate = spec_decode_metrics.draft_acceptance_rate
|
||||
|
||||
def add_metrics(self, batch_size, e2e_latency) -> None:
|
||||
self.batch_size_list.append(batch_size)
|
||||
self.e2e_latency_list.append(e2e_latency)
|
||||
self.per_token_latency_list.append([]) # new iter
|
||||
self.per_token_latency_device_list.append([])
|
||||
self.mm_encoder_latency_device_list.append([])
|
||||
self.num_free_gpu_blocks_list.append([])
|
||||
self.num_free_cpu_blocks_list.append([])
|
||||
|
||||
def get_weight_dtype_str(self, model_path, model_dtype, quantization) -> str:
|
||||
# get weight dtype based on quantization config if exists
|
||||
if quantization == 'fp8':
|
||||
return quantization
|
||||
if quantization is not None:
|
||||
quant_method = get_quantization_config(quantization)
|
||||
# combine the model path with the quantization config file name
|
||||
quant_config_paths = quant_method.get_config_filenames()
|
||||
# if there are multiple quantization config files, use the first one that exists
|
||||
for quant_config_path in quant_config_paths:
|
||||
quant_config_path = os.path.join(model_path, quant_config_path)
|
||||
# check if the quantization config file exists
|
||||
if not os.path.exists(quant_config_path):
|
||||
continue
|
||||
with open(quant_config_path, 'r') as f:
|
||||
quant_config = json.load(f)
|
||||
quant_config = quant_method.from_config(quant_config)
|
||||
# for smoothquant and weightonly, return the quantization name with the weight bits
|
||||
if quantization == "smoothquant" or quantization == ["weightonly"]:
|
||||
return "{}-int{}".format(quant_config.get_name(), quant_config.weight_bits)
|
||||
else:
|
||||
# for other quantization methods, return the quantization name
|
||||
return quant_config.get_name()
|
||||
# if the quantization config file does not exist, just return the quantization name
|
||||
return quant_method.get_name()
|
||||
else:
|
||||
# remove the prefix of model dtype from torch config
|
||||
return str(model_dtype).split(".")[-1]
|
||||
|
||||
def to_csv(self, filename: str, show_per_iter=False) -> None:
|
||||
if show_per_iter:
|
||||
df = pd.DataFrame(self.metrics_data)
|
||||
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
|
||||
memory_df = pd.DataFrame(self.memory_metrics_data)
|
||||
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
|
||||
else:
|
||||
df = pd.DataFrame(self.metrics_data)
|
||||
memory_df = pd.DataFrame(self.memory_metrics_data)
|
||||
df_mean = df.mean().round(3)
|
||||
memory_df_mean = memory_df.mean().round(3)
|
||||
header = ["datetime", "model",
|
||||
"weight dtype", self.batch_size_name,
|
||||
]
|
||||
header = header + list(self.mm_kwargs.keys())
|
||||
header = header + ["input len", "output len", "tp",
|
||||
self.context_latency_name, self.per_token_latency_name]
|
||||
data = [datetime.now().strftime("%Y-%m-%d %H:%M:%S"), self.model,
|
||||
self.weight_dtype_str, int(self.metrics_data[self.batch_size_name][0])]
|
||||
data = data + [self.mm_kwargs[k] for k in self.mm_kwargs.keys()]
|
||||
data = data + [self.input_len, self.output_len, self.tp,
|
||||
df_mean[self.context_latency_name], df_mean[self.per_token_latency_name]]
|
||||
if self.num_spec_tokens > 0:
|
||||
header += [self.per_step_latency_name]
|
||||
data += [df_mean[self.per_step_latency_name]]
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
if self.is_v1_multimodal:
|
||||
header += [self.mm_encoder_latency_device_name,]
|
||||
data += [df_mean[self.mm_encoder_latency_device_name],]
|
||||
header += [self.context_latency_device_name, self.per_token_latency_device_name]
|
||||
data += [df_mean[self.context_latency_device_name], df_mean[self.per_token_latency_device_name]]
|
||||
header += [self.e2e_latency_name, self.e2e_throughput_name, self.decoder_throughput_name,]
|
||||
if self.num_spec_tokens > 0:
|
||||
header += [self.k_name, self.acceptance_rate_name]
|
||||
header += [self.decode_times_name,
|
||||
self.peak_memory_name, self.block_memory_name]
|
||||
data += [
|
||||
df_mean[self.e2e_latency_name], df_mean[self.e2e_throughput_name], df_mean[self.decoder_throughput_name],
|
||||
]
|
||||
if self.num_spec_tokens > 0:
|
||||
data += [self.num_spec_tokens, df_mean[self.acceptance_rate_name],]
|
||||
data += [df_mean[self.decode_times_name], memory_df_mean[self.peak_memory_name], memory_df_mean[self.block_memory_name],]
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and self.save_hfu_info:
|
||||
header += [self.context_hfu_name, self.decoder_hfu_name, self.decoder_io_efficiency_name]
|
||||
data += [
|
||||
df_mean[self.context_hfu_name], df_mean[self.decoder_hfu_name],
|
||||
df_mean[self.decoder_io_efficiency_name]
|
||||
]
|
||||
data_dict = dict(zip(header, data))
|
||||
df_csv = pd.DataFrame(data_dict, index=[0])
|
||||
append = False
|
||||
if os.path.isfile(filename):
|
||||
try:
|
||||
df_old = pd.read_csv(filename)
|
||||
append = (df_old.columns.tolist() == header)
|
||||
except Exception as e:
|
||||
logger.info(f"Existing {filename} failed to be read and will be overwritten")
|
||||
if append:
|
||||
df_csv.to_csv(filename, mode='a', header=False, index=False)
|
||||
logger.info(f"Metric appended to existing {filename}")
|
||||
else:
|
||||
df_csv.to_csv(filename, index=False)
|
||||
logger.info(f"Metric written to {filename}")
|
||||
|
||||
def calc_metric(self, model, model_dtype, metrics_idx_start, only_average,
|
||||
input_len, output_len, tp_nums, quantization,
|
||||
show_per_iter=False, is_embedding_task=False, mm_kwargs=None,
|
||||
total_prefill_steps=1, num_spec_tokens=0, dp_size=1, hfu_info=None, io_efficiency=0.0) -> None:
|
||||
keep_digits = 2
|
||||
|
||||
def round_fn(data):
|
||||
return round(data, keep_digits)
|
||||
|
||||
metrics_idx_end = len(self.per_token_latency_list) - 1 # without last []
|
||||
idx_range = range(metrics_idx_start, metrics_idx_end)
|
||||
# specify entries to write to csv
|
||||
self.is_v1_multimodal = mm_kwargs
|
||||
self.mm_kwargs = mm_kwargs if mm_kwargs else {} # multimodal args
|
||||
self.batch_size_name = "batch size"
|
||||
self.input_len = input_len
|
||||
self.output_len = output_len
|
||||
self.tp = tp_nums
|
||||
self.dp = dp_size
|
||||
self.model = model
|
||||
self.context_latency_name = "context latency(ms)"
|
||||
self.mm_encoder_latency_device_name = "multimodal encoder latency device(ms)"
|
||||
self.context_latency_device_name = "context latency device(ms)"
|
||||
if num_spec_tokens > 0:
|
||||
self.per_step_latency_name = "per step latency(ms)"
|
||||
self.per_token_latency_device_name = "per step latency device(ms)"
|
||||
else:
|
||||
self.per_token_latency_device_name = "per token latency device(ms)"
|
||||
self.per_token_latency_name = "per token latency(ms)"
|
||||
self.e2e_latency_name = "e2e latency(ms)"
|
||||
self.e2e_throughput_name = "e2e throughput(tokens/s)"
|
||||
self.decoder_throughput_name = "decoder throughput(tokens/s)"
|
||||
self.k_name = "K"
|
||||
self.acceptance_rate_name = "acceptance rate"
|
||||
self.decode_times_name = "decode times"
|
||||
self.weight_dtype_str = self.get_weight_dtype_str(model, model_dtype, quantization)
|
||||
self.num_spec_tokens = num_spec_tokens
|
||||
rate_list=[]
|
||||
rate=0
|
||||
if num_spec_tokens > 0:
|
||||
for i in range(metrics_idx_end):
|
||||
if len(self.per_token_latency_list[i]) - total_prefill_steps == 0:
|
||||
logger.warning("For now output_len is 0, no need mtp info, if you need mtp info, please increase output_len.")
|
||||
rate_list.append(0.0)
|
||||
else:
|
||||
rate_list.append(((self.output_len - 1) / (float)(len(self.per_token_latency_list[i]) - total_prefill_steps) - 1) / num_spec_tokens)
|
||||
rate = statistics.fmean(rate_list[metrics_idx_start: metrics_idx_end])
|
||||
metrics_data = [
|
||||
(
|
||||
self.batch_size_name, [self.dp * int(self.batch_size_list[i]) for i in idx_range]
|
||||
),
|
||||
(
|
||||
self.context_latency_name, [round_fn(millisecond2second_unit * sum(self.per_token_latency_list[i][:total_prefill_steps])) for i in idx_range]
|
||||
),
|
||||
(
|
||||
self.per_token_latency_name, [
|
||||
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
|
||||
round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * (len(self.per_token_latency_list[i]) - total_prefill_steps) / (self.output_len - 1) * millisecond2second_unit) for i in idx_range
|
||||
]
|
||||
),
|
||||
]
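# NOTE on the metric definitions above: context latency is the summed
# prefill-step time converted to ms, while per-token latency is the total
# decode-phase time divided by (output_len - 1) generated tokens, so
# multi-token (speculative) steps are still charged per emitted token.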
|
||||
if num_spec_tokens > 0:
|
||||
metrics_data += [(self.per_step_latency_name, [
|
||||
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
|
||||
round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * millisecond2second_unit) for i in idx_range
|
||||
])]
|
||||
metrics_data += [
|
||||
(
|
||||
self.e2e_latency_name, [round_fn(millisecond2second_unit * self.e2e_latency_list[i]) for i in idx_range]
|
||||
),
|
||||
(
|
||||
self.e2e_throughput_name, [
|
||||
round_fn(self.dp * (output_len / self.e2e_latency_list[i]) * self.batch_size_list[i]) \
|
||||
for i in idx_range
|
||||
]
|
||||
),
|
||||
(
|
||||
self.decoder_throughput_name, [
|
||||
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
|
||||
round_fn(self.dp * ((output_len-1) / sum(self.per_token_latency_list[i][total_prefill_steps:])) * self.batch_size_list[i]) \
|
||||
for i in idx_range
|
||||
]
|
||||
),
|
||||
(
|
||||
self.decode_times_name, [
|
||||
0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
|
||||
len(self.per_token_latency_list[i][total_prefill_steps:]) for i in idx_range
|
||||
]
|
||||
),
|
||||
]
|
||||
if num_spec_tokens > 0:
|
||||
metrics_data.append((self.k_name, num_spec_tokens))
|
||||
metrics_data.append((self.acceptance_rate_name, [rate_list[i] for i in idx_range]))
|
||||
|
||||
insert_latency_device = VLLM_LATENCY_DEBUG_WITH_DEVICE_EN
|
||||
if insert_latency_device:
|
||||
device_item_idx = 3
|
||||
if self.is_v1_multimodal:
|
||||
mm_encoder_latency_device = [round_fn(sum(self.mm_encoder_latency_device_list[i])) for i in idx_range]
|
||||
metrics_data.insert(device_item_idx, (self.mm_encoder_latency_device_name, mm_encoder_latency_device))
|
||||
device_item_idx = device_item_idx + 1
|
||||
context_latency_device = [round_fn(sum(self.per_token_latency_device_list[i][:total_prefill_steps])) for i in idx_range]
|
||||
per_token_latency_device = [0.0 if len(self.per_token_latency_device_list[i]) <= total_prefill_steps else \
|
||||
round_fn(statistics.fmean(self.per_token_latency_device_list[i][total_prefill_steps:])) for i in idx_range]
|
||||
metrics_data.insert(device_item_idx, (self.context_latency_device_name, context_latency_device))
|
||||
metrics_data.insert(device_item_idx + 1, (self.per_token_latency_device_name, per_token_latency_device))
|
||||
|
||||
self.metrics_data = dict(metrics_data)
|
||||
|
||||
# Print
|
||||
df = pd.DataFrame(self.metrics_data)
|
||||
if show_per_iter:
|
||||
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
|
||||
else:
|
||||
df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = df.mean().round(keep_digits)
|
||||
if only_average:
|
||||
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
|
||||
|
||||
df.index.name = 'iter index'
|
||||
df[self.batch_size_name] = df[self.batch_size_name].astype(int)
|
||||
if num_spec_tokens > 0:
|
||||
df[self.k_name] = df[self.k_name].astype(int)
|
||||
|
||||
self.peak_memory_name = "profile memory(GB)"
|
||||
self.block_memory_name = "total cache memory(GB)"
|
||||
memory_metrics_data = [
|
||||
(
|
||||
self.peak_memory_name, [round_fn(self.peak_memory / 1024 / 1024 / 1024) for i in idx_range]
|
||||
),
|
||||
(
|
||||
self.block_memory_name, [round_fn(self.block_memory / 1024 / 1024 / 1024) for i in idx_range]
|
||||
),
|
||||
]
|
||||
|
||||
self.memory_metrics_data = dict(memory_metrics_data)
|
||||
|
||||
# Print
|
||||
memory_df = pd.DataFrame(self.memory_metrics_data)
|
||||
if show_per_iter:
|
||||
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
|
||||
else:
|
||||
memory_df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = memory_df.mean().round(keep_digits)
|
||||
if only_average:
|
||||
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
|
||||
|
||||
memory_df.index.name = 'iter index'
|
||||
|
||||
pd.set_option('display.colheader_justify', 'center')
|
||||
pd.set_option('display.max_columns', None)
|
||||
pd.set_option('display.max_rows', None)
|
||||
print("********************************* Test Info****************************")
|
||||
mm_params_text = " ".join(f"{key}:{value}" for key, value in self.mm_kwargs.items())
|
||||
print("Generation Config {} input len:{} output len:{} tp_nums:{} quantization:{}".format(
|
||||
mm_params_text, input_len,output_len,tp_nums,quantization))
|
||||
self.context_latency_device = np.mean(self.metrics_data['context latency device(ms)'])
|
||||
self.generate_latency_device = np.mean(self.metrics_data[self.per_token_latency_device_name])
|
||||
if self.is_v1_multimodal:
|
||||
self.mm_encoder_latency_device = np.mean(self.metrics_data[self.mm_encoder_latency_device_name])
|
||||
|
||||
print("*************************Performance Info******************************")
|
||||
print(f"Total prefill steps: {total_prefill_steps}")
|
||||
print(df.to_string())
|
||||
if not is_embedding_task:
|
||||
# embedding task does not do a profile run, so it has no memory info
|
||||
print(memory_df.to_string())
|
||||
if insert_latency_device:
|
||||
context_latency = np.mean(self.metrics_data['context latency device(ms)'])
|
||||
generate_latency = np.mean(self.metrics_data[self.per_token_latency_device_name])
|
||||
if num_spec_tokens > 0:
|
||||
print("MTP token accept rate: {:.2f}%".format(rate*100))
|
||||
self.dump_performance_info(hfu_info, io_efficiency)
|
||||
|
||||
avg_latency_e2e = sum(sum(self.per_token_latency_list[i]) for i in idx_range) / len(idx_range)
|
||||
print("Avg latency without host time is :", avg_latency_e2e)
|
||||
print("***********************************************************************")
|
||||
# collect context_hfu, decoder_hfu and decoder IO efficiency when device info dumping is enabled
|
||||
self.save_hfu_info = False
|
||||
if insert_latency_device:
|
||||
if VLLM_DUMP_MLU_INFO_EN:
|
||||
try:
|
||||
import device_info
|
||||
self.save_hfu_info = True
|
||||
except ImportError:
|
||||
logger.info(f"try import device_info failed. try pip install device_info.")
|
||||
self.context_hfu_name = "Context HFU"
|
||||
self.decoder_hfu_name = "Decoder HFU"
|
||||
self.decoder_io_efficiency_name = "Decoder IO Efficiency"
|
||||
if self.save_hfu_info:
|
||||
self.metrics_data[self.context_hfu_name] = hfu_info["context_hfu"] * 100
|
||||
self.metrics_data[self.decoder_hfu_name] = hfu_info["decoder_hfu"] * 100
|
||||
self.metrics_data[self.decoder_io_efficiency_name] = io_efficiency * 100
|
||||
if csv_path := os.getenv("OUTPUT_CSV_PATH"):
|
||||
try:
|
||||
if dir_path := os.path.dirname(csv_path):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
self.to_csv(csv_path, show_per_iter=show_per_iter)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid OUTPUT_CSV_PATH: {csv_path} to dump metrics, Error: {e}")
|
||||
|
||||
def dump_performance_info(self, hfu_info, io_efficiency):
|
||||
try:
|
||||
if VLLM_DUMP_MLU_INFO_EN and hfu_info is not None:
|
||||
hfu_info["context_hfu"] = hfu_info["context_hfu"] / (self.context_latency_device / millisecond2second_unit)
|
||||
hfu_info["decoder_hfu"] = hfu_info["decoder_hfu"] / (self.generate_latency_device / millisecond2second_unit)
|
||||
io_efficiency = io_efficiency / self.generate_latency_device
|
||||
print(f"Context HFU-visible: {hfu_info['context_hfu']:.3%}")
|
||||
print(f"Decoder HFU-visible: {hfu_info['decoder_hfu']:.3%}")
|
||||
print(f"Decoder IO Efficiency: {io_efficiency:.3%}")
|
||||
elif hfu_info is not None:
|
||||
print(f"Context FLOPS-visible: {hfu_info['context_flops']}")
|
||||
print(f"Decoder FLOPS-visible: {hfu_info['decoder_flops']}")
|
||||
else:
|
||||
logger.info("Unsupport dump performance information")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dump performance information: {str(e)}")
|
||||
3
vllm_mlu/model_executor/__init__.py
Executable file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
3
vllm_mlu/model_executor/layers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
25
vllm_mlu/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
def vllm__model_executor__activation__QuickGELU__forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: implement forward_oot
|
||||
'''
|
||||
return mlu_ops.active(x, 'quick_gelu', False)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(QuickGELU,
|
||||
QuickGELU.forward_oot,
|
||||
vllm__model_executor__activation__QuickGELU__forward_oot)
|
||||
277
vllm_mlu/model_executor/layers/compressor.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
from scipy.linalg import hadamard
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||||
|
||||
|
||||
def hadamard_transform_ref(x, scale=1.0):
|
||||
"""
|
||||
x: (..., dim)
|
||||
out: (..., dim)
|
||||
"""
|
||||
x_shape = x.shape
|
||||
dim = x.shape[-1]
|
||||
x = x.reshape(-1, dim)
|
||||
log_dim = math.ceil(math.log2(dim))
|
||||
dim_padded = 2 ** log_dim
|
||||
if dim != dim_padded:
|
||||
x = F.pad(x, (0, dim_padded - dim))
|
||||
out = F.linear(
|
||||
x,
|
||||
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
|
||||
)
|
||||
out = out * scale
|
||||
return out[..., :dim].reshape(*x_shape)
|
||||
|
||||
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dtype == torch.bfloat16
|
||||
hidden_size = x.size(-1)
|
||||
return hadamard_transform_ref(x, scale=hidden_size ** -0.5)
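# NOTE: scaling by hidden_size ** -0.5 makes the transform orthonormal: a
# Hadamard matrix H of size n satisfies H @ H.T = n * I, so H / sqrt(n)
# preserves vector norms and the rotation is numerically benign.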
|
||||
|
||||
|
||||
class Compressor(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
rope,
|
||||
compress_ratio: int = 4,
|
||||
head_dim: int = 512,
|
||||
rotate: bool = False,
|
||||
prefix: str = "",
|
||||
**kwargs,):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.dim = config.dim
|
||||
self.head_dim = head_dim
|
||||
self.rope_head_dim = config.rope_head_dim
|
||||
self.nope_head_dim = head_dim - config.rope_head_dim
|
||||
self.compress_ratio = compress_ratio
|
||||
self.overlap = compress_ratio == 4
|
||||
self.rotate = rotate
|
||||
coff = 1 + self.overlap
|
||||
self.norm_eps = config.norm_eps
|
||||
self.window_size = config.window_size
|
||||
|
||||
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
|
||||
# wkv and wgate in the checkpoint are stored in bf16, while the parameters here are stored in fp32 for convenience.
|
||||
# The first half of the dimensions is used for overlapping compression and the second half for normal compression.
|
||||
|
||||
self.wkv = ReplicatedLinear(
|
||||
self.dim,
|
||||
coff * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
params_dtype = torch.float32,
|
||||
prefix=f"{prefix}.wkv",
|
||||
)
|
||||
|
||||
self.wgate = ReplicatedLinear(
|
||||
self.dim,
|
||||
coff * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
params_dtype = torch.float32,
|
||||
prefix=f"{prefix}.wgate",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(self.head_dim, self.norm_eps)
|
||||
|
||||
self.rotary_emb = rope
|
||||
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
assert hasattr(hf_config, "cached_state_num"), \
|
||||
f"cached_state_num is not set in hf_config"
|
||||
cached_state_num = hf_config.cached_state_num
|
||||
self.register_buffer(
|
||||
"kv_state",
|
||||
torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"score_state",
|
||||
torch.full(
|
||||
(cached_state_num, coff * compress_ratio, coff * self.head_dim),
|
||||
float("-inf"),
|
||||
dtype=torch.float32,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.hadamard_matrix = torch.tensor(
|
||||
hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")
|
||||
|
||||
def overlap_transform(self, tensor: torch.Tensor, value=0):
|
||||
# tensor: [b,s,r,2d]
|
||||
b, s, _, _ = tensor.size()
|
||||
ratio, d = self.compress_ratio, self.head_dim
|
||||
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
|
||||
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
|
||||
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
|
||||
return new_tensor
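# NOTE: in the overlapped layout produced above, slots [ratio:2*ratio] hold the
# current position's non-overlapping halves and slots [:ratio] hold the previous
# position's overlapping halves (shifted by one along the sequence axis); the
# very first position keeps the fill value in its overlap slots.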
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
window_offset: int,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
x = x.float()
|
||||
kv_pack, _ = self.wkv(x)
|
||||
score_pack, _ = self.wgate(x)
|
||||
|
||||
mlu_ops.fused_compress_single_kv(
|
||||
kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
|
||||
score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
|
||||
position=positions,
|
||||
ape=self.ape,
|
||||
kv_state=self.kv_state,
|
||||
score_state=self.score_state,
|
||||
gamma=self.norm.weight,
|
||||
sin=self.rotary_emb.sin_,
|
||||
cos=self.rotary_emb.cos_,
|
||||
hadamard_matrix=self.hadamard_matrix,
|
||||
slot_mapping=compressor_slot_mapping,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache_scale=None,
|
||||
eps=self.norm_eps,
|
||||
overlap=self.overlap,
|
||||
rotate=self.rotate,
|
||||
state_idx=batch_to_kv_state,
|
||||
)
|
||||
|
||||
# Here, return fake compressed_kv.
|
||||
return None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
window_offset: int,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
common_metadata = get_common_metadata()
|
||||
forward_func: Callable = (
|
||||
self.forward_prefill if common_metadata.is_prefill_only
|
||||
else self.forward_decode
|
||||
)
|
||||
return forward_func(
|
||||
x,
|
||||
positions,
|
||||
attn_metadata,
|
||||
batch_to_kv_state,
|
||||
kv_cache,
|
||||
window_offset,
|
||||
compressor_slot_mapping,
|
||||
)
|
||||
|
||||
def forward_prefill(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
window_offset: int,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
common_metadata = get_common_metadata()
|
||||
seq_lens = common_metadata.seq_lens
|
||||
query_start_loc = common_metadata.query_start_loc
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
ratio, overlap = self.compress_ratio, self.overlap
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
kv_pack, _ = self.wkv(x)
|
||||
score_pack, _ = self.wgate(x)
|
||||
|
||||
compress_lens = query_lens // self.compress_ratio
|
||||
cu_compress_lens = torch.cat([
|
||||
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
|
||||
torch.cumsum(compress_lens, dim=0)],
|
||||
)
|
||||
|
||||
compress_positions = []
|
||||
for i in range(len(seq_lens)):
|
||||
seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
|
||||
remainder = seqlen % ratio
|
||||
cutoff = seqlen - remainder
|
||||
pos = positions[query_start_loc[i]: query_start_loc[i+1]]
|
||||
positions_ = pos[:cutoff:ratio].contiguous()
|
||||
compress_positions.append(positions_)
|
||||
kv_positions = torch.cat(compress_positions, dim=0)
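# e.g. with compress_ratio = 4 and a 10-token sequence, only positions 0 and 4
# are kept (the 2-token remainder is dropped), matching compress_lens =
# query_lens // compress_ratio, so one compressed KV is emitted per full block
# of 4 tokens.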
|
||||
|
||||
|
||||
total_compress_len = cu_compress_lens[-1].item()
|
||||
kv = torch.empty(
|
||||
[total_compress_len, self.head_dim],
|
||||
dtype=kv_pack.dtype,
|
||||
device=kv_pack.device,
|
||||
)
|
||||
|
||||
mlu_ops.fused_compress_multi_kv(
|
||||
kv = kv_pack,
|
||||
score = score_pack,
|
||||
kv_state = self.kv_state,
|
||||
score_state = self.score_state,
|
||||
state_batch_idx = batch_to_kv_state,
|
||||
cu_seqlens = query_start_loc,
|
||||
ape = self.ape,
|
||||
max_seqlen = common_metadata.max_query_len,
|
||||
overlap = overlap,
|
||||
compressed_kv = kv,
|
||||
)
|
||||
|
||||
if kv.size(0) == 0:
|
||||
return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)
|
||||
|
||||
|
||||
kv = self.norm(kv.to(dtype))
|
||||
|
||||
kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
|
||||
# use compressed cu_seqlens here, so can not call rotary_emb directly
|
||||
kv_rope = mlu_ops.rotary_embedding(
|
||||
kv_rope,
|
||||
self.rotary_emb.sin_,
|
||||
self.rotary_emb.cos_,
|
||||
kv_positions,
|
||||
torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
|
||||
True, # interleaved
|
||||
True, # discrete
|
||||
False,
|
||||
common_metadata.max_query_len,
|
||||
)
|
||||
|
||||
if self.rotate:
|
||||
kv = rotate_activation(kv)
|
||||
|
||||
mlu_ops.reshape_paged_cache(
|
||||
kv.unsqueeze(1),
|
||||
None,
|
||||
kv_cache,
|
||||
None,
|
||||
compressor_slot_mapping,
|
||||
)
|
||||
|
||||
return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)
|
||||
85
vllm_mlu/model_executor/layers/dp_logits_processor.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather, tensor_model_parallel_gather)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm_mlu.model_executor.models.dp_utils import (
|
||||
tensor_model_parallel_all_gather_dp, DataParallelRuntimeParams)
|
||||
|
||||
|
||||
class DPLogitsProcessor(LogitsProcessor):
|
||||
"""DP LogitsProcessor."""
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
dp_params: Optional[DataParallelRuntimeParams] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Get the logits for the next tokens.
|
||||
batch_sizes = None
|
||||
if (lm_head.tp_group is not None
|
||||
and dp_params is not None
|
||||
and dp_params.logits_batch_split_list is not None):
|
||||
batch_sizes = dp_params.logits_batch_split_list
|
||||
hidden_states = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=batch_sizes,
|
||||
rank=lm_head.tp_rank,
|
||||
hidden_states=hidden_states,
|
||||
group=lm_head.tp_group,
|
||||
)
|
||||
|
||||
logits = lm_head.quant_method.apply(
|
||||
lm_head, hidden_states, bias=embedding_bias)
|
||||
|
||||
if self.use_all_gather:
|
||||
# Gather is not supported for some devices such as TPUs.
|
||||
# Use all-gather instead.
|
||||
# NOTE(woosuk): Here, the outputs of every device should not be None
|
||||
# because XLA requires strict SPMD among all devices. Every device
|
||||
# should execute the same operations after gathering the logits.
|
||||
logits = tensor_model_parallel_all_gather(logits, tp_group=lm_head.tp_group)
|
||||
else:
|
||||
# None may be returned for rank > 0
|
||||
logits = tensor_model_parallel_gather(logits, tp_group=lm_head.tp_group)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[..., : self.org_vocab_size]
|
||||
|
||||
if batch_sizes is not None:
|
||||
offset = sum(batch_sizes[:lm_head.tp_rank])
|
||||
logits = logits[offset : offset + batch_sizes[lm_head.tp_rank]]
|
||||
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
dp_params: Optional[DataParallelRuntimeParams] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(
|
||||
hidden_states, lm_head, embedding_bias, dp_params)
|
||||
if logits is not None:
|
||||
if self.soft_cap is not None:
|
||||
logits = logits / self.soft_cap
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.soft_cap
|
||||
|
||||
if self.scale != 1.0:
|
||||
logits *= self.scale
|
||||
|
||||
return logits
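
A hedged reference sketch of the per-rank slicing done at the end of _get_logits above; slice_own_rows and its arguments are illustrative names only, with batch_sizes standing in for dp_params.logits_batch_split_list:

import torch

def slice_own_rows(logits: torch.Tensor, batch_sizes: list[int], rank: int) -> torch.Tensor:
    # Rows are ordered rank-by-rank after the DP all-gather, so this rank's
    # slice starts right after the rows owned by all lower ranks.
    offset = sum(batch_sizes[:rank])
    return logits[offset: offset + batch_sizes[rank]]

# Example: three ranks contributed 4, 2 and 3 rows; rank 1 keeps rows 4..5.
assert slice_own_rows(torch.randn(9, 32000), [4, 2, 3], rank=1).shape[0] == 2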
|
||||
219
vllm_mlu/model_executor/layers/dp_vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
method_has_implemented_embedding,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod,
|
||||
VocabParallelEmbedding,
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
get_masked_input_and_mask,
|
||||
pad_vocab_size,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_logits_tp_group,
|
||||
get_logits_tp_world_size,
|
||||
get_logits_tp_rank,
|
||||
)
|
||||
from vllm_mlu.model_executor.models.dp_utils import (
|
||||
DataParallelRuntimeParams,
|
||||
tensor_model_parallel_all_gather_dp,
|
||||
)
|
||||
|
||||
|
||||
class DPVocabParallelEmbedding(VocabParallelEmbedding):
|
||||
"""DP Embedding parallelized in the vocabulary dimension."""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
"""
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add self.tp_group, tp_world_size and tp_rank to support other parallel groups
|
||||
"""
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_world_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = None
|
||||
logits_tp_world_size = get_logits_tp_world_size()
|
||||
if logits_tp_world_size != self.tp_world_size:
|
||||
self.tp_group = get_logits_tp_group()
|
||||
self.tp_world_size = logits_tp_world_size
|
||||
self.tp_rank = get_logits_tp_rank()
|
||||
|
||||
# Keep the input dimensions.
|
||||
tp_rank = self.tp_rank
|
||||
self.tp_size = self.tp_world_size
|
||||
"""
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
"""
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_size = padding_size
|
||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
||||
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
|
||||
self.padding_size)
|
||||
self.num_embeddings_padded = pad_vocab_size(
|
||||
self.org_vocab_size_padded + num_added_embeddings,
|
||||
self.padding_size)
|
||||
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
||||
|
||||
self.shard_indices = self._get_indices(self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
self.num_embeddings,
|
||||
self.org_vocab_size, tp_rank,
|
||||
self.tp_size)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self) is VocabParallelEmbedding
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method))
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
||||
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
||||
self.tp_size)
|
||||
assert (self.shard_indices.num_elements_padded ==
|
||||
self.num_embeddings_per_partition)
|
||||
self.num_org_embeddings_per_partition = (
|
||||
self.shard_indices.org_vocab_end_index -
|
||||
self.shard_indices.org_vocab_start_index)
|
||||
self.num_added_embeddings_per_partition = (
|
||||
self.shard_indices.added_vocab_end_index -
|
||||
self.shard_indices.added_vocab_start_index)
|
||||
|
||||
self.quant_method.create_weights(self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_padded,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
def forward(self, input_,
|
||||
dp_params: Optional[DataParallelRuntimeParams] = None):
|
||||
token_split_list = None
|
||||
if (dp_params is not None
|
||||
and self.tp_group is not None
|
||||
and dp_params.emb_token_split_list is not None):
|
||||
token_split_list = dp_params.emb_token_split_list
|
||||
input_ = tensor_model_parallel_all_gather_dp(
|
||||
group_num_tokens=token_split_list,
|
||||
rank=self.tp_rank,
|
||||
hidden_states=input_.reshape(-1, 1),
|
||||
group=self.tp_group,
|
||||
).reshape(-1)
|
||||
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_,
|
||||
self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index,
|
||||
)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
|
||||
|
||||
if token_split_list is not None:
|
||||
offset = sum(token_split_list[:self.tp_rank])
|
||||
output = output[offset : offset + token_split_list[self.tp_rank]]
|
||||
|
||||
return output
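
For context, a minimal sketch (illustrative names, assuming an unquantized weight shard and plain F.embedding rather than quant_method.embedding) of the vocab-range masking that forward() performs before the all-reduce:

import torch
import torch.nn.functional as F

def shard_embedding(input_ids, weight_shard, vocab_start, vocab_end):
    # Tokens outside this rank's [vocab_start, vocab_end) slice are looked up
    # at index 0 and then zeroed, so the later all-reduce sums exactly one
    # non-zero contribution per token across ranks.
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = (input_ids - vocab_start).masked_fill(mask, 0)
    out = F.embedding(local_ids, weight_shard)
    return out.masked_fill(mask.unsqueeze(-1), 0)

# Example: this rank owns vocab ids 0..99; token 150 belongs to another rank.
emb = shard_embedding(torch.tensor([5, 150]), torch.randn(100, 8), 0, 100)
assert torch.all(emb[1] == 0)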
|
||||
|
||||
|
||||
class DPParallelLMHead(DPVocabParallelEmbedding):
|
||||
"""DP Parallelized LM head.
|
||||
|
||||
NOTE: A copy of the ParallelLMHead class; the only change is its parent,
from VocabParallelEmbedding to DPVocabParallelEmbedding.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
||||
org_num_embeddings, padding_size, quant_config,
|
||||
prefix)
|
||||
self.quant_config = quant_config
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
|
||||
"""Tie the weights with word embeddings."""
|
||||
# GGUF quantized embed_tokens.
|
||||
if self.quant_config and self.quant_config.get_name() == "gguf":
|
||||
return embed_tokens
|
||||
else:
|
||||
self.weight = embed_tokens.weight
|
||||
return self
|
||||
|
||||
def forward(self, input_):
|
||||
del input_
|
||||
raise RuntimeError("LMHead's weights should be used in the sampler.")
|
||||
224
vllm_mlu/model_executor/layers/feed_forward.py
Executable file
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Any
|
||||
|
||||
from vllm.distributed import (
|
||||
get_parallel_world_size_with_group,
|
||||
get_parallel_rank_with_group,
|
||||
)
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear
|
||||
)
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import set_is_gated
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
up_proj_name: str,
|
||||
is_gated: bool,
|
||||
down_proj_name: str,
|
||||
bias: bool,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
skip_bias_add: bool = False,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.is_gated = is_gated
|
||||
self.bias = bias
|
||||
self.up_proj_name = up_proj_name
|
||||
self.down_proj_name = down_proj_name
|
||||
self.quant_config = quant_config
|
||||
self.is_initialized = False
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.reduce_results = reduce_results
|
||||
self.use_bt_ffn = True
|
||||
set_is_gated(self.is_gated)
|
||||
|
||||
# modify tp_size, tp_rank and tp_group when enable data parallel
|
||||
self.tp_size = get_parallel_world_size_with_group(tp_group)
|
||||
self.tp_rank = get_parallel_rank_with_group(tp_group)
|
||||
self.tp_group = tp_group
|
||||
self.keep_full_weights = keep_full_weights
|
||||
if self.keep_full_weights:
|
||||
self.tp_size = 1
|
||||
self.tp_rank = 0
|
||||
self.tp_group = None
|
||||
|
||||
# up_proj with gate or not
|
||||
if self.is_gated:
|
||||
up_proj = MergedColumnParallelLinear(hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{up_proj_name}",
|
||||
tp_group=self.tp_group,
|
||||
keep_full_weights=keep_full_weights)
|
||||
else:
|
||||
up_proj = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{up_proj_name}",
|
||||
tp_group=self.tp_group,
|
||||
keep_full_weights=keep_full_weights)
|
||||
self.register_module(up_proj_name, up_proj)
|
||||
|
||||
# down_proj
|
||||
down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
reduce_results=reduce_results,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.{down_proj_name}",
|
||||
tp_group=self.tp_group,
|
||||
keep_full_weights=keep_full_weights)
|
||||
|
||||
self.register_module(down_proj_name, down_proj)
|
||||
|
||||
def prepare_weight(self):
|
||||
if not self.is_initialized:
|
||||
# alpha and beta are 1.0 and 0.0 respectively because we don't need the residual for now
|
||||
self.alpha = 1.0
|
||||
self.beta = 0.0
|
||||
# place it here to avoid the overhead of calling it in the forward pass
|
||||
self.is_initialized = True
|
||||
|
||||
def _forward(self, hidden_states):
|
||||
self.prepare_weight()
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
act_dict = {
|
||||
"relu": F.relu,
|
||||
"gelu": F.gelu,
|
||||
"silu": F.silu,
|
||||
}
|
||||
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
|
||||
if self.is_gated:
|
||||
d = fc1.shape[-1] // 2
|
||||
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
|
||||
else:
|
||||
fc1 = act_dict[self.hidden_act](fc1)
|
||||
fc2 = F.linear(fc1, down_proj.weight, bias=None)
|
||||
fc2 = tensor_model_parallel_all_reduce(fc2)
|
||||
if not self.skip_bias_add:
|
||||
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
|
||||
return fc2
|
||||
|
||||
def forward_naive(
|
||||
self,
|
||||
hidden_states,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None
|
||||
):
|
||||
'''
|
||||
used by quant_tools
|
||||
'''
|
||||
assert self.quant_config is None, "ffn naive forward doesn't support quantization"
assert smooth_quant_scale is None, "ffn naive forward doesn't support smooth_quant_scale"
|
||||
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
fc1, bias = up_proj(hidden_states)
|
||||
if bias is not None:
|
||||
fc1 += bias
|
||||
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
out, bias = down_proj(fc1, residual=residual_)
|
||||
|
||||
if self.skip_bias_add:
|
||||
return out, bias
|
||||
return out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
output: torch.Tensor | None = None,
|
||||
):
|
||||
self.prepare_weight()
|
||||
|
||||
if self.use_bt_ffn is False:
|
||||
return self.forward_naive(hidden_states, residual, None)
|
||||
|
||||
up_proj = getattr(self, self.up_proj_name)
|
||||
down_proj = getattr(self, self.down_proj_name)
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
|
||||
and not isinstance(down_proj, BaseLayerWithLoRA)):
|
||||
# The matmul formula is the following:
|
||||
# mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
|
||||
# output = active(mul_out)
|
||||
# Notes: We cannot use the activation function in matmul because it does not support the gated operation;
# we might support it in tmo matmul in the future.
|
||||
up_proj_weight = up_proj.weight
|
||||
down_proj_weight = down_proj.weight
|
||||
if self.keep_full_weights and use_tp_weight:
|
||||
up_proj_weight = up_proj.tp_weight
|
||||
down_proj_weight = down_proj.tp_weight
|
||||
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
|
||||
None, 'none', self.alpha, self.beta)
|
||||
act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
|
||||
beta = 0.0
|
||||
if residual_ is not None:
|
||||
beta = 1.0
|
||||
residual_ = residual_.view(-1, residual_.shape[-1])
|
||||
out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
|
||||
# the bias, if present, needs to be added after the second matmul, per the original design of vllm
|
||||
if self.reduce_results:
|
||||
out = tensor_model_parallel_all_reduce(out_, self.tp_group)
|
||||
else:
|
||||
out = out_
|
||||
# do the bias add if needed
|
||||
if not self.skip_bias_add:
|
||||
out = out + down_proj.bias if down_proj.bias is not None else out
|
||||
else:
|
||||
return out, down_proj.bias
|
||||
else:
|
||||
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
|
||||
if bias is not None:
|
||||
fc1 += bias
|
||||
input_scale = None
|
||||
if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
|
||||
self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
|
||||
down_proj.quant_method.skip_quant_input = True
|
||||
down_proj_smooth = down_proj.smooth
|
||||
if self.keep_full_weights and use_tp_weight:
|
||||
assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
|
||||
down_proj_smooth = down_proj.tp_smooth
|
||||
fc1, input_scale = mlu_ops.per_token_smooth_quantize(
|
||||
fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
|
||||
else:
|
||||
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
|
||||
out, bias = down_proj(
|
||||
fc1, residual=residual_, smooth_quant_scale=input_scale,
|
||||
use_tp_weight=use_tp_weight, output=output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
return out, bias
|
||||
return out
|
||||
3
vllm_mlu/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

935
vllm_mlu/model_executor/layers/fused_moe/fused_moe.py
Normal file
@@ -0,0 +1,935 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Fused MoE kernel."""
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_moe_kernel_gptq_awq,
|
||||
write_zeros_to_output,
|
||||
get_default_config,
|
||||
try_get_optimal_moe_config,
|
||||
_get_config_quant_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
activation_without_mul,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
import vllm_mlu._mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
b_bias_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
EM,
|
||||
num_valid_tokens,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
stride_bbe, # bias expert stride
|
||||
stride_bbn, # bias N stride
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
per_channel_quant: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||
token and expert matrices.
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Split the program ID into two dimensions (pid_0 and pid_1)
|
||||
'''
|
||||
pid_0 = tl.program_id(axis=0)
|
||||
pid_1 = tl.program_id(axis=1)
|
||||
pid = pid_1 * tl.num_programs(axis=0) + pid_0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
||||
if off_experts == -1:
|
||||
# -----------------------------------------------------------
|
||||
# Write back zeros to the output when the expert is not
|
||||
# in the current expert parallel rank.
|
||||
write_zeros_to_output(
|
||||
c_ptr,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
pid_n,
|
||||
N,
|
||||
offs_token,
|
||||
token_mask,
|
||||
BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N,
|
||||
compute_type,
|
||||
)
|
||||
return
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
||||
)
|
||||
|
||||
b_ptrs = (
|
||||
b_ptr
|
||||
+ off_experts * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
)
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
# block-wise
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
offs_bsn = offs_bn // group_n
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
|
||||
)
|
||||
# channel-wise
|
||||
elif per_channel_quant:
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
# Load per-token scale for activations
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
|
||||
# tensor-wise
|
||||
else:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
if HAS_BIAS:
|
||||
# bias shape: [num_experts, N]
|
||||
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
|
||||
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(
|
||||
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
|
||||
else:
|
||||
if use_fp8_w8a8:
|
||||
# acc used to enable fp8_fast_accum
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
if HAS_BIAS:
|
||||
accumulator = accumulator + bias[None, :]
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
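
To make the sorted_token_ids / expert_ids layout from the docstring concrete, a small host-side reference sketch follows (illustrative only; the real arrangement comes from the moe_align_block_size op invoked further down, whose exact output may differ):

import torch

def align_block_size_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    # topk_ids: (num_tokens, top_k) expert index chosen for each token slot.
    flat = topk_ids.flatten()
    numel = flat.numel()
    sorted_chunks, expert_ids = [], []
    for e in range(num_experts):
        idx = torch.nonzero(flat == e, as_tuple=False).flatten()
        # Pad every expert's token list to a multiple of block_size with the
        # sentinel `numel`; the kernel treats those slots as invalid tokens.
        pad = (-idx.numel()) % block_size
        idx = torch.cat([idx, torch.full((pad,), numel, dtype=idx.dtype)])
        sorted_chunks.append(idx)
        expert_ids.extend([e] * (idx.numel() // block_size))
    sorted_token_ids = torch.cat(sorted_chunks)
    return sorted_token_ids, torch.tensor(expert_ids), torch.tensor([sorted_token_ids.numel()])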
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-2), block_shape[0]
|
||||
) == B_scale.size(-2)
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-1), block_shape[1]
|
||||
) == B_scale.size(-1)
|
||||
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique,
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Split the program ID into two dimensions (pid_0, pid_1)
|
||||
'''
|
||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv(
|
||||
B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
|
||||
assert not (use_int8_w8a16 or use_int4_w4a16)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
HAS_BIAS = B_bias is not None
|
||||
if (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
):
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
|
||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||
num_valid_tokens=num_tokens,
|
||||
group_size=block_shape[1],
|
||||
num_experts=B.size(0),
|
||||
bit=4 if use_int4_w4a16 else 8,
|
||||
)
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(
|
||||
config=config,
|
||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.size(1),
|
||||
size_n=B.size(1),
|
||||
num_experts=B.size(1),
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"],
|
||||
)
|
||||
)
|
||||
|
||||
if use_moe_wna16_cuda:
|
||||
bit = 4 if use_int4_w4a16 else 8
|
||||
ops.moe_wna16_gemm(
|
||||
A,
|
||||
C,
|
||||
B,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights if mul_routed_weight else None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
top_k,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
bit,
|
||||
)
|
||||
return
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
A.size(1),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0),
|
||||
B_scale.stride(2),
|
||||
B_scale.stride(1),
|
||||
B_zp.stride(0) if B_zp is not None else 0,
|
||||
B_zp.stride(2) if B_zp is not None else 0,
|
||||
B_zp.stride(1) if B_zp is not None else 0,
|
||||
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
|
||||
group_size=block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
has_zp=B_zp is not None,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
else:
|
||||
config = config.copy()
|
||||
config["SPLIT_K"] = 1
|
||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
||||
if block_shape is not None:
|
||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_bias,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
B.size(2),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_bias.stride(0) if B_bias is not None else 0,
|
||||
B_bias.stride(1) if B_bias is not None else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
True,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
ocp_mx_scheme,
|
||||
per_channel_quant,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
w1_zp,
|
||||
w2_zp,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="outplace_fused_experts_mlu",
|
||||
op_func=outplace_fused_experts,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=outplace_fused_experts_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
tags=(
|
||||
()
|
||||
if is_torch_equal_or_newer("2.7.0")
|
||||
else (torch.Tag.needs_fixed_stride_order,)
|
||||
),
|
||||
)
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
return torch.ops.vllm.outplace_fused_experts_mlu(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
SILU_NO_MUL: str = activation_without_mul("silu")
|
||||
GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
RELU2_NO_MUL: str = activation_without_mul("relu2")
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
|
||||
elif ocp_mx_scheme is not None:
|
||||
if ocp_mx_scheme in {
|
||||
"w_mxfp4_a_mxfp4",
|
||||
"w_mxfp4_a_mxfp6_e3m2",
|
||||
"w_mxfp4_a_mxfp6_e2m3",
|
||||
}:
|
||||
# 16bit activation and fp4x2 packed weight
|
||||
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
|
||||
elif ocp_mx_scheme in {
|
||||
"w_mxfp6_e3m2_a_mxfp6_e3m2",
|
||||
"w_mxfp6_e2m3_a_mxfp6_e2m3",
|
||||
}:
|
||||
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
|
||||
"hidden size mismatch"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
|
||||
else:
|
||||
assert hidden_states.size(1) == w1.size(2), (
|
||||
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
|
||||
)
|
||||
|
||||
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
num_tokens = hidden_states.size(0)
|
||||
E, N, _ = w1.size()
|
||||
K = w2.size(1)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
top_k_num = topk_ids.size(1)
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
|
||||
# quantized prior to calling fused_experts.
|
||||
quant_dtype = _get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.size(),
|
||||
w2.size(),
|
||||
top_k_num,
|
||||
config_dtype,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Only use the default config
|
||||
'''
|
||||
config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1],
|
||||
hidden_states.dtype, block_shape)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
cache13 = torch.empty(
|
||||
M * top_k_num * max(N, K),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
|
||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
if inplace and not disable_inplace():
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
if ocp_mx_scheme is not None:
|
||||
# TODO: On platforms for which `current_platform.supports_mx()` is True
|
||||
# and for which we have a native OCP mx fused MOE kernel,
|
||||
# this dequantization step should not be done.
|
||||
if ocp_mx_scheme in {
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
|
||||
}:
|
||||
# Weight has to be dequantized for mxfp4 emulation.
|
||||
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
|
||||
w2_scale = None
|
||||
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
|
||||
w1 = dequant_mxfp6(
|
||||
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp6(
|
||||
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w2_scale = None
|
||||
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
|
||||
w1 = dequant_mxfp6(
|
||||
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp6(
|
||||
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w2_scale = None
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
|
||||
|
||||
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
||||
begin_chunk_idx, end_chunk_idx = (
|
||||
chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE, num_tokens),
|
||||
)
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.size()
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
# Adjust the intermediate cache size and config for the last
|
||||
# chunk. Note that in most cases we only have one chunk
|
||||
# so the cache size and config are already set correctly and
|
||||
# do not need to be adjusted.
|
||||
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
a1q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(
|
||||
curr_hidden_states, a1_scale, block_shape)
|
||||
else:
|
||||
qcurr_hidden_states = curr_hidden_states
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
||||
)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
apply_router_weight_on_input,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w1_bias,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Activate by mlu_ops
|
||||
'''
|
||||
intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N),
|
||||
act_mode=activation,
|
||||
is_gated=True)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
qintermediate_cache2, a2q_scale = _fp8_quantize(
|
||||
intermediate_cache2, a2_scale, block_shape)
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
invoke_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w2_bias,
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace moe_sum with torch.sum
|
||||
Reference Links: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513
|
||||
'''
|
||||
if topk_ids.shape[1] == 2:
|
||||
torch.add(
|
||||
intermediate_cache3[:, 0],
|
||||
intermediate_cache3[:, 1],
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
).squeeze(dim=1)
|
||||
elif topk_ids.shape[1] > 2:
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return out_hidden_states
|
||||
|
||||
106
vllm_mlu/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
|
||||
|
||||
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
#TODO: support `routed_scaling_factor`
|
||||
assert routed_scaling_factor == 1.0, (
|
||||
f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
|
||||
)
|
||||
use_fused_kernel = topk_group is None
|
||||
if use_fused_kernel:
|
||||
assert not enable_eplb, "MLU does not support eplb in the fused_moe kernel."
|
||||
assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
|
||||
f"Following params: use_grouped_topk, num_expert_group, topk_group are not support yet."
|
||||
return mlu_ops.fused_moe(
|
||||
x,
|
||||
router_logits,
|
||||
layer.w13_weight, layer.w2_weight,
|
||||
None, None, # bias1, bias2
|
||||
None, # residual
|
||||
None, # input_smooth
|
||||
None, # act_smooth
|
||||
None, None, # w1_scale, w2_scale
|
||||
top_k,
|
||||
renormalize,
|
||||
True, # gated
|
||||
activation
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
assert expert_map is None
|
||||
return self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
UnquantizedFusedMoEMethod,
|
||||
UnquantizedFusedMoEMethod.forward_oot,
|
||||
vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
|
||||
)
|
||||
248
vllm_mlu/model_executor/layers/fused_moe/moe_align_block_size.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Implementation of moe_align_block_size_triton.
|
||||
Note: the implementation has been removed from vllm since the
cuda implementation is more efficient.
|
||||
'''
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage1(
|
||||
topk_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
|
||||
off_c = (pid + 1) * num_experts
|
||||
|
||||
for i in range(tokens_per_thread):
|
||||
if start_idx + i < numel:
|
||||
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage2(
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
last_cnt = 0
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||
last_cnt = last_cnt + token_cnt
|
||||
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage3(
|
||||
total_tokens_post_pad_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
last_cumsum = 0
|
||||
off_cnt = num_experts * num_experts
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||
tl.store(cumsum_ptr + i, last_cumsum)
|
||||
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage4(
|
||||
topk_ids_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
start_idx = tl.load(cumsum_ptr + pid)
|
||||
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||
|
||||
for i in range(start_idx, end_idx, block_size):
|
||||
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
off_t = pid * num_experts
|
||||
|
||||
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
|
||||
numel)):
|
||||
expert_id = tl.load(topk_ids_ptr + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||
|
||||
|
||||
# Triton implementation based on:
|
||||
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
|
||||
def moe_align_block_size_triton(
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor,
|
||||
) -> None:
|
||||
numel = topk_ids.numel()
|
||||
grid = (num_experts, )
|
||||
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
cumsum = torch.zeros((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
tokens_per_thread = cdiv(numel, num_experts)
|
||||
sorted_token_ids.fill_(numel)
|
||||
expert_ids.zero_()
|
||||
|
||||
moe_align_block_size_stage1[grid](
|
||||
topk_ids,
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
moe_align_block_size_stage2[grid](
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
)
|
||||
moe_align_block_size_stage3[(1, )](
|
||||
num_tokens_post_pad,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
moe_align_block_size_stage4[grid](
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns the token distribution across experts to be compatible with block
|
||||
size for matrix multiplication.
|
||||
|
||||
Parameters:
|
||||
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
||||
top-k expert indices for each token.
|
||||
- block_size: The block size used in block matrix multiplication.
|
||||
- num_experts: The total number of experts.
|
||||
- expert_map: A tensor of shape [num_experts] that maps the expert index
|
||||
from the global space to the local index space of the current
|
||||
expert parallel shard. If the expert is not in the current expert
|
||||
parallel shard, the mapping is set to -1.
|
||||
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
|
||||
should be padded to a multiple of block_size,
|
||||
|
||||
Returns:
|
||||
- sorted_token_ids: A tensor containing the sorted token indices according
|
||||
to their allocated expert.
|
||||
- expert_ids: A tensor indicating the assigned expert index for each block.
|
||||
- num_tokens_post_padded: The total number of tokens after padding,
|
||||
ensuring divisibility by block_size.
|
||||
|
||||
This function pads the number of tokens that each expert needs to process
|
||||
so that it is divisible by block_size.
|
||||
Padding ensures that during block matrix multiplication, the dimensions
|
||||
align correctly.
|
||||
|
||||
Example:
|
||||
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
|
||||
block_size = 4, and num_experts = 4:
|
||||
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
|
||||
with each expert needing to process 3 tokens.
|
||||
- As block_size is 4, we pad 1 token for each expert.
|
||||
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
|
||||
- Then append padding tokens [12, 12, 12, 12] for each block.
|
||||
- After sorting by expert index, we obtain token_ids
|
||||
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
|
||||
Tokens 12 are non-existent (padding) and are ignored in
|
||||
the subsequent matrix multiplication.
|
||||
- The padding ensures that the total number of tokens is now divisible
|
||||
by block_size for proper block matrix operations.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
# Expert ids must be zeroed out to prevent index out of bounds error while
|
||||
# mapping global expert ids to local expert ids in expert parallelism.
|
||||
expert_ids = torch.zeros((max_num_m_blocks, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use only the Triton implementation of moe_align_block_size
|
||||
'''
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
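# -----------------------------------------------------------------------------
# Illustrative CPU reference (an editorial sketch, not part of the MLU/Triton
# path): it reproduces the worked example from the docstring above with plain
# PyTorch, assuming experts are numbered 1..num_experts as in that example.
# The helper name is hypothetical and only meant as a sanity check.
def _moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int,
                                    num_experts: int) -> torch.Tensor:
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # padding slots point one past the last token
    chunks = []
    for expert in range(1, num_experts + 1):
        token_idx = (flat == expert).nonzero(as_tuple=True)[0].to(torch.int32)
        pad = (-token_idx.numel()) % block_size
        chunks.append(torch.cat(
            [token_idx, torch.full((pad, ), pad_id, dtype=torch.int32)]))
    return torch.cat(chunks)
# _moe_align_block_size_reference(
#     torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]), 4, 4)
# -> tensor([ 3,  6,  9, 12,  0,  4, 10, 12,  1,  7, 11, 12,  2,  5,  8, 12],
#           dtype=torch.int32)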
|
||||
31
vllm_mlu/model_executor/layers/fused_moe/utils.py
Normal file
31
vllm_mlu/model_executor/layers/fused_moe/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from math import prod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
block_shape: List[int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Perform fp8 quantization on the inputs. If a block_shape
|
||||
is provided, the output will be blocked.
|
||||
"""
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
assert block_shape is not None
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
return A, A_scale
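# -----------------------------------------------------------------------------
# Editorial sketch (an assumption, not the MLU kernel): per-token-group fp8
# quantization keeps one scale per group of `block_k` elements along the last
# dimension, which is what the assert above checks via
# cdiv(A.shape[-1], block_k) == A_scale.shape[-1]. The helper below only
# derives the group scales in plain PyTorch and assumes K % block_k == 0; the
# real per_token_group_quant_fp8 additionally casts the scaled values to fp8.
def _per_token_group_scales_sketch(A: torch.Tensor, block_k: int,
                                   fp8_max: float = 448.0) -> torch.Tensor:
    groups = A.reshape(*A.shape[:-1], -1, block_k)  # [..., K // block_k, block_k]
    return groups.abs().amax(dim=-1) / fp8_max      # one scale per group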
|
||||
|
||||
|
||||
278
vllm_mlu/model_executor/layers/indexer.py
Normal file
278
vllm_mlu/model_executor/layers/indexer.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.compressor import (
|
||||
Compressor,
|
||||
rotate_activation,
|
||||
)
|
||||
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Indexer(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rope,
|
||||
compress_ratio: int = 4,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.dim = config.dim
|
||||
self.n_heads = config.index_n_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.n_local_heads = config.index_n_heads // self.tp_size
|
||||
self.head_dim = config.index_head_dim
|
||||
self.rope_head_dim = config.rope_head_dim
|
||||
self.index_topk = config.index_topk
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.window_size = config.window_size
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.wq_b = ReplicatedLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
|
||||
self.weights_proj = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
params_dtype=torch.bfloat16,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_dim ** -0.5
|
||||
self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
|
||||
self.compress_ratio = compress_ratio
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
self.rotary_emb = rope
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
self.compressor = Compressor(vllm_config, self.rotary_emb, compress_ratio, self.head_dim, True, f"{prefix}.compressor")
|
||||
|
||||
self.freqs_cis = None
|
||||
|
||||
def forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
k_full: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
):
|
||||
assert attn_metadata.prefill.chunked_context is None, \
|
||||
f"Prefill chunked context is not supported."
|
||||
|
||||
query_start_loc = attn_metadata.prefill.query_start_loc
|
||||
cu_seq_q_lens = query_start_loc
|
||||
cu_seq_k_lens = torch.zeros(
|
||||
context_lens.size(0) + 1, dtype=torch.int32, device=q.device,
|
||||
)
|
||||
torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])
|
||||
|
||||
seq_lens = torch.diff(cu_seq_k_lens)
|
||||
|
||||
batch_size = seq_lens.shape[0]
|
||||
|
||||
new_block_tables = torch.empty(
|
||||
[attn_metadata.num_prefill_tokens, self.index_topk],
|
||||
dtype=torch.int32,
|
||||
device=q.device,
|
||||
)
|
||||
new_context_lens = torch.empty(
|
||||
[attn_metadata.num_prefill_tokens],
|
||||
dtype=torch.int32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
q_seq_lens = cu_seq_q_lens[1:] - cu_seq_q_lens[:-1]
|
||||
max_seq_len = q_seq_lens.max().item()
|
||||
batch_size = q_seq_lens.size(0)
|
||||
max_compressed_kv_len = max_seq_len // self.compress_ratio
|
||||
kv_cache_block_table = torch.zeros([batch_size, max_compressed_kv_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
# The layout of linear kv is as follows:
|
||||
# | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
|
||||
for i in range(batch_size):
|
||||
start = cu_seq_k_lens[i].item()
|
||||
kv_cache_block_table[i] = torch.arange(
|
||||
start, start + max_compressed_kv_len,
|
||||
dtype=torch.int32,
|
||||
device=q.device,
|
||||
)
|
||||
# offset total origin_kv len
|
||||
kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1]
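# Worked example (editorial, assuming two sequences with query lengths
# [8, 4] and compress_ratio = 4): the origin kv occupies slots 0..11, the
# compressed kv of batch 0 starts at slot 12 and that of batch 1 at slot 14,
# so kv_cache_block_table becomes [[12, 13], [14, 15]]; row width is
# max_compressed_kv_len = 2, and batch 1's trailing slot lies beyond its
# valid length given by cu_seq_k_lens.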
|
||||
|
||||
# query: (tokens, index_head, index_head_dim)
|
||||
# k_full: (tokens, index_head_dim)
|
||||
# weights: (tokens, index_head, 1)
|
||||
mlu_ops.masked_indexer_select_paged_kv_prefill(
|
||||
query=q,
|
||||
key_value=k_full,
|
||||
weights=weights.unsqueeze(-1),
|
||||
kv_cache_block_table=kv_cache_block_table,
|
||||
cu_seq_q_lens=cu_seq_q_lens,
|
||||
cu_seq_k_lens=cu_seq_k_lens,
|
||||
index_topk=self.index_topk,
|
||||
kv_cache_block_size=self.block_size,
|
||||
softmax_scale=self.merged_softmax_scale,
|
||||
q_scale=None,
|
||||
k_scale_cache=None,
|
||||
sparse_block_table=new_block_tables,
|
||||
sparse_context_lens=new_context_lens,
|
||||
compress_ratio=self.compress_ratio,
|
||||
kv_cache_block_table_offset=None,
|
||||
)
|
||||
|
||||
return new_block_tables, new_context_lens
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
block_table = attn_metadata.decode.block_table
|
||||
batch_size = block_table.shape[0]
|
||||
seq_len = x.shape[0] // batch_size
|
||||
q = q.view(batch_size, seq_len, *q.shape[1:])
|
||||
weights = weights.view(batch_size, seq_len, *weights.shape[1:])
|
||||
|
||||
seq_lens = attn_metadata.decode.seq_lens
|
||||
k_block_table = block_table
|
||||
|
||||
seq_len = x.shape[0] // batch_size
|
||||
new_block_tables = torch.empty(
|
||||
[batch_size, seq_len, self.index_topk],
|
||||
dtype=torch.int32,
|
||||
device=block_table.device,
|
||||
)
|
||||
new_context_lens = torch.empty(
|
||||
[attn_metadata.num_decode_tokens],
|
||||
dtype=torch.int32,
|
||||
device=block_table.device,
|
||||
)
|
||||
|
||||
kv_cache_block_table_offset = torch.empty(
|
||||
[attn_metadata.num_decode_tokens],
|
||||
dtype=torch.int32,
|
||||
device=block_table.device,
|
||||
)
|
||||
kv_cache_block_table_offset.fill_(self.window_size)
|
||||
mlu_ops.masked_indexer_select_paged_kv_decode(
|
||||
query=q,
|
||||
k_cache=k_cache,
|
||||
weights=weights.unsqueeze(-1), # (bsz, seq_q, head_num, 1)
|
||||
kv_cache_block_table=block_table,
|
||||
k_context_lens=seq_lens // self.compress_ratio,
|
||||
k_cache_block_table=k_block_table,
|
||||
index_topk=self.index_topk,
|
||||
kv_cache_block_size=self.block_size,
|
||||
softmax_scale=self.merged_softmax_scale,
|
||||
q_scale=None,
|
||||
k_scale_cache=None,
|
||||
sparse_block_table=new_block_tables,
|
||||
sparse_context_lens=new_context_lens,
|
||||
compress_ratio=self.compress_ratio,
|
||||
kv_cache_block_table_offset=kv_cache_block_table_offset,
|
||||
)
|
||||
|
||||
# [batch, seq_q, index_topk] -> [batch, index_topk]
|
||||
new_block_tables = new_block_tables.squeeze(1)
|
||||
return new_block_tables, new_context_lens
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
batch_to_kv_state: torch.Tensor,
|
||||
indexer_kv_cache: torch.Tensor,
|
||||
compressor_slot_mapping: torch.Tensor,
|
||||
):
|
||||
common_metadata = get_common_metadata()
|
||||
query_start_loc = common_metadata.query_start_loc
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
rd = self.rope_head_dim
|
||||
q = self.wq_b(qr)[0]
|
||||
q = q.unflatten(-1, (self.n_heads, self.head_dim))
|
||||
|
||||
self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
|
||||
q_pack = rotate_activation(q)
|
||||
weights_pack = self.weights_proj(x)[0] # (tokens, index_local_head)
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
compressed_kv = self.compressor(
|
||||
x,
|
||||
positions,
|
||||
attn_metadata,
|
||||
batch_to_kv_state,
|
||||
indexer_kv_cache,
|
||||
0,
|
||||
compressor_slot_mapping,
|
||||
)
|
||||
if attn_metadata.prefill:
|
||||
assert compressed_kv is not None and compressed_kv.dim() == 3
|
||||
compressed_kv = compressed_kv.squeeze(-2)
|
||||
compressed_context_lens = query_lens // self.compress_ratio
|
||||
|
||||
prefill_q = q_pack[num_decode_tokens:, ...]
|
||||
prefill_weights = weights_pack[num_decode_tokens:, ...]
|
||||
prefill_block_tables, prefill_context_lens = self.forward_prefill(
|
||||
prefill_q,
|
||||
indexer_kv_cache,
|
||||
prefill_weights,
|
||||
attn_metadata,
|
||||
compressed_kv,
|
||||
compressed_context_lens,
|
||||
)
|
||||
|
||||
if attn_metadata.decode:
|
||||
decode_x = x[:num_decode_tokens, ...]
|
||||
decode_q = q_pack[:num_decode_tokens, ...]
|
||||
decode_weights = weights_pack[attn_metadata.num_prefills:]
|
||||
decode_block_tables, decode_context_lens = self.forward_decode(
|
||||
decode_q,
|
||||
decode_x,
|
||||
indexer_kv_cache,
|
||||
decode_weights,
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
if attn_metadata.prefill and attn_metadata.decode:
|
||||
new_block_tables = torch.cat([prefill_block_tables, decode_block_tables], dim=0)
|
||||
new_context_lens = torch.cat([prefill_context_lens, decode_context_lens], dim=0)
|
||||
elif attn_metadata.prefill:
|
||||
new_block_tables = prefill_block_tables
|
||||
new_context_lens = prefill_context_lens
|
||||
else:
|
||||
new_block_tables = decode_block_tables
|
||||
new_context_lens = decode_context_lens
|
||||
return new_block_tables, new_context_lens
|
||||
130
vllm_mlu/model_executor/layers/layernorm.py
Normal file
130
vllm_mlu/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
|
||||
|
||||
|
||||
@CustomOp.register("quant_fusion_rms_norm")
|
||||
class QuantFusionRMSNorm(RMSNorm):
|
||||
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
|
||||
super().__init__(hidden_size, variance_epsilon)
|
||||
assert not isinstance(
|
||||
proj.quant_method, UnquantizedLinearMethod
|
||||
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
|
||||
proj.quant_method.skip_quant_input = True
|
||||
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
|
||||
quant_scale = proj.smooth.data
|
||||
else:
|
||||
quant_scale = proj.scale_to_int.data
|
||||
self.dynamic_quant = dynamic_quant
|
||||
self.quant_scale = torch.nn.Parameter(quant_scale)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, residual: torch.Tensor | None = None
|
||||
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
|
||||
return mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
None,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
False,
|
||||
self.quant_scale.data,
|
||||
self.dynamic_quant,
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("quant_fusion_layer_norm")
|
||||
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
|
||||
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
|
||||
super().__init__(hidden_size, variance_epsilon)
|
||||
assert not isinstance(
|
||||
proj.quant_method, UnquantizedLinearMethod
|
||||
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
|
||||
proj.quant_method.skip_quant_input = True
|
||||
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
|
||||
quant_scale = proj.smooth.data
|
||||
else:
|
||||
quant_scale = proj.scale_to_int.data
|
||||
self.dynamic_quant = dynamic_quant
|
||||
self.quant_scale = torch.nn.Parameter(quant_scale)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, residual: torch.Tensor | None = None
|
||||
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
|
||||
bias = None if self.bias is None else self.bias.data
|
||||
return mlu_ops.fused_layer_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
bias,
|
||||
None,
|
||||
self.eps,
|
||||
False,
|
||||
self.quant_scale.data,
|
||||
self.dynamic_quant,
|
||||
)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
out: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
|
||||
|
||||
org_shape = x.shape
|
||||
x = x.reshape(-1, self.weight.data.shape[0])
|
||||
if out is not None:
|
||||
out = out.view(-1, self.weight.data.shape[0])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, self.weight.data.shape[0])
|
||||
x = mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
None,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
True,
|
||||
out=out,
|
||||
)
|
||||
else:
|
||||
x = mlu_ops.fused_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
None,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
False,
|
||||
out=out,
|
||||
)
|
||||
|
||||
if out is not None:
|
||||
return x
|
||||
|
||||
if residual is None:
|
||||
assert isinstance(x, torch.Tensor)
|
||||
return x.view(org_shape)
|
||||
|
||||
assert isinstance(x, tuple)
|
||||
assert len(x) == 2
|
||||
return x[0].view(org_shape), x[1].view(org_shape)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
RMSNorm,
|
||||
RMSNorm.forward_oot,
|
||||
vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
|
||||
)
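# -----------------------------------------------------------------------------
# Editorial note: MluHijackObject is defined elsewhere in vllm_mlu and is not
# shown in this diff. As a rough mental model only (an assumption, not the
# actual implementation), apply_hijack(cls, org_method, new_method) behaves
# like a monkey patch that rebinds the method on the class, e.g.:
#
#     def apply_hijack_sketch(cls, org_method, new_method):
#         # every later call to cls.<method> now dispatches to new_method
#         setattr(cls, org_method.__name__, new_method)
#
# The real helper may additionally record (cls, org_method) so that the patch
# can be reverted.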
|
||||
693
vllm_mlu/model_executor/layers/linear.py
Normal file
693
vllm_mlu/model_executor/layers/linear.py
Normal file
@@ -0,0 +1,693 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (divide, split_tensor_along_last_dim,
|
||||
get_parallel_rank_with_group, get_parallel_world_size_with_group,
|
||||
get_tp_world_group, get_tp_world_world_size, get_tp_world_rank)
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase,
|
||||
ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED.extend([
|
||||
"GPTQMluLinearMethod",
|
||||
"AWQMluLinearMethod"
|
||||
])
|
||||
|
||||
vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual parameter.
|
||||
@brief: dispatch unquantized_gemm to mlu ops.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
beta = 0.0
|
||||
if residual is not None:
|
||||
beta = 1.0
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
|
||||
return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]),
|
||||
layer.weight,
|
||||
bias, residual, 'none', 1.0, beta).view(res_shape)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group and keep_full_weights parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__LinearBase____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
vllm__module_executor__layers__linear__LinearBase____init__org(
|
||||
self=self,
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add self.tp_group, world_size and tp_rank to support data parallel and moe expert parallel
|
||||
'''
|
||||
self.tp_group = tp_group
|
||||
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
|
||||
self.tp_size = self.tp_world_size
|
||||
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
|
||||
|
||||
self.keep_full_weights = keep_full_weights
|
||||
if self.keep_full_weights or disable_tp:
|
||||
self.tp_group = None
|
||||
self.tp_world_size = 1
|
||||
self.tp_size = self.tp_world_size
|
||||
self.tp_rank = 0
|
||||
self.tp_world_size_org = get_tp_world_world_size()
|
||||
self.tp_rank_org = get_tp_world_rank()
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group and keep_full_weights parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__ColumnParallelLinear____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
output_sizes: list[int] | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__(
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
tp_group=tp_group,
|
||||
keep_full_weights=keep_full_weights,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: self.tp_size and self.tp_rank has been initialized in LinearBase.__init__
|
||||
'''
|
||||
# Divide the weight matrix along the last dimension.
|
||||
# self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
|
||||
# self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size) for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group in create_weights
|
||||
'''
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2
|
||||
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
|
||||
else self.weight_loader
|
||||
),
|
||||
tp_group=self.tp_group,
|
||||
)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition, dtype=params_dtype)
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.bias,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.update_param_tp_status()
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add smooth_quant_scale and use_tp_weight parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
|
||||
self,
|
||||
input_,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add input_scale and use_tp_weight parameter.
|
||||
'''
|
||||
kwargs = {'bias': bias}
|
||||
if use_tp_weight:
|
||||
kwargs['use_tp_weight'] = use_tp_weight
|
||||
if smooth_quant_scale is not None:
|
||||
kwargs['input_scale'] = smooth_quant_scale
|
||||
output_parallel = self.quant_method.apply(self, input_, **kwargs)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.gather_output and self.tp_size > 1:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group param to tensor_model_parallel_all_gather
|
||||
'''
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group and keep_full_weights parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: check output_sizes after init so that self.tp_world_size is available
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
super(MergedColumnParallelLinear, self).__init__(
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
output_sizes=self.output_sizes,
|
||||
prefix=prefix,
|
||||
tp_group=tp_group,
|
||||
keep_full_weights=keep_full_weights,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
|
||||
|
||||
if self.keep_full_weights:
|
||||
tp_size = self.tp_world_size_org
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
out_dim, in_dim = self.weight.shape
|
||||
out_dim_tp = divide(out_dim, tp_size)
|
||||
self.tp_weight = Parameter(
|
||||
self.weight.data.new_empty((out_dim_tp, in_dim)),
|
||||
requires_grad=False,
|
||||
)
|
||||
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
|
||||
and quant_config.input_quant_method == "per_token"):
|
||||
out_dim, in_dim = self.qweight.shape
|
||||
out_dim_tp = divide(out_dim, tp_size)
|
||||
self.tp_qweight = Parameter(
|
||||
self.qweight.data.new_empty((out_dim_tp, in_dim)),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.tp_per_channel_scale = Parameter(
|
||||
self.per_channel_scale.data.new_empty((out_dim_tp)),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: int | None = None,
|
||||
):
|
||||
loaded_weight_orig = loaded_weight
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org(
|
||||
self=self,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
loaded_shard_id=loaded_shard_id,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
# load into tp weight
|
||||
if self.keep_full_weights:
|
||||
tp_size = self.tp_world_size_org
|
||||
tp_rank = self.tp_rank_org
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
start_idx = tp_rank * shard_size
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
|
||||
tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size)
|
||||
tp_weight_shard.copy_(tp_weight)
|
||||
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
|
||||
if output_dim is None:
|
||||
return
|
||||
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
|
||||
if loaded_weight_orig.ndim == 1:
|
||||
tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size)
|
||||
elif loaded_weight_orig.ndim == 2:
|
||||
tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size)
|
||||
else:
|
||||
raise ValueError("only support rank 1 and 2 when using tp_weight")
|
||||
|
||||
tp_weight_shard.copy_(tp_weight)
|
||||
else:
|
||||
raise TypeError(f"quant method is expected to be either unquantized or smoothquant")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear____init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
tp_group: Any = None,
|
||||
keep_full_weights: bool = False,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__(
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
tp_group=tp_group,
|
||||
keep_full_weights=keep_full_weights,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
# Divide the weight matrix along the last dimension
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
assert self.quant_method is not None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group in create_weights
|
||||
'''
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2
|
||||
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
|
||||
else self.weight_loader
|
||||
),
|
||||
tp_group=self.tp_group,
|
||||
)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(
|
||||
self.bias,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
if self.keep_full_weights:
|
||||
tp_size = self.tp_world_size_org
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
out_dim, in_dim = self.weight.data.shape
|
||||
in_dim_tp = divide(in_dim, tp_size)
|
||||
self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)),
|
||||
requires_grad=False)
|
||||
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
|
||||
and quant_config.input_quant_method == "per_token"):
|
||||
out_dim, in_dim = self.qweight.data.shape
|
||||
in_dim_tp = divide(in_dim, tp_size)
|
||||
self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)),
|
||||
requires_grad=False)
|
||||
if hasattr(self, "smooth"):
|
||||
assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor"
|
||||
dim = self.smooth.shape[0]
|
||||
dim_tp = divide(dim, tp_size)
|
||||
self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp)),
|
||||
requires_grad=False)
|
||||
else:
|
||||
raise TypeError("quant method expected to be unquantized or smoothquant per-token")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
self.update_param_tp_status()
|
||||
|
||||
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader(
|
||||
self, param: Parameter, loaded_weight: torch.Tensor
|
||||
):
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
loaded_weight_orig = loaded_weight
|
||||
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org(
|
||||
self=self,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add keep_full_weights for dp parallelize shared expert
|
||||
'''
|
||||
if self.keep_full_weights:
|
||||
if input_dim is None:
|
||||
return
|
||||
tp_size = self.tp_world_size_org
|
||||
tp_rank = self.tp_rank_org
|
||||
shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
if isinstance(self.quant_method, UnquantizedLinearMethod):
|
||||
shard_view = self.weight.narrow(input_dim, start_idx, shard_size)
|
||||
self.tp_weight.copy_(shard_view)
|
||||
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
|
||||
if loaded_weight_orig.ndim == 1:
|
||||
shard_view = self.smooth.narrow(input_dim, start_idx, shard_size)
|
||||
self.tp_smooth.copy_(shard_view)
|
||||
elif loaded_weight_orig.ndim == 2:
|
||||
shard_view = self.qweight.narrow(input_dim, start_idx, shard_size)
|
||||
self.tp_qweight.copy_(shard_view)
|
||||
else:
|
||||
raise ValueError("only rank 1 and 2 is supported for tp_weight")
|
||||
else:
|
||||
raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant")
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters.
|
||||
'''
|
||||
def vllm__module_executor__layers__linear__RowParallelLinear__forward(
|
||||
self,
|
||||
input_,
|
||||
residual: torch.Tensor | None = None,
|
||||
smooth_quant_scale: torch.Tensor | None = None,
|
||||
use_tp_weight: bool = False,
|
||||
output: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add additional matmul parameters.
|
||||
'''
|
||||
residual_ = None if self.tp_rank > 0 else residual
|
||||
kwargs = {'bias': bias_, 'residual': residual_}
|
||||
if use_tp_weight:
|
||||
kwargs['use_tp_weight'] = use_tp_weight
|
||||
if smooth_quant_scale is not None:
|
||||
kwargs['input_scale'] = smooth_quant_scale
|
||||
if output is not None:
|
||||
kwargs['output'] = output
|
||||
output_parallel = self.quant_method.apply(self, input_parallel, **kwargs)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tensor_model_parallel_all_reduce() with self.tp_group
|
||||
'''
|
||||
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
'''
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
|
||||
UnquantizedLinearMethod.apply,
|
||||
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
|
||||
MluHijackObject.apply_hijack(LinearBase,
|
||||
LinearBase.__init__,
|
||||
vllm__module_executor__layers__linear__LinearBase____init__)
|
||||
MluHijackObject.apply_hijack(ColumnParallelLinear,
|
||||
ColumnParallelLinear.__init__,
|
||||
vllm__module_executor__layers__linear__ColumnParallelLinear____init__)
|
||||
MluHijackObject.apply_hijack(ColumnParallelLinear,
|
||||
ColumnParallelLinear.forward,
|
||||
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
|
||||
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
|
||||
MergedColumnParallelLinear.__init__,
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__)
|
||||
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
|
||||
MergedColumnParallelLinear.weight_loader,
|
||||
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.__init__,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear____init__)
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.weight_loader,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader)
|
||||
MluHijackObject.apply_hijack(RowParallelLinear,
|
||||
RowParallelLinear.forward,
|
||||
vllm__module_executor__layers__linear__RowParallelLinear__forward)
|
||||
744
vllm_mlu/model_executor/layers/longcat_sparse_moe_mlp.py
Normal file
744
vllm_mlu/model_executor/layers/longcat_sparse_moe_mlp.py
Normal file
@@ -0,0 +1,744 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
"""Inference-only MOE model."""
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.distributed.parallel_state import (
|
||||
cnclep_dispatch, cnclep_combine)
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
|
||||
class LongCatSparseMoeMlp(SparseMoeMlp):
|
||||
"""
|
||||
Sparse MoE MLP layer specific to the LongCat model.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
up_proj_name: str,
|
||||
is_gated: bool,
|
||||
down_proj_name: str,
|
||||
has_bias: bool,
|
||||
skip_bias_add: bool = False,
|
||||
renormalize: bool = False,
|
||||
hidden_act: str = "silu",
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
is_use_fused_moe: bool = False,
|
||||
expert_group: Optional[int] = 1,
|
||||
topk_group: Optional[int] = 1,
|
||||
scoring_func: str = "softmax",
|
||||
topk_method: str = "",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
tp_group: Any = None,
|
||||
use_all2all: bool = False,
|
||||
num_zero_experts: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
up_proj_name=up_proj_name,
|
||||
is_gated=is_gated,
|
||||
down_proj_name=down_proj_name,
|
||||
has_bias=has_bias,
|
||||
skip_bias_add=skip_bias_add,
|
||||
renormalize=renormalize,
|
||||
hidden_act=hidden_act,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
is_use_fused_moe=is_use_fused_moe,
|
||||
expert_group=expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
topk_method=topk_method,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
tp_group=tp_group,
|
||||
use_all2all=use_all2all,
|
||||
init_avg_moe=False,
|
||||
)
|
||||
self.num_zero_experts = num_zero_experts
|
||||
self.total_experts_including_zero = self.num_total_experts + self.num_zero_experts
|
||||
self.use_quant_all2all = use_all2all and quant_config is not None
|
||||
self.zero_expert_size = divide(self.num_zero_experts, self.moe_ep_size)
|
||||
self.start_zero_expert_id = (
|
||||
self.num_total_experts + self.moe_ep_rank * ((self.num_zero_experts + self.moe_ep_size - 1) // self.moe_ep_size)
|
||||
)
|
||||
|
||||
if VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg:
|
||||
n_tokens = SparseMoeMlp.max_batched_token * self.dp_size
|
||||
expert_group = self.moe_ep_size
|
||||
val = 1.0 / float(self.total_experts_including_zero)
|
||||
SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32)
|
||||
if VLLM_RANDOM_MOE_EN:
|
||||
import numpy as np
|
||||
# example deepseekv2: experts 160 topk 6
|
||||
# avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,......
|
||||
array = np.stack([np.random.permutation(self.total_experts_including_zero)[:top_k] for _ in range(n_tokens)])
|
||||
table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32)
|
||||
else:
|
||||
# example deepseekv2: experts 160
|
||||
# avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,......
|
||||
import math
|
||||
batch_table = math.ceil(n_tokens * top_k / self.total_experts_including_zero) * self.total_experts_including_zero
|
||||
hi_val = batch_table // self.total_experts_including_zero
|
||||
table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view(
|
||||
hi_val, expert_group, num_experts // expert_group).transpose(1, 2)
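# Worked example of the round-robin table above (editorial, for intuition
# only): with num_experts = 8 and expert_group = 2,
# torch.arange(8).view(2, 4).transpose(0, 1).flatten() gives
# [0, 4, 1, 5, 2, 6, 3, 7], i.e. consecutive expert slots cycle through the
# expert-parallel groups so every rank receives an equal share; with 160
# experts and 8 groups this is exactly the pattern listed above:
# 0, 20, 40, ..., 140, 1, 21, ...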
|
||||
if self.num_zero_experts > 0:
|
||||
# Longcat model, for avg expert, we choose eight non-zero experts and four zero
|
||||
# experts for each token according to the paper.
|
||||
assert num_experts == 512 and num_zero_experts == 256 and top_k == 12
|
||||
assert num_zero_experts % expert_group == 0
|
||||
non_zero_expert_num_per_token = 8
|
||||
zero_expert_num_per_token = 4
|
||||
zero_expert_table = torch.arange(
|
||||
num_experts, num_experts + num_zero_experts, dtype=table.dtype, device=table.device).view(
|
||||
expert_group, num_zero_experts // expert_group).transpose(0, 1).flatten()
|
||||
non_zero_expert_table = table[0].flatten()
|
||||
token_expert_list = []
|
||||
for idx in range(0, num_experts // non_zero_expert_num_per_token):
|
||||
token_expert_list.append(non_zero_expert_table[
|
||||
idx * non_zero_expert_num_per_token:
|
||||
idx * non_zero_expert_num_per_token + non_zero_expert_num_per_token])
|
||||
token_expert_list.append(zero_expert_table[
|
||||
idx * zero_expert_num_per_token:
|
||||
idx * zero_expert_num_per_token + zero_expert_num_per_token])
|
||||
avg_expert_table = torch.cat(token_expert_list)
|
||||
table = avg_expert_table.repeat(hi_val)
|
||||
SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k)
|
||||
SparseMoeMlp.is_expert_avg = True
|
||||
|
||||
|
||||
def forward_experts_nofused_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None):
|
||||
assert self.moe_ep_size == 1
|
||||
assert not self.use_all2all
|
||||
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_indices.to(torch.int32), total_num_experts)
|
||||
# If no expert is routed, expand_gather_idx and expand_scatter_idx are empty,
|
||||
# while expand_token_count and expand_cusum_token_count exist but are all zero,
|
||||
# so this rank should just return final_hidden_states filled with zeros.
|
||||
if cusum_token_count[-1] == 0:
|
||||
final_hidden_states = torch.zeros_like(hidden_states,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
return final_hidden_states
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_gather_idx, cusum_token_count,
|
||||
start_expert_id=self.start_expert_id,
|
||||
expert_size=self.end_expert_id - self.start_expert_id)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_gather_idx, cusum_token_count,
|
||||
start_expert_id=self.start_zero_expert_id,
|
||||
expert_size=self.zero_expert_size)
|
||||
|
||||
expand_output_list = []
|
||||
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
|
||||
1] - cusum_token_count[self.start_expert_id]
|
||||
|
||||
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count[:self.num_total_experts]):
|
||||
if num_tokens_per_expert > 0:
|
||||
expert_hidden_states = expand_hidden_states[
|
||||
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
|
||||
if expert_idx < self.num_total_experts:
|
||||
expert_output = self.experts[expert_idx](expert_hidden_states)
|
||||
else:
|
||||
expert_output = expert_hidden_states
|
||||
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
|
||||
expand_output_list.append(expert_output)
|
||||
expand_output = torch.cat(expand_output_list, dim=0)
|
||||
num_normal_tokens = cusum_token_count[self.num_total_experts]
|
||||
expand_hidden_states[:num_normal_tokens] = expand_output
|
||||
# reduce normal experts
|
||||
final_hidden_states = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states, topk_weights, scatter_idx,
|
||||
residual_, cusum_token_count, start_expert_id=self.start_expert_id,
|
||||
expert_size=self.end_expert_id - self.start_expert_id, bias=None)
|
||||
# reduce zero experts
|
||||
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
|
||||
final_hidden_states = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, topk_weights, scatter_idx,
|
||||
final_hidden_states, cusum_token_count, start_expert_id=self.start_zero_expert_id,
|
||||
expert_size=self.zero_expert_size, bias=None,
|
||||
output=final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
# No compute-communication parallelism; for prototyping only, not used in production.
|
||||
# May become stale.
|
||||
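# Rough flow: quantize + expand inputs -> all2all dispatch -> grouped GEMM (w13) ->
|
||||
# activation + re-quantize -> grouped GEMM (w2) -> all2all combine -> weighted
|
||||
# reduction (plus the zero-expert contribution).
|
||||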
def forward_all2all_int8_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None):
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
self.pack_params_after_loading()
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias2=self.b2
|
||||
input_smooth=self.a13_scale_all_experts
|
||||
act_smooth=self.a2_scale
|
||||
w1_scale=self.w13_scale
|
||||
w2_scale=self.w2_scale
|
||||
act_mode=self.hidden_act
|
||||
quant_input=None
|
||||
|
||||
max_m = hidden_states.shape[0]
|
||||
|
||||
reduce_weight = topk_weights
|
||||
expert_id = topk_indices
|
||||
|
||||
expand_idx, combine_idx, token_count, cusum_token_count \
|
||||
= mlu_ops.moe_gen_idx(expert_id, total_num_experts)
|
||||
|
||||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||||
|
||||
dispatch_send_token_tensor = (
|
||||
self.dispatch_send_buffer[:dispatch_bytes]
|
||||
.view(num_token_expand, self.dispatch_token_size)
|
||||
)
|
||||
|
||||
quant_size = self.hidden_size
|
||||
quant_input = dispatch_send_token_tensor[:, : quant_size]
|
||||
input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32)
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||||
hidden_states, input_smooth, None, token_count[:self.num_total_experts],
|
||||
expand_idx, None,
|
||||
output=quant_input,
|
||||
output_scale=input_scale)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
|
||||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
token_count[:self.num_total_experts], self.moe_ep_size)
|
||||
|
||||
cnclep_dispatch(self.dispatch_token_size,
|
||||
num_token_expand,
|
||||
dispatch_send_layout,
|
||||
token_count[:self.num_total_experts],
|
||||
self.dispatch_recv_layout,
|
||||
self.dispatch_recv_token_num)
|
||||
|
||||
recv_token_num = self.dispatch_recv_token_num.view(
|
||||
self.moe_ep_size, self.num_experts_per_rank)
|
||||
pad_num = self.max_num_tokens_per_rank
|
||||
|
||||
(
|
||||
gather_by_expert_index,
|
||||
gather_by_rank_index,
|
||||
tokens_per_local_expert,
|
||||
token_sum
|
||||
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
|
||||
|
||||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||||
dispatch_recv_token_tensor = (
|
||||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||||
.view(self.max_num_tokens_recv, self.dispatch_token_size))
|
||||
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.quant_input_recv,
|
||||
self.input_scale_recv)
|
||||
|
||||
max_m = self.max_num_tokens_per_expert
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1,
|
||||
tokens_per_local_expert,
|
||||
None, None, None, None,
|
||||
self.input_scale_recv.view(torch.float32).flatten(),
|
||||
w1_scale, dtype, max_m)
|
||||
|
||||
# continue reusing self.quant_input_recv and self.input_scale_recv
|
||||
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
|
||||
input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]]
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
|
||||
tokens_per_local_expert,
|
||||
output=quant_input,
|
||||
output_scale=input_scale_fp32,
|
||||
act_mode=act_mode,
|
||||
is_gated=self.is_gated)
|
||||
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
|
||||
tokens_per_local_expert,
|
||||
None, None, None, None, input_scale, w2_scale, dtype, max_m)
|
||||
|
||||
combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype)
|
||||
mlu_ops.gather_split(gemm_out,
|
||||
gather_by_rank_index,
|
||||
token_sum,
|
||||
combine_send_token_tensor,
|
||||
None)
|
||||
|
||||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size)
|
||||
combine_recv_layout = self.dispatch_recv_layout
|
||||
|
||||
# combine
|
||||
combine_args = dict(
|
||||
token_byte=self.hidden_size * 2,
|
||||
token_num=num_token_expand,
|
||||
send_src_layout=combine_send_layout,
|
||||
send_dst_layout=combine_recv_layout,
|
||||
send_token=None,
|
||||
recv_token=None)
|
||||
|
||||
cnclep_combine(**combine_args)
|
||||
|
||||
numel_recv = num_token_expand * self.hidden_size
|
||||
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
|
||||
.view(num_token_expand, self.hidden_size))
|
||||
|
||||
residual_ = None
|
||||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
|
||||
assert self.moe_ep_size > 1
|
||||
# zero expert reduce
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.num_total_experts,
|
||||
self.num_zero_experts, output=hidden_states)
|
||||
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
# No compute-communication parallelism; for prototyping only, not used in production.
|
||||
# May become stale.
|
||||
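# Same overall flow as the int8 variant above, but tokens are dispatched and the
|
||||
# grouped GEMMs run in bf16, without smooth-quant scaling.
|
||||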
def forward_all2all_bf16_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None):
|
||||
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
self.pack_params_after_loading()
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias1=self.b13
|
||||
bias2=self.b2
|
||||
gated=self.is_gated
|
||||
act_mode=self.hidden_act
|
||||
|
||||
max_m = hidden_states.shape[0]
|
||||
reduce_weight = topk_weights
|
||||
expert_id = topk_indices
|
||||
|
||||
# gen_idx
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = \
|
||||
mlu_ops.moe_gen_idx(expert_id, total_num_experts)
|
||||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||||
|
||||
dispatch_send_token_tensor = (
|
||||
self.dispatch_send_buffer[:dispatch_bytes]
|
||||
.view(num_token_expand, self.dispatch_token_size)
|
||||
.view(hidden_states.dtype)
|
||||
)
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
|
||||
dispatch_send_token_tensor.copy_(expand_hidden_states)
|
||||
|
||||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
token_count[:self.num_total_experts], self.moe_ep_size)
|
||||
|
||||
cnclep_dispatch(self.dispatch_token_size,
|
||||
num_token_expand,
|
||||
dispatch_send_layout,
|
||||
token_count[:self.num_total_experts],
|
||||
self.dispatch_recv_layout,
|
||||
self.dispatch_recv_token_num,
|
||||
use_quant_dispatch=False,
|
||||
)
|
||||
|
||||
recv_token_num = self.dispatch_recv_token_num.view(
|
||||
self.moe_ep_size, self.num_experts_per_rank)
|
||||
pad_num = self.max_num_tokens_per_rank
|
||||
|
||||
(
|
||||
gather_by_expert_index,
|
||||
gather_by_rank_index,
|
||||
tokens_per_local_expert,
|
||||
token_sum
|
||||
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
|
||||
|
||||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||||
dispatch_recv_token_tensor = (
|
||||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||||
.view(self.max_num_tokens_recv, self.dispatch_token_size)
|
||||
.view(hidden_states.dtype)
|
||||
)
|
||||
|
||||
|
||||
self.quant_input_recv = self.quant_input_recv.view(hidden_states.dtype)
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.quant_input_recv)
|
||||
|
||||
max_m = self.max_num_tokens_per_expert
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
self.quant_input_recv, w1, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
act_out = mlu_ops.moe_active(
|
||||
gemm_out, act_mode, gated)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
act_out, w2, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
|
||||
combine_send_token_tensor = self.combine_send_buffer.view(
|
||||
self.max_num_tokens_recv, -1).view(hidden_states.dtype)
|
||||
mlu_ops.gather_split(gemm_out,
|
||||
gather_by_rank_index,
|
||||
token_sum,
|
||||
combine_send_token_tensor,
|
||||
None)
|
||||
|
||||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
self.dispatch_recv_token_num, self.moe_ep_size)
|
||||
combine_recv_layout = self.dispatch_recv_layout
|
||||
|
||||
combine_args = dict(
|
||||
token_byte=self.hidden_size * 2,
|
||||
token_num=num_token_expand,
|
||||
send_src_layout=combine_send_layout,
|
||||
send_dst_layout=combine_recv_layout,
|
||||
send_token=None,
|
||||
recv_token=None,
|
||||
use_quant_dispatch=False,
|
||||
)
|
||||
|
||||
cnclep_combine(**combine_args)
|
||||
|
||||
numel_recv = num_token_expand * self.hidden_size
|
||||
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
|
||||
.view(num_token_expand, self.hidden_size))
|
||||
|
||||
residual_ = None
|
||||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
|
||||
# zero expert reduce
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.num_total_experts,
|
||||
self.num_zero_experts, output=hidden_states)
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
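# The staged forward_* helpers below split the all2all MoE forward into phases
|
||||
# (before_dispatch, dispatch, before_combine, combine, after_combine), presumably so
|
||||
# that the caller can overlap communication with computation.
|
||||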
def forward_before_dispatch(self, hidden_states: torch.Tensor,
|
||||
topk_indices: torch.Tensor):
|
||||
# For LongCat, gating and softmax top-k are performed in the router;
|
||||
# other models can perform these operations here.
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_indices, self.total_experts_including_zero)
|
||||
|
||||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||||
dispatch_send_token_tensor = (
|
||||
self.dispatch_send_buffer[:dispatch_bytes]
|
||||
.view(num_token_expand, self.dispatch_token_size)
|
||||
)
|
||||
if self.use_quant_all2all:
|
||||
hidden_states_stride = self.hidden_size
|
||||
quant_input = dispatch_send_token_tensor[:, : hidden_states_stride]
|
||||
input_scale = dispatch_send_token_tensor[:, hidden_states_stride :].view(torch.float32)
|
||||
# expand input + quantize
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||||
hidden_states, self.a13_scale_all_experts, None,
|
||||
token_count[:self.num_total_experts],
|
||||
expand_idx, None,
|
||||
output=quant_input,
|
||||
output_scale=input_scale)
|
||||
# expand input of zero-expert
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
else:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts)
|
||||
dispatch_send_token_tensor = dispatch_send_token_tensor.view(
|
||||
hidden_states.dtype)
|
||||
dispatch_send_token_tensor.copy_(expand_hidden_states)
|
||||
del expand_hidden_states
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.num_total_experts,
|
||||
expert_size=self.num_zero_experts)
|
||||
|
||||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
token_count[:self.num_total_experts], self.moe_ep_size)
|
||||
|
||||
return combine_idx, token_count, cusum_token_count, dispatch_send_layout, expand_hidden_states_zero
|
||||
|
||||
def forward_dispatch(self, token_num: int, dispatch_send_layout: torch.Tensor,
|
||||
token_count: torch.Tensor):
|
||||
num_token_expand = token_num * self.top_k
|
||||
cnclep_dispatch(self.dispatch_token_size,
|
||||
num_token_expand,
|
||||
dispatch_send_layout,
|
||||
token_count[:self.num_total_experts],
|
||||
self.dispatch_recv_layout,
|
||||
self.dispatch_recv_token_num,
|
||||
use_quant_dispatch=self.use_quant_all2all)
|
||||
|
||||
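# Runs the local expert computation (quantized or bf16 path) on the tokens received
|
||||
# from dispatch and stages the results for the combine all2all.
|
||||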
def forward_before_combine(self, hidden_states_dtype: torch.dtype):
|
||||
recv_token_num = self.dispatch_recv_token_num.view(
|
||||
self.moe_ep_size, self.num_experts_per_rank)
|
||||
|
||||
(
|
||||
gather_by_expert_index,
|
||||
gather_by_rank_index,
|
||||
tokens_per_local_expert,
|
||||
token_sum,
|
||||
cusum_token_count
|
||||
) = mlu_ops.moe_all2all_gen_gather_index(
|
||||
recv_token_num, self.max_num_tokens_per_rank,
|
||||
return_cusum_token_count=True)
|
||||
|
||||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||||
dispatch_recv_token_tensor = (
|
||||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||||
.view(self.max_num_tokens_recv, self.dispatch_token_size))
|
||||
|
||||
max_m = self.max_num_tokens_per_expert
|
||||
if self.use_quant_all2all:
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.quant_input_recv,
|
||||
self.input_scale_recv)
|
||||
# OPT: input_scale_recv_flatten can reuse self.input_scale_recv
|
||||
input_scale_recv_flatten = self.input_scale_recv.view(torch.float32).flatten()
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, self.w13,
|
||||
tokens_per_local_expert,
|
||||
None, None, None, None,
|
||||
input_scale_recv_flatten,
|
||||
self.w13_scale, hidden_states_dtype, max_m)
|
||||
|
||||
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
|
||||
input_scale_fp32 = input_scale_recv_flatten[:gemm_out.shape[0]]
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, self.a2_scale, None,
|
||||
tokens_per_local_expert,
|
||||
output=quant_input,
|
||||
output_scale=input_scale_fp32,
|
||||
act_mode=self.hidden_act,
|
||||
is_gated=self.is_gated)
|
||||
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, self.w2, tokens_per_local_expert,
|
||||
None, None, None, None, input_scale, self.w2_scale,
|
||||
hidden_states_dtype, max_m)
|
||||
else:
|
||||
dispatch_recv_token_tensor = dispatch_recv_token_tensor.view(hidden_states_dtype)
|
||||
self.input_recv = self.input_recv.view(hidden_states_dtype)
|
||||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||||
gather_by_expert_index,
|
||||
token_sum,
|
||||
self.input_recv)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
self.input_recv, self.w13, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
act_out = self.input_recv[:, :gemm_out.shape[-1] // 2]
|
||||
act_out = mlu_ops.moe_active(
|
||||
gemm_out, self.hidden_act, self.is_gated, output=act_out,
|
||||
bias=None, cusum_token_count=cusum_token_count,
|
||||
start_expert_id=0, expert_size=self.num_experts_per_rank)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
act_out, self.w2, tokens_per_local_expert,
|
||||
None, None, None, None, max_m)
|
||||
|
||||
combine_send_token_tensor = self.combine_send_buffer.view(
|
||||
self.max_num_tokens_recv, -1).view(hidden_states_dtype)
|
||||
mlu_ops.gather_split(gemm_out,
|
||||
gather_by_rank_index,
|
||||
token_sum,
|
||||
combine_send_token_tensor,
|
||||
None)
|
||||
|
||||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
|
||||
self.dispatch_recv_token_num, self.moe_ep_size)
|
||||
|
||||
return combine_send_layout
|
||||
|
||||
def forward_combine(self, token_num: int, combine_send_layout: torch.Tensor):
|
||||
num_token_expand = token_num * self.top_k
|
||||
# combine_recv_layout (self.dispatch_recv_layout) is computed during cnclep_dispatch
|
||||
# because dispatch and combine are inverse operations.
|
||||
cnclep_combine(token_byte=self.hidden_size * 2,
|
||||
token_num=num_token_expand,
|
||||
send_src_layout=combine_send_layout,
|
||||
send_dst_layout=self.dispatch_recv_layout,
|
||||
send_token=None,
|
||||
recv_token=None,
|
||||
use_quant_dispatch=self.use_quant_all2all)
|
||||
|
||||
def forward_after_combine(self, token_num: int,
|
||||
reduce_weight: torch.Tensor,
|
||||
combine_idx: torch.Tensor,
|
||||
cusum_token_count: torch.Tensor,
|
||||
expand_hidden_states_zero: torch.Tensor,
|
||||
output_tensor_dtype: torch.dtype,
|
||||
output_tensor: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None):
|
||||
num_token_expand = token_num * self.top_k
|
||||
numel_recv = num_token_expand * self.hidden_size
|
||||
recv_token = (self.combine_recv_buffer.view(output_tensor_dtype)[:numel_recv]
|
||||
.view(num_token_expand, self.hidden_size))
|
||||
|
||||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||||
residual, cusum_token_count, start_expert_id=0,
|
||||
expert_size=self.num_total_experts, bias=self.b2, output=output_tensor)
|
||||
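# Add the contribution of the zero (identity) experts on top of the normal-expert result.
|
||||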
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.num_total_experts,
|
||||
self.num_zero_experts, output=output_tensor)
|
||||
|
||||
return output
|
||||
|
||||
# No compute-communication parallelism; for prototyping only, not used in production.
|
||||
# May become stale.
|
||||
def forward_group_experts_longcat(
|
||||
self, hidden_states, total_num_experts, total_num_experts_per_rank,
|
||||
topk_indices=None, topk_weights=None, residual_=None,
|
||||
expand_idx=None, combine_idx=None, token_count=None, cusum_token_count=None):
|
||||
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
|
||||
ori_input_shape = hidden_states.shape
|
||||
dtype = hidden_states.dtype
|
||||
self.pack_params()
|
||||
self.pack_params_after_loading()
|
||||
w1=self.w13
|
||||
w2=self.w2
|
||||
bias1=self.b13
|
||||
bias2=self.b2
|
||||
input_smooth=self.a13_scale
|
||||
act_smooth=self.a2_scale
|
||||
w1_scale=self.w13_scale
|
||||
w2_scale=self.w2_scale
|
||||
gated=self.is_gated
|
||||
act_mode=self.hidden_act
|
||||
quant_input=None
|
||||
|
||||
start_expert_id=self.start_expert_id
|
||||
expert_size = w1.size(0)
|
||||
max_m = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
|
||||
# Check smooth quant parameters.
|
||||
per_token_sq = False
|
||||
if not is_fp8_quant:
|
||||
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
|
||||
if all(x is not None for x in check_list):
|
||||
per_token_sq = True
|
||||
|
||||
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
|
||||
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
|
||||
"and absent at the same time.")
|
||||
|
||||
expert_id = topk_indices
|
||||
reduce_weight = topk_weights
|
||||
|
||||
# gen_idx
|
||||
if expert_id is not None:
|
||||
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, total_num_experts)
|
||||
|
||||
# check quant
|
||||
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
|
||||
raise NotImplementedError
|
||||
elif per_token_sq:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=start_expert_id,
|
||||
expert_size=expert_size)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(
|
||||
hidden_states, expand_idx, cusum_token_count,
|
||||
start_expert_id=self.start_zero_expert_id,
|
||||
expert_size=self.zero_expert_size)
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||||
expand_hidden_states, input_smooth, None,
|
||||
token_count[start_expert_id:start_expert_id+expert_size])
|
||||
else:
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
|
||||
cusum_token_count, start_expert_id, expert_size)
|
||||
expand_hidden_states_zero = mlu_ops.moe_expand_input(hidden_states, expand_idx,
|
||||
cusum_token_count, self.start_zero_expert_id, self.zero_expert_size)
|
||||
|
||||
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input, w1,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, input_scale, w1_scale, dtype, max_m)
|
||||
else:
|
||||
gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, max_m)
|
||||
|
||||
# add_bias_active
|
||||
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
|
||||
raise NotImplementedError
|
||||
elif per_token_sq:
|
||||
quant_input = quant_input[:, :gemm_out.shape[-1] // 2]
|
||||
input_scale = input_scale[:gemm_out.shape[0]]
|
||||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
output=quant_input,
|
||||
output_scale=input_scale,
|
||||
act_mode=act_mode,
|
||||
is_gated=self.is_gated)
|
||||
|
||||
if ((is_fp8_quant and self.quant_config.activation_quant_method == 'per_token')
|
||||
or per_token_sq):
|
||||
# Removing the reference to the gemm_out tensor would make its memory eligible for
|
||||
# deallocation (if it were the only reference), so that memory could be reused by
|
||||
# the allocation in the next GEMM. Currently left disabled:
|
||||
# del gemm_out
|
||||
gemm_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input, w2,
|
||||
token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, input_scale, w2_scale, dtype, max_m,
|
||||
output=expand_hidden_states)
|
||||
else:
|
||||
act_out = mlu_ops.moe_active(
|
||||
gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2],
|
||||
bias1, cusum_token_count, start_expert_id, expert_size)
|
||||
gemm_out = mlu_ops.group_gemm(
|
||||
act_out, w2, token_count[start_expert_id:start_expert_id+expert_size],
|
||||
None, None, None, None, max_m,
|
||||
output=expand_hidden_states)
|
||||
|
||||
output = mlu_ops.moe_combine_result(
|
||||
gemm_out, reduce_weight, combine_idx,
|
||||
residual_, cusum_token_count, start_expert_id,
|
||||
expert_size, bias2)
|
||||
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
|
||||
output = mlu_ops.moe_combine_result(
|
||||
expand_hidden_states_zero, reduce_weight, combine_idx,
|
||||
output, cusum_token_count, self.start_zero_expert_id,
|
||||
self.zero_expert_size, bias2,
|
||||
output=output)
|
||||
return output.view(ori_input_shape)
|
||||
37
vllm_mlu/model_executor/layers/quantization/__init__.py
Normal file
37
vllm_mlu/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QUANTIZATION_METHODS, register_quantization_config
|
||||
)
|
||||
|
||||
MLU_QUANTIZATION_METHODS = [
|
||||
"smoothquant",
|
||||
"weightonly",
|
||||
"awq_mlu",
|
||||
"gptq_mlu",
|
||||
]
|
||||
|
||||
|
||||
def register_fake_mlu_quantization_methods():
|
||||
for quant_method in MLU_QUANTIZATION_METHODS:
|
||||
if quant_method not in QUANTIZATION_METHODS:
|
||||
QUANTIZATION_METHODS.append(quant_method)
|
||||
|
||||
|
||||
def remove_fake_mlu_quantization_methods():
|
||||
for quant_method in MLU_QUANTIZATION_METHODS:
|
||||
if quant_method in QUANTIZATION_METHODS:
|
||||
QUANTIZATION_METHODS.remove(quant_method)
|
||||
|
||||
|
||||
def register_real_mlu_quantization_methods():
|
||||
remove_fake_mlu_quantization_methods()
|
||||
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
|
||||
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
|
||||
register_quantization_config("weightonly")(WeightOnlyConfig)
|
||||
register_quantization_config("smoothquant")(SmoothQuantConfig)
|
||||
register_quantization_config("awq_mlu")(AWQMluConfig)
|
||||
register_quantization_config("gptq_mlu")(GPTQMluConfig)
|
||||
412
vllm_mlu/model_executor/layers/quantization/awq_mlu.py
Normal file
412
vllm_mlu/model_executor/layers/quantization/awq_mlu.py
Normal file
@@ -0,0 +1,412 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# Only GPTQ and AWQ are supported, on the 300 series and later, and only int4 and int8 precision.
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# @register_quantization_config("awq_mlu")
|
||||
class AWQMluConfig(QuantizationConfig):
|
||||
"""Config class for AWQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: {
|
||||
False: scalar_types.uint4b8,
|
||||
True: scalar_types.uint4,
|
||||
},
|
||||
8: {
|
||||
False: scalar_types.uint8b128,
|
||||
True: scalar_types.uint8,
|
||||
}
|
||||
}
|
||||
|
||||
VERSION = ["gemm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
version: str = "gemm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
self.version = version
|
||||
self.support_scale_zeros = False
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is supported for "
|
||||
f"AWQMlu, but got {self.weight_bits} bits.")
|
||||
if self.version not in self.VERSION:
|
||||
raise ValueError(
|
||||
"Currently, only gemm, gemv version is supported for "
|
||||
f"AWQMlu, but got verion:{self.version}.")
|
||||
|
||||
if self.version in ["gemm"]:
|
||||
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
|
||||
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
|
||||
else:
|
||||
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
version = cls.get_from_keys_or(config, ["version"],
|
||||
default="gemm")
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AWQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return AWQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "awq"
|
||||
or user_quant == "awq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info("Detected that the model can run with awq_mlu"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_mlu for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
has_zp = quant_config.get("zero_point", None)
|
||||
version = quant_config.get("version", "gemm")
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or has_zp is None):
|
||||
return False
|
||||
|
||||
if num_bits not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
if version not in cls.VERSION:
|
||||
return False
|
||||
|
||||
result = check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
|
||||
group_size=group_size,
|
||||
has_zp=has_zp)
|
||||
# check_mlu_supported returns (False, reason) on failure; reduce to a plain bool here.
|
||||
return result[0] if isinstance(result, tuple) else result
|
||||
|
||||
class AWQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
packed_qweight, scale_zeros = self.extract_autoawq(layer)
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
def extract_autoawq(self, layer: torch.nn.Module):
|
||||
qweight = layer.qweight.data
|
||||
qzeros = layer.qzeros.data
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
|
||||
# Reverse the order of the iweight and izeros tensors
|
||||
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
|
||||
|
||||
# overflow checks
|
||||
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||
if izeros is not None:
|
||||
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
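# Pack two int4 values into one int8: odd columns provide the high nibble, even
|
||||
# columns the low nibble (see combine_low_bits).
|
||||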
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
# unpacking columnwise
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
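# For the symmetric-bias (GPTQ-style) representation, subtract 2**(bits - 1) and mask
|
||||
# the values back into range.
|
||||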
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
|
||||
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
|
||||
|
||||
# unpacking columnwise
|
||||
if qzeros is not None:
|
||||
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
izeros = izeros.view(izeros.shape[0], -1)
|
||||
if not self.quant_config.zero_point:
|
||||
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
|
||||
else:
|
||||
izeros = None
|
||||
|
||||
return iweights, izeros
|
||||
|
||||
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
rweights = iweights[:, reverse_order_tensor]
|
||||
if izeros is not None:
|
||||
rzeros = izeros[:, reverse_order_tensor]
|
||||
else:
|
||||
rzeros = None
|
||||
|
||||
return rweights, rzeros
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# Ensure the inputs are int8.
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# Extract the lower 4 bits of each tensor.
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
|
||||
|
||||
# Shift the lower 4 bits of tensor_a left by 4.
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# Combine the lower 4 bits of the two tensors.
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
|
||||
753
vllm_mlu/model_executor/layers/quantization/fp8.py
Normal file
753
vllm_mlu/model_executor/layers/quantization/fp8.py
Normal file
@@ -0,0 +1,753 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import importlib.util
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from vllm import envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
get_flashinfer_moe_backend,
|
||||
ACTIVATION_SCHEMES,
|
||||
Fp8Config,
|
||||
Fp8LinearMethod,
|
||||
Fp8MoeBackend,
|
||||
Fp8MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
get_flashinfer_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
validate_fp8_block_shape
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
|
||||
maybe_create_device_identity, Fp8LinearOp)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter, ChannelQuantScaleParameter,
|
||||
ModelWeightParameter, PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
import vllm_mlu._mlu_ops as mlu_ops
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and (
|
||||
current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(90)
|
||||
)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
):
|
||||
backend = get_flashinfer_moe_backend()
|
||||
if backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
if block_quant and current_platform.is_device_capability(100):
|
||||
raise ValueError(
|
||||
"FlashInfer FP8 MoE throughput backend does not "
|
||||
"support block quantization. Please use "
|
||||
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"instead."
|
||||
)
|
||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
|
||||
# weight-only path for older GPUs without native FP8
|
||||
use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: disable marlin for MLU backend.
|
||||
'''
|
||||
if current_platform.is_rocm() or current_platform.is_out_of_tree():
|
||||
use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if use_marlin:
|
||||
logger.info_once("Using Marlin backend for FP8 MoE")
|
||||
return Fp8MoeBackend.MARLIN
|
||||
|
||||
# deepGEMM on supported platforms with block-quantized weights
|
||||
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
|
||||
if not has_deep_gemm():
|
||||
logger.warning_once("DeepGEMM backend requested but not available.")
|
||||
elif is_deep_gemm_supported():
|
||||
logger.info_once("Using DeepGEMM backend for FP8 MoE")
|
||||
return Fp8MoeBackend.DEEPGEMM
|
||||
|
||||
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and block_quant
|
||||
):
|
||||
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
|
||||
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
|
||||
|
||||
# default to Triton
|
||||
logger.info_once("Using Triton backend for FP8 MoE")
|
||||
return Fp8MoeBackend.TRITON
|
||||
|
||||
|
||||
Fp8Config____init____org = Fp8Config.__init__
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: list[str] | None = None,
|
||||
weight_block_size: list[int] | None = None,
|
||||
activation_quant_method: Optional[str] = None,
|
||||
weight_quant_method: Optional[str] = None,
|
||||
) -> None:
|
||||
super(Fp8Config, self).__init__()
|
||||
|
||||
Fp8Config____init____org(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized,
|
||||
activation_scheme,
|
||||
ignored_layers,
|
||||
weight_block_size
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add class members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
self.activation_quant_method = activation_quant_method
|
||||
self.weight_quant_method = weight_quant_method
|
||||
|
||||
assert (self.weight_block_size or
|
||||
(self.activation_quant_method == "per_token" and self.weight_quant_method == "per_channel"
|
||||
and self.activation_scheme == "dynamic")), ("Only block-wise quantization, or dynamic "
|
||||
"per-token activation with per-channel weight quantization, is supported.")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@classmethod
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
|
||||
cls, config: Dict[str, Any]
|
||||
) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
|
||||
if not ignored_layers:
|
||||
ignored_layers = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add config members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
activation_quant_method = cls.get_from_keys_or(config,
|
||||
["activation_quant_method"],
|
||||
'per_token')
|
||||
weight_quant_method = cls.get_from_keys_or(config,
|
||||
["weight_quant_method"],
|
||||
None)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers,
|
||||
weight_block_size=weight_block_size,
|
||||
activation_quant_method=activation_quant_method,
|
||||
weight_quant_method=weight_quant_method)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group.
|
||||
'''
|
||||
tp_group = extra_weight_attrs.get("tp_group", None)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
validate_fp8_block_shape(
|
||||
layer,
|
||||
input_size,
|
||||
output_size,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group.
|
||||
'''
|
||||
# WEIGHT
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
else:
|
||||
# For non-serialized checkpoints, use original dtype
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Support weight per channel quantization.
|
||||
@brief: Add tp_group to enable custom split.
|
||||
'''
|
||||
if self.weight_per_channel:
|
||||
scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
else:
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
scale = create_fp8_scale_parameter(
|
||||
BlockQuantScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.act_q_static:
|
||||
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
|
||||
self,
|
||||
quant_config: Fp8Config
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
# AITER is only supported on ROCm and only for FP8_FNUZ
|
||||
# and, at the moment, only on the MI300 series.
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
if self.block_quant:
|
||||
# Marlin doesn't support block-wise fp8
|
||||
self.use_marlin = False
|
||||
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
if self.weight_block_size:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add config members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
|
||||
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
|
||||
if self.weight_per_channel and self.activation_per_token:
|
||||
self.use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
)
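The two MLU-specific members added above (weight_per_channel, activation_per_token) describe scale granularity only: per-channel weight quantization keeps one scale per output channel, while per-token activation quantization computes one scale per input row at runtime. A minimal sketch of that arithmetic in plain torch (illustrative, not the vllm_mlu kernels; requires a PyTorch build with float8_e4m3fn):

import torch

out_features, in_features, tokens = 8, 16, 4
w = torch.randn(out_features, in_features)
x = torch.randn(tokens, in_features)

# Per-channel weight scale: one value per output channel -> shape [out_features, 1].
w_scale = w.abs().amax(dim=1, keepdim=True) / 448.0      # 448 = max finite value of e4m3
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

# Per-token activation scale: one value per row, computed dynamically -> shape [tokens, 1].
x_scale = x.abs().amax(dim=1, keepdim=True) / 448.0
x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)

# Emulated FP8 GEMM: upcast, multiply, rescale with the outer product of the two scales.
y = (x_fp8.float() @ w_fp8.float().t()) * (x_scale @ w_scale.t())
print((y - x @ w.t()).abs().max())  # small quantization error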
|
||||
|
||||
|
||||
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
|
||||
self,
|
||||
layer: Module,
|
||||
) -> None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: For dynamic activation and channel-wise weight quantization,
|
||||
additional processing is not needed.
|
||||
'''
|
||||
if (self.quant_config.is_checkpoint_fp8_serialized
|
||||
and self.weight_per_channel
|
||||
and self.quant_config.activation_scheme == "dynamic"):
|
||||
return
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
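The wrapper above follows the pattern used throughout this file: keep a module-level reference to the original bound function, short-circuit the MLU-only case, and delegate everything else. A self-contained sketch of that patch-and-delegate idea (Original/process are made-up names; apply_hijack is assumed to boil down to an attribute assignment like the last line):

class Original:
    def process(self, x):
        return x * 2

# Keep a handle on the unpatched function, like the *__org alias above.
Original__process__org = Original.process

def patched_process(self, x):
    if x is None:                           # new early-return added by the patch
        return None
    return Original__process__org(self, x)  # otherwise defer to the original

Original.process = patched_process          # roughly what MluHijackObject.apply_hijack installs

assert Original().process(3) == 6
assert Original().process(None) is None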
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert residual is None, "Fp8Linear residual is not supported yet."
|
||||
|
||||
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
|
||||
# we will use BF16 dequant when DeepGEMM is not supported.
|
||||
if vllm_is_batch_invariant():
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
# per-tensor/channel: dequant to BF16 and run GEMM
|
||||
weight_fp8 = layer.weight.to(torch.bfloat16)
|
||||
weight_scale = layer.weight_scale.to(torch.bfloat16)
|
||||
if weight_scale.numel() == 1:
|
||||
# Per-tensor: simple scalar multiplication
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
else:
|
||||
# Multiple scales (fused modules like QKV)
|
||||
# Try to infer correct broadcasting
|
||||
# weight is [K, N], scale could be [num_logical_weights]
|
||||
# Need to figure out how to broadcast - for now just try
|
||||
# direct multiplication
|
||||
if (
|
||||
weight_scale.dim() == 1
|
||||
and weight_scale.shape[0] == weight_fp8.shape[0]
|
||||
):
|
||||
# Per-row scaling
|
||||
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
|
||||
else:
|
||||
# Fallback
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_w8a8_block_fp8_linear)
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
block_size=self.quant_config.weight_block_size,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use activation per token quantization based on quantization config.
|
||||
'''
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
weight_per_channel=self.weight_per_channel,
|
||||
activation_per_token=self.activation_per_token)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
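To make the block-quant path concrete: the scale grid from create_weights has shape [ceil(N/128), ceil(K/128)], and dequantization broadcasts each scale over its 128x128 tile. An unfused, illustrative version (the real kernels fold this into the GEMM):

import torch

def dequant_block_fp8(w_fp8: torch.Tensor, scale: torch.Tensor,
                      block_n: int = 128, block_k: int = 128) -> torch.Tensor:
    # Expand the per-tile scale grid back over the [N, K] weight and multiply.
    n, k = w_fp8.shape
    s = scale.repeat_interleave(block_n, dim=0)[:n]
    s = s.repeat_interleave(block_k, dim=1)[:, :k]
    return w_fp8.to(torch.float32) * s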
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
|
||||
self,
|
||||
quant_config: Fp8Config,
|
||||
layer: torch.nn.Module
|
||||
):
|
||||
super(Fp8MoEMethod, self).__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant: bool = self.weight_block_size is not None
|
||||
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
|
||||
|
||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size == [128, 128], (
|
||||
f"Only support weight_block_size == [128, 128], "
|
||||
f"got {self.weight_block_size}"
|
||||
)
|
||||
self.flashinfer_moe_fn = partial(
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
moe=self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||
self.allow_cutlass_block_scaled_grouped_gemm = (
|
||||
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: On MLU, always set self.use_marlin to False.
|
||||
'''
|
||||
self.use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
|
||||
'''
|
||||
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
if scoring_func == "softmax":
|
||||
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
route_scale=routed_scaling_factor,
|
||||
)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
# gen_idx
|
||||
ori_input_shape = x.shape
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
router_logits = router_logits.reshape(-1, router_logits.size(-1))
|
||||
expert_num = router_logits.size(-1)
|
||||
tokens_num = x.size(0)
|
||||
expert_size = layer.w13_weight.size(0)
|
||||
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_ids, expert_num
|
||||
)
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
x, expand_idx, cumsum_token_count, 0, expert_size
|
||||
)
|
||||
quant_input, input_scale = _fp8_quantize(
|
||||
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
|
||||
)
|
||||
gemm1_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input,
|
||||
layer.w13_weight,
|
||||
token_count,
|
||||
expand_idx=None,
|
||||
c=None,
|
||||
alpha=None,
|
||||
beta=None,
|
||||
a_scale=input_scale.T.contiguous(),
|
||||
b_scale=layer.w13_weight_scale_inv,
|
||||
dtype=x.dtype,
|
||||
max_m=tokens_num,
|
||||
)
|
||||
|
||||
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
|
||||
act_out_quantize, act_out_scale = _fp8_quantize(
|
||||
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
|
||||
)
|
||||
|
||||
gemm2_out = mlu_ops.smooth_quant_group_gemm(
|
||||
act_out_quantize,
|
||||
layer.w2_weight,
|
||||
token_count,
|
||||
expand_idx=None,
|
||||
c=None,
|
||||
alpha=None,
|
||||
beta=None,
|
||||
a_scale=act_out_scale.T.contiguous(),
|
||||
b_scale=layer.w2_weight_scale_inv,
|
||||
dtype=x.dtype,
|
||||
max_m=tokens_num,
|
||||
)
|
||||
|
||||
output = mlu_ops.moe_combine_result(
|
||||
gemm2_out,
|
||||
topk_weights,
|
||||
combine_idx,
|
||||
residual=None,
|
||||
cusum_token_count=cumsum_token_count,
|
||||
start_expert_id=0,
|
||||
expert_size=expert_size,
|
||||
bias=None,
|
||||
)
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
"""
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
"""
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.apply,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8Config,
|
||||
Fp8Config.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8Config,
|
||||
Fp8Config.from_config,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.create_weights,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.process_weights_after_loading,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8MoEMethod,
|
||||
Fp8MoEMethod.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8MoEMethod,
|
||||
Fp8MoEMethod.apply,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
|
||||
)
|
||||
440
vllm_mlu/model_executor/layers/quantization/gptq_mlu.py
Normal file
440
vllm_mlu/model_executor/layers/quantization/gptq_mlu.py
Normal file
@@ -0,0 +1,440 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# We only support GPTQ and AWQ on the 300 series, and only int4 and int8 precision
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# @register_quantization_config("gptq_mlu")
|
||||
class GPTQMluConfig(QuantizationConfig):
|
||||
"""Config class for GPTQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
(4, False): scalar_types.uint4b8,
|
||||
(8, False): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
self.support_scale_zeros = False
|
||||
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is "
|
||||
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return GPTQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
sym = quant_config.get("sym", None)
|
||||
desc_act = quant_config.get("desc_act", None)
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
|
||||
group_size=group_size, has_zp=False)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
class GPTQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
|
||||
-1) and (not self.quant_config.desc_act):
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.device = layer.qweight.data.device
|
||||
packed_qweight, scale_zeros = self.extract_autogptq(layer)
|
||||
if self.quant_config.use_native:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.use_native:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def extract_autogptq(self, layer: torch.nn.Module):
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
|
||||
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
|
||||
|
||||
if self.quant_config.use_native:
|
||||
if self.quant_config.desc_act:
|
||||
scales = torch.index_select(scales, 0, layer.g_idx)
|
||||
if izeros is not None:
|
||||
izeros = torch.index_select(izeros, 0, layer.g_idx)
|
||||
else:
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
|
||||
if izeros is not None:
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
# is_sym is true here, so shift iweight to a signed representation and ignore qzeros
|
||||
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
|
||||
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
|
||||
shifts.unsqueeze(-1),
|
||||
).to(dtype)
|
||||
weight = torch.bitwise_and(weight, (2**bits) - 1)
|
||||
weight = weight.reshape(-1, weight.shape[-1])
|
||||
|
||||
return weight
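A worked example of the unpacking above: one int32 word stores eight 4-bit values, least-significant nibble first, so shifting by 0, 4, ..., 28 and masking with 0xF recovers them in order (illustrative input):

import torch

bits = 4
packed = torch.tensor([[0x76543210]], dtype=torch.int32)   # nibbles 0..7, LSB first
shifts = torch.arange(0, 32, bits)

vals = torch.bitwise_right_shift(
    packed.unsqueeze(1).expand(-1, 32 // bits, -1),
    shifts.unsqueeze(0).unsqueeze(-1),
).to(torch.int8)
vals = torch.bitwise_and(vals, (2 ** bits) - 1).reshape(-1, packed.shape[-1])

print(vals.flatten().tolist())   # [0, 1, 2, 3, 4, 5, 6, 7]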
|
||||
|
||||
|
||||
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
zeros = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
|
||||
shifts.unsqueeze(0),
|
||||
).to(dtype)
|
||||
|
||||
zeros = zeros + 1
|
||||
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
zeros = zeros.reshape(qzeros.shape[0], -1)
|
||||
|
||||
return zeros
|
||||
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# Ensure both inputs are int8
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# 提取每个 tensor 的低4位
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
|
||||
|
||||
# Shift tensor_a's lower 4 bits into the upper nibble
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# Combine the two nibbles into a single int8 value
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
|
||||
337
vllm_mlu/model_executor/layers/quantization/smoothquant.py
Executable file
337
vllm_mlu/model_executor/layers/quantization/smoothquant.py
Executable file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
|
||||
str_dtype_to_bits,
|
||||
is_fp8_str_dtype)
|
||||
|
||||
|
||||
# @register_quantization_config("smoothquant")
|
||||
class SmoothQuantConfig(QuantizationConfig):
|
||||
"""Config class for SmoothQuant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_mode: str, # smoothquant
|
||||
input_quant_method: str, # per token/per tensor
|
||||
group_size: int,
|
||||
weight_precision: str,
|
||||
activation_precision: str,
|
||||
only_expert_per_group: bool,
|
||||
expert_weight_precision: str,
|
||||
expert_activation_precision: str,
|
||||
force_use_weightonly_except_expert: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_mode = quant_mode
|
||||
self.input_quant_method = input_quant_method
|
||||
self.group_size = group_size
|
||||
self.weight_precision = weight_precision
|
||||
self.activation_precision = activation_precision
|
||||
self.only_expert_per_group = only_expert_per_group
|
||||
self.expert_weight_precision = expert_weight_precision
|
||||
self.expert_activation_precision = expert_activation_precision
|
||||
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
|
||||
|
||||
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
|
||||
self.weight_bits = str_dtype_to_bits(self.weight_precision)
|
||||
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
|
||||
if self.weight_precision == 'int4':
|
||||
self.weight_dtype = torch.int8
|
||||
else:
|
||||
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
|
||||
if self.expert_weight_precision == 'int4':
|
||||
self.expert_weight_dtype = torch.int8
|
||||
else:
|
||||
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
|
||||
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
|
||||
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
|
||||
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
self.expert_pack_factor = 8 // self.expert_weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
|
||||
f"quant_mode={self.quant_mode}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"weight_precision={self.weight_precision}, "
|
||||
f"activation_precision={self.activation_precision}, "
|
||||
f"only_expert_per_group={self.only_expert_per_group}, "
|
||||
f"expert_weight_precision={self.expert_weight_precision}, "
|
||||
f"expert_activation_precision={self.expert_activation_precision}, "
|
||||
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
return "SmoothQuant"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
|
||||
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
|
||||
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
|
||||
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
|
||||
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
|
||||
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
|
||||
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
|
||||
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
|
||||
|
||||
if expert_weight_precision is None:
|
||||
expert_weight_precision = weight_precision
|
||||
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
|
||||
weight_precision = 'int8'
|
||||
|
||||
if expert_activation_precision is None:
|
||||
expert_activation_precision = activation_precision
|
||||
|
||||
return cls(quant_mode=quant_mode,
|
||||
input_quant_method=input_quant_method,
|
||||
group_size=group_size,
|
||||
weight_precision=weight_precision,
|
||||
activation_precision=activation_precision,
|
||||
only_expert_per_group=only_expert_per_group,
|
||||
expert_weight_precision=expert_weight_precision,
|
||||
expert_activation_precision=expert_activation_precision,
|
||||
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return SmoothQuantLinearMethod(self, prefix)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class SmoothQuantLinearMethod(LinearMethodBase):
|
||||
"""Linear method for SmoothQuant.
|
||||
|
||||
Args:
|
||||
quant_config: The SmoothQuant quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
|
||||
self.quant_config = quant_config
|
||||
# for the per-tensor case, we can skip quantizing the input for the first attn/ffn linear
|
||||
# and fuse this step into layernorm to get better performance
|
||||
self.skip_quant_input = False
|
||||
self.compute_dtype = torch.get_default_dtype()
|
||||
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
|
||||
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
|
||||
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
|
||||
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
|
||||
|
||||
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
|
||||
self.is_group_quant = True
|
||||
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
|
||||
self.is_group_quant = True
|
||||
else:
|
||||
self.is_group_quant = False
|
||||
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
|
||||
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor != 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
group_num = 1
|
||||
if self.is_group_quant:
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
f"The input size {input_size_per_partition} is not aligned with the quantized "
|
||||
f"weight shape. This can be caused by too large "
|
||||
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
|
||||
|
||||
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
|
||||
if input_size_per_partition != input_size:
|
||||
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
|
||||
|
||||
qweight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.pack_factor,
|
||||
device="mlu",
|
||||
dtype=self.weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.is_group_quant:
|
||||
per_channel_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
group_num,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
per_channel_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("per_channel_scale", per_channel_scale)
|
||||
|
||||
if self.has_smooth:
|
||||
smooth = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
set_weight_attrs(smooth, {
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("smooth", smooth)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
scale_to_int = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale_to_int, {
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("scale_to_int", scale_to_int)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.has_smooth and layer.smooth.dtype != torch.float:
|
||||
layer.smooth = layer.smooth.to(torch.float)
|
||||
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
|
||||
layer.scale_to_int = layer.scale_to_int.to(torch.float)
|
||||
if layer.per_channel_scale.dtype != torch.float:
|
||||
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
|
||||
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
|
||||
if self.has_smooth:
|
||||
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
use_tp_weight : bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
layer_smooth = layer.smooth if self.has_smooth else None
|
||||
layer_qweight = layer.qweight
|
||||
layer_per_channel_scale = layer.per_channel_scale
|
||||
if use_tp_weight:
|
||||
if hasattr(layer, 'tp_smooth'):
|
||||
layer_smooth = layer.tp_smooth
|
||||
if hasattr(layer, 'tp_qweight'):
|
||||
layer_qweight = layer.tp_qweight
|
||||
if hasattr(layer, 'tp_per_channel_scale'):
|
||||
layer_per_channel_scale = layer.tp_per_channel_scale
|
||||
|
||||
quant_input = None
|
||||
if self.skip_quant_input:
|
||||
quant_input = x
|
||||
elif self.quant_config.input_quant_method == "per_token":
|
||||
if self.is_fp8:
|
||||
quant_input, input_scale = mlu_ops.scaled_quantize(x,
|
||||
layer_smooth,
|
||||
quant_type=self.weight_dtype,
|
||||
quant_mode='dynamic_per_token')
|
||||
else:
|
||||
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
|
||||
elif self.quant_config.input_quant_method == "per_tensor":
|
||||
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
quant_input_shape = quant_input.shape
|
||||
if len(quant_input_shape) > 2:
|
||||
quant_input = quant_input.view(-1, quant_input_shape[-1])
|
||||
input_scale = input_scale.view(-1)
|
||||
if residual is not None and len(residual.shape) > 2:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
if self.is_fp8:
|
||||
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
|
||||
layer_per_channel_scale,
|
||||
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
|
||||
bias,
|
||||
c=residual, act_mode="none",quant_bit_size=8,
|
||||
alpha=1.0, beta=1.0, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None)
|
||||
if output is not None:
|
||||
out = out.view(output.shape)
|
||||
output.copy_(out)
|
||||
out = output
|
||||
else:
|
||||
if output is not None:
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
|
||||
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
|
||||
else:
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
|
||||
layer_per_channel_scale, self.compute_dtype, bias, residual)
|
||||
if len(quant_input_shape) > 2:
|
||||
out = out.view(*quant_input_shape[:-1], out.shape[-1])
|
||||
return out
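Numerically, the per-token path above is dynamic per-token activation quantization of a smoothed input combined with per-channel weight scales. A dense torch sketch of that arithmetic (int8 for readability; the mlu_ops kernels fuse these steps, and whether `smooth` multiplies or divides the activation depends on how the checkpoint stores it - multiplication is assumed here):

import torch

def smooth_quant_reference(x, weight_q, per_channel_scale, smooth, bias=None):
    # x: [T, K] float; weight_q: [N, K] int8; per_channel_scale: [N]; smooth: [K]
    x = x * smooth                                           # fold the smoothing factor into the activation
    x_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0     # dynamic per-token scale
    x_q = torch.clamp((x / x_scale).round(), -128, 127).to(torch.int8)
    y = x_q.float() @ weight_q.float().t()                   # emulated int8 GEMM
    y = y * x_scale * per_channel_scale                      # rescale rows (tokens) and columns (channels)
    return y + bias if bias is not None else y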
|
||||
Some files were not shown because too many files have changed in this diff.