[Model] Support DeepSeek-V4

commit b9925203b8
Author: chenxb002
Date: 2026-04-24 09:50:34 +08:00
172 changed files with 44780 additions and 0 deletions

206
.gitignore vendored Normal file

@@ -0,0 +1,206 @@
# version file generated by setuptools-scm
/vllm/_version.py
# vllm-flash-attn built from source
vllm/vllm_flash_attn/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/.deps/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# generated files
**/generated/**
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VSCode
.vscode/
# DS Store
.DS_Store
# Results
*.csv
# but may add new accuracy reference
!benchmarks/TruthfulQA/*.csv
# Python pickle files
*.pkl
# Sphinx documentation
_build/
# vim swap files
*.swo
*.swp
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h
# Benchmark dataset
benchmarks/*.json
# Linting
actionlint
shellcheck*/

76
CMakeLists.txt Normal file

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
cmake_minimum_required(VERSION 3.16)
project(vllm_mlu_C)
function(detect_debian10)
if(EXISTS "/etc/os-release")
file(READ "/etc/os-release" os_release)
if(os_release MATCHES "PRETTY_NAME=\"Debian GNU/Linux 10" OR
os_release MATCHES "VERSION_ID=\"10")
set(DEBIAN_10 TRUE PARENT_SCOPE)
message(STATUS "Detected Debian 10 (buster)")
endif()
endif()
endfunction()
detect_debian10()
if(DEBIAN_10)
find_program(GCC_PATH "gcc" PATHS "/usr/local/bin")
if(GCC_PATH)
message(STATUS "Using GCC on Debian 10: ${GCC_PATH}")
set(CMAKE_C_COMPILER "${GCC_PATH}")
set(CMAKE_CXX_COMPILER "/usr/local/bin/g++")
else()
message(WARNING "Debian 10 detected but gcc not found!")
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
find_package(pybind11 REQUIRED)
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
set(VLLM_MLU_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}")
find_package(Torch REQUIRED)
if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type Release/Debug (default Release)" FORCE)
endif()
file(GLOB VLLM_MLU_SRC ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
include_directories(
${pybind11_INCLUDE_DIRS}
${PYTHON_INCLUDE_PATH}
${TORCH_INCLUDE_DIRS}
$ENV{NEUWARE_HOME}/include
)
pybind11_add_module(vllm_mlu_C ${VLLM_MLU_SRC})
target_link_directories(
vllm_mlu_C
PRIVATE
$ENV{NEUWARE_HOME}/lib64
)
target_link_libraries(
vllm_mlu_C
PUBLIC
${TORCH_LIBRARIES}
libcndrv.so
)
target_link_options(vllm_mlu_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
install(TARGETS vllm_mlu_C DESTINATION ${VLLM_MLU_INSTALL_PATH})

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Modifications made by Cambricon Technologies Corporation Limited. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

116
README.md Normal file

@@ -0,0 +1,116 @@
<!-- SPDX-License-Identifier: Apache-2.0 -->
<!-- SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project -->
### Cambricon vLLM (vllm_mlu)
#### 1. Project Description
Cambricon vLLM (vllm_mlu) is developed on top of the [plugin system](https://docs.vllm.ai/en/latest/design/plugin_system.html) provided by community vLLM. It enables users to run large language model (LLM) inference and serving efficiently on Cambricon MLU hardware.
vllm_mlu supports native vLLM features including, but not limited to, Chunked Prefill, Prefix Caching, Spec Decode, Graph Mode, and Sleep Mode.
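vLLM discovers the plugin through the standard Python entry points that setup.py declares under `vllm.platform_plugins` and `vllm.general_plugins`. A minimal sketch of that discovery step is shown below; the printed return value and the Python 3.10+ `entry_points(group=...)` API are assumptions for illustration, not part of this repository.
```python
# Sketch only: how a vLLM-style plugin loader could resolve the entry points
# declared by vllm_mlu in setup.py.
from importlib.metadata import entry_points

for ep in entry_points(group="vllm.platform_plugins"):  # Python 3.10+ API
    register = ep.load()   # e.g. vllm_mlu:register_mlu_platform
    platform = register()  # assumed to return the MLU platform to register
    print(f"loaded platform plugin {ep.name!r}: {platform}")
```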
#### 2. Update History
[2026.04.24] vllm_mlu adds day-0 support for DeepSeek-V4.
#### 3. Usage
The software environment depends on the Cambricon SDK. To obtain the SDK, please contact the official Cambricon support channel at [ecosystem@cambricon.com](mailto:ecosystem@cambricon.com).
*NOTE: the vllm-mlu repository only supports MLU370 and later devices.*
##### 3.1 Using the Container Image
Use the Cambricon vLLM Container image provided with the Cambricon SDK.
```
# Load the image
docker load -i cambricon_vllm_container.tar.gz
# Start a container
docker run -it --net=host \
  --shm-size '64gb' --privileged \
  --ulimit memlock=-1 ${IMAGE_NAME} \
  /bin/bash
# Activate the inference environment
source /torch/venv3/pytorch_infer/bin/activate
```
##### 3.2 Custom Installation
Before installing Cambricon vLLM, make sure its dependencies are installed correctly.
Installation steps:
```bash
# Assumes the Cambricon vLLM source package has been obtained; it includes the vllm source.
# Install from the vllm source
cd vllm-v{community vLLM version}/
VLLM_TARGET_DEVICE=empty pip install -e .  # install in editable (developer) mode
# Install from the vllm-mlu source
git clone https://github.com/Cambricon/vllm-mlu
cd vllm-mlu
pip install -e .  # install in editable (developer) mode
# Install ray
# 1. Enter the vllm-mlu source tree.
cd tools/ray_mlu/
# 2. Install the Ray version that the adaptation is based on.
pip install --no-cache-dir --force-reinstall ray==2.51.1
# 3. To run Ray on Cambricon devices, the Cambricon adaptation files also need to be applied.
# ${RAY_DIR} points to the installed ray package directory (under the pip installation path).
cp __init__.py ${RAY_DIR}/_private/accelerators/__init__.py
cp mlu.py ${RAY_DIR}/_private/accelerators/
cp nsight.py ${RAY_DIR}/_private/runtime_env/nsight.py
cp node.py ${RAY_DIR}/_private/node.py
cp worker.py ${RAY_DIR}/_private/worker.py
cp device_manager/__init__.py ${RAY_DIR}/air/_internal/device_manager/__init__.py
cp device_manager/mlu.py ${RAY_DIR}/air/_internal/device_manager/
```
##### 3.3 Running
Cambricon vLLM is run in the same way as community vLLM.
###### 3.3.1 Offline Inference
```
# Run offline inference
python examples/offline_inference/offline_inference.py ${MODEL_PATH}
```
###### 3.3.2 Online Serving
Start the server and the client separately to complete the serving flow. An example is shown below:
```
# server
vllm serve ${MODEL_PATH} \
--port 8100 \
--block-size 1 \
--max-model-len 4096 \
--tensor-parallel-size 8 \
--gpu-memory-utilization 0.96 \
--trust-remote-code \
--enable-expert-parallel \
--no-enable-prefix-caching \
--disable-log-requests \
--enforce-eager
# client: post a single request
curl -X POST http://localhost:8100/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "'"${MODEL_PATH}"'",
       "prompt": "The future of AI is",
       "max_tokens": 128, "temperature": 0.7}'
```
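Equivalently, the server's OpenAI-compatible API can be queried from Python. A minimal client sketch follows; the `openai` package is an assumption for illustration and is not a dependency of this repository.
```python
# Sketch only: query the `vllm serve` endpoint started above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8100/v1", api_key="EMPTY")
resp = client.completions.create(
    model="<MODEL_PATH>",  # the same model path passed to `vllm serve`
    prompt="The future of AI is",
    max_tokens=128,
    temperature=0.7,
)
print(resp.choices[0].text)
```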

50
cmake/utils.cmake Normal file

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#
# Attempt to find the python package that uses the same python executable as
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
#
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif()
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
message(FATAL_ERROR
"Python version (${_VER}) is not one of the supported versions: "
"${_SUPPORTED_VERSIONS_LIST}.")
endif()
message(STATUS "Found python matching: ${EXECUTABLE}.")
endmacro()
#
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
# has trailing whitespace stripped. If an error is encountered when running
# python, a fatal message `ERR_MSG` is issued.
#
function (run_python OUT EXPR ERR_MSG)
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}" "-c" "${EXPR}"
OUTPUT_VARIABLE PYTHON_OUT
RESULT_VARIABLE PYTHON_ERROR_CODE
ERROR_VARIABLE PYTHON_STDERR
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT PYTHON_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
endif()
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
endfunction()
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
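# Example (as used in CMakeLists.txt):
#   append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")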
macro (append_cmake_prefix_path PKG EXPR)
run_python(_PREFIX_PATH
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()

310
csrc/cnmem_allocator.cpp Normal file

@@ -0,0 +1,310 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
// A MLU PluggableAllocator based on cn_api APIs.
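// Overview: my_malloc() aligns the requested size to the allocation
// granularity, reserves a device address range, invokes the Python malloc
// callback with the tuple (device, aligned_size, d_mem, p_memHandle), and then
// maps and enables access to the physical memory. my_free() asks the Python
// free callback to translate the pointer back into the same tuple, then unmaps
// the range and releases the handle. The callbacks are installed once via
// init_module(python_malloc, python_free).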
#include <iostream>
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cn_api.h>
#define DRV_CHECK_GET_RETURN(...) \
DRV_CHECK_GET_RETURN_IMPL(__VA_ARGS__, return, )
#define DRV_CHECK_GET_RETURN_IMPL(_1, _2, ...) _2
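// Note on CN_CHECK: on error it prints the driver error string and returns
// from the enclosing function; an optional second argument (e.g. nullptr)
// becomes the return value, via the `return` selected by
// DRV_CHECK_GET_RETURN and the trailing `__VA_ARGS__;`.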
#define CN_CHECK(return_code, ...) \
do { \
CNresult rc = (return_code); \
if (rc) { \
const char *error_str; \
cnGetErrorString(rc, &error_str); \
std::cout << "Error: " << error_str \
<< " at " << __FILE__ \
<< ":" << __LINE__ \
<< std::endl; \
DRV_CHECK_GET_RETURN(__VA_ARGS__) \
__VA_ARGS__; \
} \
} while (0)
// Global references to Python callables
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;
// ---------------------------------------------------------------------------
// Helper functions:
void ensure_context(CNdev device) {
CNcontext pctx;
CN_CHECK(cnCtxGetCurrent(&pctx));
if (!pctx) {
// Ensure device context;
CN_CHECK(cnCtxCreate(&pctx, 0, device));
CN_CHECK(cnCtxSetCurrent(pctx));
}
}
void create_and_map(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) {
ensure_context(device);
// Define memory allocation properties
CNmemAllocationProp prop = {};
// The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to cndrv developer guide.
prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT; //CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE;
prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE;
// Allocate memory using cnMemCreate
CN_CHECK(cnMemCreate(p_memHandle, size, &prop, 0));
CN_CHECK(cnMemMap(d_mem, size, 0, *p_memHandle, 0));
CNmemAccessDesc accessDesc = {};
accessDesc.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = device;
accessDesc.accessFlags = CN_MEM_ACCESS_FLAGS_PROT_READWRITE;
CN_CHECK(cnMemSetAccess(d_mem, size, &accessDesc, 1));
}
void unmap_and_release(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) {
ensure_context(device);
CN_CHECK(cnMemUnmap(d_mem, size));
CN_CHECK(cnMemRelease(*p_memHandle));
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
unsigned long long b,
unsigned long long c,
unsigned long long d) {
// Create a new tuple of size 4
PyObject* tuple = PyTuple_New(4);
if (!tuple) {
return NULL;
}
// Convert integers to Python objects and set them in the tuple
// Steals reference to the PyLong
PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a));
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
// Note: PyTuple_SetItem "steals" a reference to each object,
// so we do not need to Py_DECREF the PyLong objects explicitly.
return tuple;
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, CNqueue stream) {
ensure_context(device);
// first allocation, align the size, and reserve an address, and also allocate
// a CNmemGenericAllocationHandle
// Define memory allocation properties
CNmemAllocationProp prop = {};
// The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to cndrv developer guide.
prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT; //CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE;
prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE;
//Check if the allocation is supported
size_t granularity;
CN_CHECK(cnMemGetAllocationGranularity(&granularity, &prop, CN_MEM_ALLOC_GRANULARITY_MINIMUM), nullptr);
size_t alignedSize = ((size+granularity-1)/granularity)*granularity;
CNaddr d_mem;
CN_CHECK(cnMemAddressReserve(&d_mem, alignedSize, 0, 0, 0), nullptr);
// allocate the CNmemGenericAllocationHandle
CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)malloc(sizeof(CNmemGenericAllocationHandle));
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
// Call g_python_malloc_callback
PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);
if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}
Py_DECREF(py_result);
PyGILState_Release(gstate);
// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);
return (void*)d_mem;
}
__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, CNqueue stream) {
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "ERROR: g_python_free_callback not set.\n";
return;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
Py_XDECREF(py_ptr);
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
Py_XDECREF(py_result);
// Release the GIL before bailing out to avoid deadlocking the caller.
PyGILState_Release(gstate);
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
Py_DECREF(py_result);
PyGILState_Release(gstate);
return;
}
Py_DECREF(py_result);
PyGILState_Release(gstate);
// Free memory
CNaddr d_mem = (CNaddr)recv_d_mem;
CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
//free address and the handle
CN_CHECK(cnMemAddressFree(d_mem, size));
free(p_memHandle);
}
// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
static PyObject* py_init_module(PyObject* self, PyObject* args) {
PyObject* malloc_callback = nullptr;
PyObject* free_callback = nullptr;
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
return nullptr;
}
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
return nullptr;
}
// Save the Python callables
// This module does not handle GC of these objects, so they must be kept alive
// outside of this module.
// This module keeps a strong reference to prevent premature GC
Py_XINCREF(malloc_callback);
Py_XINCREF(free_callback);
Py_XDECREF(g_python_malloc_callback);
Py_XDECREF(g_python_free_callback);
g_python_malloc_callback = malloc_callback;
g_python_free_callback = free_callback;
Py_RETURN_NONE;
}
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CNaddr d_mem_ptr = (CNaddr)recv_d_mem;
CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CNaddr d_mem_ptr = (CNaddr)recv_d_mem;
CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyObject* python_cn_memcpy(PyObject* self, PyObject* args){
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 3) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 3");
return nullptr;
}
CNaddr dst, src;
cn_uint64_t bytes;
if (!PyArg_ParseTuple(args, "KKK", &dst, &src, &bytes)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CN_CHECK(cnMemcpy(dst, src, bytes), nullptr);
Py_RETURN_NONE;
}
static PyMethodDef module_methods[] = {
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
"Initialize module with python_malloc and python_free callables."},
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
"Create and map memory on the device."},
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
METH_VARARGS, "Unmap and release memory on the device."},
{"python_cn_memcpy", (PyCFunction)python_cn_memcpy, METH_VARARGS, "Copies data from source address to destination address."},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef cnmem_allocator_module = {
PyModuleDef_HEAD_INIT, "cnmem_allocator",
"cnapi-mem-based allocator for MLUPluggableAllocator", -1, module_methods};
PyMODINIT_FUNC PyInit_vllm_mlu_C(void) {
// Initialize the module
PyObject* module = PyModule_Create(&cnmem_allocator_module);
if (!module) {
return NULL;
}
return module;
}
} // extern "C"

34
csrc/ops.h Normal file

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#pragma once
#include <optional>
#include <torch/library.h>
#include <vector>
namespace vllm_mlu {
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
// Ensure tensor is on MLU
if (!tensor.is_privateuseone()) {
throw std::runtime_error("Tensor must be on MLU device");
}
// Get the raw data pointer
void* data_ptr = tensor.data_ptr();
// Get tensor sizes and strides
std::vector<int64_t> sizes = tensor.sizes().vec();
std::vector<int64_t> strides = tensor.strides().vec();
// Get tensor options (dtype, device)
auto options = tensor.options();
// Create a new tensor from the raw data pointer
auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
}

18
csrc/torch_bindings.cpp Normal file

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <pybind11/pybind11.h>
#include "ops.h"
#include "utils.h"
TORCH_LIBRARY_EXPAND(_C, ops)
{
// vLLM-MLU custom ops
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_mlu::weak_ref_tensor);
}
REGISTER_EXTENSION(_C)

29
csrc/utils.h Normal file

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#pragma once
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}

48
examples/offline_inference/offline_inference.py Normal file

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import sys
from vllm import LLM, SamplingParams
def main(model_path):
# Sample prompts.
prompts = [
"The benefits of exercise include",
"The importance of reading books is",
"Gardening can be relaxing because",
"A good night's sleep is essential for",
]
sampling_params = SamplingParams(
temperature=0.6, top_p=0.95, max_tokens=10)
# Create an LLM.
engine_args_dict = {
"model": model_path,
"tensor_parallel_size": 8,
"enable_expert_parallel": True,
"enable_prefix_caching": False,
"enforce_eager": True,
"trust_remote_code": True,
"max_num_seqs": len(prompts),
"max_model_len": 4096,
"block_size": 1,
"gpu_memory_utilization": 0.96,
}
llm = LLM(**engine_args_dict)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == '__main__':
if len(sys.argv) < 2:
print("Usage: python offline_inference.py <model_path>")
sys.exit(1)
main(sys.argv[1])

14
requirements.txt Normal file

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# Dependencies for Cambricon MLUs
ray == 2.51.1
click == 8.2.1
triton >= 3.2.0
torch == 2.9.1
torch-mlu >= 1.29.1
torch_mlu_ops >= 1.8.1
matplotlib == 3.10.3
datasets == 3.6.0
blobfile == 3.0.0
scipy == 1.10.1

276
setup.py Normal file

@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import importlib.util
import io
import logging
import os
import re
import subprocess
import sys
from sysconfig import get_paths
from typing import List, Dict
from setuptools import Extension
from setuptools import find_namespace_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.install import install
from setuptools.command.develop import develop
ROOT_DIR = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
def check_or_set_default_env(cmake_args, env_name, env_variable, default_path=""):
if env_variable is None:
logging.warning(f"Set default {env_name}: {default_path}")
env_variable = default_path
else:
logging.info(f"Found existing {env_name}: {env_variable}")
cmake_args += [f"-D{env_name}={env_variable}"]
return cmake_args
def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm_mlu", "envs.py"))
class CMakeExtension(Extension):
def __init__(self,
name: str,
cmake_lists_dir: str = ".",
**kwargs) -> None:
super().__init__(name, sources=[], py_limited_api=False, **kwargs)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
def get_vllm_version() -> str:
"""
get vllm version
"""
with open(get_path("tools/build.property"), 'r') as file:
content = file.read()
results = re.findall(r'VLLM_VERSION=([\d|\.]+)\+mlu([\d|\.]+)\.pt(\d+)', content)
assert results, "failed to get vllm, vllm_mlu and pytorch versions."
version = f"{results[-1][0]}+mlu{results[-1][1]}.pt{results[-1][2]}"
return version
def read_readme() -> str:
"""Read the README file if present."""
p = get_path("README.md")
if os.path.isfile(p):
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
else:
return ""
def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
def _read_requirements(filename: str) -> List[str]:
with open(get_path(filename)) as f:
requirements = f.read().strip().split("\n")
resolved_requirements = []
for line in requirements:
if line.startswith("-r "):
resolved_requirements += _read_requirements(line.split()[1])
elif line.startswith("--"):
continue
else:
resolved_requirements.append(line)
return resolved_requirements
return _read_requirements("requirements.txt")
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config: Dict[str, bool] = {}
# Determine number of compilation jobs
def compute_num_jobs(self):
# `num_jobs` is either the value of the MAX_JOBS environment variable
# (if defined) or the number of CPUs available.
num_jobs = envs.MAX_JOBS
if num_jobs is not None:
num_jobs = int(num_jobs)
logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
else:
try:
# os.sched_getaffinity() isn't universally available, so fall
# back to os.cpu_count() if we get an error here.
num_jobs = len(os.sched_getaffinity(0))
except AttributeError:
num_jobs = os.cpu_count()
num_jobs = max(1, num_jobs)
return num_jobs
#
# Perform cmake configuration for a single extension.
#
def configure(self, ext: CMakeExtension) -> None:
os.makedirs(self.build_temp, exist_ok=True)
source_dir = os.path.abspath(ROOT_DIR)
python_executable = sys.executable
cmake_args = ["cmake"]
# Compile the csrc code in Release mode by default.
# Supported build types are Release, Debug and RelWithDebInfo.
if envs.CMAKE_BUILD_TYPE is None or envs.CMAKE_BUILD_TYPE not in [
"Debug",
"Release",
"RelWithDebInfo",
]:
envs.CMAKE_BUILD_TYPE = "Release"
cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"]
# Default dump the compile commands for lsp
cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"]
if envs.CXX_COMPILER is not None:
cmake_args += [f"-DCMAKE_CXX_COMPILER={envs.CXX_COMPILER}"]
if envs.C_COMPILER is not None:
cmake_args += [f"-DCMAKE_C_COMPILER={envs.C_COMPILER}"]
if envs.VERBOSE:
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
# find PYTHON_EXECUTABLE
check_or_set_default_env(cmake_args, "PYTHON_EXECUTABLE", sys.executable)
# find PYTHON_INCLUDE_PATH
check_or_set_default_env(cmake_args, "PYTHON_INCLUDE_PATH",
get_paths()["include"])
try:
# Install pybind11 via pip and query its CMake config path.
subprocess.check_call([sys.executable, "-m", "pip", "install", "pybind11==2.13.6"])
pybind11_cmake_path = (subprocess.check_output([python_executable, "-m",
"pybind11", "--cmake"]).decode().strip())
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Failed to set up pybind11: {e}")
install_path = os.path.join(ROOT_DIR, self.build_lib)
if isinstance(self.distribution.get_command_obj("develop"), develop):
install_path = os.path.join(ROOT_DIR, "vllm_mlu")
# add CMAKE_INSTALL_PATH
cmake_args += [f"-DCMAKE_INSTALL_PREFIX={install_path}"]
cmake_args += [f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"]
cmake_args += [source_dir]
logging.info(f"cmake config command: {cmake_args}")
try:
subprocess.check_call(cmake_args, cwd=self.build_temp)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"CMake configuration failed: {e}")
def build_extensions(self) -> None:
if not envs.COMPILE_CUSTOM_KERNELS:
return
# Ensure that CMake is present and working
try:
subprocess.check_output(["cmake", "--version"])
except OSError as e:
raise RuntimeError(f"Cannot find CMake executable: {e}")
# Create build directory if it does not exist.
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
os.makedirs(os.path.join(self.build_lib, "vllm_mlu"), exist_ok=True)
targets = []
def get_target_name(s: str) -> str:
return s.removeprefix("vllm_mlu.")
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
targets.append(get_target_name(ext.name))
num_jobs = self.compute_num_jobs()
build_args = ["--build", ".", f"-j={num_jobs}",
*[f"--target={name}" for name in targets],
]
logger.info(build_args)
try:
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
except (OSError, subprocess.CalledProcessError) as e:
raise RuntimeError(f"Build library failed: {e}")
# Install the libraries
install_args = ["--install", ".", ]
try:
subprocess.check_call(["cmake", *install_args], cwd=self.build_temp)
except (OSError, subprocess.CalledProcessError) as e:
raise RuntimeError(f"Install library failed: {e}")
# copy back to build folder for editable build
if isinstance(self.distribution.get_command_obj("develop"), develop):
for root, _, files in os.walk(self.build_temp):
for file in files:
if file.endswith(".so"):
src_path = os.path.join(root, file)
dst_path = os.path.join(self.build_lib, "vllm_mlu", file)
self.copy_file(src_path, dst_path)
logger.info(f"Copy: {src_path} -> {dst_path}")
def run(self):
# First, run the standard build_ext command to compile the extensions
super().run()
class custom_install(install):
def run(self):
self.run_command("build_ext")
install.run(self)
ext_modules = []
if envs.COMPILE_CUSTOM_KERNELS:
ext_modules = [CMakeExtension(name="vllm_mlu.vllm_mlu_C")]
cmdclass = {"build_ext": cmake_build_ext, "install": custom_install}
setup(
name="vllm_mlu",
version=get_vllm_version(),
author="Cambricon vLLM Team",
license="Apache 2.0",
description=("A high-throughput and memory-efficient inference and "
"serving engine for LLMs on MLU backend"),
long_description=read_readme(),
long_description_content_type="text/markdown",
url="",
project_urls={
"Homepage": "https://github.com/vllm-project/vllm",
"Documentation": "https://vllm.readthedocs.io/en/latest/",
},
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=find_namespace_packages(exclude=("docs", "examples", "tests*", "csrc")),
include_package_data=True,
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules = ext_modules,
cmdclass=cmdclass,
entry_points={
'vllm.platform_plugins': ["mlu = vllm_mlu:register_mlu_platform"],
'vllm.general_plugins': ["mlu_hijack = vllm_mlu:register_mlu_hijack"]
}
)

89
tools/ray_mlu/__init__.py Normal file

@@ -0,0 +1,89 @@
from typing import Optional, Set
from ray._private.accelerators.accelerator import (
RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR,
AcceleratorManager,
)
from ray._private.accelerators.amd_gpu import AMDGPUAcceleratorManager
from ray._private.accelerators.hpu import HPUAcceleratorManager
from ray._private.accelerators.intel_gpu import IntelGPUAcceleratorManager
from ray._private.accelerators.neuron import NeuronAcceleratorManager
from ray._private.accelerators.npu import NPUAcceleratorManager
from ray._private.accelerators.nvidia_gpu import NvidiaGPUAcceleratorManager
from ray._private.accelerators.rbln import RBLNAcceleratorManager
from ray._private.accelerators.tpu import TPUAcceleratorManager
from ray._private.accelerators.mlu import MLUAcceleratorManager
def get_all_accelerator_managers() -> Set[AcceleratorManager]:
"""Get all accelerator managers supported by Ray."""
return {
NvidiaGPUAcceleratorManager,
IntelGPUAcceleratorManager,
AMDGPUAcceleratorManager,
TPUAcceleratorManager,
NeuronAcceleratorManager,
HPUAcceleratorManager,
NPUAcceleratorManager,
RBLNAcceleratorManager,
MLUAcceleratorManager,
}
def get_all_accelerator_resource_names() -> Set[str]:
"""Get all resource names for accelerators."""
return {
accelerator_manager.get_resource_name()
for accelerator_manager in get_all_accelerator_managers()
}
def get_accelerator_manager_for_resource(
resource_name: str,
) -> Optional[AcceleratorManager]:
"""Get the corresponding accelerator manager for the given
accelerator resource name
E.g., TPUAcceleratorManager is returned if resource name is "TPU"
"""
try:
return get_accelerator_manager_for_resource._resource_name_to_accelerator_manager.get( # noqa: E501
resource_name, None
)
except AttributeError:
# Lazy initialization.
resource_name_to_accelerator_manager = {
accelerator_manager.get_resource_name(): accelerator_manager
for accelerator_manager in get_all_accelerator_managers()
}
# Special handling for GPU resource name since multiple accelerator managers
# have the same GPU resource name.
if AMDGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager
elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager
elif MLUAcceleratorManager.get_current_node_num_accelerators() > 0:
resource_name_to_accelerator_manager["GPU"] = MLUAcceleratorManager
else:
resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager
get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = (
resource_name_to_accelerator_manager
)
return resource_name_to_accelerator_manager.get(resource_name, None)
__all__ = [
"NvidiaGPUAcceleratorManager",
"IntelGPUAcceleratorManager",
"AMDGPUAcceleratorManager",
"TPUAcceleratorManager",
"NeuronAcceleratorManager",
"HPUAcceleratorManager",
"NPUAcceleratorManager",
"RBLNAcceleratorManager",
"MLUAcceleratorManager",
"get_all_accelerator_managers",
"get_all_accelerator_resource_names",
"get_accelerator_manager_for_resource",
"RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR",
]

114
tools/ray_mlu/device_manager/__init__.py Normal file

@@ -0,0 +1,114 @@
import logging
import threading
from typing import Optional
import ray
import ray._private.ray_constants as ray_constants
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
from ray.air._internal.device_manager.mlu import MLUTorchDeviceManager
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
logger = logging.getLogger(__name__)
DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager
'''
=============================
Modify by vllm_mlu
=============================
@brief: use MLUTorchDeviceManager when key="GPU"
'''
SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
ray_constants.GPU: MLUTorchDeviceManager,
ray_constants.HPU: HPUTorchDeviceManager,
ray_constants.NPU: NPUTorchDeviceManager,
}
'''
==================
End of MLU Hijack
==================
'''
def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
if backend == "hccl":
# The name for the communication backend of Habana and torch-npu is the same.
HPUTorchDeviceManager.register_custom_torch_dist_backend()
NPUTorchDeviceManager.register_custom_torch_dist_backend()
_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()
def get_torch_device_manager_by_context() -> TorchDeviceManager:
global _torch_device_manager
with _torch_device_manager_lock:
if not _torch_device_manager:
existing_device_manager_cls = None
resources = ray.get_runtime_context().get_accelerator_ids()
# select correct accelerator type from resources
for resource_type, resource_value in resources.items():
device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
resource_type, None
)
if resource_value and device_manager_cls:
# An error will be raised when multiple accelerators are specified.
if existing_device_manager_cls:
raise RuntimeError(
"Unable to determine the appropriate DeviceManager "
f"for the specified resources {resources}."
)
else:
existing_device_manager_cls = device_manager_cls
device_manager_cls = (
existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
)
_torch_device_manager = device_manager_cls()
return _torch_device_manager
def get_torch_device_manager_by_device_type(device_type: str):
'''
=============================
Modify by vllm_mlu
=============================
@brief: use MLUTorchDeviceManager when key="GPU"
'''
if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda":
return MLUTorchDeviceManager()
elif device_type.lower() == ray_constants.NPU.lower():
return NPUTorchDeviceManager()
elif device_type.lower() == ray_constants.HPU.lower():
return HPUTorchDeviceManager()
elif device_type.lower() == "cpu":
return CPUTorchDeviceManager()
'''
==================
End of MLU Hijack
==================
'''
raise RuntimeError(f"Device type {device_type} cannot be recognized.")
__all__ = [
TorchDeviceManager,
CPUTorchDeviceManager,
CUDATorchDeviceManager,
HPUTorchDeviceManager,
NPUTorchDeviceManager,
MLUTorchDeviceManager,
register_custom_torch_dist_backend,
get_torch_device_manager_by_context,
get_torch_device_manager_by_device_type,
]

103
tools/ray_mlu/device_manager/mlu.py Normal file

@@ -0,0 +1,103 @@
import os
from importlib.util import find_spec
from typing import List, Union
import torch
import ray
import ray._private.ray_constants as ray_constants
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
from ray._private.accelerators.mlu import MLU_VISIBLE_DEVICES_ENV_VAR
def is_package_present(package_name: str) -> bool:
try:
return find_spec(package_name) is not None
except ModuleNotFoundError:
return False
MLU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_mlu")
if MLU_TORCH_PACKAGE_AVAILABLE:
import torch_mlu # noqa: F401
class MLUTorchDeviceManager(TorchDeviceManager):
"""Cambricon MLU device manager"""
@staticmethod
def register_custom_torch_dist_backend():
if MLU_TORCH_PACKAGE_AVAILABLE:
import torch_mlu # noqa: F401, F811
def is_available(self) -> bool:
if not MLU_TORCH_PACKAGE_AVAILABLE:
return False
return torch.mlu.is_available()
def get_devices(self) -> List[torch.device]:
"""Gets the correct torch device list configured for this process.
Returns a list of torch MLU devices allocated for the current worker.
If no MLUs are assigned, then it returns a list with a single CPU device.
"""
if MLU_TORCH_PACKAGE_AVAILABLE and torch.mlu.is_available():
mlu_ids = [
str(id)
for id in ray.get_runtime_context().get_accelerator_ids()[
ray_constants.GPU
]
]
device_ids = []
if len(mlu_ids) > 0:
mlu_visible_str = os.environ.get(MLU_VISIBLE_DEVICES_ENV_VAR, "")
if mlu_visible_str and mlu_visible_str != "NoDevFiles":
mlu_visible_list = mlu_visible_str.split(",")
else:
mlu_visible_list = []
for mlu_id in mlu_ids:
try:
device_ids.append(mlu_visible_list.index(mlu_id))
except ValueError:
raise RuntimeError(
"MLU_VISIBLE_DEVICES set incorrectly. "
f"Got {mlu_visible_str}, expected to include {mlu_id}. "
"Did you override the `MLU_VISIBLE_DEVICES` "
"environment variable?"
)
else:
# If called on the driver or outside of Ray Train, return the
# 0th device.
device_ids.append(0)
devices = [torch.device(f"mlu:{device_id}") for device_id in device_ids]
else:
raise RuntimeError(
"Using MLUTorchDeviceManager but torch mlu is not available."
)
return devices
def set_device(self, device: Union[torch.device, int]):
torch.mlu.set_device(device)
def supports_stream(self) -> bool:
"""Validate if the device type support to create a stream"""
return True
def create_stream(self, device):
"""Create a stream on MLU device"""
return torch.mlu.Stream(device)
def get_stream_context(self, stream):
"""Get a torch.stream context on MLU device"""
return torch.mlu.stream(stream)
def get_current_stream(self):
"""Get current stream for MLU device"""
return torch.mlu.current_stream()

243
tools/ray_mlu/diff.patch Normal file

@@ -0,0 +1,243 @@
commit 7376225d16e381ecae5cc07d84db9eed043ed06a
Author: tanhaojue <tanhaojue@cambricon.com>
Date: Thu Mar 7 15:54:09 2024 +0800
support mlu
diff --git a/python/ray/_private/accelerators/__init__.py b/python/ray/_private/accelerators/__init__.py
index 71550bc..07bdcd6 100644
--- a/python/ray/_private/accelerators/__init__.py
+++ b/python/ray/_private/accelerators/__init__.py
@@ -8,6 +8,7 @@ from ray._private.accelerators.tpu import TPUAcceleratorManager
from ray._private.accelerators.neuron import NeuronAcceleratorManager
from ray._private.accelerators.hpu import HPUAcceleratorManager
from ray._private.accelerators.npu import NPUAcceleratorManager
+from ray._private.accelerators.mlu import MLUAcceleratorManager
def get_all_accelerator_managers() -> Set[AcceleratorManager]:
@@ -20,6 +21,7 @@ def get_all_accelerator_managers() -> Set[AcceleratorManager]:
NeuronAcceleratorManager,
HPUAcceleratorManager,
NPUAcceleratorManager,
+ MLUAcceleratorManager,
}
@@ -55,6 +57,8 @@ def get_accelerator_manager_for_resource(
resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager
elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager
+ elif MLUAcceleratorManager.get_current_node_num_accelerators() > 0:
+ resource_name_to_accelerator_manager["GPU"] = MLUAcceleratorManager
else:
resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager
get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = (
@@ -71,6 +75,7 @@ __all__ = [
"NeuronAcceleratorManager",
"HPUAcceleratorManager",
"NPUAcceleratorManager",
+ "MLUAcceleratorManager",
"get_all_accelerator_managers",
"get_all_accelerator_resource_names",
"get_accelerator_manager_for_resource",
diff --git a/python/ray/_private/accelerators/mlu.py b/python/ray/_private/accelerators/mlu.py
new file mode 100755
index 0000000..21a5771
--- /dev/null
+++ b/python/ray/_private/accelerators/mlu.py
@@ -0,0 +1,92 @@
+import os
+import glob
+import logging
+from typing import Optional, List, Tuple
+import torch
+import torch_mlu
+from ray._private.accelerators.accelerator import AcceleratorManager
+
+logger = logging.getLogger(__name__)
+
+MLU_VISIBLE_DEVICES_ENV_VAR = "MLU_VISIBLE_DEVICES"
+NOSET_MLU_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_MLU_VISIBLE_DEVICES"
+
+
+class MLUAcceleratorManager(AcceleratorManager):
+ """Cambricon MLU accelerators."""
+
+ @staticmethod
+ def get_resource_name() -> str:
+ return "GPU"
+
+ @staticmethod
+ def get_visible_accelerator_ids_env_var() -> str:
+ return MLU_VISIBLE_DEVICES_ENV_VAR
+
+ @staticmethod
+ def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
+ mlu_visible_devices = os.environ.get(
+ MLUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
+ )
+
+ if mlu_visible_devices is None:
+ return None
+
+ if mlu_visible_devices == "":
+ return []
+
+ if mlu_visible_devices == "NoDevFiles":
+ return []
+
+ return list(mlu_visible_devices.split(","))
+
+ @staticmethod
+ def get_current_node_num_accelerators() -> int:
+ """Attempt to detect the number of MLUs on this machine.
+
+ MLU chips are represented as devices within `/dev/`, e.g. `/dev/cambricon_dev?`.
+
+ Returns:
+ The number of MLUs if any were detected, otherwise 0.
+ """
+ try:
+ return torch.mlu.device_count()
+ except Exception as e:
+ logger.debug("Could not import CambriconCL: %s", e)
+
+ try:
+ mlu_files = glob.glob("/dev/cambricon_dev?")
+ return len(mlu_files)
+ except Exception as e:
+ logger.debug("Failed to detect number of MLUs: %s", e)
+ return 0
+
+ @staticmethod
+ def get_current_node_accelerator_type() -> Optional[str]:
+ """Get the type of the Cambricon MLU on the current node.
+
+ Returns:
+ A string of the type, such as "MLU370".
+ """
+ try:
+ return torch.mlu.get_device_name(0)
+ except Exception:
+ logger.exception("Failed to detect MLU type.")
+ return None
+
+ @staticmethod
+ def validate_resource_request_quantity(
+ quantity: float,
+ ) -> Tuple[bool, Optional[str]]:
+ return (True, None)
+
+ @staticmethod
+ def set_current_process_visible_accelerator_ids(
+ visible_mlu_devices: List[str],
+ ) -> None:
+ if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR):
+ return
+
+ os.environ[
+ MLUAcceleratorManager.get_visible_accelerator_ids_env_var()
+ ] = ",".join([str(i) for i in visible_mlu_devices])
diff --git a/python/ray/tests/accelerators/test_mlu.py b/python/ray/tests/accelerators/test_mlu.py
new file mode 100755
index 0000000..70e81f7
--- /dev/null
+++ b/python/ray/tests/accelerators/test_mlu.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import pytest
+from unittest.mock import patch
+
+import ray
+from ray._private.accelerators import MLUAcceleratorManager as Accelerator
+
+
+@patch("glob.glob")
+@patch("os.listdir")
+def test_autodetect_num_mlus(mock_list, mock_glob):
+ mock_glob.return_value = [f"/dev/davinci{i}" for i in range(4)]
+ # mock_list.return_value = []
+ assert Accelerator.get_current_node_num_accelerators() == 4
+
+
+@patch("glob.glob")
+@patch("os.listdir")
+def test_autodetect_num_mlus_without_devices(mock_list, mock_glob):
+ mock_glob.side_effect = Exception
+ # mock_list.return_value = []
+ assert Accelerator.get_current_node_num_accelerators() == 0
+
+
+def test_mlu_accelerator_manager_api():
+ assert Accelerator.get_resource_name() == "MLU"
+ assert Accelerator.get_visible_accelerator_ids_env_var() == "MLU_VISIBLE_DEVICES"
+ assert Accelerator.validate_resource_request_quantity(0.5) == (True, None)
+ assert Accelerator.validate_resource_request_quantity(1) == (True, None)
+
+
+def test_visible_mlu_type(monkeypatch, shutdown_only):
+ with patch.object(
+ Accelerator, "get_current_node_num_accelerators", return_value=4
+ ), patch.object(
+ Accelerator, "get_current_node_accelerator_type", return_value="MLU370"
+ ):
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
+ manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
+ assert manager.get_current_node_accelerator_type() == "MLU370"
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
+def test_visible_mlu_ids(monkeypatch, shutdown_only):
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
+ with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
+
+ ray.init()
+ manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
+ assert manager.get_current_node_num_accelerators() == 4
+ assert manager.__name__ == "MLUAcceleratorManager"
+ assert ray.available_resources()["MLU"] == 3
+
+def test_get_current_process_visible_accelerator_ids(monkeypatch, shutdown_only):
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
+ assert Accelerator.get_current_process_visible_accelerator_ids() == ["0", "1", "2"]
+
+ monkeypatch.delenv("MLU_VISIBLE_DEVICES")
+ assert Accelerator.get_current_process_visible_accelerator_ids() is None
+
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "")
+ assert Accelerator.get_current_process_visible_accelerator_ids() == []
+
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "NoDevFiles")
+ assert Accelerator.get_current_process_visible_accelerator_ids() == []
+
+
+def test_set_current_process_visible_accelerator_ids(shutdown_only):
+ Accelerator.set_current_process_visible_accelerator_ids(["0"])
+ assert os.environ["MLU_VISIBLE_DEVICES"] == "0"
+
+ Accelerator.set_current_process_visible_accelerator_ids(["0", "1"])
+ assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1"
+
+ Accelerator.set_current_process_visible_accelerator_ids(["0", "1", "2"])
+ assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1,2"
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
+def test_auto_detected_more_than_visible(monkeypatch, shutdown_only):
+ with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
+ # If more MLUs are detected than visible.
+ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
+
+ ray.init()
+ assert ray.available_resources()["MLU"] == 3
+
+if __name__ == "__main__":
+ if os.environ.get("PARALLEL_CI"):
+ sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
+ else:
+ sys.exit(pytest.main(["-sv", __file__]))
diff --git a/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl b/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
new file mode 100644
index 0000000..8628a88
Binary files /dev/null and b/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl differ

View File

@@ -0,0 +1,11 @@
diff --git a/ray_mlu/mlu.py b/ray_mlu/mlu.py
index 21a57719..2c63fd5b 100755
--- a/ray_mlu/mlu.py
+++ b/ray_mlu/mlu.py
@@ -87,6 +87,3 @@ class MLUAcceleratorManager(AcceleratorManager):
if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR):
return
- os.environ[
- MLUAcceleratorManager.get_visible_accelerator_ids_env_var()
- ] = ",".join([str(i) for i in visible_mlu_devices])

94
tools/ray_mlu/mlu.py Executable file
View File

@@ -0,0 +1,94 @@
import os
import glob
import logging
from typing import Optional, List, Tuple
import torch
import torch_mlu
from ray._private.accelerators.accelerator import AcceleratorManager
logger = logging.getLogger(__name__)
MLU_VISIBLE_DEVICES_ENV_VAR = "MLU_VISIBLE_DEVICES"
NOSET_MLU_VISIBLE_DEVICES_ENV_VAR = (
"RAY_EXPERIMENTAL_NOSET_MLU_VISIBLE_DEVICES"
)
class MLUAcceleratorManager(AcceleratorManager):
"""Cambricon MLU accelerators."""
@staticmethod
def get_resource_name() -> str:
return "GPU"
@staticmethod
def get_visible_accelerator_ids_env_var() -> str:
return MLU_VISIBLE_DEVICES_ENV_VAR
@staticmethod
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
mlu_visible_devices = os.environ.get(
MLUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
)
if mlu_visible_devices is None:
return None
if mlu_visible_devices == "":
return []
if mlu_visible_devices == "NoDevFiles":
return []
return list(mlu_visible_devices.split(","))
@staticmethod
def get_current_node_num_accelerators() -> int:
"""Attempt to detect the number of MLUs on this machine.
MLU chips are exposed as device files under `/dev/`, e.g. `/dev/cambricon_dev?`.
Returns:
The number of MLUs if any were detected, otherwise 0.
"""
try:
return torch.mlu.device_count()
except Exception as e:
logger.debug("Could not import CambriconCL: %s", e)
try:
mlu_files = glob.glob("/dev/cambricon_dev?")
return len(mlu_files)
except Exception as e:
logger.debug("Failed to detect number of MLUs: %s", e)
return 0
@staticmethod
def get_current_node_accelerator_type() -> Optional[str]:
"""Get the type of the Cambricon MLU on the current node.
Returns:
A string of the type, such as "MLU370".
"""
try:
return torch.mlu.get_device_name(0)
except Exception:
logger.exception("Failed to detect MLU type.")
return None
@staticmethod
def validate_resource_request_quantity(
quantity: float,
) -> Tuple[bool, Optional[str]]:
return (True, None)
@staticmethod
def set_current_process_visible_accelerator_ids(
visible_mlu_devices: List[str],
) -> None:
if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR):
return
os.environ[
MLUAcceleratorManager.get_visible_accelerator_ids_env_var()
] = ",".join([str(i) for i in visible_mlu_devices])

1890
tools/ray_mlu/node.py Normal file

File diff suppressed because it is too large

142
tools/ray_mlu/nsight.py Normal file
View File

@@ -0,0 +1,142 @@
import asyncio
import copy
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from ray._common.utils import (
try_to_create_directory,
)
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray.exceptions import RuntimeEnvSetupError
default_logger = logging.getLogger(__name__)
# Nsight options used when runtime_env={"_nsight": "default"}
# use default cnperf config, no need to specify any options
NSIGHT_DEFAULT_CONFIG = {}
def parse_nsight_config(nsight_config: Dict[str, str]) -> List[str]:
"""
Convert a dictionary of CNPerf (cnperf-cli) options into
a command line.
The function returns:
- List[str]: the `cnperf-cli record` command split into a list of strings
"""
nsight_cmd = ["cnperf-cli", "record"]
for option, option_val in nsight_config.items():
# option standard based on
# https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html
if len(option) > 1:
nsight_cmd.append(f"--{option}={option_val}")
else:
nsight_cmd += [f"-{option}", option_val]
return nsight_cmd
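# Example (illustrative; "some_option" is a hypothetical key, "o" mirrors its use below):
# parse_nsight_config({"o": "/tmp/report", "some_option": "value"}) returns
# ["cnperf-cli", "record", "-o", "/tmp/report", "--some_option=value"]:
# single-character keys become "-k value" pairs, longer keys become "--key=value".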
class NsightPlugin(RuntimeEnvPlugin):
name = "_nsight"
def __init__(self, resources_dir: str):
self.nsight_cmd = []
# replace this with better way to get logs dir
session_dir, runtime_dir = os.path.split(resources_dir)
self._nsight_dir = Path(session_dir) / "logs" / "nsight"
try_to_create_directory(self._nsight_dir)
async def _check_nsight_script(
self, nsight_config: Dict[str, str]
) -> Tuple[bool, str]:
"""
Validate whether nsight_config contains valid cnperf profiling options
Args:
nsight_config: dictionary mapping each cnperf option to its value
Returns:
a tuple consisting of a boolean indicating whether the nsight_config
is a valid option set, and an error message if it is invalid
"""
# use empty as nsight report test filename
nsight_config_copy = copy.deepcopy(nsight_config)
try_to_create_directory(Path(self._nsight_dir) / "empty")
nsight_config_copy["o"] = str(Path(self._nsight_dir) / "empty/test")
nsight_cmd = parse_nsight_config(nsight_config_copy)
try:
nsight_cmd = nsight_cmd + ["python", "-c", '""']
process = await asyncio.create_subprocess_exec(
*nsight_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = await process.communicate()
error_msg = stderr.strip() if stderr.strip() != "" else stdout.strip()
# cleanup test.cnperf-rep file
clean_up_cmd = ["rm", f"{nsight_config_copy['o']}.cnperf-rep"]
cleanup_process = await asyncio.create_subprocess_exec(
*clean_up_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
_, _ = await cleanup_process.communicate()
if process.returncode == 0:
return True, None
else:
return False, error_msg
except FileNotFoundError:
return False, ("cnperf-cli is not installed")
async def create(
self,
uri: Optional[str],
runtime_env: "RuntimeEnv", # noqa: F821
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
) -> int:
nsight_config = runtime_env.nsight()
if not nsight_config:
return 0
if nsight_config and sys.platform != "linux":
raise RuntimeEnvSetupError(
"CNPerf CLI is only available in Linux.\n"
"More information can be found in "
"https://docs.nvidia.com/nsight-compute/NsightComputeCli/index.html"
)
if isinstance(nsight_config, str):
if nsight_config == "default":
nsight_config = NSIGHT_DEFAULT_CONFIG
else:
raise RuntimeEnvSetupError(
f"Unsupported nsight config: {nsight_config}. "
"The supported config is 'default' or "
"Dictionary of cnperf options"
)
is_valid_nsight_cmd, error_msg = await self._check_nsight_script(nsight_config)
if not is_valid_nsight_cmd:
logger.warning(error_msg)
raise RuntimeEnvSetupError(
"cnperf-cli failed to run with the following "
f"error message:\n {error_msg}"
)
self.nsight_cmd = parse_nsight_config(nsight_config)
return 0
def modify_context(
self,
uris: List[str],
runtime_env: "RuntimeEnv", # noqa: F821
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
):
context.py_executable = " ".join(self.nsight_cmd) + " python"
logger.info("Running CNPerf cmd: %s", context.py_executable)

92
tools/ray_mlu/test_mlu.py Executable file
View File

@@ -0,0 +1,92 @@
import os
import sys
import pytest
from unittest.mock import patch
import ray
from ray._private.accelerators import MLUAcceleratorManager as Accelerator
@patch("glob.glob")
@patch("os.listdir")
def test_autodetect_num_mlus(mock_list, mock_glob):
mock_glob.return_value = [f"/dev/davinci{i}" for i in range(4)]
# mock_list.return_value = []
assert Accelerator.get_current_node_num_accelerators() == 4
@patch("glob.glob")
@patch("os.listdir")
def test_autodetect_num_mlus_without_devices(mock_list, mock_glob):
mock_glob.side_effect = Exception
# mock_list.return_value = []
assert Accelerator.get_current_node_num_accelerators() == 0
def test_mlu_accelerator_manager_api():
assert Accelerator.get_resource_name() == "MLU"
assert Accelerator.get_visible_accelerator_ids_env_var() == "MLU_VISIBLE_DEVICES"
assert Accelerator.validate_resource_request_quantity(0.5) == (True, None)
assert Accelerator.validate_resource_request_quantity(1) == (True, None)
def test_visible_mlu_type(monkeypatch, shutdown_only):
with patch.object(
Accelerator, "get_current_node_num_accelerators", return_value=4
), patch.object(
Accelerator, "get_current_node_accelerator_type", return_value="MLU370"
):
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
assert manager.get_current_node_accelerator_type() == "MLU370"
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
def test_visible_mlu_ids(monkeypatch, shutdown_only):
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
ray.init()
manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU")
assert manager.get_current_node_num_accelerators() == 4
assert manager.__name__ == "MLUAcceleratorManager"
assert ray.available_resources()["MLU"] == 3
def test_get_current_process_visible_accelerator_ids(monkeypatch, shutdown_only):
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
assert Accelerator.get_current_process_visible_accelerator_ids() == ["0", "1", "2"]
monkeypatch.delenv("MLU_VISIBLE_DEVICES")
assert Accelerator.get_current_process_visible_accelerator_ids() is None
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "")
assert Accelerator.get_current_process_visible_accelerator_ids() == []
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "NoDevFiles")
assert Accelerator.get_current_process_visible_accelerator_ids() == []
def test_set_current_process_visible_accelerator_ids(shutdown_only):
Accelerator.set_current_process_visible_accelerator_ids(["0"])
assert os.environ["MLU_VISIBLE_DEVICES"] == "0"
Accelerator.set_current_process_visible_accelerator_ids(["0", "1"])
assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1"
Accelerator.set_current_process_visible_accelerator_ids(["0", "1", "2"])
assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1,2"
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows")
def test_auto_detected_more_than_visible(monkeypatch, shutdown_only):
with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4):
# If more MLUs are detected than visible.
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2")
ray.init()
assert ray.available_resources()["MLU"] == 3
if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))

3785
tools/ray_mlu/worker.py Normal file

File diff suppressed because it is too large

15
vllm_mlu/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
def register_mlu_platform():
"""Register the MLU platform."""
return "vllm_mlu.platforms.mlu.MLUPlatform"
def register_mlu_hijack():
"""Register the MLU models and hijack."""
from vllm_mlu import mlu_hijack
from vllm_mlu.model_executor.models import register_model
register_model()
return

1853
vllm_mlu/_mlu_ops.py Normal file

File diff suppressed because it is too large

107
vllm_mlu/_mlu_utils.py Normal file
View File

@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
import torch
import vllm.envs as envs
def _check_env(env, default=False):
if env in os.environ:
return os.environ[env].lower() in ["true", "1"]
return default
def _check_env_value(env, default=0):
if env in os.environ:
if not os.environ[env].isdigit():
raise ValueError(f"'{env}' should be set with integer")
value = int(os.environ[env])
return value
return default
def _check_env_float(env, default=0):
if env in os.environ:
try:
value = float(os.environ[env])
except ValueError:
raise ValueError(f"'{env}' should be set with float")
return value
return default
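# Example behavior of the helpers above (illustrative): with VLLM_LATENCY_DEBUG=1 set,
# _check_env("VLLM_LATENCY_DEBUG") returns True; with VLLM_V1_MIN_PREFILL_BATCH=4 set,
# _check_env_value("VLLM_V1_MIN_PREFILL_BATCH", default=1) returns 4, and a non-numeric
# value raises ValueError.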
# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency.
VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False)
# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info (without device) for benchmark latency.
VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False)
# VLLM_DUMP_OUTPUTS: Dump each layer's outputs when running vLLM inference.
VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False)
# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference.
VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False)
# VLLM_DUMP_MLU_INFO_DEBUG: Dump device debug info when running vLLM inference.
VLLM_DUMP_MLU_INFO_DEBUG = _check_env("VLLM_DUMP_MLU_INFO_DEBUG", default=False)
# VLLM_SCHEDULER_PROFILE: Profiling vLLM scheduler.
VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False)
# VLLM_GRAPH_DEBUG: Debug the graph status when running decoder, default value is True.
# Set to False to disable warning messages.
VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)
# VLLM_AVG_MOE_EN: balance the workload across MoE experts, default value is False.
VLLM_AVG_MOE_EN = _check_env("VLLM_AVG_MOE_EN", default=False) or _check_env("VLLM_RANDOM_MOE_EN", default=False)
VLLM_RANDOM_MOE_EN = _check_env("VLLM_RANDOM_MOE_EN", default=False)
# VLLM_LOGITS_USE_ALL_GATHER: use allgather for logits collection, default value is False.
VLLM_LOGITS_USE_ALL_GATHER = _check_env("VLLM_LOGITS_USE_ALL_GATHER", default=False)
VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
VLLM_DUMP_MLU_INFO_DEBUG = (VLLM_DUMP_MLU_INFO_DEBUG and VLLM_DUMP_MLU_INFO_EN)
# VLLM_V1_USE_UNCHUNK_SCHED: v1 use unchunk scheduler, default value is True.
VLLM_V1_USE_UNCHUNK_SCHED = _check_env("VLLM_V1_USE_UNCHUNK_SCHED", default=True)
# VLLM_V1_MIN_PREFILL_BATCH: the min scheduling batch in v1, default is 1.
VLLM_V1_MIN_PREFILL_BATCH = _check_env_value("VLLM_V1_MIN_PREFILL_BATCH", default=1)
# VLLM_V1_USE_FULL_GRAPH: v1 use full graph capture, default value is True.
VLLM_V1_USE_FULL_GRAPH = _check_env("VLLM_V1_USE_FULL_GRAPH", default=True)
# VLLM_V1_BENCHMARK: v1 benchmark, default value is False.
VLLM_V1_BENCHMARK = _check_env("VLLM_V1_BENCHMARK", default=False)
# VLLM_MTP_DEBUG: use to show mtp accepted rate, default value is False.
VLLM_MTP_DEBUG = _check_env("VLLM_MTP_DEBUG", default=False)
# VLLM_MTP_NO_QUANT: MTP uses the original dtype; quant_config is set to None.
VLLM_MTP_NO_QUANT = _check_env("VLLM_MTP_NO_QUANT", default=False)
# VLLM_MTP_FIXED_ACCEPTANCE_RATE: use fixed acceptance rate, default value is None.
VLLM_MTP_FIXED_ACCEPTANCE_RATE = _check_env_float("VLLM_MTP_FIXED_ACCEPTANCE_RATE", default=None)
# VLLM_V1_UNCHUNK_SCHED_LOG: print v1 unchunk scheduler state
VLLM_V1_UNCHUNK_SCHED_LOG = _check_env("VLLM_V1_UNCHUNK_SCHED_LOG", default=False)
# VLLM_MOE_PREFILL_CHUNK_SIZE: in number of tokens. enabled when > 0.
VLLM_MOE_PREFILL_CHUNK_SIZE = _check_env_value("VLLM_MOE_PREFILL_CHUNK_SIZE", default=0)
# VLLM_CI_ACCURACY_TEST: CI accuracy test, default value is False.
VLLM_CI_ACCURACY_TEST = _check_env("VLLM_CI_ACCURACY_TEST", default=False)
# VLLM_DISAGG_TRANS_ALL_BLOCKS: optimize the performance of disagg
VLLM_DISAGG_TRANS_ALL_BLOCKS = _check_env("VLLM_DISAGG_TRANS_ALL_BLOCKS", default=True)
# vllm disagg debug
VLLM_DISAGG_CNPX_EXECUTE = _check_env("VLLM_DISAGG_CNPX_EXECUTE", default=False)
VLLM_DISAGG_CNPX_REQUEST = _check_env("VLLM_DISAGG_CNPX_REQUEST", default=False)
VLLM_DISAGG_FAKE_DECODER = _check_env("VLLM_DISAGG_FAKE_DECODER", default=False)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

351
vllm_mlu/attention/layer.py Normal file
View File

@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, cast
import torch
from torch import nn
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import MLAAttentionImpl
from vllm.attention.layer import Attention, MLAAttention, _init_kv_cache_quant
from vllm.attention.selector import get_attn_backend
from vllm.config.cache import CacheConfig
from vllm.config.vllm import QuantizationConfig, VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm_mlu.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.v1.kv_cache_interface import (
MLUFullAttentionSpec,
MLUMLAAttentionSpec,
MLUSlidingWindowSpec,
)
@maybe_transfer_kv_layer
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
kwargs: dict[str, Any] = {},
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
'''
=============================
Modify by vllm_mlu
=============================
@brief: return the output of self.impl.forward and add its kwargs parameter
'''
output = self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
kwargs=kwargs,
)
'''
==================
End of MLU Hijack
==================
'''
return output
class Attention_MluHijack(Attention):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
if self.sliding_window is not None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace SlidingWindowSpec with MLUSlidingWindowSpec.
'''
return MLUSlidingWindowSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
'''
==================
End of MLU Hijack
==================
'''
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace FullAttentionSpec with MLUFullAttentionSpec.
'''
return MLUFullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)
'''
==================
End of MLU Hijack
==================
'''
class MLAAttention_MluHijack(MLAAttention):
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
**extra_impl_args,
) -> None:
nn.Module.__init__(self)
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
# self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
'''
=============================
Modify by vllm_mlu
=============================
@brief: insert num_kv_heads for mlu platform
'''
self.head_size = qk_nope_head_dim + qk_rope_head_dim
self.num_kv_heads = extra_impl_args.pop("num_kv_heads", None)
if self.num_kv_heads is None:
self.num_kv_heads = num_heads
self.decoder_attn_dtype = None
decoder_attn_dtype = get_current_vllm_config().mlu_config.decoder_attn_dtype
if decoder_attn_dtype in ["int8", "fp8_e4m3", "fp8"]:
self.decoder_attn_dtype = (
torch.int8 if decoder_attn_dtype == "int8"
else torch.float8_e4m3fn
)
extra_impl_args['decoder_attn_dtype'] = self.decoder_attn_dtype
'''
==================
End of MLU Hijack
==================
'''
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
self.num_heads,
self.head_size,
self.scale,
self.num_kv_heads,
None, # alibi_slopes
None, # sliding_window
kv_cache_dtype,
None, # logits_soft_cap
AttentionType.DECODER, # attn_type
None, # kv_sharing_target_layer_name
**extra_impl_args,
)
self.dtype = dtype
self.use_direct_call = not current_platform.opaque_attention_op()
if current_platform.is_out_of_tree():
self.use_direct_call = False
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
'''
=============================
Modify by vllm_mlu
=============================
@brief: support kv8 and deepseek v3.2
'''
self.kv_cache = [
[torch.tensor([]), torch.tensor([]), torch.tensor([])]
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.impl.use_mla = True
'''
==================
End of MLU Hijack
==================
'''
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace MLAAttentionSpec with MLUMLAAttentionSpec.
'''
index_head_dim, index_n_heads = 0, 0
if vllm_config.model_config.hf_text_config.model_type == "deepseek_v32":
index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
index_n_heads = 1
if vllm_config.model_config.hf_text_config.model_type == "deepseek_v4":
index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim
index_n_heads = 1
return MLUMLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
index_head_dim=index_head_dim,
index_n_heads=index_n_heads,
)
'''
==================
End of MLU Hijack
==================
'''
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output_shape: torch.Size | None = None,
kwargs: dict[str, Any] = {},
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
assert not self.use_direct_call, "MLU-V1 does not support direct call."
if self.attn_backend.accept_output_buffer:
output_lse = None
output_shape = (output_shape if output_shape is not None else query.shape)
output_shape = [output_shape[0], self.num_heads * self.v_head_dim]
output = torch.empty(
output_shape,
dtype=self.dtype if query.dtype == torch.int8 else query.dtype,
device=query.device,
)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.v_head_dim)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.v_head_dim)
if not kwargs:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
)
attn_output_list = output
else:
attn_output_list = unified_attention_with_output(
query, key, value, output, self.layer_name, kwargs=kwargs)
if isinstance(attn_output_list, (list, tuple)) and len(attn_output_list) > 1:
output_lse = attn_output_list[1]
if output_lse is not None:
return output.view(-1, hidden_size), output_lse
else:
return output.view(-1, hidden_size)
'''
==================
End of MLU Hijack
==================
'''
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
)
MluHijackObject.apply_hijack(
Attention,
Attention.get_kv_cache_spec,
Attention_MluHijack.get_kv_cache_spec,
)
MluHijackObject.apply_hijack(
MLAAttention,
MLAAttention.__init__,
MLAAttention_MluHijack.__init__,
)
MluHijackObject.apply_hijack(
MLAAttention,
MLAAttention.get_kv_cache_spec,
MLAAttention_MluHijack.get_kv_cache_spec,
)
MluHijackObject.apply_hijack(
MLAAttention,
MLAAttention.forward,
MLAAttention_MluHijack.forward,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import inspect
from collections.abc import Callable
from functools import wraps
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
def maybe_transfer_kv_layer(func: Callable) -> Callable:
"""Decorator that handles KV layer transfer prior and after execution of
an attention layer, if enabled. Otherwise, the wrapper is a no-op.
On entry: waits for the KV layer from the connector.
On exit: saves the KV layer to the connector.
"""
# Import at runtime to avoid circular dependency
from vllm.attention.layer import get_attention_context
# Inspect the signature ONCE when the decorator is applied.
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
# Find the index of 'layer_name' parameter.
try:
layer_name_index = param_names.index("layer_name")
except ValueError as e:
raise TypeError(
f"Function {func.__name__} must have a 'layer_name' parameter"
) from e
@wraps(func)
def wrapper(*args, **kwargs):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs)
layer_name: str = args[layer_name_index]
# Extract attention context (layer-specific metadata, layer, and kv_cache)
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)
# Wait for KV layer on entry
connector.wait_for_layer_load(layer_name)
# Execute the function
result = func(*args, **kwargs)
# Save KV cache layer on exit
if kwargs is None or kwargs.get("save_kv_layer", True):
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
return result
return wrapper

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
- ShareGPT
- Random (synthetic)
- Sonnet
- BurstGPT
- HuggingFace
- VisionArena
"""
from tempfile import NamedTemporaryFile
import numpy as np
from vllm.benchmarks.datasets import RandomMultiModalDataset
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video(
self, width: int, height: int, num_frames: int
) -> dict:
"""Generate synthetic video with random values.
Creates a video with random pixel values, encodes it to MP4 format,
and returns the content as bytes.
"""
import cv2
random_pixels = self._rng.integers(
0,
256,
(num_frames, height, width, 3),
dtype=np.uint8,
)
# Create a temporary video file in memory
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
fps = 30 # frames per second
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
temp_path = temp_file.name
# Create video writer
video_writer = cv2.VideoWriter(
temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height)
)
if not video_writer.isOpened():
raise RuntimeError("Failed to create video writer")
for frame in random_pixels:
video_writer.write(frame)
video_writer.release()
temp_file.close()
# Read the video file content
with open(temp_path, "rb") as f:
video_content = f.read()
return {"bytes": video_content}
MluHijackObject.apply_hijack(
RandomMultiModalDataset,
RandomMultiModalDataset.generate_synthetic_video,
vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import operator
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.platforms import current_platform
from vllm.logger import init_logger
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import is_func
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
class FixFunctionalizationPass_MluHijack(FixFunctionalizationPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug(
"XPU platform does not support fix functionalizationpass currently."
)
return
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
'''
=============================
Modify by vllm_mlu
=============================
@brief: skip custom op on mlu
'''
if current_platform.is_out_of_tree():
continue # skip the count on mlu
'''
==================
End of MLU Hijack
==================
'''
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
mm_node = query.args[0].args[0]
for user in getitem_nodes.values():
for user_of_getitem in user.users:
if is_func(
user_of_getitem, torch.ops.aten.slice_scatter.default
):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
else:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args = {1: "query", 2: "key"}
self.defunctionalize(graph, node, mutated_args)
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: "input", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "scale", 3: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input")
)
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input", "scale")
)
elif (
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
):
mutated_args = {1: "result", 2: "result_block_scale"}
self.defunctionalize(
graph,
node,
mutated_args,
args=(
"result",
"result_block_scale",
"input",
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
else:
continue # skip the count
count += 1
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug(
"De-functionalized %s nodes, removed %s nodes", count, count_removed
)
self.nodes_to_remove.clear()
MluHijackObject.apply_hijack(
FixFunctionalizationPass,
FixFunctionalizationPass.__call__,
FixFunctionalizationPass_MluHijack.__call__
)

View File

@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import dataclasses
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from unittest.mock import patch
import torch
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import weak_ref_tensors
from vllm.compilation.cuda_graph import (
CUDAGraphEntry,
CUDAGraphWrapper,
CUDAGraphOptions,
)
from vllm_mlu.v1.attention.backends.utils import MLUInferMode
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: specialized graph entry for prefill graphs
'''
@dataclasses.dataclass
class PrefillGraphEntry:
batch_size: int = 0
seq_len: int = 0
cudagraph: torch.mlu.MLUGraph | None = None
output: Any | None = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: list[int] | None = None
'''
==================
End of MLU Hijack
==================
'''
class MLUGraphWrapper(CUDAGraphWrapper):
def __init__(
self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
):
super().__init__(runnable, vllm_config, runtime_mode, cudagraph_options)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add separate dict for prefill graph entries
'''
self.prefill_mlugraph_entry: PrefillGraphEntry | None = None
'''
==================
End of MLU Hijack
==================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: check if running in prefill mode
'''
def is_running_in_prefill(self, entry: PrefillGraphEntry | None = None) -> bool:
forward_context = get_forward_context()
if forward_context.attn_metadata is None:
return False
infer_mode = forward_context.attn_metadata['common_metadata'].infer_mode
seq_lens_cpu = forward_context.attn_metadata['common_metadata'].seq_lens_cpu
if entry is not None \
and infer_mode == MLUInferMode.PREFILL_ONLY \
and seq_lens_cpu.size(0) == entry.batch_size \
and (seq_lens_cpu == entry.seq_len).all().item():
return True
return False
'''
==================
End of MLU Hijack
==================
'''
def __call__(
self,
is_capturing_prefill: bool = False,
prefill_enable_mlugraph: bool = False,
prefill_batch_size: int = 0,
prefill_seq_len: int = 0,
is_running_drafter: bool = False,
*args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
if (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode != self.runtime_mode
):
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without cudagraphs.
# We do not trigger capture/replay if the runtime mode does not
# match. This enables properly dispatching to the correct
# CUDAGraphWrapper when nesting multiple instances with different
# runtime modes.
return self.runnable(*args, **kwargs)
'''
=============================
Modify by vllm_mlu
=============================
@brief: handle prefill graph separately
@brief: skip the cudagraph-capture validation when running the drafter model
'''
if is_capturing_prefill: # PREFILL capture
self.prefill_mlugraph_entry = PrefillGraphEntry(
batch_size=prefill_batch_size,
seq_len=prefill_seq_len)
else: # FULL/DECODE capture
if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
batch_descriptor=batch_descriptor
)
if ((self.is_running_in_prefill(self.prefill_mlugraph_entry) and prefill_enable_mlugraph)
or is_capturing_prefill):
entry = self.prefill_mlugraph_entry
logger.debug(
f"Hitting a prefill cudagraph on {self.runtime_mode.name}, "
f"batch_size: {entry.batch_size}, seq_len: {entry.seq_len}")
else: # FULL/DECODE capture
entry = self.concrete_cudagraph_entries[batch_descriptor]
logger.debug(
"Hitting a decode cudagraph on (%s, %s)",
self.runtime_mode.name,
entry.batch_descriptor,
)
if entry.cudagraph is None:
if self.cudagraph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
if is_capturing_prefill:
logger.debug(
"Capturing a prefill cudagraph on (%s, batch_size=%d, seq_len=%d)",
self.runtime_mode.name,
entry.batch_size,
entry.seq_len,
)
else:
logger.debug(
"Capturing a decode cudagraph on (%s, %s)",
self.runtime_mode.name,
entry.batch_descriptor,
)
if ((not is_capturing_prefill) and (not is_running_drafter)):
# validate that cudagraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
'''
==================
End of MLU Hijack
==================
'''
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.mlu.MLUGraph()
with ExitStack() as stack:
if self.cudagraph_options.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.mlu.empty_cache", lambda: None))
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.mlu.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

71
vllm_mlu/config/model.py Normal file
View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.config.model import ModelConfig
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__model__ModelConfig__is_embedding_task(self) -> bool:
return self.runner_type == "pooling"
def vllm__config__model__ModelConfig__get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
if self.use_mla:
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
else:
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if hasattr(self.hf_text_config, "model_type") and (
self.hf_text_config.model_type == "zamba2"
):
return self.hf_text_config.attention_head_dim
if self.is_attention_free:
return 0
# NOTE: Some configs may set head_dim=None in the config
if getattr(self.hf_text_config, "head_dim", None) is not None:
return self.hf_text_config.head_dim
# NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head`
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head
# FIXME(woosuk): This may not be true for all models.
'''
=============================
Modify by vllm_mlu
=============================
@brief: prefer num_heads over num_attention_heads when computing the fallback head size.
'''
if hasattr(self.hf_text_config, "num_heads"):
num_attention_heads = self.hf_text_config.num_heads
else:
num_attention_heads = self.hf_text_config.num_attention_heads
return (self.hf_text_config.hidden_size // num_attention_heads)
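# Worked example (illustrative): a text config with hidden_size=4096, num_attention_heads=32
# and no head_dim, hidden_size_per_head or num_heads attribute resolves to 4096 // 32 = 128.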
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(
ModelConfig,
"is_embedding_task",
vllm__config__model__ModelConfig__is_embedding_task,
)
MluHijackObject.apply_hijack(
ModelConfig,
ModelConfig.get_head_size,
vllm__config__model__ModelConfig__get_head_size,
)

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing_extensions import Self
from vllm.config.scheduler import SchedulerConfig
from vllm.logger import init_logger
from vllm_mlu._mlu_utils import VLLM_V1_BENCHMARK
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__scheduler__SchedulerConfig__verify_max_model_len(
self, max_model_len: int,
) -> Self:
'''
=============================
Modify by vllm_mlu
=============================
@brief: This restriction is removed when VLLM_V1_BENCHMARK is set to True
'''
if not VLLM_V1_BENCHMARK:
if (
self.max_num_batched_tokens < max_model_len
and not self.enable_chunked_prefill
):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len."
)
'''
==================
End of MLU Hijack
==================
'''
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs})."
)
if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
self.max_num_seqs * max_model_len,
)
if self.max_num_partial_prefills > 1:
if not self.enable_chunked_prefill:
raise ValueError(
"Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1."
)
if self.long_prefill_token_threshold > max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({max_model_len})."
)
if self.max_long_partial_prefills > self.max_num_partial_prefills:
raise ValueError(
f"{self.max_long_partial_prefills=} must be less than or equal to "
f"{self.max_num_partial_prefills=}."
)
return self
MluHijackObject.apply_hijack(
SchedulerConfig,
SchedulerConfig.verify_max_model_len,
vllm__config__scheduler__SchedulerConfig__verify_max_model_len,
)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.config.parallel import ParallelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
@staticmethod
def vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
'''
=============================
Modify by vllm_mlu
@brief: add draft data parallel parameters
=============================
'''
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
# add draft data parallel parameters
data_parallel_size=target_parallel_config.data_parallel_size,
data_parallel_size_local=target_parallel_config.data_parallel_size_local,
data_parallel_master_ip=target_parallel_config.data_parallel_master_ip,
data_parallel_rpc_port=target_parallel_config.data_parallel_rpc_port,
)
'''
==================
End of MLU Hijack
==================
'''
return draft_parallel_config
vllm__config__speculative__SpeculativeConfig____post_init___org = SpeculativeConfig.__post_init__
def vllm__config__speculative__SpeculativeConfig____post_init__(self):
if self.model is None and self.num_speculative_tokens is not None and self.method is None:
self.method = "mtp"
vllm__config__speculative__SpeculativeConfig____post_init___org(self)
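# Illustrative effect of the override above: SpeculativeConfig(num_speculative_tokens=1) with
# no draft model and no explicit method now resolves to method="mtp" before the original
# __post_init__ runs.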
MluHijackObject.apply_hijack(
SpeculativeConfig,
SpeculativeConfig.create_draft_parallel_config,
vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config,
)
MluHijackObject.apply_hijack(
SpeculativeConfig,
SpeculativeConfig.__post_init__,
vllm__config__speculative__SpeculativeConfig____post_init__,
)

213
vllm_mlu/config/vllm.py Normal file
View File

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
from vllm.config.vllm import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__config__vllm__VllmConfig___set_cudagraph_sizes(self):
"""
vLLM defines the default candidate list of batch sizes for CUDA graph
capture as:
```python
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
    range(256, max_graph_size + 1, 16))
```
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
will be the final sizes to capture cudagraph (in ascending order).
These sizes are used to capture and reuse CUDA graphs for
performance-critical paths (e.g., decoding). Capturing enables
significantly faster kernel dispatch by avoiding Python overhead. The
list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
most GPUs), which controls the total allowed number of tokens in a
batch. Since each sequence may have a variable number of tokens, the
maximum usable batch size will depend on actual sequence lengths.
Example:
With `max_num_batched_tokens = 8192`, and typical sequences
averaging ~32 tokens, most practical batch sizes fall below 256.
However, the system will still allow capture sizes up to 512 if
shape and memory permit.
Note:
If users explicitly specify cudagraph capture sizes in the
compilation config, those will override this default logic.
At runtime:
- If batch size <= one of the `cudagraph_capture_sizes`, the closest
padded CUDA graph will be used.
- If batch size > largest `cudagraph_capture_sizes`, cudagraph will
not be used.
"""
if hasattr(self.compilation_config, "_has_set_capture_list"):
# avoid set capture list twice while init
return
if (
self.model_config is not None
and not self.model_config.enforce_eager
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
# determine the initial max_cudagraph_capture_size
max_cudagraph_capture_size = (
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
assert max_cudagraph_capture_size >= 1, (
"Maximum cudagraph size should be greater than or equal to 1 "
"when using cuda graph."
)
# determine the cudagraph_capture_sizes
if self.compilation_config.cudagraph_capture_sizes is not None:
assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
"cudagraph_capture_sizes should contain at least one element "
"when using cuda graph."
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = [
i for i in dedup_sizes if i <= max_num_tokens
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256(not included)
cudagraph_capture_sizes += list(
range(8, min(max_cudagraph_capture_size + 1, 256), 8)
)
if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16)
)
'''
=============================
Modify by vllm_mlu
=============================
@brief:
1) check batch_size_capture_list when MTP is enabled, because bs * (K + 1)
may be greater than max_num_batched_tokens
2) capture MLUGraph for a user-given batch list
'''
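# Accepted formats for MLU_GRAPH_CAPTURE_LIST (values are illustrative):
#   "8-129-8"  -> [1, 2, 4] plus range(8, 129, 8), i.e. 1, 2, 4, 8, 16, ..., 128
#   "1,2,4,8"  -> the comma-separated list is used as given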
mlu_graph_capture_list = os.getenv("MLU_GRAPH_CAPTURE_LIST", None)
if mlu_graph_capture_list:
if "-" in mlu_graph_capture_list:
batch_info = mlu_graph_capture_list.split("-")
assert len(batch_info) == 3, \
f"Got invalid graph_capture_list={mlu_graph_capture_list}, " + \
"expected format 'min_bs-max_bs(exclusive)-step'."
start, end, step = batch_info
cudagraph_capture_sizes = [1, 2, 4] + [
i for i in range(int(start), int(end), int(step))
]
cudagraph_capture_sizes = sorted(list(set(cudagraph_capture_sizes)))
else:
cudagraph_capture_sizes = [int(x) for x in mlu_graph_capture_list.split(",")]
if (self.speculative_config is not None
and self.speculative_config.num_speculative_tokens > 0
):
K = self.speculative_config.num_speculative_tokens
cudagraph_capture_sizes = [x * (1 + K) for x in cudagraph_capture_sizes]
cudagraph_capture_sizes = [
size for size in cudagraph_capture_sizes
if size <= self.scheduler_config.max_num_batched_tokens
]
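# e.g. with K = 1 speculative token, capture sizes [1, 2, 4, 8] become
# [2, 4, 8, 16], and any scaled size above max_num_batched_tokens is dropped.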
'''
==================
End of MLU Hijack
==================
'''
if (
self.parallel_config.tensor_parallel_size > 1
and self.compilation_config.pass_config.enable_sequence_parallelism
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
)
# A user-specified compilation_config.max_cudagraph_capture_size gets
# truncated to valid_max_size when they are inconsistent.
valid_max_size = (
cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
)
if (
self.compilation_config.max_cudagraph_capture_size is not None
and self.compilation_config.max_cudagraph_capture_size != valid_max_size
):
# raise error only when both two flags are user-specified
# and they are inconsistent with each other
if self.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"customized max_cudagraph_capture_size"
f"(={self.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning(
"Truncating max_cudagraph_capture_size to %d",
valid_max_size,
)
# always set the final max_cudagraph_capture_size
self.compilation_config.max_cudagraph_capture_size = valid_max_size
if self.compilation_config.cudagraph_capture_sizes is not None and len(
cudagraph_capture_sizes
) < len(self.compilation_config.cudagraph_capture_sizes):
# If users have specified capture sizes, we only need to
# compare the lengths before and after modification, since the
# modified list is a subset of the original list.
logger.warning(
(
"cudagraph_capture_sizes specified in compilation_config"
" %s is overridden by config %s"
),
self.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes,
)
# always write back the final sizes
self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
# no cudagraph in use
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
# complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes()
setattr(self.compilation_config, "_has_set_capture_list", True)
MluHijackObject.apply_hijack(
VllmConfig,
VllmConfig._set_cudagraph_sizes,
vllm__config__vllm__VllmConfig___set_cudagraph_sizes,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# CN-API based PyTorch pluggable allocator implementing sleep mode.
import dataclasses
import gc
import os
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
def find_loaded_library(lib_name) -> str | None:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which include the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
found_line = None
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found_line = line
break
if found_line is None:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = found_line.index("/")
path = found_line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), (
f"Unexpected filename: {filename} for library {lib_name}"
)
return path
cnmem_available = False
try:
from vllm_mlu.vllm_mlu_C import (
init_module,
python_create_and_map,
python_unmap_and_release,
python_cn_memcpy,
)
lib_name = find_loaded_library("vllm_mlu_C")
cnmem_available = True
except ModuleNotFoundError as e:
logger.error("Failed to import cnmem_allocator: %s", e)
init_module = None
python_create_and_map = None
python_unmap_and_release = None
python_cn_memcpy = None
lib_name = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = tuple[int, int, int, int]
@dataclasses.dataclass
class AllocationData:
handle: HandleType
tag: str
cpu_backup_tensor: torch.Tensor | None = None
def create_and_map(allocation_handle: HandleType) -> None:
python_create_and_map(*allocation_handle)
def unmap_and_release(allocation_handle: HandleType) -> None:
python_unmap_and_release(*allocation_handle)
def get_pluggable_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]
) -> torch.mlu.memory.MLUPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.mlu.memory.MLUPluggableAllocator(
lib_name, "my_malloc", "my_free"
)
return new_alloc
@contextmanager
def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]):
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.mlu.memory.MemPool(new_alloc._allocator)
with torch.mlu.memory.use_mem_pool(mem_pool):
yield mem_pool, new_alloc
class CnMemAllocator:
"""
A singleton class that manages a memory pool for MLU tensors.
The memory in this pool can be offloaded or discarded when the
allocator sleeps.
Inside the `use_memory_pool(tag)` context, all tensors created will
be allocated in the memory pool and carry the tag passed to the
context.
When we call `sleep`, all tensors with the specified tag will be
offloaded to CPU memory, and the rest of the tensors will be discarded.
When we call `wake_up`, all tensors that are previously offloaded
will be loaded back to GPU memory, and the rest of the tensors will
have empty memory.
Why does it need to be a singleton?
When allocated tensors are garbage collected, PyTorch will call
the free callback, which will call the `python_free_callback` method.
The C-extension uses a global variable to store the function of an
instance of this class. If we create multiple instances of this class,
the global variable will be overwritten and the free callback will
not work as expected.
"""
instance: "CnMemAllocator" = None
default_tag: str = "default"
@staticmethod
def get_instance() -> "CnMemAllocator":
"""
CnMemAllocator is a singleton class.
We cannot call the constructor directly.
Call this method to get the instance.
"""
assert cnmem_available, "cnmem allocator is not available"
if CnMemAllocator.instance is None:
CnMemAllocator.instance = CnMemAllocator()
return CnMemAllocator.instance
def __init__(self):
conf = os.environ.get("PYTORCH_MLU_ALLOC_CONF", "")
assert "expandable_segments:True" not in conf, (
"Expandable segments are not compatible with memory pool. "
"Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates."
)
self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CnMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {}
# Creating strong references to the two callbacks here to prevent
# these ephemeral bound-method objects being garbage collected.
# See discussions in https://github.com/vllm-project/vllm/pull/22724
self.python_malloc_callback = self._python_malloc_callback
self.python_free_callback = self._python_free_callback
def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
py_d_mem = allocation_handle[2]
self.pointer_to_data[py_d_mem] = AllocationData(
allocation_handle, self.current_tag
)
logger.debug(
"Allocated %s bytes for %s with address %s from cnmem allocator",
allocation_handle[1],
self.current_tag,
py_d_mem,
)
return
def _python_free_callback(self, ptr: int) -> HandleType:
"""
Internal method to look up the allocation data
when memory is freed in the memory pool."""
data = self.pointer_to_data.pop(ptr)
if data.cpu_backup_tensor is not None:
data.cpu_backup_tensor = None
logger.debug(
"Freed %s bytes for %s with address %s from cnmem allocator",
data.handle[1],
data.tag,
ptr,
)
return data.handle
def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
"""
Put the allocator in sleep mode.
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
:param offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
"""
if offload_tags is None:
# by default, allocated tensors are offloaded
# when the allocator sleeps
offload_tags = (CnMemAllocator.default_tag, )
elif isinstance(offload_tags, str):
offload_tags = (offload_tags,)
assert isinstance(offload_tags, tuple)
total_bytes = 0
backup_bytes = 0
for ptr, data in self.pointer_to_data.items():
handle = data.handle
total_bytes += handle[1]
if data.tag in offload_tags:
backup_bytes += handle[1]
size_in_bytes = handle[1]
cpu_backup_tensor = torch.empty(
size_in_bytes,
dtype=torch.uint8,
device="cpu",
pin_memory=is_pin_memory_available(),
)
cpu_ptr = cpu_backup_tensor.data_ptr()
python_cn_memcpy(cpu_ptr, ptr, size_in_bytes)
data.cpu_backup_tensor = cpu_backup_tensor
unmap_and_release(handle)
logger.info(
"CnMemAllocator: sleep freed %.2f GiB memory in total, of which "
"%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
"directly.",
total_bytes / 1024**3,
backup_bytes / 1024**3,
(total_bytes - backup_bytes) / 1024**3,
)
gc.collect()
torch.mlu.empty_cache()
def wake_up(self, tags: list[str] | None = None) -> None:
"""
Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
"""
for ptr, data in self.pointer_to_data.items():
if tags is None or data.tag in tags:
handle = data.handle
create_and_map(handle)
if data.cpu_backup_tensor is not None:
cpu_backup_tensor = data.cpu_backup_tensor
if cpu_backup_tensor is not None:
size_in_bytes = (
cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
)
cpu_ptr = cpu_backup_tensor.data_ptr()
python_cn_memcpy(ptr, cpu_ptr, size_in_bytes)
data.cpu_backup_tensor = None
@contextmanager
def use_memory_pool(self, tag: str | None = None):
"""
A context manager to use the memory pool.
All memory allocated inside the context will come from the memory
pool and carry the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
will be used.
"""
if tag is None:
tag = CnMemAllocator.default_tag
assert isinstance(tag, str)
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(
self.python_malloc_callback, self.python_free_callback
) as data:
# We start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of a gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield
# Due to a PyTorch bug, calling torch.cuda.empty_cache() will error
# when using a pluggable allocator, see
# https://github.com/pytorch/pytorch/issues/145168 .
# if we have some memory allocated and then freed,
# the memory will not be released, e.g. in online quantization,
# where the model is created in higher precision, and then
# quantized in lower precision.
# Find all unused allocations and manually release them.
# TODO: we should expose `empty_cache` method in the memory pool.
# TODO: ask for help from PyTorch team to expose this method.
allocations = data[0].snapshot()
for allocation in allocations:
if allocation["allocated_size"] == 0:
handle = self._python_free_callback(allocation["address"])
unmap_and_release(handle)
self.current_tag = old_tag
def get_current_usage(self) -> int:
"""
Get the total number of bytes allocated in the memory pool.
"""
sum_bytes: int = 0
for ptr, data in self.pointer_to_data.items():
handle = data.handle
sum_bytes += handle[1]
return sum_bytes

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
class MLUCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = ""
):
super().__init__(cpu_group, device, device_group, unique_name)
# init device according to rank
self.device = torch.mlu.current_device()
# NOTE: CustomAllreduce is not imported in this module; keep the annotation
# as a string (forward reference) so it does not resolve an undefined name.
self.ca_comm: "CustomAllreduce | None" = None

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
MLUKVConnectors: dict[str, tuple[str, str]] = {
"MLUSharedStorageConnector": (
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector"
),
"MLUNixlConnector": (
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"MLUNixlConnector"
),
}
for name, (module_path, class_name) in MLUKVConnectors.items():
if name not in KVConnectorFactory._registry:
KVConnectorFactory.register_connector(name, module_path, class_name)
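# Once registered, a connector can be selected at engine start-up, e.g.
# (illustrative): kv_transfer_config={"kv_connector": "MLUNixlConnector",
#                                     "kv_role": "kv_both"}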

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class LMCacheConnectorV1_MluHijack(LMCacheConnectorV1):
def response_remote_alloc_once(self) -> None:
self._lmcache_engine.response_remote_alloc_once()
def request_remote_memory_send(self) -> None:
self._lmcache_engine.request_remote_memory_send()
MluHijackObject.apply_hijack(LMCacheConnectorV1,
"response_remote_alloc_once",
LMCacheConnectorV1_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(LMCacheConnectorV1,
"request_remote_memory_send",
LMCacheConnectorV1_MluHijack.request_remote_memory_send)

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
import zmq
from vllm import envs
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
EngineId, NixlConnectorWorker, NixlAgentMetadata, NixlConnectorScheduler, NixlConnector)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
class MLUNixlConnector(NixlConnector):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super(NixlConnector, self).__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler : MLUNixlConnectorScheduler | None = (
MLUNixlConnectorScheduler(vllm_config, self.engine_id)
)
self.connector_worker: MLUNixlConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MLUNixlConnectorWorker(vllm_config, self.engine_id)
class MLUNixlConnectorScheduler(NixlConnectorScheduler):
"""Implementation of Scheduler side methods"""
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: kv transfer info
'''
if request.kv_transfer_params.get("do_remote_prefill", False):
logger.info(f"NIXLConnector update_state_after_alloc: request_id={request.request_id}, "
f"num_prompt_tokens={request.num_prompt_tokens}, "
f"num_external_tokens={num_external_tokens}, "
f"kv_transfer_params={request.kv_transfer_params}")
'''
==================
End of MLU Hijack
==================
'''
params = request.kv_transfer_params
logger.debug(
"NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
p in params
for p in ("remote_engine_id", "remote_host", "remote_port")
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids()
if num_external_tokens > 0
else []
)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request,
local_block_ids,
)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
else:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
class MLUNixlConnectorWorker(NixlConnectorWorker):
"""Implementation of Worker side methods"""
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
'''
=============================
Add by vllm_mlu
=============================
@brief: kv8 is not supported
'''
if not isinstance(first_kv_cache, torch.Tensor):
kv_caches = {key: value[0] for key, value in kv_caches.items()}
_, first_kv_cache = next(iter(kv_caches.items()))
'''
==================
End of MLU Hijack
==================
'''
kv_elem_size = first_kv_cache.element_size()
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# only affects the strides. For MLA instead, we require no such thing
# and resort to the standard layout.
'''
=============================
Add by vllm_mlu
=============================
@brief: support mla
'''
use_mla = first_kv_cache.shape[0] == 1
'''
==================
End of MLU Hijack
==================
'''
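# As used below, the MLA cache is a single latent tensor shaped
# [1, num_blocks, block_size, latent_dim], while the regular K/V cache leads
# with 2, so a leading dimension of 1 identifies the MLA layout here.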
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if use_mla:
# MLA case.
'''
=============================
Add by vllm_mlu
=============================
@brief: support mla
'''
self.num_blocks = first_kv_cache.shape[1]
'''
==================
End of MLU Hijack
==================
'''
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
'''
=============================
Add by vllm_mlu
=============================
@brief: MLU kv_cache layout is [2 (k and v), num_blocks, kv_heads, block_size, head_dim]
'''
n_kv_heads, block_size, head_dim = block_shape[-3:]
'''
==================
End of MLU Hijack
==================
'''
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
self.num_blocks, block_shape, first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
# elegantly support MLA and any cases where the K and V tensors
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
caches_data.append(
(base_addr, region_len, cache.device.index, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
self.num_layers = len(self.kv_caches.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")
self._registered_descs.append(descs)
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="nixl_handshake_listener")
self._nixl_handshake_listener_t.start()
ready_event.wait()

View File

@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
from vllm_mlu.v1.attention.backends.flash_mla import MLAFlashAttentionCommonMetadata
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
mm_hashes: list[str]
@staticmethod
def make_meta(
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
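# Illustrative: block_size=4, block_ids=[2, 5] gives
#   [[ 8,  9, 10, 11],
#    [20, 21, 22, 23]]
# before flattening and truncation to valid_num_tokens.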
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
mm_hashes=mm_hashes,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
)
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is a simple debug implementation of the KV connector.
# It saves / loads the KV cache to / from disk.
# It does extra work which will overwrite the existing prefix cache on the device;
# to remove that overhead, a "mask" would need to be added to the ReqMeta class.
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1
)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata()
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
assert isinstance(metadata, SharedStorageConnectorMetadata)
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info(
"Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping),
)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None)
if kv_cache_attr is None:
continue
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0, False
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
token_ids = request.prompt_token_ids or []
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
mm_hashes = [f.identifier for f in new_req.mm_features]
if new_req.req_id in self._requests_need_load:
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=mm_hashes,
)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_prompt(token_ids, mm_hashes):
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=mm_hashes,
)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue
num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features],
)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request."""
return self._found_match_for_prompt(
list(request.prompt_token_ids or []),
[f.identifier for f in request.mm_features],
)
def _found_match_for_prompt(
self,
prompt_token_ids: list[int],
mm_hashes: list[str],
) -> bool:
num_tokens_to_check = align_to_block_size(
len(prompt_token_ids) - 1, self._block_size
)
foldername = self._generate_foldername_debug(
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
mm_hashes,
create_folder=False,
)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
token_ids: torch.Tensor,
mm_hashes: list[str],
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
token_bytes = token_ids.numpy().tobytes()
# Add mm_hashes to the bytes being hashed to avoid path traversal and
# to create a canonical key.
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
token_ids: torch.Tensor,
mm_hashes: list[str],
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(
token_ids, mm_hashes=mm_hashes, create_folder=True
)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size) -> int:
"""Align the number of tokens to the block size."""
return (num_tokens - 1) // block_size * block_size

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from contextlib import contextmanager, nullcontext
from typing import Optional
from dataclasses import dataclass
import torch
from vllm.distributed.parallel_state import (
GroupCoordinator,
GraphCaptureContext,
get_pp_group,
get_tp_group,
)
from vllm.distributed.mlu_parallel_state import (
get_moe_expert_parallel_world_size,
get_moe_expert_parallel_rank,
get_moe_expert_parallel_group,
)
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
@dataclass
class MLUGraphCaptureContext:
stream: torch.mlu.Stream
@contextmanager
def mlu_graph_capture(device: torch.device):
"""
`mlu_graph_capture` is a context manager which should surround the code that
is capturing the MLU graph. Its main purpose is to ensure that some
operations will be run after the graph is captured, before the graph
is replayed. It returns an `MLUGraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current MLU stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture runs on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched in the background on the default stream.
"""
context = MLUGraphCaptureContext(torch.mlu.Stream(device=device))
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context):
yield context
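# Minimal usage sketch (assumes torch_mlu mirrors the torch.cuda graph API,
# i.e. that torch.mlu.MLUGraph and torch.mlu.graph exist):
#   with mlu_graph_capture(torch.device("mlu:0")) as ctx:
#       g = torch.mlu.MLUGraph()
#       with torch.mlu.graph(g, stream=ctx.stream):
#           ...  # ops to capture and later replay with g.replay()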
@contextmanager
def vllm__distributed__parallel_state__GroupCoordinator__graph_capture(
self,
graph_capture_context: GraphCaptureContext | None = None,
):
if graph_capture_context is None:
stream = torch.mlu.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
from vllm_mlu.distributed.device_communicators.mlu_communicator import (
MLUCommunicator,
)
if self.device_communicator is not None:
assert isinstance(self.device_communicator, MLUCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.mlu.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.mlu.stream(stream), maybe_ca_context:
yield graph_capture_context
@dataclass
class CnclEPBuffer:
dispatch_send_token_tensor: torch.Tensor
dispatch_recv_token_tensor: torch.Tensor
combine_send_token_tensor: torch.Tensor
combine_recv_token_tensor: torch.Tensor
class CnclEP:
def __init__(self,
dispatch_token_size: int,
combine_token_size: int,
max_num_tokens_per_rank: int,
num_global_experts: int,
use_quant_dispatch: bool = True) -> None:
nranks = get_moe_expert_parallel_world_size()
rank = get_moe_expert_parallel_rank()
moe_ep_group = get_moe_expert_parallel_group()
self.max_num_tokens_per_rank = max_num_tokens_per_rank
self.use_quant_dispatch = use_quant_dispatch
(
handle,
exchange_info_size,
exchange_info,
dispatch_send_token_tensor,
dispatch_recv_token_tensor,
combine_send_token_tensor,
combine_recv_token_tensor
) = mlu_ops.moe_all2all_create(dispatch_token_size,
combine_token_size,
num_global_experts,
max_num_tokens_per_rank,
rank,
nranks)
self.handle = handle
self.buffer = CnclEPBuffer(
dispatch_send_token_tensor,
dispatch_recv_token_tensor,
combine_send_token_tensor,
combine_recv_token_tensor)
assert exchange_info.ndim == 1, "exchange_info should be 1D"
all_exchange_info = torch.empty((nranks, exchange_info.size(0)),
dtype=exchange_info.dtype,
device=exchange_info.device)
exchange_info = exchange_info.unsqueeze(0)
torch.distributed.all_gather_into_tensor(all_exchange_info,
exchange_info,
group=moe_ep_group.cpu_group,
async_op=False)
mlu_ops.moe_all2all_init(self.handle, all_exchange_info)
torch.distributed.barrier(group=moe_ep_group.cpu_group)
def dispatch(self,
token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
'''
The output tensors are modified in place, so they can be used directly
after dispatch finishes.
'''
mlu_ops.moe_all2all_dispatch(self.handle,
token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
def combine(self,
token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
mlu_ops.moe_all2all_combine(self.handle,
token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
def destroy(self) -> None:
mlu_ops.moe_all2all_destroy(self.handle)
_CNCLEP: CnclEP | None = None
_CNCLEP_BF16: CnclEP | None = None
def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP:
if use_quant_dispatch:
assert _CNCLEP is not None, "cnclep is not initialized"
return _CNCLEP
else:
assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized"
return _CNCLEP_BF16
def init_cnclep(dispatch_token_size: int,
combine_token_size: int,
max_num_tokens_per_rank: int,
num_global_experts: int,
use_quant_dispatch: bool = True):
if use_quant_dispatch:
global _CNCLEP
assert _CNCLEP is None, "cnclep has been initialized"
_CNCLEP = CnclEP(dispatch_token_size,
combine_token_size,
max_num_tokens_per_rank,
num_global_experts,
use_quant_dispatch)
else:
global _CNCLEP_BF16
assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized"
_CNCLEP_BF16 = CnclEP(dispatch_token_size,
combine_token_size,
max_num_tokens_per_rank,
num_global_experts,
use_quant_dispatch)
def cnclep_dispatch(token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
use_quant_dispatch: bool = True,
):
if use_quant_dispatch:
_CNCLEP.dispatch(token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
else:
_CNCLEP_BF16.dispatch(token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
def cnclep_combine(token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
use_quant_dispatch: bool = True,
):
if use_quant_dispatch:
_CNCLEP.combine(token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
else:
_CNCLEP_BF16.combine(token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
def destroy_cnclep():
global _CNCLEP
if _CNCLEP:
_CNCLEP.destroy()
_CNCLEP = None
global _CNCLEP_BF16
if _CNCLEP_BF16:
_CNCLEP_BF16.destroy()
_CNCLEP_BF16 = None
MluHijackObject.apply_hijack(GroupCoordinator,
GroupCoordinator.graph_capture,
vllm__distributed__parallel_state__GroupCoordinator__graph_capture)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import get_args
from vllm.platforms import current_platform
from vllm.config import (
ModelConfig,
VllmConfig,
SchedulerConfig,
)
from vllm.config.cache import CacheDType
from vllm.engine.arg_utils import (
EngineArgs,
_raise_unsupported_error,
)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
@classmethod
def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults(
cls,
model_config: ModelConfig,
) -> tuple[bool, bool]:
if model_config.runner_type != "pooling":
'''
=============================
Modify by vllm_mlu
=============================
@brief: MLU V1 defaults to the unchunked scheduler
'''
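# e.g. when mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED is truthy (presumably set via an
# environment variable of the same name), chunked prefill defaults to False.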
if mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED:
default_chunked_prefill = False
else:
default_chunked_prefill = True
'''
==================
End of MLU Hijack
==================
'''
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid
else:
assert model_config.pooler_config is not None
pooling_type = model_config.pooler_config.pooling_type
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and getattr(model_config.hf_config, "is_causal", True)
)
default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported
return default_chunked_prefill, default_prefix_caching
def vllm__engine__arg_utils__EngineArgs___set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
logger.debug(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
logger.warning(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = default_prefix_caching
logger.debug(
"%s prefix caching by default",
"Enabling" if default_prefix_caching else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_prefix_caching
and not default_prefix_caching
):
logger.warning(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,
default_max_num_seqs,
) = self.get_batch_defaults(world_size)
orig_max_num_batched_tokens = self.max_num_batched_tokens
orig_max_num_seqs = self.max_num_seqs
if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: only set max_num_batched_tokens when chunked prefill is enabled
'''
if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
if orig_max_num_batched_tokens is None:
if not self.enable_chunked_prefill:
# If max_model_len is too short, use the default for higher throughput.
self.max_num_batched_tokens = max(
model_config.max_model_len,
self.max_num_batched_tokens,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * model_config.max_model_len,
self.max_num_batched_tokens,
)
logger.debug(
"Defaulting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens,
usage_context.value if usage_context else None,
)
if orig_max_num_seqs is None:
if self.max_num_batched_tokens is not None: # For type checking
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
logger.debug(
"Defaulting max_num_seqs to %d for %s usage context.",
self.max_num_seqs,
usage_context.value if usage_context else None,
)
'''
==================
End of MLU Hijack
==================
'''
_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3']
def vllm__engine__arg_utils__EngineArgs__create_engine_config(
self,
usage_context: UsageContext | None = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
NOTE: If VllmConfig is incompatible, we raise an error.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: add data parallel params to parallel config.
'''
if self.mlu_config and "decoder_attn_dtype" in self.mlu_config:
if self.mlu_config.get("decoder_attn_dtype") in ["int8", "fp8", "fp8_e4m3"]:
self.kv_cache_dtype = self.mlu_config.get("decoder_attn_dtype")
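# Illustrative: mlu_config={"decoder_attn_dtype": "int8"} also forces
# kv_cache_dtype to "int8" before the base engine config is created.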
engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(
self, usage_context, headless)
world_size = engine_config.parallel_config.world_size_across_dp
tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size
embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size
if embedding_tp_size:
assert embedding_tp_size >= tensor_parallel_size and embedding_tp_size <= world_size, (
f"embedding_tp_size = {embedding_tp_size} out of bounds. "
f"Require {tensor_parallel_size} ≤ size ≤ {world_size}")
dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size
if dense_mlp_tp_size:
assert dense_mlp_tp_size >= 1 and dense_mlp_tp_size <= world_size, (
f"dense_mlp_tp_size = {dense_mlp_tp_size} out of bounds. Require 1 ≤ size ≤ {world_size}")
if dense_mlp_tp_size != world_size:
assert not engine_config.mlu_config.is_dpsk_mcc_enabled, (
"A dense_mlp_tp_size different from world_size is not supported when dpsk mcc is enabled.")
if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1:
raise ValueError("For now, for longcat model, custom dense mlp tp split in data parallel requires dpXtp1. "
"Necessity of this constraint requires further investigation.")
if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size:
raise ValueError(f"For longcat model, custom dense mlp tp_size {dense_mlp_tp_size} "
f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}")
if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0:
raise ValueError(f"For deepseek mla model, custom mlp tp size {dense_mlp_tp_size} must "
f"be divisible by {tensor_parallel_size}")
if ((engine_config.parallel_config.data_parallel_size > 1 or engine_config.speculative_config is not None
or engine_config.mlu_config.prefill_use_sequence_parallel) and engine_config.mlu_config.prefill_enable_mlugraph):
logger.info("Data parallel or sequence parallel or speculative is enabled, forcing context mlugraph to be disabled.")
engine_config.mlu_config.prefill_enable_mlugraph = False
if engine_config.mlu_config.decoder_attn_dtype:
if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType):
raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} "
f"decoder_attn_dtype for now")
is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and
engine_config.model_config.hf_text_config.model_type == "glm4_moe")
if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe)
and engine_config.mlu_config.decoder_attn_dtype != "auto"):
raise ValueError(f"mlu_config.decoder_attn_dtype only support deepseek_mla and glm4_moe model")
# sequence parallel checks
if (engine_config.mlu_config.prefill_use_sequence_parallel
and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]):
raise ValueError("Prefill sequence parallel can only use in deepseek model.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.scheduler_config.enable_chunked_prefill:
raise ValueError("Prefill sequence parallel can not use with chunked prefill for now.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled:
raise ValueError("Prefill sequence parallel can not use with mcc.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.parallel_config.data_parallel_size > 1:
raise ValueError("Prefill sequence parallel can not use with data parallel.")
if (engine_config.mlu_config.prefill_use_sequence_parallel
and engine_config.model_config.hf_text_config.model_type == "deepseek_v3"
and engine_config.quant_config.get_name() != "SmoothQuant"):
raise ValueError("Prefill sequence parallel can only use SmoothQuant for deepseek_v3.")
# disagg constraints
# 1. only supports DeepSeek-V3/R1
# 2. int8 (kv8) KV cache is not supported
if self.kv_transfer_config is not None:
if engine_config.model_config.hf_config.model_type != "deepseek_v3":
raise ValueError("Disagg only support DeepDeek-V3/R1")
if engine_config.cache_config.cache_dtype == "int8":
raise ValueError("Disagg does not support KV cache dtype is int8")
if engine_config.cache_config.enable_prefix_caching:
raise ValueError("Disagg does not support prefix caching")
if isinstance(self.kv_transfer_config, dict):
kv_connector = self.kv_transfer_config.get("kv_connector")
kv_role = self.kv_transfer_config.get("kv_role")
else:
kv_connector = self.kv_transfer_config.kv_connector
kv_role = self.kv_transfer_config.kv_role
if kv_connector != "LMCacheConnectorV1":
raise ValueError("Disagg only support LMCacheConnectorV1 connector")
if kv_role == "kv_consumer":
if not self.enable_chunked_prefill:
raise ValueError("Disagg decoder only support chunk scheduler")
'''
==================
End of MLU Hijack
==================
'''
return engine_config
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs._set_default_args,
vllm__engine__arg_utils__EngineArgs___set_default_args)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.create_engine_config,
vllm__engine__arg_utils__EngineArgs__create_engine_config)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.get_chunked_prefill_prefix_caching_defaults,
vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults)
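# Note (assumption, for readers of this patch): MluHijackObject.apply_hijack(cls, original, new)
# is expected to rebind the hijacked attribute on `cls` to the vllm_mlu replacement, roughly
#   setattr(EngineArgs, "create_engine_config", vllm__engine__arg_utils__EngineArgs__create_engine_config)
# so later calls such as EngineArgs(...).create_engine_config(usage_context) dispatch to the
# MLU-aware versions defined above. The real helper lives in vllm_mlu.mlu_hijack_utils and may
# also record the original method for restoration.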

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

158
vllm_mlu/entrypoints/llm.py Normal file
View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from tqdm import tqdm
from typing import Callable
from vllm.entrypoints.llm import LLM
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.logger import init_logger
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.mlu_metric import LLMMetric
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
def vllm__entrypoints__llm__LLM__get_mlu_metrics(
self,
metrics_idx_start,
only_average,
input_len,
output_len,
tp_nums,
quantization,
show_per_iter=False,
is_embedding_task=False,
mm_kwargs=None,
total_prefill_steps=1,
num_speculative_tokens=0,
dp_size=1,
) -> None:
'''
@brief: Print the performance metrics collected while vLLM runs the generate interface.
@params:
    metrics_idx_start: Some generate calls may be warmup runs, so metrics collected in
        [0, metrics_idx_start) are ignored. Defaults to 0, i.e. all performance data is kept.
    only_average: True prints only the average performance over the N generate calls; False also
        prints the per-call performance. If the N runs fluctuate heavily, check whether the test
        environment is stable.
    Remaining parameters: model configuration parameters.
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
batch_size = self.metric.batch_size_list[-1] * dp_size
if mm_kwargs or is_embedding_task:
# Multimodal and pooling models don't support the hfu feature yet.
hfu_info, io_efficiency = None, None
else:
hfu_info, io_efficiency = self.llm_engine.get_hfu_info(batch_size, input_len, output_len)
self.metric.calc_metric(
self.llm_engine.model_config.model,
self.llm_engine.model_config.dtype,
metrics_idx_start, only_average,
input_len, output_len, tp_nums,
quantization, show_per_iter,
is_embedding_task, mm_kwargs, total_prefill_steps,
num_speculative_tokens, dp_size=dp_size, hfu_info=hfu_info, io_efficiency=io_efficiency)
else:
print("Warnning:please set VLLM_LATENCY_DEBUG=true!")
def vllm__entrypoints__llm__LLM___run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True
) -> list[RequestOutput | PoolingRequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
pbar = tqdm_func(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
)
'''
=============================
Added by vllm_mlu
=============================
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
total_request_num = self.llm_engine.get_num_unfinished_requests()
e2e_start_time = self.metric.get_mlu_cost_time()
if not self.llm_engine.model_config.is_embedding_task():
peak_memory, block_memory, num_gpu_blocks, num_cpu_blocks = \
self.llm_engine.get_memory_usage()
self.metric.update_memory_usage(peak_memory, block_memory,
num_gpu_blocks, num_cpu_blocks)
'''
==================
End of addition
==================
'''
# Run the engine.
outputs: list[RequestOutput | PoolingRequestOutput] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
n = len(output.outputs)
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids) * n
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs
)
out_spd = total_out_toks / pbar.format_dict["elapsed"]
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s"
)
pbar.update(n)
else:
pbar.update(1)
if pbar.n == num_requests:
pbar.refresh()
if use_tqdm:
pbar.close()
'''
=============================
Added by vllm_mlu
=============================
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
e2e_end_time = self.metric.get_mlu_cost_time()
e2e_latency = e2e_end_time - e2e_start_time
engine_step_latency, model_forward_latency, mm_encoder_latency = self.llm_engine.get_latency()
self.metric.update_step_latency(engine_step_latency)
if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
self.metric.update_step_latency_device(model_forward_latency)
self.metric.update_mm_encoder_latency_device(mm_encoder_latency)
self.metric.add_metrics(total_request_num, e2e_latency)
'''
==================
End of addition
==================
'''
# Sort the outputs by request ID.
# This is necessary because some requests may finish earlier than
# requests submitted before them.
return sorted(outputs, key=lambda x: int(x.request_id))
LLM.metric = LLMMetric()
MluHijackObject.apply_hijack(LLM,
"get_mlu_metrics",
vllm__entrypoints__llm__LLM__get_mlu_metrics)
MluHijackObject.apply_hijack(LLM,
LLM._run_engine,
vllm__entrypoints__llm__LLM___run_engine)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from fastapi import Request
from fastapi.responses import Response
import vllm_mlu._mlu_utils as mlu_envs
from vllm.entrypoints.openai.api_server import (
router, engine_client
)
from vllm_mlu.logger import logger
if mlu_envs.VLLM_SCHEDULER_PROFILE:
logger.info(
"vLLM V1 Scheduler Profiler is enabled in the API server. Please use "
"'tools/utils/post_scheduler_view_action.py' to dump profiling data "
"after all requests finished.")
@router.post("/v1/start_scheduler_profile")
async def start_scheduler_profile(raw_request: Request):
logger.info("VLLM-V1 starting scheduler profiler...")
await engine_client(raw_request).start_scheduler_profile()
return Response(status_code=200)
@router.post("/v1/stop_scheduler_profile")
async def stop_scheduler_profile(raw_request: Request):
logger.info("VLLM-V1 scheduler stopping profiler...")
await engine_client(raw_request).stop_scheduler_profile()
return Response(status_code=200)
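# Illustrative usage (assumed host/port): once the OpenAI-compatible server is up with
# VLLM_SCHEDULER_PROFILE set, the profiler can be driven over HTTP, e.g.
#   import requests
#   requests.post("http://localhost:8000/v1/start_scheduler_profile")
#   ...  # send inference requests
#   requests.post("http://localhost:8000/v1/stop_scheduler_profile")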

41
vllm_mlu/envs.py Normal file
View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
from typing import Any, Callable, Dict
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
env_variables: Dict[str, Callable[[], Any]] = {
# max compile thread num
"MAX_JOBS":
lambda: os.getenv("MAX_JOBS", None),
"CMAKE_BUILD_TYPE":
lambda: os.getenv("CMAKE_BUILD_TYPE"),
"COMPILE_CUSTOM_KERNELS":
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
"VERBOSE":
lambda: bool(int(os.getenv('VERBOSE', '0'))),
"LD_LIBRARY_PATH":
lambda: os.getenv("LD_LIBRARY_PATH", None),
"CXX_COMPILER":
lambda: os.getenv("CXX_COMPILER", None),
"C_COMPILER":
lambda: os.getenv("C_COMPILER", None)
}
# end-env-vars-definition
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in env_variables:
return env_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(env_variables.keys())
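# Illustrative usage: because of the module-level __getattr__ above, environment variables are
# read lazily at attribute-access time rather than at import time, e.g.
#   import vllm_mlu.envs as envs
#   max_jobs = envs.MAX_JOBS   # re-reads os.environ on every access; None if unset
#   verbose = envs.VERBOSE     # parsed via bool(int(...)), defaults to False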

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

47
vllm_mlu/logger.py Normal file
View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import logging
from typing import cast
from vllm.logger import _VllmLogger
class _ColorFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
if not record.name.startswith('vllm_mlu'):
return True
if record.levelno == logging.INFO:
record.msg = f"\033[32m{record.msg}\033[0m"
elif record.levelno == logging.WARNING:
record.msg = f"\033[33m{record.msg}\033[0m"
return True
def _apply_mlu_color(logger):
if not logger.handlers:
return
for h in logger.handlers:
if any(isinstance(f, _ColorFilter) for f in h.filters):
return
h.addFilter(_ColorFilter())
def _mlu_init_logger(name: str) -> logging.Logger:
"""Initialize loggers for vllm_mlu module,
and keep the configuration consistent with the vllm module"""
mlu_logger = logging.getLogger(name)
vllm_logger = logging.Logger.manager.loggerDict.get('vllm', None)
if vllm_logger:
mlu_logger.setLevel(vllm_logger.level)
mlu_logger.propagate = vllm_logger.propagate
mlu_logger.handlers = vllm_logger.handlers
return mlu_logger
def init_logger(name: str) -> _VllmLogger:
vllm_logger = cast(_VllmLogger, _mlu_init_logger(name))
_apply_mlu_color(vllm_logger)
return vllm_logger
logger = init_logger(__name__)
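# Illustrative usage: modules under vllm_mlu obtain their logger with
#   from vllm_mlu.logger import init_logger
#   logger = init_logger(__name__)
# and, thanks to the _ColorFilter above, their INFO messages render green and WARNING messages
# yellow, while loggers outside the vllm_mlu namespace are left untouched.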

41
vllm_mlu/lora/__init__.py Normal file
View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
__all__ = [
"BaseLayerWithLoRA",
"VocabParallelEmbeddingWithLoRA",
"LogitsProcessorWithLoRA",
"ColumnParallelLinearWithLoRA",
"ColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearWithLoRA",
"MergedColumnParallelLinearWithShardedLoRA",
"MergedQKVParallelLinearWithLoRA",
"MergedQKVParallelLinearWithShardedLoRA",
"QKVParallelLinearWithLoRA",
"QKVParallelLinearWithShardedLoRA",
"RowParallelLinearWithLoRA",
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
]

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
from vllm.platforms import current_platform
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply(
self,
x: torch.Tensor,
bias: torch.Tensor | None,
residual: torch.Tensor | None = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual in matmul
'''
output = self.base_layer.quant_method.apply(self.base_layer, x, bias, residual)
'''
==================
End of MLU Hijack
==================
'''
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
)
if not current_platform.can_update_inplace():
output = lora_output
return output
MluHijackObject.apply_hijack(
BaseLinearLayerWithLoRA,
BaseLinearLayerWithLoRA.apply,
vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply,
)

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.lora.layers.column_parallel_linear import ColumnParallelLinearWithLoRA
from vllm_mlu.mlu_hijack_utils import MluHijackObject
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward(
self,
input_,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
assert not use_tp_weight, "LoRa does not support use_tp_weight yet."
assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
return vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org(self, input_)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithLoRA.forward,
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward,
)

View File

@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
)
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.platforms import current_platform
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply(
self,
x: torch.Tensor,
bias: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual and bias in matmul
'''
output = self.base_layer.quant_method.apply(
self.base_layer, x, bias, residual)
'''
==================
End of MLU Hijack
==================
'''
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffer = torch.zeros(
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
buffer, x, self.lora_a_stacked, 1.0
)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
if self.tp_size > 1:
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
# NOTE offset are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
return output
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward(
self,
input_: torch.Tensor,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add parameters `residual`, `smooth_quant_scale`, `use_tp_weight` and `output`
to keep parameters consistent with RowParallelLinear.forward.
'''
assert (not use_tp_weight) and output is None, (
"RowParallelLinearWithLoRA.forward does not support use_tp_weight=True"
" or passing an output parameter.")
'''
==================
End of MLU Hijack
==================
'''
# Set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[self.tp_rank].contiguous()
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) apply residual fusion in matmul like RowParallelLinear
2) add bias in matmul, not after all reduce
'''
# Matrix multiply.
bias_ = (
None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add)
else self.base_layer.bias
)
residual_ = None if self.base_layer.tp_rank > 0 else residual
output_parallel = self.apply(input_parallel, bias_, residual_)
'''
==================
End of MLU Hijack
==================
'''
if self.base_layer.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
'''
=============================
Modify by vllm_mlu
=============================
@brief: do not add bias after all_reduce
'''
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
'''
==================
End of MLU Hijack
==================
'''
if not self.base_layer.return_bias:
return output
return output, output_bias
MluHijackObject.apply_hijack(
RowParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA.apply,
vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply,
)
MluHijackObject.apply_hijack(
RowParallelLinearWithLoRA,
RowParallelLinearWithLoRA.forward,
vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward,
)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm_mlu.lora.ops.triton_ops.sgmv_expand import sgmv_expand_mlu
from vllm_mlu.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink_mlu
from vllm_mlu.lora.ops.triton_ops.lora_shrink_op import lora_shrink
from vllm_mlu.lora.ops.triton_ops.lora_expand_op import lora_expand
__all__ = [
"sgmv_expand_mlu",
"sgmv_expand_slice_mlu",
"sgmv_shrink_mlu",
"lora_expand",
"lora_shrink"
]

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Utilities for Punica kernel construction.
"""
from vllm.triton_utils import tl, triton
'''
=============================
Modify by vllm_mlu
=============================
@brief: modify the mm triton kernel
1) add parameter offset_n: on MLU, pass the column offsets of matrix B,
value: tl.arange(0, BLOCK_N) + pid_n * BLOCK_N, shape: [BLOCK_N];
also add parameter N: the number of columns of matrix B
2) tiled_b always needs a mask in case offset_n >= N
'''
@triton.jit
def mm_k(
a_ptr,
b_ptr,
ak_stride,
bk_stride,
offset_n,
offset_k,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
N: tl.constexpr,
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
):
"""
Given a_ptr and b_ptr, which identify the rows of A (m x k) and columns of
B (k x n), iterate through the K dimension to compute the partial/complete
matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bk_stride: K dimension stride of the B matrix
offset_n: Column offsets into the current B block, shape [BLOCK_N]
offset_k: K-dimension offsets within the current block, shape [BLOCK_K]
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
N: Number of columns of the B matrix, used to mask loads of B.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N, other=0.0)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :]
< K - k * (BLOCK_K * SPLIT_K),
other=0)
tiled_b = tl.load(b_ptr,
mask=(offset_k[:, None]
< K - k * (BLOCK_K * SPLIT_K)) & (offset_n < N)[None, :],
other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * SPLIT_K * ak_stride
b_ptr += BLOCK_K * SPLIT_K * bk_stride
return accumulator
'''
==================
End of MLU Hijack
==================
'''
@triton.jit
def do_expand_kernel(
pid_n,
lora_index,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SAME_STRIDE: tl.constexpr,
SLICE_NUM: tl.constexpr,
EVEN_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
ADD_INPUTS: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice,
compute the matrix product and store in the appropriate output location.
Given that this is an expand kernel, we don't perform any split-K reduction
as the K dimension is assumed to be small.
"""
# ls_d*_ptr can be either an integer or a pointer
if SAME_STRIDE: # 'same_stride': True
# integer
cur_lora_d0_stride = ls_d0_ptr
cur_lora_d1_stride = ls_d1_ptr
cur_lora_d2_stride = ls_d2_ptr
else:
# pointer
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
# Identify the input_ptr and lora_ptr from slice_id.
if SLICE_NUM == 1:
cur_input_ptr = input_ptr
cur_lora_ptr = lora_ptr
else:
cur_input_ptr = input_ptr + slice_id * input_d0_stride
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(out_ptr.dtype.element_ty))
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) remove the rbn definition: MLU does not support the tl.max_contiguous hint and
would mishandle it (head corruption)
2) rewrite b_ptr, using offset_n to identify its position
'''
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
# rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K)
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride)
# b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
# offset_k[:, None] * cur_lora_d2_stride +
# rbn[None, :] * cur_lora_d1_stride)
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride +
offset_n[None, :] * cur_lora_d1_stride)
# Compute the block matrix product.
SPLIT_K = 1
accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, offset_n,
offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N,
CAST_TYPE, cur_lora_ptr.dtype.element_ty)
'''
==================
End of MLU Hijack
==================
'''
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
if SLICE_NUM == 1:
cur_slice_start = slice_start_loc
else:
cur_slice_start = tl.load(slice_start_loc + slice_id)
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
offset_cm = tl.arange(0, BLOCK_M)
c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
offset_cn[None, :] * output_d1_stride)
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
< (cur_slice_start + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@triton.jit
def do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_index,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram,
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice, compute the
matrix product and store in the appropriate output location.
"""
# Identify the lora_ptr from slice_id.
if SLICE_NUM == 1:
# current lora ptr
cur_lora_ptr = lora_ptr
else:
# current lora ptr
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(input_ptr.dtype.element_ty))
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1) remove the rbn definition: MLU does not support the tl.max_contiguous hint and
would mishandle it (head corruption)
2) rewrite b_ptr, using offset_n to identify its position
'''
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
# rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
offset_k[None, :] * input_d1_stride)
# b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
# rbn[None, :] * lora_d1_stride +
# offset_k[:, None] * lora_d2_stride)
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
offset_n[None, :] * lora_d1_stride +
offset_k[:, None] * lora_d2_stride)
# Compute partial/complete block matrix product.
accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_n, offset_k,
K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N, False,
cur_lora_ptr.dtype.element_ty)
'''
==================
End of MLU Hijack
==================
'''
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_cm = tl.arange(0, BLOCK_M)
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
slice_id * output_d0_stride)
c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
None, :] * output_d2_stride
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
'''
=============================
Modify by vllm_mlu
=============================
@brief: use vllm_mlu hijacked kernel
'''
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_expand_kernel
'''
==================
End of MLU Hijack
==================
'''
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
@triton.jit
def _lora_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_mn = tl.program_id(axis=0)
pid_m = pid_mn % cta_m_num
pid_n = (pid_mn // cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# When the output dimensions of each slice are the same, curr_N = N; otherwise
# curr_N = tl.load(output_hs_ptr + slice_id). The latter situation exists in GQA's
# qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
if pid_n * BLOCK_N >= curr_N:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_expand_kernel(
pid_n,
lora_id,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS,
)
@torch.inference_mode()
def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (list[torch.Tensor]): LoRA B weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == len(lora_b_weights)
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
NUM_SLICES = len(lora_b_weights)
# Triton kernel configs.
BLOCK_M = 64
BLOCK_N = 128
BLOCK_K = 16
NUM_WARPS = 4
NUM_CTAS = 1
NUM_STAGES = 2
EVEN_K = K % BLOCK_K == 0 # type: ignore
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only a few input tokens require
# LoRA. This might not be the best in all cases.
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
)
_lora_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
MAX_N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
NUM_SLICES,
same_stride,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
)
return
def _lora_expand_fake(
inputs: torch.Tensor,
lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: expose the vLLM operator as a plain function (no custom-op registration)
'''
lora_expand = _lora_expand
'''
==================
End of MLU Hijack
==================
'''

View File

@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
'''
=============================
Modify by vllm_mlu
=============================
@brief: use vllm_mlu hijacked kernel
'''
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
'''
==================
End of MLU Hijack
==================
'''
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
@triton.jit
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
lora_token_start_loc, lora_ids, scaling,
input_d0_stride, input_d1_stride, lora_d0_stride,
lora_d1_stride, lora_d2_stride, output_d0_stride,
output_d1_stride, output_d2_stride,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
lora_m_indices_start + cta_m_offset)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM)
@torch.inference_mode()
def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: list[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): Input tensor
lora_a_weights (list[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
# Triton kernel configs
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 256 if M < 128 else 32
SPLIT_K = 64 if M < 128 else 8
NUM_WARPS = 4
NUM_CTAS = 1
NUM_STAGES = 2
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only a few of the input tokens
# require LoRA. This might not be the best in all cases.
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
)
_lora_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
NUM_SLICES,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
)
return
def _lora_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float,
) -> None:
return
'''
=============================
Modify by vllm_mlu
=============================
@brief: expose the vLLM operator as a plain function (no custom-op registration)
'''
lora_shrink = _lora_shrink
'''
==================
End of MLU Hijack
==================
'''

View File

@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
The sgmv's expand triton kernel is based on GroupGEMM.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
offset_k[None, :] * xk_stride
b_ptr = lora_ptr + l0_stride * lora_index + \
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
other=0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
other=0)
'''
==================
End of MLU Hijack
==================
'''
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_mlu(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): LoRA B weights
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: adjust the block sizes to meet MLU restrictions.
The grid of an MLU triton kernel must be smaller than 65536. With a very long input
sequence the grid would go out of bounds and cause a runtime error, so we adjust the
block sizes to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
'''
==================
End of MLU Hijack
==================
'''
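# A minimal sketch of what adjust_kernel_block_size is assumed to do (the real helper lives in
# vllm_mlu.lora.ops.triton_ops.utils and may differ): grow the block sizes from the given
# minimums until the launch grid fits under the 65536 limit, e.g.
#   def adjust_kernel_block_size(m, block_m, n, block_n, grid_limit=65536):
#       while triton.cdiv(m, block_m) * triton.cdiv(n, block_n) >= grid_limit:
#           if triton.cdiv(m, block_m) >= triton.cdiv(n, block_n):
#               block_m *= 2
#           else:
#               block_n *= 2
#       return block_m, block_n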
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_expand_kernel_mlu
'''
_sgmv_expand_kernel_mlu[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
'''
==================
End of MLU Hijack
==================
'''
return

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_slice_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
offset_k[None, :] * xk_stride
b_ptr = lora_ptr + l0_stride * lora_index + \
offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
other=0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
other=0)
'''
==================
End of MLU Hijack
==================
'''
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
(slice_offset + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand_slice_mlu(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> None:
"""_summary_
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): LoRA B weights
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: adjust the block sizes to meet MLU restrictions.
The grid of an MLU triton kernel must be smaller than 65536. With a very long input
sequence the grid would go out of bounds and cause a runtime error, so we adjust the
block sizes to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_expand_slice_kernel_mlu
'''
_sgmv_expand_slice_kernel_mlu[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
'''
==================
End of MLU Hijack
==================
'''
return

View File

@@ -0,0 +1,231 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import triton
import triton.language as tl
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _sgmv_shrink_kernel_mlu(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
scaling,
xm_stride, # hidden_size
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
The SGMV shrink Triton kernel is based on GroupGEMM + SPLIT-K.
The GEMM of Multi-LoRA can be considered a GroupGEMM. Additionally,
introducing SPLIT-K can improve performance.
"""
pid = tl.program_id(axis=0)
pid_sk = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
offset_k[None, :] * xk_stride
b_ptr = lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride + \
offset_k[:, None] * lora_n_stride
'''
==================
End of MLU Hijack
==================
'''
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
'''
=============================
Modify by vllm_mlu
=============================
@brief: adjust kernel impl to fit mlu.
'''
if EVEN_K:
tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
tiled_a = tl.load(a_ptr,
mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)),
other=0.0)
tiled_b = tl.load(b_ptr,
mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)),
other=0.0)
'''
==================
End of MLU Hijack
==================
'''
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += BLOCK_K * SPLIT_K * xk_stride
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def sgmv_shrink_mlu(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): LoRA A weights
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no LoRA should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The number of tokens in the batch. Used to verify
that the token count in the inputs matches the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
'''
=============================
Modify by vllm_mlu
=============================
@brief: Workaround: adjust the block size to meet MLU restrictions.
The grid of an MLU Triton kernel must be smaller than 65536; with a very long
input sequence the grid would go out of bounds and cause a runtime error, so the
block size is adjusted to avoid this.
'''
BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
'''
==================
End of MLU Hijack
==================
'''
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K,
batches,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: call _sgmv_shrink_kernel_mlu
'''
_sgmv_shrink_kernel_mlu[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
)
'''
==================
End of MLU Hijack
==================
'''
return
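
Two details of the shrink launch above are easy to miss: the EVEN_K fast path only triggers when K is a multiple of BLOCK_K * SPLIT_K = 256, and because SPLIT_K > 1 writes partial results back with tl.atomic_add, the kernel relies on the caller handing it a zero-initialized output_tensor. A small sketch of the divisibility check (the hidden sizes are illustrative assumptions):

BLOCK_K, SPLIT_K = 32, 8
for K in (4096, 7168, 14336, 5504):          # illustrative hidden sizes
    print(K, K % (BLOCK_K * SPLIT_K) == 0)   # True, True, True, False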

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
from math import ceil
_MLU_MAX_GRID_SIZE = 65536
def adjust_kernel_block_size(
m: int,
block_m: int,
n: int,
block_n: int
) -> Tuple[int, int]:
"""Adjust block size to meet mlu triton grid restrictions.
Calculation of the max block size in candidates list:
LLama3.1-8b-tp1 max n is 14336
LLama3.1-70b-tp4 max n is 7168
LLama3.1-405b-tp8 max n is 6656
when n is 14336, the max sequence length of block size 256 can be
floor(65536 / ceil(14336 / 256)) * 256 = 299520.
"""
candidates_list = [16, 32, 64, 96, 128, 192, 256]
candidates_list_len = len(candidates_list)
m_idx = 1
n_idx = 0 if block_n == 16 else 1
while m_idx < candidates_list_len and n_idx < candidates_list_len:
block_m = candidates_list[m_idx]
block_n = candidates_list[n_idx]
if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
break
if m_idx < candidates_list_len:
m_idx += 1
if n_idx < candidates_list_len:
n_idx += 1
if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE:
raise ValueError(f"the max seq len {m} is too long for lora triton kernel")
return block_m, block_n
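
The worst-case figure quoted in the docstring can be reproduced with a couple of lines; this only re-checks the arithmetic for the same n = 14336 example:

from math import ceil, floor

n, block = 14336, 256                        # largest candidate block size
max_m = floor(65536 / ceil(n / block)) * block
print(max_m)                                 # 299520, matching the docstring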

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Optional, Tuple, Union, final
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
@final
class PunicaWrapperMLU(PunicaWrapperCPU):
"""
PunicaWrapperMLU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
add_inputs,
)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice_mlu(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing import List
from vllm.forward_context import DPMetadata
@dataclass
class MLUDPMetadata(DPMetadata):
# mlu platform arguments
# token num for current dp group
token_num: int = None
# token num offset for current dp group
token_num_offset: int = None
# whether we can use reduce scatter for both attn layer and mlp layer
layer_use_reduce_scatter: bool = False
# token num that needs to be padded for prefill, so we can do reduce scatter +
# all gather to optimize comm time
prefill_pad_to_token_num: int = -1
# token num in each dp group, the list length is attn data parallel size
# used to do all gather in dp groups after all reduce in attn
token_split_list: List[int] = None
# token num in each card, the list length is world size
# used to do all gather in all cards after reduce scatter in attn
attn_token_split_list_reduce_scatter: List[int] = None
# token num in each tp group, the list length is tensor parallel size
# used to do all gather in tp groups after reduce scatter in moe
moe_token_split_list_reduce_scatter: List[int] = None
# prefill or decode stage in each dp group
dp_is_prefill: List[bool] = None
# ADDITIONAL fields for merged compute and communication.
# Global sequence lengths for each batch size for prefill stage.
seq_lens: List[int] = None
# Batch sizes for each attn dp rank for prefill stage.
batch_sizes: List[int] = None
# ADDITIONAL fields for custom split for embedding, logits and dense mlp layer
# token num in each emb tp group, the list length is tensor parallel size
# used to do all gather in emb tp groups after reduce scatter in moe
emb_token_split_list: List[int] = None
# batch sizes in each logits tp group, the list length is tensor parallel size
# used to do all gather in logits tp groups after reduce scatter in moe
logits_batch_split_list: List[int] = None
# token num in each dense mlp group, the list length is dense mlp tp size
# used to do one more all gather after dense mlp and before reduce scatter
dense_attn_token_split_list: List[int] = None
@staticmethod
def make_oot(
data_parallel_rank: int,
data_parallel_size: int,
tensor_parallel_size: int,
dp_token_nums: List[int],
dp_is_prefill: List[bool],
prefill_dispatch_use_RS_AG: bool,
seq_lens: List[int] = None,
batch_sizes: List[int] = None,
emb_query_lens: List[int] = None,
logits_batch_sizes: List[int] = None,
dense_attn_token_split_list: List[int] = None,
) -> "MLUDPMetadata":
token_num_offset = sum(dp_token_nums[:data_parallel_rank])
token_num = dp_token_nums[data_parallel_rank]
token_split_list = dp_token_nums
attn_can_use_reduce_scatter = all(
(num != 0 and num % tensor_parallel_size == 0)
for num in token_split_list
)
all_split_token_num_equal = all(
num == token_split_list[0] for num in token_split_list
)
layer_can_use_reduce_scatter = (
attn_can_use_reduce_scatter and all_split_token_num_equal
)
attn_token_split_list_reduce_scatter = None
moe_token_split_list_reduce_scatter = None
prefill_pad_to_token_num = -1
tp_world_size = data_parallel_size * tensor_parallel_size
if layer_can_use_reduce_scatter:
attn_token_split_list_reduce_scatter = (
[token_split_list[0] // tensor_parallel_size] * tp_world_size
)
moe_token_split_list_reduce_scatter = (
attn_token_split_list_reduce_scatter[:tensor_parallel_size]
)
elif (
prefill_dispatch_use_RS_AG
and all(is_prefill for is_prefill in dp_is_prefill)
):
dp_group_max_token_nums = max(dp_token_nums)
prefill_pad_to_token_num = (
(dp_group_max_token_nums + tensor_parallel_size - 1)
// tensor_parallel_size
) * tensor_parallel_size
attn_token_split_list_reduce_scatter = (
[prefill_pad_to_token_num // tensor_parallel_size] * tp_world_size
)
return MLUDPMetadata(
max_tokens_across_dp_cpu=None,
num_tokens_across_dp_cpu=None,
token_num=token_num,
token_num_offset=token_num_offset,
token_split_list=token_split_list,
layer_use_reduce_scatter=layer_can_use_reduce_scatter,
prefill_pad_to_token_num=prefill_pad_to_token_num,
attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter,
moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter,
seq_lens=seq_lens,
batch_sizes=batch_sizes,
dp_is_prefill=dp_is_prefill,
emb_token_split_list=emb_query_lens,
logits_batch_split_list=logits_batch_sizes,
dense_attn_token_split_list=dense_attn_token_split_list,
)
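
A shape-level sketch of the reduce-scatter eligibility logic in make_oot, using made-up sizes (two attention DP groups, TP = 4, eight tokens per group). It only mirrors the two checks above and is not the real call:

tensor_parallel_size = 4
dp_token_nums = [8, 8]                       # tokens per attention DP group (made up)

attn_ok = all(n != 0 and n % tensor_parallel_size == 0 for n in dp_token_nums)
all_equal = all(n == dp_token_nums[0] for n in dp_token_nums)
print(attn_ok and all_equal)                 # True -> layer_use_reduce_scatter

# after the reduce-scatter every rank holds token_num // tp tokens
tp_world_size = len(dp_token_nums) * tensor_parallel_size
print([dp_token_nums[0] // tensor_parallel_size] * tp_world_size)   # [2, 2, 2, 2, 2, 2, 2, 2]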

79
vllm_mlu/mlu_hijack.py Normal file
View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import importlib.util
from vllm_mlu._mlu_utils import *
from vllm_mlu.logger import logger
def is_module_available(module_name):
spec = importlib.util.find_spec(module_name)
return spec is not None
def check_environ_compatibility():
if is_module_available('apex'):
logger.error(f"The `apex` package is currently present in your environment, "
f"which may cause model accuracy issues or other problems. It is "
f"strongly recommended that you uninstall it before using vLLM.")
# Check environment compatibility first before applying mlu hijack.
check_environ_compatibility()
logger.info(f"[MLU] Apply Monkey Patch.")
# Apply v1 hijack
import vllm_mlu.v1.engine.core
import vllm_mlu.v1.engine.core_client
import vllm_mlu.v1.engine.llm_engine
import vllm_mlu.v1.engine.async_llm
import vllm_mlu.v1.core.sched.scheduler
import vllm_mlu.v1.core.single_type_kv_cache_manager
import vllm_mlu.v1.core.kv_cache_utils
import vllm_mlu.v1.core.kv_cache_manager
import vllm_mlu.v1.executor.abstract
import vllm_mlu.v1.executor.ray_executor
import vllm_mlu.v1.executor.multiproc_executor
import vllm_mlu.v1.sample.rejection_sampler
import vllm_mlu.v1.worker.lora_model_runner_mixin
import vllm_mlu.v1.worker.block_table
import vllm_mlu.v1.worker.gpu_input_batch
import vllm_mlu.v1.worker.kv_connector_model_runner_mixin
import vllm_mlu.v1.attention.backends.gdn_attn
import vllm_mlu.v1.attention.backends.mla.flashmla
import vllm_mlu.compilation.fix_functionalization
# Apply common hijack
import vllm_mlu.attention.layer
import vllm_mlu.benchmarks.datasets
import vllm_mlu.config.model
import vllm_mlu.config.scheduler
import vllm_mlu.config.speculative
import vllm_mlu.config.vllm
import vllm_mlu.utils
import vllm_mlu.distributed.parallel_state
import vllm_mlu.distributed.kv_transfer.kv_connector.factory
import vllm_mlu.engine.arg_utils
import vllm_mlu.entrypoints.llm
import vllm_mlu.lora.layers.base_linear
import vllm_mlu.lora.layers.row_parallel_linear
import vllm_mlu.lora.layers.column_parallel_linear
import vllm_mlu.model_executor.parameter
import vllm_mlu.model_executor.layers.linear
import vllm_mlu.model_executor.layers.rotary_embedding
import vllm_mlu.model_executor.layers.quantization.utils.w8a8_utils
import vllm_mlu.model_executor.layers.quantization.fp8
import vllm_mlu.model_executor.layers.activation
import vllm_mlu.model_executor.layers.layernorm
import vllm_mlu.model_executor.layers.fused_moe.layer
import vllm_mlu.model_executor.model_loader.tensorizer_loader
import vllm_mlu.model_executor.models.registry
import vllm_mlu.model_executor.models.config
import vllm_mlu.multimodal.utils
if is_module_available('lmcache'):
import vllm_mlu.distributed.kv_transfer.kv_connector.v1.lmcache_connector
if VLLM_CI_ACCURACY_TEST:
import vllm_mlu.model_executor.model_loader.dummy_loader
if VLLM_SCHEDULER_PROFILE:
import vllm_mlu.entrypoints.openai.api_server

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.logger import init_logger
logger = init_logger(__name__)
IS_GATED=False
class MluHijackObject:
hijack_objs = []
@classmethod
def apply_hijack(cls, obj, org_func, hijack_func,
verify_orig_func_exists: bool = False):
"""
Optional Args:
verify_orig_func_exists (bool): If True, verify that the original function
exists on obj and that the hijack actually took effect.
"""
cls.hijack_objs.append((obj, org_func, hijack_func))
if type(org_func) == str:
org_func_name = org_func
else:
if isinstance(org_func, property):
split_name = org_func.fget.__name__.split('__')
else:
split_name = org_func.__name__.split('__')
org_func_name = split_name[-1]
if org_func_name == "":
assert split_name[-2] != "", f"invalid {org_func.__name__} to apply hijack"
org_func_name = split_name[-2] + "__"
if len(split_name) >= 3 and split_name[-3] == "":
org_func_name = "__" + org_func_name
if verify_orig_func_exists and not hasattr(obj, org_func_name):
raise AttributeError(f"function {org_func_name} is not part of {obj}")
setattr(obj, org_func_name, hijack_func)
if (verify_orig_func_exists and getattr(obj, org_func_name) is not hijack_func):
raise AttributeError(
f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")
@classmethod
def undo_hijack(cls, obj_ = None, hijack_func_ = None):
if obj_ and hijack_func_:
for obj, org_func, hijack_func in cls.hijack_objs:
if obj_ == obj and hijack_func == hijack_func_:
if type(org_func) == str:
if hasattr(obj, org_func):
delattr(obj, org_func)
else:
org_func_name = org_func.__name__
setattr(obj, org_func_name, org_func)
return
for obj, org_func, hijack_func in cls.hijack_objs:
if type(org_func) == str:
if hasattr(obj, org_func):
delattr(obj, org_func)
else:
org_func_name = org_func.__name__
setattr(obj, org_func_name, org_func)
TypedDict = {
"hidden_size": 0,
"vocab_size": 0,
"ffn_inner_size": 0,
"moe_inner_size": 0,
"layer_num": 0,
"moe_layer_num": 0,
"head_num": 0,
"head_size": 0,
"head_num_kv": 0,
"tp_num": 0,
"shared_expert_intermediate_size": 0,
"shared_experts": 0,
"qk_nope_head_dim": 0,
"qk_rope_head_dim": 0,
"q_lora_rank": 0.0,
"num_attention_heads": 0,
"kv_lora_rank": 0,
"v_head_dim": 0,
"use_gated_ffn": False,
"experts_num": 0,
"topk_num": 0,
"use_causal_mask": False,
"cla_coeffient": 0,
"kv_cache_dtype": "",
"smooth_quant_type": "",
"data_type": "",
"model_type": "",
"filter_data_type": "",
}
def set_is_gated(flag):
global IS_GATED
IS_GATED=flag
def get_is_gated():
return IS_GATED
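
A toy usage sketch of MluHijackObject; the Greeter class and the patched functions are invented purely to show how the attribute name is recovered from the original function (including dunder names) and how a patch is undone:

from vllm_mlu.mlu_hijack_utils import MluHijackObject

class Greeter:
    def hello(self):
        return "hello"

    def __len__(self):
        return 1

def patched_hello(self):
    return "hello from MLU"

def patched_len(self):
    return 2

MluHijackObject.apply_hijack(Greeter, Greeter.hello, patched_hello)
MluHijackObject.apply_hijack(Greeter, Greeter.__len__, patched_len)
print(Greeter().hello(), len(Greeter()))     # hello from MLU 2

MluHijackObject.undo_hijack(Greeter, patched_hello)
print(Greeter().hello())                     # hello, the original is restored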

412
vllm_mlu/mlu_metric.py Normal file
View File

@@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import time
import statistics
import pandas as pd
import numpy as np
import json
import os
from datetime import datetime
from vllm.logger import init_logger
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG_WITH_DEVICE_EN, VLLM_DUMP_MLU_INFO_EN
from vllm.model_executor.layers.quantization import get_quantization_config
logger = init_logger(__name__)
millisecond2second_unit = 1000
class LLMMetric:
def __init__(self)->None:
self.batch_size_list = []
self.context_latency_list = []
self.e2e_latency_list = []
self.per_token_latency_list = [ [] ]
self.per_token_latency_device_list = [ [] ]
self.mm_encoder_latency_device_list = [ [] ]
self.peak_memory = 0
self.block_memory = 0
self.num_total_gpu_blocks = 0
self.num_total_cpu_blocks = 0
self.num_free_gpu_blocks_list = [ [] ]
self.num_free_cpu_blocks_list = [ [] ]
self.num_spec_tokens = 0
self.draft_acceptance_rate = 0.0
self.context_latency_device = 0.0
self.generate_latency_device = 0.0
self.mm_encoder_latency_device = 0.0
def reset_metric(self):
self.batch_size_list = []
self.context_latency_list = []
self.e2e_latency_list = []
self.per_token_latency_list = [ [] ]
self.per_token_latency_device_list = [ [] ]
self.mm_encoder_latency_device_list = [ [] ]
self.num_free_gpu_blocks_list = [ [] ]
self.num_free_cpu_blocks_list = [ [] ]
self.num_spec_tokens = 0
self.draft_acceptance_rate = 0.0
@classmethod
def get_mlu_cost_time(cls):
torch.mlu.synchronize()
return time.perf_counter()
def is_prefill_stage(self):
return len(self.per_token_latency_list[-1]) == 0
def update_memory_usage(self, peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks):
self.peak_memory = peak_memory
self.block_memory = block_memory
self.num_total_gpu_blocks = num_total_gpu_blocks
self.num_total_cpu_blocks = num_total_cpu_blocks
def update_step_block_usage(self, num_free_gpu_blocks, num_free_cpu_blocks):
self.num_free_gpu_blocks_list[-1].append(num_free_gpu_blocks)
self.num_free_cpu_blocks_list[-1].append(num_free_cpu_blocks)
def update_step_latency(self, step_latency):
if isinstance(step_latency, list):
self.per_token_latency_list[-1].extend(step_latency)
else:
self.per_token_latency_list[-1].append(step_latency)
def update_step_latency_device(self, step_latency):
if isinstance(step_latency, list):
self.per_token_latency_device_list[-1].extend(step_latency)
else:
self.per_token_latency_device_list[-1].append(step_latency)
def update_mm_encoder_latency_device(self, step_latency):
if isinstance(step_latency, list):
if len(step_latency) == 0:
return
assert len(step_latency) == 1, f"Not supported! Model with multi mm encoder steps. {len(step_latency)} {step_latency}"
self.mm_encoder_latency_device_list[-1].extend(step_latency)
else:
self.mm_encoder_latency_device_list[-1].append(step_latency)
def update_spec_decode_metrics(self, spec_decode_metrics):
self.num_spec_tokens = spec_decode_metrics.num_spec_tokens
self.draft_acceptance_rate = spec_decode_metrics.draft_acceptance_rate
def add_metrics(self, batch_size, e2e_latency)->None:
self.batch_size_list.append(batch_size)
self.e2e_latency_list.append(e2e_latency)
self.per_token_latency_list.append([]) # new iter
self.per_token_latency_device_list.append([])
self.mm_encoder_latency_device_list.append([])
self.num_free_gpu_blocks_list.append([])
self.num_free_cpu_blocks_list.append([])
def get_weight_dtype_str(self, model_path, model_dtype, quantization) -> str:
# get weight dtype based on quantization config if exists
if quantization == 'fp8':
return quantization
if quantization is not None:
quant_method = get_quantization_config(quantization)
# combine the model path with the quantization config file name
quant_config_paths = quant_method.get_config_filenames()
# if there are multiple quantization config files, use the first one that exists
for quant_config_path in quant_config_paths:
quant_config_path = os.path.join(model_path, quant_config_path)
# check if the quantization config file exists
if not os.path.exists(quant_config_path):
continue
with open(quant_config_path, 'r') as f:
quant_config = json.load(f)
quant_config = quant_method.from_config(quant_config)
# for smoothquant and weightonly, return the quantization name with the weight bits
if quantization == "smoothquant" or quantization == ["weightonly"]:
return "{}-int{}".format(quant_config.get_name(), quant_config.weight_bits)
else:
# for other quantization methods, return the quantization name
return quant_config.get_name()
# if no quantization config file exists, just return the quantization name
return quantization
else:
# remove the prefix of model dtype from torch config
return str(model_dtype).split(".")[-1]
def to_csv(self, filename: str, show_per_iter=False) -> None:
if show_per_iter:
df = pd.DataFrame(self.metrics_data)
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
memory_df = pd.DataFrame(self.memory_metrics_data)
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
else:
df = pd.DataFrame(self.metrics_data)
memory_df = pd.DataFrame(self.memory_metrics_data)
df_mean = df.mean().round(3)
memory_df_mean = memory_df.mean().round(3)
header = ["datetime", "model",
"weight dtype", self.batch_size_name,
]
header = header + list(self.mm_kwargs.keys())
header = header + ["input len", "output len", "tp",
self.context_latency_name, self.per_token_latency_name]
data = [datetime.now().strftime("%Y-%m-%d %H:%M:%S"), self.model,
self.weight_dtype_str, int(self.metrics_data[self.batch_size_name][0])]
data = data + [self.mm_kwargs[k] for k in self.mm_kwargs.keys()]
data = data + [self.input_len, self.output_len, self.tp,
df_mean[self.context_latency_name], df_mean[self.per_token_latency_name]]
if self.num_spec_tokens > 0:
header += [self.per_step_latency_name]
data += [df_mean[self.per_step_latency_name]]
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
if self.is_v1_multimodal:
header += [self.mm_encoder_latency_device_name,]
data += [df_mean[self.mm_encoder_latency_device_name],]
header += [self.context_latency_device_name, self.per_token_latency_device_name]
data += [df_mean[self.context_latency_device_name], df_mean[self.per_token_latency_device_name]]
header += [self.e2e_latency_name, self.e2e_throughput_name, self.decoder_throughput_name,]
if self.num_spec_tokens > 0:
header += [self.k_name, self.acceptance_rate_name]
header += [self.decode_times_name,
self.peak_memory_name, self.block_memory_name]
data += [
df_mean[self.e2e_latency_name], df_mean[self.e2e_throughput_name], df_mean[self.decoder_throughput_name],
]
if self.num_spec_tokens > 0:
data += [self.num_spec_tokens, df_mean[self.acceptance_rate_name],]
data += [df_mean[self.decode_times_name], memory_df_mean[self.peak_memory_name], memory_df_mean[self.block_memory_name],]
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and self.save_hfu_info:
header += [self.context_hfu_name, self.decoder_hfu_name, self.decoder_io_efficiency_name]
data += [
df_mean[self.context_hfu_name], df_mean[self.decoder_hfu_name],
df_mean[self.decoder_io_efficiency_name]
]
data_dict = dict(zip(header, data))
df_csv = pd.DataFrame(data_dict, index=[0])
append = False
if os.path.isfile(filename):
try:
df_old = pd.read_csv(filename)
append = (df_old.columns.tolist() == header)
except Exception as e:
logger.info(f"Existing {filename} failed to be read and will be overwritten")
if append:
df_csv.to_csv(filename, mode='a', header=False, index=False)
logger.info(f"Metric appended to existing {filename}")
else:
df_csv.to_csv(filename, index=False)
logger.info(f"Metric written to {filename}")
def calc_metric(self, model, model_dtype, metrics_idx_start, only_average,
input_len, output_len, tp_nums, quantization,
show_per_iter=False, is_embedding_task=False, mm_kwargs=None,
total_prefill_steps=1, num_spec_tokens=0, dp_size=1, hfu_info=None, io_efficiency=0.0) -> None:
keep_digits = 2
def round_fn(data):
return round(data, keep_digits)
metrics_idx_end = len(self.per_token_latency_list) - 1 # without last []
idx_range = range(metrics_idx_start, metrics_idx_end)
# specify entries to write to csv
self.is_v1_multimodal = mm_kwargs
self.mm_kwargs = mm_kwargs if mm_kwargs else {} # multimodal args
self.batch_size_name = "batch size"
self.input_len = input_len
self.output_len = output_len
self.tp = tp_nums
self.dp = dp_size
self.model = model
self.context_latency_name = "context latency(ms)"
self.mm_encoder_latency_device_name = "multimodal encoder latency device(ms)"
self.context_latency_device_name = "context latency device(ms)"
if num_spec_tokens > 0:
self.per_step_latency_name = "per step latency(ms)"
self.per_token_latency_device_name = "per step latency device(ms)"
else:
self.per_token_latency_device_name = "per token latency device(ms)"
self.per_token_latency_name = "per token latency(ms)"
self.e2e_latency_name = "e2e latency(ms)"
self.e2e_throughput_name = "e2e throughput(tokens/s)"
self.decoder_throughput_name = "decoder throughput(tokens/s)"
self.k_name = "K"
self.acceptance_rate_name = "acceptance rate"
self.decode_times_name = "decode times"
self.weight_dtype_str = self.get_weight_dtype_str(model, model_dtype, quantization)
self.num_spec_tokens = num_spec_tokens
rate_list=[]
rate=0
if num_spec_tokens > 0:
for i in range(metrics_idx_end):
if len(self.per_token_latency_list[i]) - total_prefill_steps == 0:
logger.warning("For now output_len is 0, no need mtp info, if you need mtp info, please increase output_len.")
rate_list.append(0.0)
else:
rate_list.append(((self.output_len - 1) / (float)(len(self.per_token_latency_list[i]) - total_prefill_steps) - 1) / num_spec_tokens)
rate = statistics.fmean(rate_list[metrics_idx_start: metrics_idx_end])
metrics_data = [
(
self.batch_size_name, [self.dp * int(self.batch_size_list[i]) for i in idx_range]
),
(
self.context_latency_name, [round_fn(millisecond2second_unit * sum(self.per_token_latency_list[i][:total_prefill_steps])) for i in idx_range]
),
(
self.per_token_latency_name, [
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * (len(self.per_token_latency_list[i]) - total_prefill_steps) / (self.output_len - 1) * millisecond2second_unit) for i in idx_range
]
),
]
if num_spec_tokens > 0:
metrics_data += [(self.per_step_latency_name, [
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * millisecond2second_unit) for i in idx_range
])]
metrics_data += [
(
self.e2e_latency_name, [round_fn(millisecond2second_unit * self.e2e_latency_list[i]) for i in idx_range]
),
(
self.e2e_throughput_name, [
round_fn(self.dp * (output_len / self.e2e_latency_list[i]) * self.batch_size_list[i]) \
for i in idx_range
]
),
(
self.decoder_throughput_name, [
0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
round_fn(self.dp * ((output_len-1) / sum(self.per_token_latency_list[i][total_prefill_steps:])) * self.batch_size_list[i]) \
for i in idx_range
]
),
(
self.decode_times_name, [
0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \
len(self.per_token_latency_list[i][total_prefill_steps:]) for i in idx_range
]
),
]
if num_spec_tokens > 0:
metrics_data.append((self.k_name, num_spec_tokens))
metrics_data.append((self.acceptance_rate_name, [rate_list[i] for i in idx_range]))
insert_latency_device = VLLM_LATENCY_DEBUG_WITH_DEVICE_EN
if insert_latency_device:
device_item_idx = 3
if self.is_v1_multimodal:
mm_encoder_latency_device = [round_fn(sum(self.mm_encoder_latency_device_list[i])) for i in idx_range]
metrics_data.insert(device_item_idx, (self.mm_encoder_latency_device_name, mm_encoder_latency_device))
device_item_idx = device_item_idx + 1
context_latency_device = [round_fn(sum(self.per_token_latency_device_list[i][:total_prefill_steps])) for i in idx_range]
per_token_latency_device = [0.0 if len(self.per_token_latency_device_list[i]) <= total_prefill_steps else \
round_fn(statistics.fmean(self.per_token_latency_device_list[i][total_prefill_steps:])) for i in idx_range]
metrics_data.insert(device_item_idx, (self.context_latency_device_name, context_latency_device))
metrics_data.insert(device_item_idx + 1, (self.per_token_latency_device_name, per_token_latency_device))
self.metrics_data = dict(metrics_data)
# Print
df = pd.DataFrame(self.metrics_data)
if show_per_iter:
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
else:
df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = df.mean().round(keep_digits)
if only_average:
df = pd.DataFrame([df.iloc[-1]], columns=df.columns)
df.index.name = 'iter index'
df[self.batch_size_name] = df[self.batch_size_name].astype(int)
if num_spec_tokens > 0:
df[self.k_name] = df[self.k_name].astype(int)
self.peak_memory_name = "profile memory(GB)"
self.block_memory_name = "total cache memory(GB)"
memory_metrics_data = [
(
self.peak_memory_name, [round_fn(self.peak_memory / 1024 / 1024 / 1024) for i in idx_range]
),
(
self.block_memory_name, [round_fn(self.block_memory / 1024 / 1024 / 1024) for i in idx_range]
),
]
self.memory_metrics_data = dict(memory_metrics_data)
# Print
memory_df = pd.DataFrame(self.memory_metrics_data)
if show_per_iter:
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
else:
memory_df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = memory_df.mean().round(keep_digits)
if only_average:
memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns)
memory_df.index.name = 'iter index'
pd.set_option('display.colheader_justify', 'center')
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
print("********************************* Test Info****************************")
mm_params_text = " ".join(f"{key}:{value}" for key, value in self.mm_kwargs.items())
print("Generation Config {} input len:{} output len:{} tp_nums:{} quantization:{}".format(
mm_params_text, input_len,output_len,tp_nums,quantization))
self.context_latency_device = np.mean(self.metrics_data['context latency device(ms)'])
self.generate_latency_device = np.mean(self.metrics_data[self.per_token_latency_device_name])
if self.is_v1_multimodal:
self.mm_encoder_latency_device = np.mean(self.metrics_data[self.mm_encoder_latency_device_name])
print("*************************Performance Info******************************")
print(f"Total prefill steps: {total_prefill_steps}")
print(df.to_string())
if not is_embedding_task:
# embedding task does not do profile run, so does not have memory infos
print(memory_df.to_string())
if insert_latency_device :
context_latency = np.mean(self.metrics_data['context latency device(ms)'])
generate_latency = np.mean(self.metrics_data[self.per_token_latency_device_name])
if num_spec_tokens > 0:
print("MTP token accept rate: {:.2f}%".format(rate*100))
self.dump_performance_info(hfu_info, io_efficiency)
avg_latency_e2e = sum(sum(self.per_token_latency_list[i]) for i in idx_range) / len(idx_range)
print("Avg latency without host time is :", avg_latency_e2e)
print("***********************************************************************")
# collect context_hfu, decoder_hfu and decoder io efficiency info
self.save_hfu_info = False
if insert_latency_device:
if VLLM_DUMP_MLU_INFO_EN:
try:
import device_info
self.save_hfu_info = True
except ImportError:
logger.info("Importing device_info failed; try `pip install device_info`.")
self.context_hfu_name = "Context HFU"
self.decoder_hfu_name = "Decoder HFU"
self.decoder_io_efficiency_name = "Decoder IO Efficiency"
if self.save_hfu_info:
self.metrics_data[self.context_hfu_name] = hfu_info["context_hfu"] * 100
self.metrics_data[self.decoder_hfu_name] = hfu_info["decoder_hfu"] * 100
self.metrics_data[self.decoder_io_efficiency_name] = io_efficiency * 100
if csv_path := os.getenv("OUTPUT_CSV_PATH"):
try:
if dir_path := os.path.dirname(csv_path):
os.makedirs(dir_path, exist_ok=True)
self.to_csv(csv_path, show_per_iter=show_per_iter)
except Exception as e:
logger.error(f"Invalid OUTPUT_CSV_PATH: {csv_path} to dump metrics, Error: {e}")
def dump_performance_info(self, hfu_info, io_efficiency):
try:
if VLLM_DUMP_MLU_INFO_EN and hfu_info is not None:
hfu_info["context_hfu"] = hfu_info["context_hfu"] / (self.context_latency_device / millisecond2second_unit)
hfu_info["decoder_hfu"] = hfu_info["decoder_hfu"] / (self.generate_latency_device / millisecond2second_unit)
io_efficiency = io_efficiency / self.generate_latency_device
print(f"Context HFU-visible: {hfu_info['context_hfu']:.3%}")
print(f"Decoder HFU-visible: {hfu_info['decoder_hfu']:.3%}")
print(f"Decoder IO Efficiency: {io_efficiency:.3%}")
elif hfu_info is not None:
print(f"Context FLOPS-visible: {hfu_info['context_flops']}")
print(f"Decoder FLOPS-visible: {hfu_info['decoder_flops']}")
else:
logger.info("Unsupport dump performance information")
except Exception as e:
logger.error(f"Failed to dump performance information: {str(e)}")

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.activation import QuickGELU
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
def vllm__model_executor__activation__QuickGELU__forward_oot(self, x: torch.Tensor) -> torch.Tensor:
'''
=============================
Modify by vllm_mlu
=============================
@brief: implement forward_oot
'''
return mlu_ops.active(x, 'quick_gelu', False)
'''
==================
End of MLU Hijack
==================
'''
MluHijackObject.apply_hijack(QuickGELU,
QuickGELU.forward_oot,
vllm__model_executor__activation__QuickGELU__forward_oot)

View File

@@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Callable
from scipy.linalg import hadamard
import torch
from torch import nn
import torch.nn.functional as F
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
def hadamard_transform_ref(x, scale=1.0):
"""
x: (..., dim)
out: (..., dim)
"""
x_shape = x.shape
dim = x.shape[-1]
x = x.reshape(-1, dim)
log_dim = math.ceil(math.log2(dim))
dim_padded = 2 ** log_dim
if dim != dim_padded:
x = F.pad(x, (0, dim_padded - dim))
out = F.linear(
x,
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
)
out = out * scale
return out[..., :dim].reshape(*x_shape)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
hidden_size = x.size(-1)
return hadamard_transform_ref(x, scale=hidden_size ** -0.5)
class Compressor(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
head_dim: int = 512,
rotate: bool = False,
prefix: str = "",
**kwargs,):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.head_dim = head_dim
self.rope_head_dim = config.rope_head_dim
self.nope_head_dim = head_dim - config.rope_head_dim
self.compress_ratio = compress_ratio
self.overlap = compress_ratio == 4
self.rotate = rotate
coff = 1 + self.overlap
self.norm_eps = config.norm_eps
self.window_size = config.window_size
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
# wkv and wgate in the checkpoint are stored in bf16, while the parameters here are stored in fp32 for convenience.
# The first half of the dimensions is for overlapping compression and the second half for normal compression.
self.wkv = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
params_dtype = torch.float32,
prefix=f"{prefix}.wkv",
)
self.wgate = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
params_dtype = torch.float32,
prefix=f"{prefix}.wgate",
)
self.norm = RMSNorm(self.head_dim, self.norm_eps)
self.rotary_emb = rope
hf_config = vllm_config.model_config.hf_config
assert hasattr(hf_config, "cached_state_num"), \
f"cached_state_num is not set in hf_config"
cached_state_num = hf_config.cached_state_num
self.register_buffer(
"kv_state",
torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"score_state",
torch.full(
(cached_state_num, coff * compress_ratio, coff * self.head_dim),
float("-inf"),
dtype=torch.float32,
),
persistent=False,
)
self.hadamard_matrix = torch.tensor(
hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")
def overlap_transform(self, tensor: torch.Tensor, value=0):
# tensor: [b,s,r,2d]
b, s, _, _ = tensor.size()
ratio, d = self.compress_ratio, self.head_dim
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
return new_tensor
def forward_decode(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
mlu_ops.fused_compress_single_kv(
kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
position=positions,
ape=self.ape,
kv_state=self.kv_state,
score_state=self.score_state,
gamma=self.norm.weight,
sin=self.rotary_emb.sin_,
cos=self.rotary_emb.cos_,
hadamard_matrix=self.hadamard_matrix,
slot_mapping=compressor_slot_mapping,
kv_cache=kv_cache,
kv_cache_scale=None,
eps=self.norm_eps,
overlap=self.overlap,
rotate=self.rotate,
state_idx=batch_to_kv_state,
)
# Here, return fake compressed_kv.
return None
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
forward_func: Callable = (
self.forward_prefill if common_metadata.is_prefill_only
else self.forward_decode
)
return forward_func(
x,
positions,
attn_metadata,
batch_to_kv_state,
kv_cache,
window_offset,
compressor_slot_mapping,
)
def forward_prefill(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
seq_lens = common_metadata.seq_lens
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
ratio, overlap = self.compress_ratio, self.overlap
dtype = x.dtype
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
compress_lens = query_lens // self.compress_ratio
cu_compress_lens = torch.cat([
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
torch.cumsum(compress_lens, dim=0)],
)
compress_positions = []
for i in range(len(seq_lens)):
seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
remainder = seqlen % ratio
cutoff = seqlen - remainder
pos = positions[query_start_loc[i]: query_start_loc[i+1]]
positions_ = pos[:cutoff:ratio].contiguous()
compress_positions.append(positions_)
kv_positions = torch.cat(compress_positions, dim=0)
total_compress_len = cu_compress_lens[-1].item()
kv = torch.empty(
[total_compress_len, self.head_dim],
dtype=kv_pack.dtype,
device=kv_pack.device,
)
mlu_ops.fused_compress_multi_kv(
kv = kv_pack,
score = score_pack,
kv_state = self.kv_state,
score_state = self.score_state,
state_batch_idx = batch_to_kv_state,
cu_seqlens = query_start_loc,
ape = self.ape,
max_seqlen = common_metadata.max_query_len,
overlap = overlap,
compressed_kv = kv,
)
if kv.size(0) == 0:
return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)
kv = self.norm(kv.to(dtype))
kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
# use compressed cu_seqlens here, so can not call rotary_emb directly
kv_rope = mlu_ops.rotary_embedding(
kv_rope,
self.rotary_emb.sin_,
self.rotary_emb.cos_,
kv_positions,
torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
True, # interleaved
True, # discrete
False,
common_metadata.max_query_len,
)
if self.rotate:
kv = rotate_activation(kv)
mlu_ops.reshape_paged_cache(
kv.unsqueeze(1),
None,
kv_cache,
None,
compressor_slot_mapping,
)
return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)
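
A quick numeric sanity check of the rotation used by rotate_activation: with scale = dim ** -0.5 the Hadamard transform is orthonormal, so it preserves per-token norms while spreading activation outliers across channels. The 512-dim tensor below is an assumption (any power-of-two size works) and simply exercises the same scipy.linalg.hadamard matrix:

import torch
from scipy.linalg import hadamard

dim = 512                                    # illustrative power-of-two head_dim
x = torch.randn(4, dim, dtype=torch.float64)
H = torch.tensor(hadamard(dim, dtype=float))
y = (x @ H) * dim ** -0.5                    # same scaling as rotate_activation
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1)))   # True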

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import torch
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm_mlu.model_executor.models.dp_utils import (
tensor_model_parallel_all_gather_dp, DataParallelRuntimeParams)
class DPLogitsProcessor(LogitsProcessor):
"""DP LogitsProcessor."""
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
dp_params: Optional[DataParallelRuntimeParams] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
batch_sizes = None
if (lm_head.tp_group is not None
and dp_params is not None
and dp_params.logits_batch_split_list is not None):
batch_sizes = dp_params.logits_batch_split_list
hidden_states = tensor_model_parallel_all_gather_dp(
group_num_tokens=batch_sizes,
rank=lm_head.tp_rank,
hidden_states=hidden_states,
group=lm_head.tp_group,
)
logits = lm_head.quant_method.apply(
lm_head, hidden_states, bias=embedding_bias)
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits, tp_group=lm_head.tp_group)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits, tp_group=lm_head.tp_group)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., : self.org_vocab_size]
if batch_sizes is not None:
offset = sum(batch_sizes[:lm_head.tp_rank])
logits = logits[offset : offset + batch_sizes[lm_head.tp_rank]]
return logits
def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
dp_params: Optional[DataParallelRuntimeParams] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
# Get the logits for the next tokens.
logits = self._get_logits(
hidden_states, lm_head, embedding_bias, dp_params)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
logits = torch.tanh(logits)
logits = logits * self.soft_cap
if self.scale != 1.0:
logits *= self.scale
return logits
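
The gather-then-slice-back pattern in _get_logits is easiest to see without any distributed setup. The sketch below fakes the post-all-gather logits with random numbers and only demonstrates the offset arithmetic; the batch sizes and vocab size are made up:

import torch

batch_sizes = [3, 5]                         # requests per logits-TP rank (made up)
vocab_size = 11
logits_all = torch.randn(sum(batch_sizes), vocab_size)   # as if after all-gather

rank = 1
offset = sum(batch_sizes[:rank])
mine = logits_all[offset: offset + batch_sizes[rank]]
print(mine.shape)                            # torch.Size([5, 11])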

View File

@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
method_has_implemented_embedding,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod,
VocabParallelEmbedding,
DEFAULT_VOCAB_PADDING_SIZE,
get_masked_input_and_mask,
pad_vocab_size,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
divide,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_logits_tp_group,
get_logits_tp_world_size,
get_logits_tp_rank,
)
from vllm_mlu.model_executor.models.dp_utils import (
DataParallelRuntimeParams,
tensor_model_parallel_all_gather_dp,
)
class DPVocabParallelEmbedding(VocabParallelEmbedding):
"""DP Embedding parallelized in the vocabulary dimension."""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
torch.nn.Module.__init__(self)
"""
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, world_size and tp_rank to support other parallel
"""
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_world_size = get_tensor_model_parallel_world_size()
self.tp_group = None
logits_tp_world_size = get_logits_tp_world_size()
if logits_tp_world_size != self.tp_world_size:
self.tp_group = get_logits_tp_group()
self.tp_world_size = logits_tp_world_size
self.tp_rank = get_logits_tp_rank()
# Keep the input dimensions.
tp_rank = self.tp_rank
self.tp_size = self.tp_world_size
"""
=================
End of MLU Hijack
=================
"""
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def forward(self, input_,
dp_params: Optional[DataParallelRuntimeParams] = None):
token_split_list = None
if (dp_params is not None
and self.tp_group is not None
and dp_params.emb_token_split_list is not None):
token_split_list = dp_params.emb_token_split_list
input_ = tensor_model_parallel_all_gather_dp(
group_num_tokens=token_split_list,
rank=self.tp_rank,
hidden_states=input_.reshape(-1, 1),
group=self.tp_group,
).reshape(-1)
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index,
)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
if token_split_list is not None:
offset = sum(token_split_list[:self.tp_rank])
output = output[offset : offset + token_split_list[self.tp_rank]]
return output
class DPParallelLMHead(DPVocabParallelEmbedding):
"""DP Parallelized LM head.
NOTE: A copy of ParallelLMHead class, and only change its parent
from VocabParallelEmbedding to DPVocabParallelEmbedding.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")

View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn.functional as F
from typing import Any
from vllm.distributed import (
get_parallel_world_size_with_group,
get_parallel_rank_with_group,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
ColumnParallelLinear,
RowParallelLinear
)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import set_is_gated
logger = init_logger(__name__)
class FeedForward(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
bias: bool,
quant_config: QuantizationConfig | None = None,
skip_bias_add: bool = False,
reduce_results: bool = True,
prefix: str = "",
tp_group: Any = None,
keep_full_weights: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.is_gated = is_gated
self.bias = bias
self.up_proj_name = up_proj_name
self.down_proj_name = down_proj_name
self.quant_config = quant_config
self.is_initialized = False
self.skip_bias_add = skip_bias_add
self.reduce_results = reduce_results
self.use_bt_ffn = True
set_is_gated(self.is_gated)
# modify tp_size, tp_rank and tp_group when enable data parallel
self.tp_size = get_parallel_world_size_with_group(tp_group)
self.tp_rank = get_parallel_rank_with_group(tp_group)
self.tp_group = tp_group
self.keep_full_weights = keep_full_weights
if self.keep_full_weights:
self.tp_size = 1
self.tp_rank = 0
self.tp_group = None
# up_proj with gate or not
if self.is_gated:
up_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
else:
up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=bias,
skip_bias_add=skip_bias_add,
quant_config=quant_config,
prefix=f"{prefix}.{up_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
self.register_module(up_proj_name, up_proj)
# down_proj
down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=bias,
skip_bias_add=skip_bias_add,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=f"{prefix}.{down_proj_name}",
tp_group=self.tp_group,
keep_full_weights=keep_full_weights)
self.register_module(down_proj_name, down_proj)
def prepare_weight(self):
if not self.is_initialized:
# alpha and beta are 1.0 and 0.0 respectively due to the fact that we don't need residual for now
self.alpha = 1.0
self.beta = 0.0
# place it here to avoid the overhead of calling it in the forward pass
self.is_initialized = True
def _forward(self, hidden_states):
self.prepare_weight()
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
act_dict = {
"relu": F.relu,
"gelu": F.gelu,
"silu": F.silu,
}
fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias)
if self.is_gated:
d = fc1.shape[-1] // 2
fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:]
else:
fc1 = act_dict[self.hidden_act](fc1)
fc2 = F.linear(fc1, down_proj.weight, bias=None)
fc2 = tensor_model_parallel_all_reduce(fc2)
if not self.skip_bias_add:
fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2
return fc2
def forward_naive(
self,
hidden_states,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None
):
'''
used by quant_tools
'''
assert self.quant_config is None, "ffn naive forward doesn't support quantization"
assert smooth_quant_scale is None, "ffn naive forward doesn't support smooth_quant_scale"
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
fc1, bias = up_proj(hidden_states)
if bias is not None:
fc1 += bias
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(fc1, residual=residual_)
if self.skip_bias_add:
return out, bias
return out
def forward(
self,
hidden_states,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
):
self.prepare_weight()
if not self.use_bt_ffn:
return self.forward_naive(hidden_states, residual, None)
up_proj = getattr(self, self.up_proj_name)
down_proj = getattr(self, self.down_proj_name)
residual_ = None if self.tp_rank > 0 else residual
if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA)
and not isinstance(down_proj, BaseLayerWithLoRA)):
# The matmul formula is the following:
# mul_out = alpha * (matmul(input, filter, transpose_b=True) + bias) + beta * residual
# output = active(mul_out)
# Note: we cannot use the activation function inside matmul because it does not
# support the gated operation; we might support it in tmo matmul in the future.
up_proj_weight = up_proj.weight
down_proj_weight = down_proj.weight
if self.keep_full_weights and use_tp_weight:
up_proj_weight = up_proj.tp_weight
down_proj_weight = down_proj.tp_weight
fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias,
None, 'none', self.alpha, self.beta)
act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype)
beta = 0.0
if residual_ is not None:
beta = 1.0
residual_ = residual_.view(-1, residual_.shape[-1])
out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta)
# the bias, if present, must be added after the second matmul, following the original design of vllm
if self.reduce_results:
out = tensor_model_parallel_all_reduce(out_, self.tp_group)
else:
out = out_
# do the bias add if needed
if not self.skip_bias_add:
out = out + down_proj.bias if down_proj.bias is not None else out
else:
return out, down_proj.bias
else:
fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight)
if bias is not None:
fc1 += bias
input_scale = None
if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and
self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8):
down_proj.quant_method.skip_quant_input = True
down_proj_smooth = down_proj.smooth
if self.keep_full_weights and use_tp_weight:
assert down_proj.tp_smooth is not None, "tp_smooth is not initialized"
down_proj_smooth = down_proj.tp_smooth
fc1, input_scale = mlu_ops.per_token_smooth_quantize(
fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated)
else:
fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated)
out, bias = down_proj(
fc1, residual=residual_, smooth_quant_scale=input_scale,
use_tp_weight=use_tp_weight, output=output)
if self.skip_bias_add:
return out, bias
return out
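# A minimal plain-PyTorch sketch of the gated path implemented above, for
# illustration only. It assumes the merged up_proj stores the gate half first
# (matching the split in _forward); `w_up_gate` and `w_down` are hypothetical
# stand-ins for the merged up_proj.weight and down_proj.weight.
def _reference_gated_mlp(x: torch.Tensor,
                         w_up_gate: torch.Tensor,
                         w_down: torch.Tensor,
                         act=F.silu) -> torch.Tensor:
    # fc1 has width 2 * intermediate_size: [gate | up]
    fc1 = F.linear(x, w_up_gate)
    d = fc1.shape[-1] // 2
    # gated activation: act(gate) elementwise-multiplied with the up half
    hidden = act(fc1[..., :d]) * fc1[..., d:]
    # down projection back to hidden_size
    return F.linear(hidden, w_down)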

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,935 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe_kernel_gptq_awq,
write_zeros_to_output,
get_default_config,
get_moe_wna16_block_config,
should_moe_wna16_use_cuda,
try_get_optimal_moe_config,
_get_config_quant_dtype,
)
from vllm.model_executor.layers.fused_moe.utils import (
activation_without_mul,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
b_bias_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
stride_bbe, # bias expert stride
stride_bbn, # bias N stride
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
'''
=============================
Modify by vllm_mlu
=============================
@brief: Split the program ID into two dimensions (pid_0 and pid_1)
'''
pid_0 = tl.program_id(axis=0)
pid_1 = tl.program_id(axis=1)
pid = pid_1 * tl.num_programs(axis=0) + pid_0
'''
==================
End of MLU Hijack
==================
'''
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_int8_w8a16:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# channel-wise
elif per_channel_quant:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if HAS_BIAS:
# bias shape: [num_experts, N]
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
if use_fp8_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
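# A small plain-Python sketch (illustration only) of the grouped program-id
# mapping used at the top of fused_moe_kernel: a flat `pid` is mapped to the
# (pid_m, pid_n) output tile, walking GROUP_SIZE_M rows of tiles per group to
# improve L2 reuse. The helper name is hypothetical.
def _grouped_pid_to_tile(pid: int, EM: int, N: int, BLOCK_SIZE_M: int,
                         BLOCK_SIZE_N: int, GROUP_SIZE_M: int) -> tuple[int, int]:
    num_pid_m = -(-EM // BLOCK_SIZE_M)  # ceil division
    num_pid_n = -(-N // BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n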
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert block_shape is None or triton.cdiv(
B.size(-2), block_shape[0]
) == B_scale.size(-2)
assert block_shape is None or triton.cdiv(
B.size(-1), block_shape[1]
) == B_scale.size(-1)
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
'''
=============================
Modify by vllm_mlu
=============================
@brief: Split the program ID into two dimensions (pid_0, pid_1)
'''
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), )
assert not (use_int8_w8a16 or use_int4_w4a16)
'''
==================
End of MLU Hijack
==================
'''
HAS_BIAS = B_bias is not None
if (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
config = config.copy()
config.update(
get_moe_wna16_block_config(
config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=num_tokens,
size_k=A.size(1),
size_n=B.size(1),
num_experts=B.size(1),
group_size=block_shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"],
)
)
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm(
A,
C,
B,
B_scale,
B_zp,
topk_weights if mul_routed_weight else None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
bit,
)
return
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
A.size(1),
EM,
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else:
config = config.copy()
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
fused_moe_kernel[grid](
A,
B,
C,
B_bias,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
B.size(2),
EM,
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_bias.stride(0) if B_bias is not None else 0,
B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
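# Worked example for the two-dimensional grid above (illustration only): with
# EM = 1024, BLOCK_SIZE_M = 64, N = 4096 and BLOCK_SIZE_N = 128, the launch
# grid is (cdiv(1024, 64), cdiv(4096, 128)) = (16, 32). Inside the kernel the
# flat id is rebuilt as pid = pid_1 * tl.num_programs(axis=0) + pid_0,
# i.e. pid_1 * 16 + pid_0, which reproduces the single-axis ordering that the
# grouped tile mapping expects.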
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
True,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
w1_bias,
w2_bias,
)
def outplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> None:
pass
direct_register_custom_op(
op_name="outplace_fused_experts_mlu",
op_func=outplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=outplace_fused_experts_fake,
dispatch_key="PrivateUse1",
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
)
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
return torch.ops.vllm.outplace_fused_experts_mlu(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in {
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch"
)
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
)
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
num_tokens = hidden_states.size(0)
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = _get_config_dtype_str(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
ocp_mx_scheme=ocp_mx_scheme,
dtype=hidden_states.dtype,
)
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
# quantized prior to calling fused_experts.
quant_dtype = _get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
ocp_mx_scheme=ocp_mx_scheme,
)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.size(),
w2.size(),
top_k_num,
config_dtype,
block_shape=block_shape,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Only use the default config
'''
config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1],
hidden_states.dtype, block_shape)
'''
==================
End of MLU Hijack
==================
'''
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(
M * top_k_num * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty(
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme in {
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w2_scale = None
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE, num_tokens),
)
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.size()
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qcurr_hidden_states, a1q_scale = _fp8_quantize(
curr_hidden_states, a1_scale, block_shape)
else:
qcurr_hidden_states = curr_hidden_states
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
)
invoke_fused_moe_kernel(
qcurr_hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Activate by mlu_ops
'''
intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N),
act_mode=activation,
is_gated=True)
'''
==================
End of MLU Hijack
==================
'''
a2q_scale: Optional[torch.Tensor] = None
if use_fp8_w8a8:
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
else:
qintermediate_cache2 = intermediate_cache2
invoke_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: replace moe_sum with torch.sum
Reference Links: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513
'''
if topk_ids.shape[1] == 2:
torch.add(
intermediate_cache3[:, 0],
intermediate_cache3[:, 1],
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
).squeeze(dim=1)
elif topk_ids.shape[1] > 2:
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
'''
==================
End of MLU Hijack
==================
'''
return out_hidden_states
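# Usage sketch for fused_experts_impl (illustration only; assumes an MLU device
# and unquantized bfloat16 weights). With E experts, hidden size K, gated
# intermediate size I and M tokens, the expected shapes are:
#   hidden_states: (M, K)      w1: (E, 2 * I, K)      w2: (E, K, I)
#   topk_weights, topk_ids: (M, top_k)
#
#   out = fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
#                            inplace=False, activation="silu")
#   assert out.shape == hidden_states.shape
#
# Tokens are processed in chunks of VLLM_FUSED_MOE_CHUNK_SIZE, and the per-
# expert results are summed over the top_k dimension into out_hidden_states.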

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
#TODO: support `routed_scaling_factor`
assert routed_scaling_factor == 1.0, (
f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
)
use_fused_kernel = topk_group is None
if use_fused_kernel:
assert not enable_eplb, "MLU does not support EPLB in the fused_moe kernel."
assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
"The params use_grouped_topk, num_expert_group and topk_group are not supported yet."
return mlu_ops.fused_moe(
x,
router_logits,
layer.w13_weight, layer.w2_weight,
None, None, # bias1, bias2
None, # residual
None, # input_smooth
None, # act_smooth
None, None, # w1_scale, w2_scale
top_k,
renormalize,
True, # gated
activation
)
else:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.rocm_aiter_moe_enabled:
assert expert_map is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
MluHijackObject.apply_hijack(
UnquantizedFusedMoEMethod,
UnquantizedFusedMoEMethod.forward_oot,
vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, round_up
'''
=============================
Modify by vllm_mlu
=============================
@brief: Implementation of moe_align_block_size_triton.
Note: this implementation has been removed from vllm since the
CUDA implementation is more efficient.
'''
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = cdiv(numel, num_experts)
sorted_token_ids.fill_(numel)
expert_ids.zero_()
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
'''
==================
End of MLU Hijack
==================
'''
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Only use triton to implement moe_align_block_size
'''
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
'''
==================
End of MLU Hijack
==================
'''
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
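# Worked example of the buffer sizes above (illustration only): with the
# docstring's topk_ids of shape (4, 3), num_experts = 4 and block_size = 4,
# max_num_tokens_padded = 12 + 4 * (4 - 1) = 24 and
# max_num_m_blocks = cdiv(24, 4) = 6. sorted_ids therefore has 24 slots,
# unused ones holding the sentinel value topk_ids.numel() = 12, and expert_ids
# holds one expert index per block of 4 sorted token slots.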

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from math import prod
from typing import List, Optional, Tuple
import torch
from vllm.utils.math_utils import cdiv
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
assert block_shape is not None
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
return A, A_scale
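# Shape sketch (illustration only, assuming the common block_shape [128, 128]):
# for an activation A of shape (M, K), per_token_group_quant_fp8 quantizes each
# row in groups of block_k = 128 columns and returns an fp8 tensor of shape
# (M, K) together with a scale tensor of shape (M, cdiv(K, 128)), which is
# exactly what the trailing assert checks.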

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.compressor import (
Compressor,
rotate_activation,
)
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
logger = init_logger(__name__)
class Indexer(torch.nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
prefix: str = "",
**kwargs,
):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.n_heads = config.index_n_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.n_local_heads = config.index_n_heads // self.tp_size
self.head_dim = config.index_head_dim
self.rope_head_dim = config.rope_head_dim
self.index_topk = config.index_topk
self.q_lora_rank = config.q_lora_rank
self.window_size = config.window_size
self.block_size = vllm_config.cache_config.block_size
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=None,
prefix=f"{prefix}.wq_b",
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=None,
params_dtype=torch.bfloat16,
prefix=f"{prefix}.weights_proj",
)
self.softmax_scale = self.head_dim ** -0.5
self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
self.compress_ratio = compress_ratio
self.max_model_len = vllm_config.model_config.max_model_len
self.rotary_emb = rope
self.tp_group = get_tp_group()
self.compressor = Compressor(vllm_config, self.rotary_emb, compress_ratio, self.head_dim, True, f"{prefix}.compressor")
self.freqs_cis = None
def forward_prefill(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
attn_metadata: AttentionMetadata,
k_full: torch.Tensor,
context_lens: torch.Tensor,
):
assert attn_metadata.prefill.chunked_context is None, \
"Prefill chunked context is not supported."
query_start_loc = attn_metadata.prefill.query_start_loc
cu_seq_q_lens = query_start_loc
cu_seq_k_lens = torch.zeros(
context_lens.size(0) + 1, dtype=torch.int32, device=q.device,
)
torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])
seq_lens = torch.diff(cu_seq_k_lens)
batch_size = seq_lens.shape[0]
new_block_tables = torch.empty(
[attn_metadata.num_prefill_tokens, self.index_topk],
dtype=torch.int32,
device=q.device,
)
new_context_lens = torch.empty(
[attn_metadata.num_prefill_tokens],
dtype=torch.int32,
device=q.device,
)
q_seq_lens = cu_seq_q_lens[1:] - cu_seq_q_lens[:-1]
max_seq_len = q_seq_lens.max().item()
batch_size = q_seq_lens.size(0)
max_compressed_kv_len = max_seq_len // self.compress_ratio
kv_cache_block_table = torch.zeros([batch_size, max_compressed_kv_len], dtype=torch.int32, device=q.device)
# The layout of linear kv is as follows:
# | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
for i in range(batch_size):
start = cu_seq_k_lens[i].item()
kv_cache_block_table[i] = torch.arange(
start, start + max_compressed_kv_len,
dtype=torch.int32,
device=q.device,
)
# offset by the total origin_kv length
kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1]
# query: (tokens, index_head, index_head_dim)
# k_full: (tokens, index_head_dim)
# weights: (tokens, index_head, 1)
mlu_ops.masked_indexer_select_paged_kv_prefill(
query=q,
key_value=k_full,
weights=weights.unsqueeze(-1),
kv_cache_block_table=kv_cache_block_table,
cu_seq_q_lens=cu_seq_q_lens,
cu_seq_k_lens=cu_seq_k_lens,
index_topk=self.index_topk,
kv_cache_block_size=self.block_size,
softmax_scale=self.merged_softmax_scale,
q_scale=None,
k_scale_cache=None,
sparse_block_table=new_block_tables,
sparse_context_lens=new_context_lens,
compress_ratio=self.compress_ratio,
kv_cache_block_table_offset=None,
)
return new_block_tables, new_context_lens
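# Worked example of the block-table construction above (illustration only):
# with two prefill sequences of 16 and 8 tokens and compress_ratio = 4,
# cu_seq_q_lens = [0, 16, 24], context_lens = [4, 2], cu_seq_k_lens = [0, 4, 6]
# and max_compressed_kv_len = 16 // 4 = 4. Row i starts at cu_seq_k_lens[i]
# and is shifted by cu_seq_q_lens[-1] = 24, giving [24, 25, 26, 27] and
# [28, 29, 30, 31], so every compressed kv slot lands after the original kv of
# all sequences, matching the linear layout sketched above.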
def forward_decode(
self,
q: torch.Tensor,
x: torch.Tensor,
k_cache: torch.Tensor,
weights: torch.Tensor,
attn_metadata: AttentionMetadata,
):
block_table = attn_metadata.decode.block_table
batch_size = block_table.shape[0]
seq_len = x.shape[0] // batch_size
q = q.view(batch_size, seq_len, *q.shape[1:])
weights = weights.view(batch_size, seq_len, *weights.shape[1:])
seq_lens = attn_metadata.decode.seq_lens
k_block_table = block_table
new_block_tables = torch.empty(
[batch_size, seq_len, self.index_topk],
dtype=torch.int32,
device=block_table.device,
)
new_context_lens = torch.empty(
[attn_metadata.num_decode_tokens],
dtype=torch.int32,
device=block_table.device,
)
kv_cache_block_table_offset = torch.empty(
[attn_metadata.num_decode_tokens],
dtype=torch.int32,
device=block_table.device,
)
kv_cache_block_table_offset.fill_(self.window_size)
mlu_ops.masked_indexer_select_paged_kv_decode(
query=q,
k_cache=k_cache,
weights=weights.unsqueeze(-1), # (bsz, seq_q, head_num, 1)
kv_cache_block_table=block_table,
k_context_lens=seq_lens // self.compress_ratio,
k_cache_block_table=k_block_table,
index_topk=self.index_topk,
kv_cache_block_size=self.block_size,
softmax_scale=self.merged_softmax_scale,
q_scale=None,
k_scale_cache=None,
sparse_block_table=new_block_tables,
sparse_context_lens=new_context_lens,
compress_ratio=self.compress_ratio,
kv_cache_block_table_offset=kv_cache_block_table_offset,
)
# [batch, seq_q, index_topk] -> [batch, index_topk]
new_block_tables = new_block_tables.squeeze(1)
return new_block_tables, new_context_lens
def forward(self,
x: torch.Tensor,
qr: torch.Tensor,
positions: torch.Tensor,
offsets: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
indexer_kv_cache: torch.Tensor,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
rd = self.rope_head_dim
q = self.wq_b(qr)[0]
q = q.unflatten(-1, (self.n_heads, self.head_dim))
self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
q_pack = rotate_activation(q)
weights_pack = self.weights_proj(x)[0] # (tokens, index_local_head)
num_decode_tokens = attn_metadata.num_decode_tokens
compressed_kv = self.compressor(
x,
positions,
attn_metadata,
batch_to_kv_state,
indexer_kv_cache,
0,
compressor_slot_mapping,
)
if attn_metadata.prefill:
assert compressed_kv is not None and compressed_kv.dim() == 3
compressed_kv = compressed_kv.squeeze(-2)
compressed_context_lens = query_lens // self.compress_ratio
prefill_q = q_pack[num_decode_tokens:, ...]
prefill_weights = weights_pack[num_decode_tokens:, ...]
prefill_block_tables, prefill_context_lens = self.forward_prefill(
prefill_q,
indexer_kv_cache,
prefill_weights,
attn_metadata,
compressed_kv,
compressed_context_lens,
)
if attn_metadata.decode:
decode_x = x[:num_decode_tokens, ...]
decode_q = q_pack[:num_decode_tokens, ...]
decode_weights = weights_pack[attn_metadata.num_prefills:]
decode_block_tables, decode_context_lens = self.forward_decode(
decode_q,
decode_x,
indexer_kv_cache,
decode_weights,
attn_metadata,
)
if attn_metadata.prefill and attn_metadata.decode:
new_block_tables = torch.cat([prefill_block_tables, decode_block_tables], dim=0)
new_context_lens = torch.cat([prefill_context_lens, decode_context_lens], dim=0)
elif attn_metadata.prefill:
new_block_tables = prefill_block_tables
new_context_lens = prefill_context_lens
else:
new_block_tables = decode_block_tables
new_context_lens = decode_context_lens
return new_block_tables, new_context_lens

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
@CustomOp.register("quant_fusion_rms_norm")
class QuantFusionRMSNorm(RMSNorm):
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
super().__init__(hidden_size, variance_epsilon)
assert not isinstance(
proj.quant_method, UnquantizedLinearMethod
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
proj.quant_method.skip_quant_input = True
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
quant_scale = proj.smooth.data
else:
quant_scale = proj.scale_to_int.data
self.dynamic_quant = dynamic_quant
self.quant_scale = torch.nn.Parameter(quant_scale)
def forward(
self, x: torch.Tensor, residual: torch.Tensor | None = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
return mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
False,
self.quant_scale.data,
self.dynamic_quant,
)
@CustomOp.register("quant_fusion_layer_norm")
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
super().__init__(hidden_size, variance_epsilon)
assert not isinstance(
proj.quant_method, UnquantizedLinearMethod
), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
proj.quant_method.skip_quant_input = True
if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config):
quant_scale = proj.smooth.data
else:
quant_scale = proj.scale_to_int.data
self.dynamic_quant = dynamic_quant
self.quant_scale = torch.nn.Parameter(quant_scale)
def forward(
self, x: torch.Tensor, residual: torch.Tensor | None = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
bias = None if self.bias is None else self.bias.data
return mlu_ops.fused_layer_norm(
x,
residual,
self.weight.data,
bias,
None,
self.eps,
False,
self.quant_scale.data,
self.dynamic_quant,
)
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
org_shape = x.shape
x = x.reshape(-1, self.weight.data.shape[0])
if out is not None:
out = out.view(-1, self.weight.data.shape[0])
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
x = mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
True,
out=out,
)
else:
x = mlu_ops.fused_rms_norm(
x,
residual,
self.weight.data,
None,
None,
self.variance_epsilon,
False,
out=out,
)
if out is not None:
return x
if residual is None:
assert isinstance(x, torch.Tensor)
return x.view(org_shape)
assert isinstance(x, tuple)
assert len(x) == 2
return x[0].view(org_shape), x[1].view(org_shape)
MluHijackObject.apply_hijack(
RMSNorm,
RMSNorm.forward_oot,
vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
)
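# A plain-PyTorch reference of the non-quantized RMSNorm path dispatched above
# (illustration only; the fused kernel is assumed to follow the standard
# definition, with the residual added before normalization). The helper name
# is hypothetical.
def _reference_rms_norm(x: torch.Tensor,
                        weight: torch.Tensor,
                        eps: float,
                        residual: torch.Tensor | None = None):
    if residual is not None:
        # The pre-norm sum is also returned so the caller can reuse it as the
        # next layer's residual stream, mirroring the tuple return above.
        x = x + residual
        residual = x
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)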

View File

@@ -0,0 +1,693 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Any
import torch
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, split_tensor_along_last_dim,
get_parallel_rank_with_group, get_parallel_world_size_with_group,
get_tp_world_group, get_tp_world_world_size, get_tp_world_rank)
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase,
ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED.extend([
"GPTQMluLinearMethod",
"AWQMluLinearMethod"
])
vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual parameter.
@brief: dispatch unquantized_gemm to mlu ops.
'''
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
residual: torch.Tensor | None = None
) -> torch.Tensor:
beta = 0.0
if residual is not None:
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]),
layer.weight,
bias, residual, 'none', 1.0, beta).view(res_shape)
'''
==================
End of MLU Hijack
==================
'''
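# Sketch of the assumed mlu_ops.matmul semantics used above with activation
# mode 'none' (illustration only, based on the formula quoted in FeedForward):
#     out = alpha * (x @ weight.T + bias) + beta * residual
# The helper name is hypothetical.
def _reference_matmul(x: torch.Tensor,
                      weight: torch.Tensor,
                      bias: torch.Tensor | None = None,
                      residual: torch.Tensor | None = None,
                      alpha: float = 1.0,
                      beta: float = 0.0) -> torch.Tensor:
    out = alpha * torch.nn.functional.linear(x, weight, bias)
    if residual is not None and beta != 0.0:
        out = out + beta * residual
    return out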
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__LinearBase____init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
vllm__module_executor__layers__linear__LinearBase____init__org(
self=self,
input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add self.tp_group, world_size and tp_rank to support data parallel and moe expert parallel
'''
self.tp_group = tp_group
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
self.tp_size = self.tp_world_size
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
self.keep_full_weights = keep_full_weights
if self.keep_full_weights or disable_tp:
self.tp_group = None
self.tp_world_size = 1
self.tp_size = self.tp_world_size
self.tp_rank = 0
self.tp_world_size_org = get_tp_world_world_size()
self.tp_rank_org = get_tp_world_rank()
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(ColumnParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: self.tp_size and self.tp_rank have been initialized in LinearBase.__init__
'''
# Divide the weight matrix along the last dimension.
# self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
# self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
'''
=================
End of MLU Hijack
=================
'''
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size) for output_size in self.output_sizes
]
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
)
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
self,
input_,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add input_scale and use_tp_weight parameters.
'''
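# Extra kwargs are only passed when actually set, so quant methods whose
# apply() does not accept them keep working unchanged.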
kwargs = {'bias': bias}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
output_parallel = self.quant_method.apply(self, input_, **kwargs)
'''
==================
End of MLU Hijack
==================
'''
if self.gather_output and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group param to tensor_model_parallel_all_gather
'''
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
'''
=============================
Modify by vllm_mlu
=============================
@brief: check output_sizes after __init__, once self.tp_world_size is available
@brief: add keep_full_weights for the data-parallel shared expert
'''
super(MergedColumnParallelLinear, self).__init__(
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
output_sizes=self.output_sizes,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
if self.keep_full_weights:
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_weight = Parameter(
self.weight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_qweight = Parameter(
self.qweight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
self.tp_per_channel_scale = Parameter(
self.per_channel_scale.data.new_empty((out_dim_tp)),
requires_grad=False,
)
else:
raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
):
loaded_weight_orig = loaded_weight
output_dim = getattr(param, "output_dim", None)
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
loaded_shard_id=loaded_shard_id,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for the data-parallel shared expert
'''
# load into tp weight
if self.keep_full_weights:
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
start_idx = tp_rank * shard_size
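# Slice the originally loaded (full) weight with the original TP rank/size and
# copy it into the matching shard of the local tp_* parameter.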
if isinstance(self.quant_method, UnquantizedLinearMethod):
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size)
tp_weight_shard.copy_(tp_weight)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if output_dim is None:
return
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
if loaded_weight_orig.ndim == 1:
tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size)
elif loaded_weight_orig.ndim == 2:
tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size)
else:
raise ValueError("only support rank 1 and 2 when using tp_weight")
tp_weight_shard.copy_(tp_weight)
else:
raise TypeError(f"quant method is expected to be either unquantized or smoothquant")
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__RowParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(RowParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
# Divide the weight matrix along the last dimension
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for the data-parallel shared expert
'''
if self.keep_full_weights:
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
if hasattr(self, "smooth"):
assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor"
dim = self.smooth.shape[0]
dim_tp = divide(dim, tp_size)
self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp)),
requires_grad=False)
else:
raise TypeError("quant method expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
self.update_param_tp_status()
def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader(
self, param: Parameter, loaded_weight: torch.Tensor
):
input_dim = getattr(param, "input_dim", None)
loaded_weight_orig = loaded_weight
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add keep_full_weights for the data-parallel shared expert
'''
if self.keep_full_weights:
if input_dim is None:
return
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size)
start_idx = tp_rank * shard_size
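# The shard is taken along the input dimension of the full, un-sharded
# parameter kept on this layer and copied into the matching tp_* buffer.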
if isinstance(self.quant_method, UnquantizedLinearMethod):
shard_view = self.weight.narrow(input_dim, start_idx, shard_size)
self.tp_weight.copy_(shard_view)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if loaded_weight_orig.ndim == 1:
shard_view = self.smooth.narrow(input_dim, start_idx, shard_size)
self.tp_smooth.copy_(shard_view)
elif loaded_weight_orig.ndim == 2:
shard_view = self.qweight.narrow(input_dim, start_idx, shard_size)
self.tp_qweight.copy_(shard_view)
else:
raise ValueError("only rank 1 and 2 is supported for tp_weight")
else:
raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant")
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters.
'''
def vllm__module_executor__layers__linear__RowParallelLinear__forward(
self,
input_,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add additional matmul parameters.
'''
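# Like the bias, the residual is only fused into the GEMM on rank 0 so it is
# added exactly once when TP > 1.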
residual_ = None if self.tp_rank > 0 else residual
kwargs = {'bias': bias_, 'residual': residual_}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
if output is not None:
kwargs['output'] = output
output_parallel = self.quant_method.apply(self, input_parallel, **kwargs)
'''
=================
End of MLU Hijack
=================
'''
if self.reduce_results and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
@brief: pass self.tp_group to tensor_model_parallel_all_reduce()
'''
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.apply_hijack(LinearBase,
LinearBase.__init__,
vllm__module_executor__layers__linear__LinearBase____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__ColumnParallelLinear____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.forward,
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.weight_loader,
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.__init__,
vllm__module_executor__layers__linear__RowParallelLinear____init__)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.weight_loader,
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__module_executor__layers__linear__RowParallelLinear__forward)

View File

@@ -0,0 +1,744 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""Inference-only MOE model."""
from typing import Optional, Any, List, Dict
import torch
from torch import nn
from vllm.distributed import (
divide,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.distributed.parallel_state import (
cnclep_dispatch, cnclep_combine)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
class LongCatSparseMoeMlp(SparseMoeMlp):
"""
sparse moe mlp layer specific to longcat model
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
up_proj_name: str,
is_gated: bool,
down_proj_name: str,
has_bias: bool,
skip_bias_add: bool = False,
renormalize: bool = False,
hidden_act: str = "silu",
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
is_use_fused_moe: bool = False,
expert_group: Optional[int] = 1,
topk_group: Optional[int] = 1,
scoring_func: str = "softmax",
topk_method: str = "",
routed_scaling_factor: float = 1.0,
tp_group: Any = None,
use_all2all: bool = False,
num_zero_experts: int = 0,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
up_proj_name=up_proj_name,
is_gated=is_gated,
down_proj_name=down_proj_name,
has_bias=has_bias,
skip_bias_add=skip_bias_add,
renormalize=renormalize,
hidden_act=hidden_act,
params_dtype=params_dtype,
quant_config=quant_config,
is_use_fused_moe=is_use_fused_moe,
expert_group=expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
topk_method=topk_method,
routed_scaling_factor=routed_scaling_factor,
tp_group=tp_group,
use_all2all=use_all2all,
init_avg_moe=False,
)
self.num_zero_experts = num_zero_experts
self.total_experts_including_zero = self.num_total_experts + self.num_zero_experts
self.use_quant_all2all = use_all2all and quant_config is not None
self.zero_expert_size = divide(self.num_zero_experts, self.moe_ep_size)
self.start_zero_expert_id = (
self.num_total_experts + self.moe_ep_rank * ((self.num_zero_experts + self.moe_ep_size - 1) // self.moe_ep_size)
)
if VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg:
n_tokens = SparseMoeMlp.max_batched_token * self.dp_size
expert_group = self.moe_ep_size
val = 1.0 / float(self.total_experts_including_zero)
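# In avg-MoE mode every token is given a uniform routing weight of
# 1 / total_experts_including_zero for each of its top_k experts.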
SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32)
if VLLM_RANDOM_MOE_EN:
import numpy as np
# example deepseekv2: experts 160 topk 6
# avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,......
array = np.stack([np.random.permutation(self.total_experts_including_zero)[:top_k] for _ in range(n_tokens)])
table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32)
else:
# example deepseekv2: experts 160
# avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,......
import math
batch_table = math.ceil(n_tokens * top_k / self.total_experts_including_zero) * self.total_experts_including_zero
hi_val = batch_table // self.total_experts_including_zero
table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view(
hi_val, expert_group, num_experts // expert_group).transpose(1, 2)
if self.num_zero_experts > 0:
# Longcat model: for avg-expert mode, we choose eight non-zero experts and four zero
# experts for each token according to the paper.
assert num_experts == 512 and num_zero_experts == 256 and top_k == 12
assert num_zero_experts % expert_group == 0
non_zero_expert_num_per_token = 8
zero_expert_num_per_token = 4
zero_expert_table = torch.arange(
num_experts, num_experts + num_zero_experts, dtype=table.dtype, device=table.device).view(
expert_group, num_zero_experts // expert_group).transpose(0, 1).flatten()
non_zero_expert_table = table[0].flatten()
token_expert_list = []
for idx in range(0, num_experts // non_zero_expert_num_per_token):
token_expert_list.append(non_zero_expert_table[
idx * non_zero_expert_num_per_token:
idx * non_zero_expert_num_per_token + non_zero_expert_num_per_token])
token_expert_list.append(zero_expert_table[
idx * zero_expert_num_per_token:
idx * zero_expert_num_per_token + zero_expert_num_per_token])
avg_expert_table = torch.cat(token_expert_list)
table = avg_expert_table.repeat(hi_val)
SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k)
SparseMoeMlp.is_expert_avg = True
def forward_experts_nofused_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
assert self.moe_ep_size == 1
assert not self.use_all2all
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = mlu_ops.moe_gen_idx(
topk_indices.to(torch.int32), total_num_experts)
# If no expert is routed, expand_gather_idx and scatter_idx are empty, while
# expand_token_count and cusum_token_count contain only zeros, so this rank
# simply returns an all-zero final_hidden_states.
if cusum_token_count[-1] == 0:
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_gather_idx, cusum_token_count,
start_expert_id=self.start_expert_id,
expert_size=self.end_expert_id - self.start_expert_id)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_gather_idx, cusum_token_count,
start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size)
expand_output_list = []
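# Rebase the cumulative token counts to this rank's expert range so the
# per-expert slices below index into expand_hidden_states directly.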
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count[:self.num_total_experts]):
if num_tokens_per_expert > 0:
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
if expert_idx < self.num_total_experts:
expert_output = self.experts[expert_idx](expert_hidden_states)
else:
expert_output = expert_hidden_states
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
num_normal_tokens = cusum_token_count[self.num_total_experts]
expand_hidden_states[:num_normal_tokens] = expand_output
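# Only the rows routed to normal experts are overwritten; zero-expert tokens
# were expanded separately into expand_hidden_states_zero and are combined
# below as an identity pass-through.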
# reduce normal experts
final_hidden_states = mlu_ops.moe_combine_result(
expand_hidden_states, topk_weights, scatter_idx,
residual_, cusum_token_count, start_expert_id=self.start_expert_id,
expert_size=self.end_expert_id - self.start_expert_id, bias=None)
# reduce zero experts
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
final_hidden_states = mlu_ops.moe_combine_result(
expand_hidden_states_zero, topk_weights, scatter_idx,
final_hidden_states, cusum_token_count, start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size, bias=None,
output=final_hidden_states)
return final_hidden_states
# No compute-communication overlap; kept for prototyping only and not used in
# production. May become stale.
def forward_all2all_int8_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1=self.w13
w2=self.w2
bias2=self.b2
input_smooth=self.a13_scale_all_experts
act_smooth=self.a2_scale
w1_scale=self.w13_scale
w2_scale=self.w2_scale
act_mode=self.hidden_act
quant_input=None
max_m = hidden_states.shape[0]
reduce_weight = topk_weights
expert_id = topk_indices
expand_idx, combine_idx, token_count, cusum_token_count \
= mlu_ops.moe_gen_idx(expert_id, total_num_experts)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
)
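# Each dispatch token packs the int8-quantized activation in the first
# hidden_size bytes, followed by its per-token fp32 quantization scale(s).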
quant_size = self.hidden_size
quant_input = dispatch_send_token_tensor[:, : quant_size]
input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32)
quant_input, input_scale = mlu_ops.moe_quantize(
hidden_states, input_smooth, None, token_count[:self.num_total_experts],
expand_idx, None,
output=quant_input,
output_scale=input_scale)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num)
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
pad_num = self.max_num_tokens_per_rank
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size))
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv,
self.input_scale_recv)
max_m = self.max_num_tokens_per_expert
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1,
tokens_per_local_expert,
None, None, None, None,
self.input_scale_recv.view(torch.float32).flatten(),
w1_scale, dtype, max_m)
# continue reusing self.quant_input_recv and self.input_scale_recv
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
tokens_per_local_expert,
output=quant_input,
output_scale=input_scale_fp32,
act_mode=act_mode,
is_gated=self.is_gated)
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
tokens_per_local_expert,
None, None, None, None, input_scale, w2_scale, dtype, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size)
combine_recv_layout = self.dispatch_recv_layout
# combine
combine_args = dict(
token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=combine_recv_layout,
send_token=None,
recv_token=None)
cnclep_combine(**combine_args)
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
residual_ = None
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
assert self.moe_ep_size > 1
# zero expert reduce
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=hidden_states)
return output.view(ori_input_shape)
# No compute-communication overlap; kept for prototyping only and not used in
# production. May become stale.
def forward_all2all_bf16_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None):
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1=self.w13
w2=self.w2
bias1=self.b13
bias2=self.b2
gated=self.is_gated
act_mode=self.hidden_act
max_m = hidden_states.shape[0]
reduce_weight = topk_weights
expert_id = topk_indices
# gen_idx
expand_idx, combine_idx, token_count, cusum_token_count = \
mlu_ops.moe_gen_idx(expert_id, total_num_experts)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
.view(hidden_states.dtype)
)
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_token_tensor.copy_(expand_hidden_states)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num,
use_quant_dispatch=False,
)
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
pad_num = self.max_num_tokens_per_rank
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size)
.view(hidden_states.dtype)
)
self.quant_input_recv = self.quant_input_recv.view(hidden_states.dtype)
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv)
max_m = self.max_num_tokens_per_expert
gemm_out = mlu_ops.group_gemm(
self.quant_input_recv, w1, tokens_per_local_expert,
None, None, None, None, max_m)
act_out = mlu_ops.moe_active(
gemm_out, act_mode, gated)
gemm_out = mlu_ops.group_gemm(
act_out, w2, tokens_per_local_expert,
None, None, None, None, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(
self.max_num_tokens_recv, -1).view(hidden_states.dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
self.dispatch_recv_token_num, self.moe_ep_size)
combine_recv_layout = self.dispatch_recv_layout
combine_args = dict(
token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=combine_recv_layout,
send_token=None,
recv_token=None,
use_quant_dispatch=False,
)
cnclep_combine(**combine_args)
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
residual_ = None
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=bias2, output=hidden_states)
# zero expert reduce
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=hidden_states)
return output.view(ori_input_shape)
def forward_before_dispatch(self, hidden_states: torch.Tensor,
topk_indices: torch.Tensor):
# For longcat, the gate and softmax top-k are computed in the router;
# other models can perform these operations here.
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(
topk_indices, self.total_experts_including_zero)
num_token_expand = hidden_states.shape[0] * self.top_k
dispatch_bytes = num_token_expand * self.dispatch_token_size
dispatch_send_token_tensor = (
self.dispatch_send_buffer[:dispatch_bytes]
.view(num_token_expand, self.dispatch_token_size)
)
if self.use_quant_all2all:
hidden_states_stride = self.hidden_size
quant_input = dispatch_send_token_tensor[:, : hidden_states_stride]
input_scale = dispatch_send_token_tensor[:, hidden_states_stride :].view(torch.float32)
# expand input + quantize
quant_input, input_scale = mlu_ops.moe_quantize(
hidden_states, self.a13_scale_all_experts, None,
token_count[:self.num_total_experts],
expand_idx, None,
output=quant_input,
output_scale=input_scale)
# expand input of zero-expert
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
else:
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts)
dispatch_send_token_tensor = dispatch_send_token_tensor.view(
hidden_states.dtype)
dispatch_send_token_tensor.copy_(expand_hidden_states)
del expand_hidden_states
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.num_total_experts,
expert_size=self.num_zero_experts)
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(
token_count[:self.num_total_experts], self.moe_ep_size)
return combine_idx, token_count, cusum_token_count, dispatch_send_layout, expand_hidden_states_zero
def forward_dispatch(self, token_num: int, dispatch_send_layout: torch.Tensor,
token_count: torch.Tensor):
num_token_expand = token_num * self.top_k
cnclep_dispatch(self.dispatch_token_size,
num_token_expand,
dispatch_send_layout,
token_count[:self.num_total_experts],
self.dispatch_recv_layout,
self.dispatch_recv_token_num,
use_quant_dispatch=self.use_quant_all2all)
def forward_before_combine(self, hidden_states_dtype: torch.dtype):
recv_token_num = self.dispatch_recv_token_num.view(
self.moe_ep_size, self.num_experts_per_rank)
(
gather_by_expert_index,
gather_by_rank_index,
tokens_per_local_expert,
token_sum,
cusum_token_count
) = mlu_ops.moe_all2all_gen_gather_index(
recv_token_num, self.max_num_tokens_per_rank,
return_cusum_token_count=True)
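# gather_by_expert_index groups the received tokens by local expert for the
# grouped GEMMs; gather_by_rank_index restores the per-source-rank order
# expected by the combine step.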
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
dispatch_recv_token_tensor = (
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
.view(self.max_num_tokens_recv, self.dispatch_token_size))
max_m = self.max_num_tokens_per_expert
if self.use_quant_all2all:
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.quant_input_recv,
self.input_scale_recv)
# OPT: input_scale_recv_flatten can reuse self.input_scale_recv
input_scale_recv_flatten = self.input_scale_recv.view(torch.float32).flatten()
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, self.w13,
tokens_per_local_expert,
None, None, None, None,
input_scale_recv_flatten,
self.w13_scale, hidden_states_dtype, max_m)
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
input_scale_fp32 = input_scale_recv_flatten[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, self.a2_scale, None,
tokens_per_local_expert,
output=quant_input,
output_scale=input_scale_fp32,
act_mode=self.hidden_act,
is_gated=self.is_gated)
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, self.w2, tokens_per_local_expert,
None, None, None, None, input_scale, self.w2_scale,
hidden_states_dtype, max_m)
else:
dispatch_recv_token_tensor = dispatch_recv_token_tensor.view(hidden_states_dtype)
self.input_recv = self.input_recv.view(hidden_states_dtype)
mlu_ops.gather_split(dispatch_recv_token_tensor,
gather_by_expert_index,
token_sum,
self.input_recv)
gemm_out = mlu_ops.group_gemm(
self.input_recv, self.w13, tokens_per_local_expert,
None, None, None, None, max_m)
act_out = self.input_recv[:, :gemm_out.shape[-1] // 2]
act_out = mlu_ops.moe_active(
gemm_out, self.hidden_act, self.is_gated, output=act_out,
bias=None, cusum_token_count=cusum_token_count,
start_expert_id=0, expert_size=self.num_experts_per_rank)
gemm_out = mlu_ops.group_gemm(
act_out, self.w2, tokens_per_local_expert,
None, None, None, None, max_m)
combine_send_token_tensor = self.combine_send_buffer.view(
self.max_num_tokens_recv, -1).view(hidden_states_dtype)
mlu_ops.gather_split(gemm_out,
gather_by_rank_index,
token_sum,
combine_send_token_tensor,
None)
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(
self.dispatch_recv_token_num, self.moe_ep_size)
return combine_send_layout
def forward_combine(self, token_num: int, combine_send_layout: torch.Tensor):
num_token_expand = token_num * self.top_k
# combine_recv_layout (self.dispatch_recv_layout) is computed during cnclep_dispatch,
# because dispatch and combine are inverse operations
cnclep_combine(token_byte=self.hidden_size * 2,
token_num=num_token_expand,
send_src_layout=combine_send_layout,
send_dst_layout=self.dispatch_recv_layout,
send_token=None,
recv_token=None,
use_quant_dispatch=self.use_quant_all2all)
def forward_after_combine(self, token_num: int,
reduce_weight: torch.Tensor,
combine_idx: torch.Tensor,
cusum_token_count: torch.Tensor,
expand_hidden_states_zero: torch.Tensor,
output_tensor_dtype: torch.dtype,
output_tensor: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
num_token_expand = token_num * self.top_k
numel_recv = num_token_expand * self.hidden_size
recv_token = (self.combine_recv_buffer.view(output_tensor_dtype)[:numel_recv]
.view(num_token_expand, self.hidden_size))
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
residual, cusum_token_count, start_expert_id=0,
expert_size=self.num_total_experts, bias=self.b2, output=output_tensor)
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.num_total_experts,
self.num_zero_experts, output=output_tensor)
return output
# No compute-communication overlap; kept for prototyping only and not used in
# production. May become stale.
def forward_group_experts_longcat(
self, hidden_states, total_num_experts, total_num_experts_per_rank,
topk_indices=None, topk_weights=None, residual_=None,
expand_idx=None, combine_idx=None, token_count=None, cusum_token_count=None):
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
ori_input_shape = hidden_states.shape
dtype = hidden_states.dtype
self.pack_params()
self.pack_params_after_loading()
w1=self.w13
w2=self.w2
bias1=self.b13
bias2=self.b2
input_smooth=self.a13_scale
act_smooth=self.a2_scale
w1_scale=self.w13_scale
w2_scale=self.w2_scale
gated=self.is_gated
act_mode=self.hidden_act
quant_input=None
start_expert_id=self.start_expert_id
expert_size = w1.size(0)
max_m = hidden_states.shape[0]
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
# Check smooth quant parameters.
per_token_sq = False
if not is_fp8_quant:
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
if all(x is not None for x in check_list):
per_token_sq = True
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
"and absent at the same time.")
expert_id = topk_indices
reduce_weight = topk_weights
# gen_idx
if expert_id is not None:
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, total_num_experts)
# check quant
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
raise NotImplementedError
elif per_token_sq:
expand_hidden_states = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=start_expert_id,
expert_size=expert_size)
expand_hidden_states_zero = mlu_ops.moe_expand_input(
hidden_states, expand_idx, cusum_token_count,
start_expert_id=self.start_zero_expert_id,
expert_size=self.zero_expert_size)
quant_input, input_scale = mlu_ops.moe_quantize(
expand_hidden_states, input_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size])
else:
expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, start_expert_id, expert_size)
expand_hidden_states_zero = mlu_ops.moe_expand_input(hidden_states, expand_idx,
cusum_token_count, self.start_zero_expert_id, self.zero_expert_size)
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
gemm_out = mlu_ops.smooth_quant_group_gemm(
quant_input, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w1_scale, dtype, max_m)
else:
gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m)
# add_bias_active
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
raise NotImplementedError
elif per_token_sq:
quant_input = quant_input[:, :gemm_out.shape[-1] // 2]
input_scale = input_scale[:gemm_out.shape[0]]
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
token_count[start_expert_id:start_expert_id+expert_size],
output=quant_input,
output_scale=input_scale,
act_mode=act_mode,
is_gated=self.is_gated)
if ((is_fp8_quant and self.quant_config.activation_quant_method == 'per_token')
or per_token_sq):
# Removing the reference to the gemm_out tensor (if it were the only reference)
# would make its memory eligible for deallocation, so it could be reused by the
# next gemm operation's allocation.
# del gemm_out
gemm_out = mlu_ops.smooth_quant_group_gemm(
quant_input, w2,
token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, input_scale, w2_scale, dtype, max_m,
output=expand_hidden_states)
else:
act_out = mlu_ops.moe_active(
gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2],
bias1, cusum_token_count, start_expert_id, expert_size)
gemm_out = mlu_ops.group_gemm(
act_out, w2, token_count[start_expert_id:start_expert_id+expert_size],
None, None, None, None, max_m,
output=expand_hidden_states)
output = mlu_ops.moe_combine_result(
gemm_out, reduce_weight, combine_idx,
residual_, cusum_token_count, start_expert_id,
expert_size, bias2)
if self.moe_ep_size > 1 or self.moe_tp_rank == 0:
output = mlu_ops.moe_combine_result(
expand_hidden_states_zero, reduce_weight, combine_idx,
output, cusum_token_count, self.start_zero_expert_id,
self.zero_expert_size, bias2,
output=output)
return output.view(ori_input_shape)

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config
)
MLU_QUANTIZATION_METHODS = [
"smoothquant",
"weightonly",
"awq_mlu",
"gptq_mlu",
]
def register_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method not in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.append(quant_method)
def remove_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.remove(quant_method)
def register_real_mlu_quantization_methods():
remove_fake_mlu_quantization_methods()
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
register_quantization_config("weightonly")(WeightOnlyConfig)
register_quantization_config("smoothquant")(SmoothQuantConfig)
register_quantization_config("awq_mlu")(AWQMluConfig)
register_quantization_config("gptq_mlu")(GPTQMluConfig)

View File

@@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# We only support GPTQ and AWQ on the 300 series and above, and only int4 and int8 precision
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
# @register_quantization_config("awq_mlu")
class AWQMluConfig(QuantizationConfig):
"""Config class for AWQMlu.
Reference: https://arxiv.org/abs/2306.00978
"""
# num_bits -> type
TYPE_MAP = {
4: {
False: scalar_types.uint4b8,
True: scalar_types.uint4,
},
8: {
False: scalar_types.uint8b128,
True: scalar_types.uint8,
}
}
VERSION = ["gemm"]
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
version: str = "gemm",
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
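# pack_factor is the number of quantized values stored in one int32 word
# (8 for 4-bit weights, 4 for 8-bit weights).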
self.version = version
self.support_scale_zeros = False
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is supported for "
f"AWQMlu, but got {self.weight_bits} bits.")
if self.version not in self.VERSION:
raise ValueError(
"Currently, only gemm, gemv version is supported for "
f"AWQMlu, but got verion:{self.version}.")
if self.version in ["gemm"]:
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
else:
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
def __repr__(self) -> str:
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}), "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
return "awq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
version = cls.get_from_keys_or(config, ["version"],
default="gemm")
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "awq"
or user_quant == "awq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_mlu"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_mlu for"
" faster inference")
return None
@classmethod
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
version = quant_config.get("version", "gemm")
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or has_zp is None):
return False
if num_bits not in cls.TYPE_MAP:
return False
if version not in cls.VERSION:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
group_size=group_size,
has_zp=has_zp)
class AWQMluLinearMethod(LinearMethodBase):
"""Linear method for AWQMlu.
Args:
quant_config: The AWQMlu quantization config.
"""
def __init__(self, quant_config: AWQMluConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
qzeros = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
packed_qweight, scale_zeros = self.extract_autoawq(layer)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autoawq(self, layer: torch.nn.Module):
qweight = layer.qweight.data
qzeros = layer.qzeros.data
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
if izeros is not None:
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device)
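# shifts holds the bit offset of each packed value inside an int32 word;
# right-shifting by each offset extracts the values column by column.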
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
iweights = iweights.view(iweights.shape[0], -1)
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
izeros = izeros.view(izeros.shape[0], -1)
if not self.quant_config.zero_point:
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
else:
izeros = None
return iweights, izeros
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
reverse_order_tensor = reverse_order_tensor.view(-1)
rweights = iweights[:, reverse_order_tensor]
if izeros is not None:
rzeros = izeros[:, reverse_order_tensor]
return rweights, rzeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# Ensure both inputs are int8
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# Extract the lower 4 bits of each tensor
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # keep the lower 4 bits of tensor_a
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # keep the lower 4 bits of tensor_b
# Shift the lower 4 bits of tensor_a left by 4
shifted_low_bits_a = low_bits_a << 4
# Combine the lower 4 bits of the two tensors
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,753 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from typing import Any, Dict, List, Optional, Callable
from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
get_flashinfer_moe_backend,
ACTIVATION_SCHEMES,
Fp8Config,
Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
validate_fp8_block_shape
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
maybe_create_device_identity, Fp8LinearOp)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: disable marlin for MLU backend.
'''
if current_platform.is_rocm() or current_platform.is_out_of_tree():
use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
# deepGEMM on supported platforms with block-quantized weights
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
if not has_deep_gemm():
logger.warning_once("DeepGEMM backend requested but not available.")
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE")
return Fp8MoeBackend.DEEPGEMM
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and block_quant
):
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
Fp8Config____init____org = Fp8Config.__init__
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
activation_quant_method: Optional[str] = None,
weight_quant_method: Optional[str] = None,
) -> None:
super(Fp8Config, self).__init__()
Fp8Config____init____org(
self,
is_checkpoint_fp8_serialized,
activation_scheme,
ignored_layers,
weight_block_size
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add class members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.activation_quant_method = activation_quant_method
self.weight_quant_method = weight_quant_method
assert (self.weight_block_size or
(self.activation_quant_method == "per_token"
and self.weight_quant_method == "per_channel"
and self.activation_scheme == "dynamic")), (
"Only block-wise quantization, or dynamic per-token activation with "
"per-channel weight quantization, is supported.")
'''
==================
End of MLU Hijack
==================
'''
@classmethod
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
cls, config: Dict[str, Any]
) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
activation_quant_method = cls.get_from_keys_or(config,
["activation_quant_method"],
'per_token')
weight_quant_method = cls.get_from_keys_or(config,
["weight_quant_method"],
None)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
activation_quant_method=activation_quant_method,
weight_quant_method=weight_quant_method)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
tp_group = extra_weight_attrs.get("tp_group", None)
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
# WEIGHT
if self.quant_config.is_checkpoint_fp8_serialized:
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
)
else:
# For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
'''
==================
End of MLU Hijack
==================
'''
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Support weight per channel quantization.
@brief: Add tp_group to enable custom split.
'''
if self.weight_per_channel:
scale = ChannelQuantScaleParameter(
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
'''
==================
End of MLU Hijack
==================
'''
else:
assert not self.act_q_static
assert self.weight_block_size is not None
scale = create_fp8_scale_parameter(
BlockQuantScaleParameter,
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
)
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.act_q_static:
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
self,
quant_config: Fp8Config
):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_is_batch_invariant():
self.use_marlin = False
# AITER is only supported on ROCm, only for FP8_FNUZ,
# and at the moment only on MI300-series GPUs.
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
if self.weight_per_channel and self.activation_per_token:
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
)
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
self,
layer: Module,
) -> None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: For dynamic activation and channel-wise weight quantization,
additional processing is not needed.
'''
if (self.quant_config.is_checkpoint_fp8_serialized
and self.weight_per_channel
and self.quant_config.activation_scheme == "dynamic"):
return
'''
==================
End of MLU Hijack
==================
'''
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert residual is None, "Fp8Linear residual is not supported yet."
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
else:
# Multiple scales (fused modules like QKV)
# Try to infer correct broadcasting
# weight is [K, N], scale could be [num_logical_weights]
# Need to figure out how to broadcast - for now just try
# direct multiplication
if (
weight_scale.dim() == 1
and weight_scale.shape[0] == weight_fp8.shape[0]
):
# Per-row scaling
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
if self.block_quant:
assert self.weight_block_size is not None
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use activation per token quantization based on quantization config.
'''
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
weight_per_channel=self.weight_per_channel,
activation_per_token=self.activation_per_token)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
self,
quant_config: Fp8Config,
layer: torch.nn.Module
):
super(Fp8MoEMethod, self).__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
self.allow_cutlass_block_scaled_grouped_gemm = (
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: In mlu, always set self.use_marlin as False.
'''
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
'''
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
if scoring_func == "softmax":
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
route_scale=routed_scaling_factor,
)
elif scoring_func == "sigmoid":
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
routed_scaling_factor,
e_score_correction_bias,
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# gen_idx
ori_input_shape = x.shape
x = x.reshape(-1, x.size(-1))
router_logits = router_logits.reshape(-1, router_logits.size(-1))
expert_num = router_logits.size(-1)
tokens_num = x.size(0)
expert_size = layer.w13_weight.size(0)
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
topk_ids, expert_num
)
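# Gather each token's hidden state once per selected expert, grouped by expert id,
# so that experts can be processed with grouped GEMMs.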
expand_hidden_states = mlu_ops.moe_expand_input(
x, expand_idx, cumsum_token_count, 0, expert_size
)
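# Dynamically quantize the expanded activations to FP8 (group-wise along the hidden
# dimension when block quantization is enabled).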
quant_input, input_scale = _fp8_quantize(
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm1_out = mlu_ops.smooth_quant_group_gemm(
quant_input,
layer.w13_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=input_scale.T.contiguous(),
b_scale=layer.w13_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
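# Apply the gated activation (typically 'silu'); is_gated=True indicates the w13 output
# contains both the gate and up projections.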
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
act_out_quantize, act_out_scale = _fp8_quantize(
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm2_out = mlu_ops.smooth_quant_group_gemm(
act_out_quantize,
layer.w2_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=act_out_scale.T.contiguous(),
b_scale=layer.w2_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
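# Scatter the expert outputs back to the original token order and apply the routing weights.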
output = mlu_ops.moe_combine_result(
gemm2_out,
topk_weights,
combine_idx,
residual=None,
cusum_token_count=cumsum_token_count,
start_expert_id=0,
expert_size=expert_size,
bias=None,
)
return output.view(ori_input_shape)
"""
==================
End of MLU Hijack
==================
"""
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.from_config,
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.create_weights,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.process_weights_after_loading,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
)

View File

@@ -0,0 +1,440 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# Only GPTQ and AWQ are supported on MLU 300-series and later, and only int4 and int8 precision
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
# @register_quantization_config("gptq_mlu")
class GPTQMluConfig(QuantizationConfig):
"""Config class for GPTQMlu.
Reference: https://arxiv.org/abs/2210.17323
"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
(4, False): scalar_types.uint4b8,
(8, False): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
self.support_scale_zeros = False
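# use_native: fall back to dequantizing the weight to floating point and using a plain matmul
# (see process_weights_after_loading / apply) when act-order (desc_act) is enabled or
# asymmetric zeros cannot be folded into scale_zeros.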
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is "
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod
def get_name(cls) -> str:
return "gptq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size, has_zp=False)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
class GPTQMluLinearMethod(LinearMethodBase):
"""Linear method for GPTQMlu.
Args:
quant_config: The GPTQMlu quantization config.
"""
def __init__(self, quant_config: GPTQMluConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
-1) and (not self.quant_config.desc_act):
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.device = layer.qweight.data.device
packed_qweight, scale_zeros = self.extract_autogptq(layer)
if self.quant_config.use_native:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.use_native:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autogptq(self, layer: torch.nn.Module):
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
if self.quant_config.use_native:
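# Native path: fully dequantize the weight to the original dtype instead of packing it.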
if self.quant_config.desc_act:
scales = torch.index_select(scales, 0, layer.g_idx)
if izeros is not None:
izeros = torch.index_select(izeros, 0, layer.g_idx)
else:
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
if izeros is not None:
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
# is_sym is true in this path, so convert iweight to a signed representation and ignore qzeros
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
weight = torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
shifts.unsqueeze(-1),
).to(dtype)
weight = torch.bitwise_and(weight, (2**bits) - 1)
weight = weight.reshape(-1, weight.shape[-1])
return weight
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
shifts.unsqueeze(0),
).to(dtype)
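# GPTQ checkpoints commonly store zero-points offset by -1 (AutoGPTQ convention);
# add 1 to recover the actual zero-point.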
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
zeros = zeros.reshape(qzeros.shape[0], -1)
return zeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
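Example:
With 4-bit values a = 0x3 and b = 0x5, the packed byte is (0x3 << 4) | 0x5 = 0x35.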
"""
# Ensure both inputs are int8
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# Extract the lower 4 bits of each tensor
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
# Shift the lower 4 bits of tensor_a left by 4
shifted_low_bits_a = low_bits_a << 4
# Combine the lower 4 bits of the two tensors
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined

View File

@@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
RowvLLMParameter)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
str_dtype_to_bits,
is_fp8_str_dtype)
# @register_quantization_config("smoothquant")
class SmoothQuantConfig(QuantizationConfig):
"""Config class for SmoothQuant.
"""
def __init__(
self,
quant_mode: str, # smoothquant
input_quant_method: str, # per token/per tensor
group_size: int,
weight_precision: str,
activation_precision: str,
only_expert_per_group: bool,
expert_weight_precision: str,
expert_activation_precision: str,
force_use_weightonly_except_expert: bool,
) -> None:
super().__init__()
self.quant_mode = quant_mode
self.input_quant_method = input_quant_method
self.group_size = group_size
self.weight_precision = weight_precision
self.activation_precision = activation_precision
self.only_expert_per_group = only_expert_per_group
self.expert_weight_precision = expert_weight_precision
self.expert_activation_precision = expert_activation_precision
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
self.weight_bits = str_dtype_to_bits(self.weight_precision)
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
if self.weight_precision == 'int4':
self.weight_dtype = torch.int8
else:
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
if self.expert_weight_precision == 'int4':
self.expert_weight_dtype = torch.int8
else:
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
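# Number of quantized values packed into one int8 storage element (2 for int4, 1 for int8/fp8).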
self.pack_factor = 8 // self.weight_bits
self.expert_pack_factor = 8 // self.expert_weight_bits
def __repr__(self) -> str:
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
f"quant_mode={self.quant_mode}, "
f"group_size={self.group_size}, "
f"weight_precision={self.weight_precision}, "
f"activation_precision={self.activation_precision}, "
f"only_expert_per_group={self.only_expert_per_group}, "
f"expert_weight_precision={self.expert_weight_precision}, "
f"expert_activation_precision={self.expert_activation_precision}, "
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
@classmethod
def get_name(cls) -> str:
return "SmoothQuant"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
quant_mode = cls.get_from_keys(config, ["quant_mode"])
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
if expert_weight_precision is None:
expert_weight_precision = weight_precision
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
weight_precision = 'int8'
if expert_activation_precision is None:
expert_activation_precision = activation_precision
return cls(quant_mode=quant_mode,
input_quant_method=input_quant_method,
group_size=group_size,
weight_precision=weight_precision,
activation_precision=activation_precision,
only_expert_per_group=only_expert_per_group,
expert_weight_precision=expert_weight_precision,
expert_activation_precision=expert_activation_precision,
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
if isinstance(layer, LinearBase):
return SmoothQuantLinearMethod(self, prefix)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class SmoothQuantLinearMethod(LinearMethodBase):
"""Linear method for SmoothQuant.
Args:
quant_config: The SmoothQuant quantization config.
"""
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
self.quant_config = quant_config
# For the per-tensor case, we can skip quantizing the input of the first attn/ffn linear
# and fuse this step into the preceding layernorm for better performance.
self.skip_quant_input = False
self.compute_dtype = torch.get_default_dtype()
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
self.is_group_quant = True
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
self.is_group_quant = True
else:
self.is_group_quant = False
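# 'smooth' is SmoothQuant's per-channel smoothing vector, folded into the activations
# before per-token quantization (see apply()).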
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor != 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
group_num = 1
if self.is_group_quant:
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"The input size {input_size_per_partition} is not aligned with the quantized "
f"weight shape. This can be caused by too large "
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
if input_size_per_partition != input_size:
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
qweight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.pack_factor,
device="mlu",
dtype=self.weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.is_group_quant:
per_channel_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
group_num,
device="mlu",
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
per_channel_scale = ChannelQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
device="mlu",
dtype=torch.float32,
),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("per_channel_scale", per_channel_scale)
if self.has_smooth:
smooth = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(smooth, {
"ignore_warning": True,
})
layer.register_parameter("smooth", smooth)
if self.quant_config.input_quant_method == "per_tensor":
scale_to_int = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(scale_to_int, {
"ignore_warning": True,
})
layer.register_parameter("scale_to_int", scale_to_int)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.has_smooth and layer.smooth.dtype != torch.float:
layer.smooth = layer.smooth.to(torch.float)
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
layer.scale_to_int = layer.scale_to_int.to(torch.float)
if layer.per_channel_scale.dtype != torch.float:
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
if self.has_smooth:
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
if self.quant_config.input_quant_method == "per_tensor":
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
use_tp_weight : bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
layer_smooth = layer.smooth if self.has_smooth else None
layer_qweight = layer.qweight
layer_per_channel_scale = layer.per_channel_scale
if use_tp_weight:
if hasattr(layer, 'tp_smooth'):
layer_smooth = layer.tp_smooth
if hasattr(layer, 'tp_qweight'):
layer_qweight = layer.tp_qweight
if hasattr(layer, 'tp_per_channel_scale'):
layer_per_channel_scale = layer.tp_per_channel_scale
quant_input = None
if self.skip_quant_input:
quant_input = x
elif self.quant_config.input_quant_method == "per_token":
if self.is_fp8:
quant_input, input_scale = mlu_ops.scaled_quantize(x,
layer_smooth,
quant_type=self.weight_dtype,
quant_mode='dynamic_per_token')
else:
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
elif self.quant_config.input_quant_method == "per_tensor":
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
else:
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.quant_config.input_quant_method}.")
quant_input_shape = quant_input.shape
if len(quant_input_shape) > 2:
quant_input = quant_input.view(-1, quant_input_shape[-1])
input_scale = input_scale.view(-1)
if residual is not None and len(residual.shape) > 2:
residual = residual.view(-1, residual.shape[-1])
if self.is_fp8:
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
layer_per_channel_scale,
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
bias,
c=residual, act_mode="none",quant_bit_size=8,
alpha=1.0, beta=1.0, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
if output is not None:
out = out.view(output.shape)
output.copy_(out)
out = output
else:
if output is not None:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual)
if len(quant_input_shape) > 2:
out = out.view(*quant_input_shape[:-1], out.shape[-1])
return out

Some files were not shown because too many files have changed in this diff