commit b9925203b89ca6e57dcdc0fb3680bd037d796823 Author: chenxb002 Date: Fri Apr 24 09:50:34 2026 +0800 [Model] Support DeepSeek-V4 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..69da720 --- /dev/null +++ b/.gitignore @@ -0,0 +1,206 @@ +# version file generated by setuptools-scm +/vllm/_version.py + +# vllm-flash-attn built from source +vllm/vllm_flash_attn/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +/.deps/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/source/getting_started/examples/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# generated files +**/generated/** + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# VSCode +.vscode/ + +# DS Store +.DS_Store + +# Results +*.csv +# but may add new accuracy reference +!benchmarks/TruthfulQA/*.csv + +# Python pickle files +*.pkl + +# Sphinx documentation +_build/ + +# vim swap files +*.swo +*.swp + +# hip files generated by PyTorch +*.hip +*_hip* +hip_compat.h + +# Benchmark dataset +benchmarks/*.json + +# Linting +actionlint +shellcheck*/ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..dad0589 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +cmake_minimum_required(VERSION 3.16) +project(vllm_mlu_C) + +function(detect_debian10) + if(EXISTS "/etc/os-release") + file(READ "/etc/os-release" os_release) + if(os_release MATCHES "PRETTY_NAME=\"Debian GNU/Linux 10" OR + os_release MATCHES "VERSION_ID=\"10") + set(DEBIAN_10 TRUE PARENT_SCOPE) + message(STATUS "Detected Debian 10 (buster)") + endif() + endif() +endfunction() + +detect_debian10() + +if(DEBIAN_10) + find_program(GCC_PATH "gcc" PATHS "/usr/local/bin") + if(GCC_PATH) + message(STATUS "Using GCC on Debian 10: ${GCC_PATH}") + set(CMAKE_C_COMPILER "${GCC_PATH}") + set(CMAKE_CXX_COMPILER "/usr/local/bin/g++") + else() + message(WARNING "Debian 10 detected but gcc not found!") + endif() +endif() + +set(CMAKE_CXX_STANDARD 17) + +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + +# Suppress potential warnings about unused manually-specified variables +set(ignoreMe "${VLLM_PYTHON_PATH}") + +set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") + +find_package(pybind11 REQUIRED) + +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") +set(VLLM_MLU_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}") + +find_package(Torch REQUIRED) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRINGS "Build type Release/Debug (default Release)" FORCE) +endif() + +file(GLOB VLLM_MLU_SRC ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp) + +include_directories( + ${pybind11_INCLUDE_DIRS} + ${PYTHON_INCLUDE_PATH} + ${TORCH_INCLUDE_DIRS} + $ENV{NEUWARE_HOME}/include +) + +pybind11_add_module(vllm_mlu_C ${VLLM_MLU_SRC}) + +target_link_directories( + vllm_mlu_C + PRIVATE + $ENV{NEUWARE_HOME}/lib64 +) + +target_link_libraries( + vllm_mlu_C + PUBLIC + ${TORCH_LIBRARIES} + libcndrv.so +) + +target_link_options(vllm_mlu_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib") + +install(TARGETS vllm_mlu_C DESTINATION ${VLLM_MLU_INSTALL_PATH}) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2a047d6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache 
License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Modifications made by Cambricon Technologies Corporation Limited. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..cecc1ab
--- /dev/null
+++ b/README.md
@@ -0,0 +1,116 @@
+
+### Cambricon vLLM (vllm_mlu)
+
+#### 1. Project Description
+
+Cambricon vLLM (vllm_mlu) is built on the [plugin system](https://docs.vllm.ai/en/latest/design/plugin_system.html) provided by community vLLM. It enables users to run large language model (LLM) inference and serving efficiently on Cambricon MLU hardware.
+
+vllm_mlu supports native vLLM features including, but not limited to, Chunked Prefill, Prefix Caching, Spec Decode, Graph Mode, and Sleep Mode.
+
+#### 2. Update History
+
+[2026.04.24] vllm_mlu adds day-0 support for DeepSeek-V4
+
+#### 3. Usage
+
+Software environment dependency: Cambricon SDK. To obtain the SDK, please contact Cambricon's official support channel: [ecosystem@cambricon.com](mailto:ecosystem@cambricon.com)
+
+*NOTE: the vllm-mlu repository only supports MLU370 and later devices.*
+
+##### 3.1 Using the container image
+
+Use the Cambricon vLLM Container image provided with the Cambricon SDK.
+
+```
+# Load the image
+docker load -i cambricon_vllm_container.tar.gz
+
+# Enter the container
+docker run -it --net=host \
+    --shm-size '64gb' --privileged \
+    --ulimit memlock=-1 ${IMAGE_NAME} \
+    /bin/bash
+
+# Activate the inference environment
+source /torch/venv3/pytorch_infer/bin/activate
+```
+
+##### 3.2 Custom installation steps
+
+Before installing Cambricon vLLM, make sure its dependencies are installed correctly.
+
+Installation steps:
+
+```bash
+# Assumes the Cambricon vLLM sources, including the vLLM sources, have already been obtained
+
+# Install from the vLLM sources
+cd vllm-v{community vLLM version}/
+VLLM_TARGET_DEVICE=empty pip install -e .  # install in development (editable) mode
+
+# Install from the vllm-mlu sources
+git clone https://github.com/Cambricon/vllm-mlu
+cd vllm-mlu
+pip install -e .  # install in development (editable) mode
+
+# Install ray
+# 1. Enter the vllm-mlu sources.
+cd tools/ray_mlu/
+# 2. Install the Ray release that the adaptation is based on.
+pip install --no-cache-dir --force-reinstall ray==2.51.1
+# 3. To run on Cambricon devices, Ray also needs to be adapted to the Cambricon software stack.
+# ${RAY_DIR} points to the installed ray package directory under your pip installation path
+cp __init__.py ${RAY_DIR}/_private/accelerators/__init__.py
+cp mlu.py ${RAY_DIR}/_private/accelerators/
+cp nsight.py ${RAY_DIR}/_private/runtime_env/nsight.py
+cp node.py ${RAY_DIR}/_private/node.py
+cp worker.py ${RAY_DIR}/_private/worker.py
+cp device_manager/__init__.py ${RAY_DIR}/air/_internal/device_manager/__init__.py
+cp device_manager/mlu.py ${RAY_DIR}/air/_internal/device_manager/
+```
+
+##### 3.3 Running
+
+Cambricon vLLM is run in the same way as community vLLM.
+
+###### 3.3.1 Offline inference command
+
+```
+# Run the inference command
+python examples/offline_inference/offline_inference.py ${MODEL_PATH}
+```
+
+###### 3.3.2 Online inference command
+
+Start the server and the client separately to complete the serving workflow, for example:
+
+```
+# server
+vllm serve ${MODEL_PATH} \
+    --port 8100 \
+    --block-size 1 \
+    --max-model-len 4096 \
+    --tensor-parallel-size 8 \
+    --gpu-memory-utilization 0.96 \
+    --trust-remote-code \
+    --enable-expert-parallel \
+    --no-enable-prefix-caching \
+    --disable-log-requests \
+    --enforce-eager
+
+# client: post a single request
+curl -X POST http://localhost:8100/v1/completions \
+    -H "Content-Type: application/json" \
+    -d "{\"model\": \"${MODEL_PATH}\", \"prompt\": \"The future of AI is\", \"max_tokens\": 128, \"temperature\": 0.7}"
+```
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
new file mode 100644
index 0000000..a01a4e3
--- /dev/null
+++ b/cmake/utils.cmake
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
+#
+# Attempt to find the python package that uses the same python executable as
+# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
+# +macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) + file(REAL_PATH ${EXECUTABLE} EXECUTABLE) + set(Python_EXECUTABLE ${EXECUTABLE}) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) + if (NOT Python_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") + endif() + set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}") + set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN}) + if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST) + message(FATAL_ERROR + "Python version (${_VER}) is not one of the supported versions: " + "${_SUPPORTED_VERSIONS_LIST}.") + endif() + message(STATUS "Found python matching: ${EXECUTABLE}.") +endmacro() + +# +# Run `EXPR` in python. The standard output of python is stored in `OUT` and +# has trailing whitespace stripped. If an error is encountered when running +# python, a fatal message `ERR_MSG` is issued. +# +function (run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${PYTHON_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. +macro (append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() diff --git a/csrc/cnmem_allocator.cpp b/csrc/cnmem_allocator.cpp new file mode 100644 index 0000000..28b2906 --- /dev/null +++ b/csrc/cnmem_allocator.cpp @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +// A MLU PluggableAllocator based on cn_api APIs. + +#include + +extern "C" { + +#define PY_SSIZE_T_CLEAN +#include +#include + + +#define DRV_CHECK_GET_RETURN(...) \ + DRV_CHECK_GET_RETURN_IMPL(__VA_ARGS__, return, ) +#define DRV_CHECK_GET_RETURN_IMPL(_1, _2, ...) _2 + +#define CN_CHECK(return_code, ...) \ + do { \ + CNresult rc = (return_code); \ + if (rc) { \ + const char *error_str; \ + cnGetErrorString(rc, &error_str); \ + std::cout << "Error: " << error_str \ + << " at " << __FILE__ \ + << ":" << __LINE__ \ + << std::endl; \ + DRV_CHECK_GET_RETURN(__VA_ARGS__) \ + __VA_ARGS__; \ + } \ + } while (0) + +// Global references to Python callables +static PyObject* g_python_malloc_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; + +// --------------------------------------------------------------------------- +// Helper functions: +void ensure_context(CNdev device) { + CNcontext pctx; + CN_CHECK(cnCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context; + CN_CHECK(cnCtxCreate(&pctx, 0, device)); + CN_CHECK(cnCtxSetCurrent(pctx)); + } +} + +void create_and_map(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) { + ensure_context(device); + // Define memory allocation properties + CNmemAllocationProp prop = {}; + // The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to cndrv developer guide. 
+ prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT; //CU_MEM_ALLOCATION_TYPE_PINNED + prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE; + prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE; + + // Allocate memory using cnMemCreate + CN_CHECK(cnMemCreate(p_memHandle, size, &prop, 0)); + CN_CHECK(cnMemMap(d_mem, size, 0, *p_memHandle, 0)); + + CNmemAccessDesc accessDesc = {}; + accessDesc.location.type = CN_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.accessFlags = CN_MEM_ACCESS_FLAGS_PROT_READWRITE; + CN_CHECK(cnMemSetAccess(d_mem, size, &accessDesc, 1)); +} + +void unmap_and_release(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) { + ensure_context(device); + CN_CHECK(cnMemUnmap(d_mem, size)); + CN_CHECK(cnMemRelease(*p_memHandle)); +} + +PyObject* create_tuple_from_c_integers(unsigned long long a, + unsigned long long b, + unsigned long long c, + unsigned long long d) { + // Create a new tuple of size 4 + PyObject* tuple = PyTuple_New(4); + if (!tuple) { + return NULL; + } + // Convert integers to Python objects and set them in the tuple + // Steals reference to the PyLong + PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a)); + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. + + return tuple; +} + +// --------------------------------------------------------------------------- +// Our exported C functions that call Python: +__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, CNqueue stream) { + ensure_context(device); + // first allocation, align the size, and reserve an address, and also allocate + // a CNmemGenericAllocationHandle + + // Define memory allocation properties + CNmemAllocationProp prop = {}; + // The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to cndrv developer guide. 
+ prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT; //CU_MEM_ALLOCATION_TYPE_PINNED + prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE; + prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE; + + //Check if the allocation is supported + size_t granularity; + CN_CHECK(cnMemGetAllocationGranularity(&granularity, &prop, CN_MEM_ALLOC_GRANULARITY_MINIMUM), nullptr); + + size_t alignedSize = ((size+granularity-1)/granularity)*granularity; + CNaddr d_mem; + CN_CHECK(cnMemAddressReserve(&d_mem, alignedSize, 0, 0, 0), nullptr); + + // allocate the CNmemGenericAllocationHandle + CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)malloc(sizeof(CNmemGenericAllocationHandle)); + + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* arg_tuple = create_tuple_from_c_integers( + (unsigned long long)device, (unsigned long long)alignedSize, + (unsigned long long)d_mem, (unsigned long long)p_memHandle); + + // Call g_python_malloc_callback + PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + + if (!py_result) { + PyErr_Print(); + PyGILState_Release(gstate); + return nullptr; + } + + PyGILState_Release(gstate); + + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; +} + +__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, CNqueue stream) { + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; + } + + PyGILState_Release(gstate); + + // Free memory + CNaddr d_mem = (CNaddr)recv_d_mem; + CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); + + //free address and the handle + CN_CHECK(cnMemAddressFree(d_mem, size)); + free(p_memHandle); +} + +// --------------------------------------------------------------------------- +// Python extension boilerplate: +// Python-exposed function: init_module(python_malloc, python_free) +static PyObject* py_init_module(PyObject* self, PyObject* args) { + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does 
not handle GC of these objects, so they must be kept alive + // outside of this module. + // This module keeps a strong reference to prevent premature GC + Py_XINCREF(malloc_callback); + Py_XINCREF(free_callback); + + Py_XDECREF(g_python_malloc_callback); + Py_XDECREF(g_python_free_callback); + + g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; + } + +static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CNaddr d_mem_ptr = (CNaddr)recv_d_mem; + CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; + } + +static PyObject* python_create_and_map(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CNaddr d_mem_ptr = (CNaddr)recv_d_mem; + CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyObject* python_cn_memcpy(PyObject* self, PyObject* args){ + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 3) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 3"); + return nullptr; + } + + CNaddr dst, src; + cn_uint64_t bytes; + if (!PyArg_ParseTuple(args, "KKK", &dst, &src, &bytes)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CN_CHECK(cnMemcpy(dst, src, bytes), nullptr); + + Py_RETURN_NONE; +} + +static PyMethodDef module_methods[] = { + {"init_module", (PyCFunction)py_init_module, METH_VARARGS, + "Initialize module with python_malloc and python_free callables."}, + {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, + "Create and map memory on the device."}, + {"python_unmap_and_release", (PyCFunction)python_unmap_and_release, + METH_VARARGS, "Unmap and release memory on the device."}, + {"python_cn_memcpy", (PyCFunction)python_cn_memcpy, METH_VARARGS, "Copies data from source address to destination address."}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef cnmem_allocator_module = { + PyModuleDef_HEAD_INIT, "cnmem_allocator", + "cnapi-mem-based allocator for MLUPluggableAllocator", -1, module_methods}; + +PyMODINIT_FUNC PyInit_vllm_mlu_C(void) { + // Initialize the module + PyObject* module = PyModule_Create(&cnmem_allocator_module); + if (!module) { + return NULL; + } + return module; +} + +} // extern "C" \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h new file mode 100644 index 0000000..c7d3642 --- /dev/null +++ b/csrc/ops.h @@ -0,0 
+1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +#pragma once + +#include +#include + +#include + +namespace vllm_mlu { + + torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { + // Ensure tensor is on MLU + if (!tensor.is_privateuseone()) { + throw std::runtime_error("Tensor must be on MLU device"); + } + + // Get the raw data pointer + void* data_ptr = tensor.data_ptr(); + + // Get tensor sizes and strides + std::vector sizes = tensor.sizes().vec(); + std::vector strides = tensor.strides().vec(); + + // Get tensor options (dtype, device) + auto options = tensor.options(); + + // Create a new tensor from the raw data pointer + auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); + + return new_tensor; + } + +} \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp new file mode 100644 index 0000000..b9f6de5 --- /dev/null +++ b/csrc/torch_bindings.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +#include +#include +#include +#include +#include "ops.h" +#include "utils.h" + + +TORCH_LIBRARY_EXPAND(_C, ops) +{ + // vLLM-MLU custom ops + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); + ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_mlu::weak_ref_tensor); +} + +REGISTER_EXTENSION(_C) diff --git a/csrc/utils.h b/csrc/utils.h new file mode 100644 index 0000000..ac58244 --- /dev/null +++ b/csrc/utils.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/examples/offline_inference/offline_inference.py b/examples/offline_inference/offline_inference.py new file mode 100644 index 0000000..1f7f2b1 --- /dev/null +++ b/examples/offline_inference/offline_inference.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import sys +from vllm import LLM, SamplingParams + + +def main(model_path): + # Sample prompts. + prompts = [ + "The benefits of exercise include", + "The importance of reading books is", + "Gardening can be relaxing because", + "A good night's sleep is essential for", + ] + sampling_params = SamplingParams( + temperature=0.6, top_p=0.95, max_tokens=10) + + # Create an LLM. 
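+    # These engine arguments mirror the online-serving example in the README:
+    # 8-way tensor parallelism with expert parallelism enabled, eager mode,
+    # block_size=1, prefix caching disabled, and 0.96 memory utilization.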
+ engine_args_dict = { + "model": model_path, + "tensor_parallel_size": 8, + "enable_expert_parallel": True, + "enable_prefix_caching": False, + "enforce_eager": True, + "trust_remote_code": True, + "max_num_seqs": len(prompts), + "max_model_len": 4096, + "block_size": 1, + "gpu_memory_utilization": 0.96, + } + llm = LLM(**engine_args_dict) + + # Generate texts from the prompts. + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == '__main__': + if len(sys.argv) < 2: + print("Usage: python offline_inference.py ") + sys.exit(1) + main(sys.argv[1]) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..75f59c2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +# Dependencies for Cambricon MLUs +ray == 2.51.1 +click == 8.2.1 +triton >= 3.2.0 +torch == 2.9.1 +torch-mlu >= 1.29.1 +torch_mlu_ops >= 1.8.1 + +matplotlib == 3.10.3 +datasets == 3.6.0 +blobfile == 3.0.0 +scipy == 1.10.1 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..44b216d --- /dev/null +++ b/setup.py @@ -0,0 +1,276 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +import importlib.util +import io +import logging +import os +import re +import subprocess +import sys + +from sysconfig import get_paths +from typing import List, Dict +from setuptools import Extension +from setuptools import find_namespace_packages, setup +from setuptools.command.build_ext import build_ext +from setuptools.command.install import install +from setuptools.command.develop import develop + +ROOT_DIR = os.path.dirname(__file__) +logger = logging.getLogger(__name__) + +def check_or_set_default_env(cmake_args, env_name, env_variable, default_path=""): + if env_variable is None: + logging.warning(f"Set default {env_name}: {default_path}") + env_variable = default_path + else: + logging.info(f"Found existing {env_name}: {env_variable}") + cmake_args += [f"-D{env_name}={env_variable}"] + return cmake_args + +def load_module_from_path(module_name, path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + +envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm_mlu", "envs.py")) +class CMakeExtension(Extension): + def __init__(self, + name: str, + cmake_lists_dir: str = ".", + **kwargs) -> None: + super().__init__(name, sources=[], py_limited_api=False, **kwargs) + self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) + + +def get_path(*filepath) -> str: + return os.path.join(ROOT_DIR, *filepath) + + +def get_vllm_version() -> str: + """ + get vllm version + """ + with open(get_path("tools/build.property"), 'r') as file: + content = file.read() + + results = re.findall(r'VLLM_VERSION=([\d|\.]+)\+mlu([\d|\.]+)\.pt(\d+)', content) + + assert results, "fail to get vllm, vllm_mlu and pytorch version." 
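+    # For example, an entry like VLLM_VERSION=0.8.4+mlu1.29.1.pt2 in
+    # tools/build.property (illustrative values) yields the package version
+    # "0.8.4+mlu1.29.1.pt2".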
+ + version = f"{results[-1][0]}+mlu{results[-1][1]}.pt{results[-1][2]}" + + return version + + +def read_readme() -> str: + """Read the README file if present.""" + p = get_path("README.md") + if os.path.isfile(p): + return io.open(get_path("README.md"), "r", encoding="utf-8").read() + else: + return "" + + +def get_requirements() -> List[str]: + """Get Python package dependencies from requirements.txt.""" + + def _read_requirements(filename: str) -> List[str]: + with open(get_path(filename)) as f: + requirements = f.read().strip().split("\n") + resolved_requirements = [] + for line in requirements: + if line.startswith("-r "): + resolved_requirements += _read_requirements(line.split()[1]) + elif line.startswith("--"): + continue + else: + resolved_requirements.append(line) + return resolved_requirements + return _read_requirements("requirements.txt") + + +class cmake_build_ext(build_ext): + # A dict of extension directories that have been configured. + did_config: Dict[str, bool] = {} + + # Determine number of compilation jobs + def compute_num_jobs(self): + # `num_jobs` is either the value of the MAX_JOBS environment variable + # (if defined) or the number of CPUs available. + num_jobs = envs.MAX_JOBS + if num_jobs is not None: + num_jobs = int(num_jobs) + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) + else: + try: + # os.sched_getaffinity() isn't universally available, so fall + # back to os.cpu_count() if we get an error here. + num_jobs = len(os.sched_getaffinity(0)) + except AttributeError: + num_jobs = os.cpu_count() + num_jobs = max(1, num_jobs) + + return num_jobs + + # + # Perform cmake configuration for a single extension. + # + def configure(self, ext: CMakeExtension) -> None: + os.makedirs(self.build_temp, exist_ok=True) + source_dir = os.path.abspath(ROOT_DIR) + python_executable = sys.executable + cmake_args = ["cmake"] + # Default use release mode to compile the csrc code + # Turbo now support compiled with Release, Debug and RelWithDebugInfo + if envs.CMAKE_BUILD_TYPE is None or envs.CMAKE_BUILD_TYPE not in [ + "Debug", + "Release", + "RelWithDebugInfo", + ]: + envs.CMAKE_BUILD_TYPE = "Release" + cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"] + # Default dump the compile commands for lsp + cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"] + if envs.CXX_COMPILER is not None: + cmake_args += [f"-DCMAKE_CXX_COMPILER={envs.CXX_COMPILER}"] + if envs.C_COMPILER is not None: + cmake_args += [f"-DCMAKE_C_COMPILER={envs.C_COMPILER}"] + if envs.VERBOSE: + cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"] + + # find PYTHON_EXECUTABLE + check_or_set_default_env(cmake_args, "PYTHON_EXECUTABLE", sys.executable) + + # find PYTHON_INCLUDE_PATH + check_or_set_default_env(cmake_args, "PYTHON_INCLUDE_PATH", + get_paths()["include"]) + + try: + # if pybind11 is installed via pip + subprocess.check_call([sys.executable, "-m", "pip", "install", "pybind11==2.13.6"]) + pybind11_cmake_path = (subprocess.check_output([python_executable, "-m", + "pybind11", "--cmake"]).decode().strip()) + except subprocess.CalledProcessError as e: + # else specify pybind11 path installed from source code on CI container + raise RuntimeError(f"CMake configuration failed: {e}") + + install_path = os.path.join(ROOT_DIR, self.build_lib) + if isinstance(self.distribution.get_command_obj("develop"), develop): + install_path = os.path.join(ROOT_DIR, "vllm_mlu") + + # add CMAKE_INSTALL_PATH + cmake_args += [f"-DCMAKE_INSTALL_PREFIX={install_path}"] + + cmake_args += 
[f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"] + + cmake_args += [source_dir] + logging.info(f"cmake config command: {cmake_args}") + try: + subprocess.check_call(cmake_args, cwd=self.build_temp) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"CMake configuration failed: {e}") + + def build_extensions(self) -> None: + if not envs.COMPILE_CUSTOM_KERNELS: + return + # Ensure that CMake is present and working + try: + subprocess.check_output(["cmake", "--version"]) + except OSError as e: + raise RuntimeError(f"Cannot find CMake executable: {e}") + + # Create build directory if it does not exist. + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + os.makedirs(os.path.join(self.build_lib, "vllm_mlu"), exist_ok=True) + + targets = [] + + def get_target_name(s: str) -> str: + return s.removeprefix("vllm_mlu.") + + # Build all the extensions + for ext in self.extensions: + self.configure(ext) + targets.append(get_target_name(ext.name)) + + num_jobs = self.compute_num_jobs() + + build_args = ["--build", ".", f"-j={num_jobs}", + *[f"--target={name}" for name in targets], + ] + logger.info(build_args) + try: + subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) + except OSError as e: + raise RuntimeError(f"Build library failed: {e}") + + # Install the libraries + install_args = ["--install", ".", ] + try: + subprocess.check_call(["cmake", *install_args], cwd=self.build_temp) + except OSError as e: + raise RuntimeError(f"Install library failed: {e}") + + # copy back to build folder for editable build + if isinstance(self.distribution.get_command_obj("develop"), develop): + for root, _, files in os.walk(self.build_temp): + for file in files: + if file.endswith(".so"): + src_path = os.path.join(root, file) + dst_path = os.path.join(self.build_lib, "vllm_mlu", file) + self.copy_file(src_path, dst_path) + logger.info(f"Copy: {src_path} -> {dst_path}") + + def run(self): + # First, run the standard build_ext command to compile the extensions + super().run() + + +class custom_install(install): + def run(self): + self.run_command("build_ext") + install.run(self) + +ext_modules = [] +if envs.COMPILE_CUSTOM_KERNELS: + ext_modules = [CMakeExtension(name="vllm_mlu.vllm_mlu_C")] +cmdclass = {"build_ext": cmake_build_ext, "install": custom_install} + +setup( + name="vllm_mlu", + version=get_vllm_version(), + author="Cambricon vLLM Team", + license="Apache 2.0", + description=("A high-throughput and memory-efficient inference and " + "serving engine for LLMs on MLU backend"), + long_description=read_readme(), + long_description_content_type="text/markdown", + url="", + project_urls={ + "Homepage": "https://github.com/vllm-project/vllm", + "Documentation": "https://vllm.readthedocs.io/en/latest/", + }, + classifiers=[ + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: Apache Software License", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + packages=find_namespace_packages(exclude=("docs", "examples", "tests*", "csrc")), + include_package_data=True, + python_requires=">=3.8", + install_requires=get_requirements(), + ext_modules = ext_modules, + cmdclass=cmdclass, + entry_points={ + 'vllm.platform_plugins': ["mlu = vllm_mlu:register_mlu_platform"], + 'vllm.general_plugins': ["mlu_hijack = vllm_mlu:register_mlu_hijack"] + } +) diff --git 
a/tools/ray_mlu/__init__.py b/tools/ray_mlu/__init__.py new file mode 100644 index 0000000..34808f1 --- /dev/null +++ b/tools/ray_mlu/__init__.py @@ -0,0 +1,89 @@ +from typing import Optional, Set + +from ray._private.accelerators.accelerator import ( + RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR, + AcceleratorManager, +) +from ray._private.accelerators.amd_gpu import AMDGPUAcceleratorManager +from ray._private.accelerators.hpu import HPUAcceleratorManager +from ray._private.accelerators.intel_gpu import IntelGPUAcceleratorManager +from ray._private.accelerators.neuron import NeuronAcceleratorManager +from ray._private.accelerators.npu import NPUAcceleratorManager +from ray._private.accelerators.nvidia_gpu import NvidiaGPUAcceleratorManager +from ray._private.accelerators.rbln import RBLNAcceleratorManager +from ray._private.accelerators.tpu import TPUAcceleratorManager +from ray._private.accelerators.mlu import MLUAcceleratorManager + + +def get_all_accelerator_managers() -> Set[AcceleratorManager]: + """Get all accelerator managers supported by Ray.""" + return { + NvidiaGPUAcceleratorManager, + IntelGPUAcceleratorManager, + AMDGPUAcceleratorManager, + TPUAcceleratorManager, + NeuronAcceleratorManager, + HPUAcceleratorManager, + NPUAcceleratorManager, + RBLNAcceleratorManager, + MLUAcceleratorManager, + } + + +def get_all_accelerator_resource_names() -> Set[str]: + """Get all resource names for accelerators.""" + return { + accelerator_manager.get_resource_name() + for accelerator_manager in get_all_accelerator_managers() + } + + +def get_accelerator_manager_for_resource( + resource_name: str, +) -> Optional[AcceleratorManager]: + """Get the corresponding accelerator manager for the given + accelerator resource name + + E.g., TPUAcceleratorManager is returned if resource name is "TPU" + """ + try: + return get_accelerator_manager_for_resource._resource_name_to_accelerator_manager.get( # noqa: E501 + resource_name, None + ) + except AttributeError: + # Lazy initialization. + resource_name_to_accelerator_manager = { + accelerator_manager.get_resource_name(): accelerator_manager + for accelerator_manager in get_all_accelerator_managers() + } + # Special handling for GPU resource name since multiple accelerator managers + # have the same GPU resource name. 
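+        # In this MLU adaptation, MLUAcceleratorManager also reports "GPU" as
+        # its resource name, so MLU devices are scheduled through Ray's GPU
+        # resource when they are detected on the node.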
+ if AMDGPUAcceleratorManager.get_current_node_num_accelerators() > 0: + resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager + elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0: + resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager + elif MLUAcceleratorManager.get_current_node_num_accelerators() > 0: + resource_name_to_accelerator_manager["GPU"] = MLUAcceleratorManager + else: + resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager + get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = ( + resource_name_to_accelerator_manager + ) + return resource_name_to_accelerator_manager.get(resource_name, None) + + +__all__ = [ + "NvidiaGPUAcceleratorManager", + "IntelGPUAcceleratorManager", + "AMDGPUAcceleratorManager", + "TPUAcceleratorManager", + "NeuronAcceleratorManager", + "HPUAcceleratorManager", + "NPUAcceleratorManager", + "RBLNAcceleratorManager", + "MLUAcceleratorManager", + "get_all_accelerator_managers", + "get_all_accelerator_resource_names", + "get_accelerator_manager_for_resource", + "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR", +] diff --git a/tools/ray_mlu/device_manager/__init__.py b/tools/ray_mlu/device_manager/__init__.py new file mode 100644 index 0000000..c400184 --- /dev/null +++ b/tools/ray_mlu/device_manager/__init__.py @@ -0,0 +1,114 @@ +import logging +import threading +from typing import Optional + +import ray +import ray._private.ray_constants as ray_constants +from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager +from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager +from ray.air._internal.device_manager.npu import NPUTorchDeviceManager +from ray.air._internal.device_manager.mlu import MLUTorchDeviceManager +from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager +from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager + +logger = logging.getLogger(__name__) + + +DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager + +''' +============================= +Modify by vllm_mlu +============================= +@brief: use MLUTorchDeviceManager when key="GPU" +''' +SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = { + ray_constants.GPU: MLUTorchDeviceManager, + ray_constants.HPU: HPUTorchDeviceManager, + ray_constants.NPU: NPUTorchDeviceManager, +} +''' +================== +End of MLU Hijack +================== +''' + + +def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None: + if backend == "hccl": + # The name for the communication backend of Habana and torch-npu is the same. + HPUTorchDeviceManager.register_custom_torch_dist_backend() + + NPUTorchDeviceManager.register_custom_torch_dist_backend() + + +_torch_device_manager = None +_torch_device_manager_lock = threading.Lock() + + +def get_torch_device_manager_by_context() -> TorchDeviceManager: + global _torch_device_manager + + with _torch_device_manager_lock: + if not _torch_device_manager: + existing_device_manager_cls = None + resources = ray.get_runtime_context().get_accelerator_ids() + + # select correct accelerator type from resources + for resource_type, resource_value in resources.items(): + device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get( + resource_type, None + ) + if resource_value and device_manager_cls: + # An error will raise when multiple accelerators are specified. 
+ if existing_device_manager_cls: + raise RuntimeError( + "Unable to determine the appropriate DeviceManager " + f"for the specified resources {resources}." + ) + else: + existing_device_manager_cls = device_manager_cls + + device_manager_cls = ( + existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS + ) + + _torch_device_manager = device_manager_cls() + + return _torch_device_manager + + +def get_torch_device_manager_by_device_type(device_type: str): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use MLUTorchDeviceManager when key="GPU" + ''' + if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda": + return MLUTorchDeviceManager() + elif device_type.lower() == ray_constants.NPU.lower(): + return NPUTorchDeviceManager() + elif device_type.lower() == ray_constants.HPU.lower(): + return HPUTorchDeviceManager() + elif device_type.lower() == "cpu": + return CPUTorchDeviceManager() + ''' + ================== + End of MLU Hijack + ================== + ''' + raise RuntimeError(f"Device type {device_type} cannot be recognized.") + + +__all__ = [ + TorchDeviceManager, + CPUTorchDeviceManager, + CUDATorchDeviceManager, + HPUTorchDeviceManager, + NPUTorchDeviceManager, + MLUTorchDeviceManager, + register_custom_torch_dist_backend, + get_torch_device_manager_by_context, + get_torch_device_manager_by_device_type, +] diff --git a/tools/ray_mlu/device_manager/mlu.py b/tools/ray_mlu/device_manager/mlu.py new file mode 100644 index 0000000..e709638 --- /dev/null +++ b/tools/ray_mlu/device_manager/mlu.py @@ -0,0 +1,103 @@ +import os +from importlib.util import find_spec +from typing import List, Union + +import torch + +import ray +import ray._private.ray_constants as ray_constants +from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager +from ray._private.accelerators.mlu import MLU_VISIBLE_DEVICES_ENV_VAR + + +def is_package_present(package_name: str) -> bool: + try: + return find_spec(package_name) is not None + except ModuleNotFoundError: + return False + + +MLU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_mlu") + + +if MLU_TORCH_PACKAGE_AVAILABLE: + import torch_mlu # noqa: F401 + + +class MLUTorchDeviceManager(TorchDeviceManager): + """Cambricon MLU device manager""" + + @staticmethod + def register_custom_torch_dist_backend(): + if MLU_TORCH_PACKAGE_AVAILABLE: + import torch_mlu # noqa: F401, F811 + + def is_available(self) -> bool: + if not MLU_TORCH_PACKAGE_AVAILABLE: + return False + + return torch.mlu.is_available() + + def get_devices(self) -> List[torch.device]: + """Gets the correct torch device list configured for this process. + Returns a list of torch MLU devices allocated for the current worker. + If no MLUs are assigned, then it returns a list with a single CPU device. + """ + if MLU_TORCH_PACKAGE_AVAILABLE and torch.mlu.is_available(): + mlu_ids = [ + str(id) + for id in ray.get_runtime_context().get_accelerator_ids()[ + ray_constants.GPU + ] + ] + + device_ids = [] + + if len(mlu_ids) > 0: + mlu_visible_str = os.environ.get(MLU_VISIBLE_DEVICES_ENV_VAR, "") + if mlu_visible_str and mlu_visible_str != "NoDevFiles": + mlu_visible_list = mlu_visible_str.split(",") + else: + mlu_visible_list = [] + + for mlu_id in mlu_ids: + try: + device_ids.append(mlu_visible_list.index(mlu_id)) + except IndexError: + raise RuntimeError( + "MLU_VISIBLE_DEVICES set incorrectly. " + f"Got {mlu_visible_str}, expected to include {mlu_id}. 
" + "Did you override the `MLU_VISIBLE_DEVICES` " + "environment variable?" + ) + else: + # If called on the driver or outside of Ray Train, return the + # 0th device. + device_ids.append(0) + + devices = [torch.device(f"mlu:{device_id}") for device_id in device_ids] + else: + raise RuntimeError( + "Using MLUTorchDeviceManager but torch mlu is not available." + ) + + return devices + + def set_device(self, device: Union[torch.device, int]): + torch.mlu.set_device(device) + + def supports_stream(self) -> bool: + """Validate if the device type support to create a stream""" + return True + + def create_stream(self, device): + """Create a stream on MLU device""" + return torch.mlu.Stream(device) + + def get_stream_context(self, stream): + """Get a torch.stream context on MLU device""" + return torch.mlu.stream(stream) + + def get_current_stream(self): + """Get current stream for MLU device""" + return torch.mlu.current_stream() diff --git a/tools/ray_mlu/diff.patch b/tools/ray_mlu/diff.patch new file mode 100644 index 0000000..4ed43cf --- /dev/null +++ b/tools/ray_mlu/diff.patch @@ -0,0 +1,243 @@ +commit 7376225d16e381ecae5cc07d84db9eed043ed06a +Author: tanhaojue +Date: Thu Mar 7 15:54:09 2024 +0800 + + support mlu + +diff --git a/python/ray/_private/accelerators/__init__.py b/python/ray/_private/accelerators/__init__.py +index 71550bc..07bdcd6 100644 +--- a/python/ray/_private/accelerators/__init__.py ++++ b/python/ray/_private/accelerators/__init__.py +@@ -8,6 +8,7 @@ from ray._private.accelerators.tpu import TPUAcceleratorManager + from ray._private.accelerators.neuron import NeuronAcceleratorManager + from ray._private.accelerators.hpu import HPUAcceleratorManager + from ray._private.accelerators.npu import NPUAcceleratorManager ++from ray._private.accelerators.mlu import MLUAcceleratorManager + + + def get_all_accelerator_managers() -> Set[AcceleratorManager]: +@@ -20,6 +21,7 @@ def get_all_accelerator_managers() -> Set[AcceleratorManager]: + NeuronAcceleratorManager, + HPUAcceleratorManager, + NPUAcceleratorManager, ++ MLUAcceleratorManager, + } + + +@@ -55,6 +57,8 @@ def get_accelerator_manager_for_resource( + resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager + elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0: + resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager ++ elif MLUAcceleratorManager.get_current_node_num_accelerators() > 0: ++ resource_name_to_accelerator_manager["GPU"] = MLUAcceleratorManager + else: + resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager + get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = ( +@@ -71,6 +75,7 @@ __all__ = [ + "NeuronAcceleratorManager", + "HPUAcceleratorManager", + "NPUAcceleratorManager", ++ "MLUAcceleratorManager", + "get_all_accelerator_managers", + "get_all_accelerator_resource_names", + "get_accelerator_manager_for_resource", +diff --git a/python/ray/_private/accelerators/mlu.py b/python/ray/_private/accelerators/mlu.py +new file mode 100755 +index 0000000..21a5771 +--- /dev/null ++++ b/python/ray/_private/accelerators/mlu.py +@@ -0,0 +1,92 @@ ++import os ++import glob ++import logging ++from typing import Optional, List, Tuple ++import torch ++import torch_mlu ++from ray._private.accelerators.accelerator import AcceleratorManager ++ ++logger = logging.getLogger(__name__) ++ ++MLU_VISIBLE_DEVICES_ENV_VAR = "MLU_VISIBLE_DEVICES" ++NOSET_MLU_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_MLU_VISIBLE_DEVICES" ++ ++ ++class 
MLUAcceleratorManager(AcceleratorManager): ++ """Cambricon MLU accelerators.""" ++ ++ @staticmethod ++ def get_resource_name() -> str: ++ return "GPU" ++ ++ @staticmethod ++ def get_visible_accelerator_ids_env_var() -> str: ++ return MLU_VISIBLE_DEVICES_ENV_VAR ++ ++ @staticmethod ++ def get_current_process_visible_accelerator_ids() -> Optional[List[str]]: ++ mlu_visible_devices = os.environ.get( ++ MLUAcceleratorManager.get_visible_accelerator_ids_env_var(), None ++ ) ++ ++ if mlu_visible_devices is None: ++ return None ++ ++ if mlu_visible_devices == "": ++ return [] ++ ++ if mlu_visible_devices == "NoDevFiles": ++ return [] ++ ++ return list(mlu_visible_devices.split(",")) ++ ++ @staticmethod ++ def get_current_node_num_accelerators() -> int: ++ """Attempt to detect the number of MLUs on this machine. ++ ++ MLU chips are represented as devices within `/dev/`, either as `/dev/davinci?`. ++ ++ Returns: ++ The number of MLUs if any were detected, otherwise 0. ++ """ ++ try: ++ return torch.mlu.device_count() ++ except Exception as e: ++ logger.debug("Could not import CambriconCL: %s", e) ++ ++ try: ++ mlu_files = glob.glob("/dev/cambricon_dev?") ++ return len(mlu_files) ++ except Exception as e: ++ logger.debug("Failed to detect number of MLUs: %s", e) ++ return 0 ++ ++ @staticmethod ++ def get_current_node_accelerator_type() -> Optional[str]: ++ """Get the type of the Cambricon MLU on the current node. ++ ++ Returns: ++ A string of the type, such as "MLU370". ++ """ ++ try: ++ return torch.mlu.get_device_name(0) ++ except Exception: ++ logger.exception("Failed to detect MLU type.") ++ return None ++ ++ @staticmethod ++ def validate_resource_request_quantity( ++ quantity: float, ++ ) -> Tuple[bool, Optional[str]]: ++ return (True, None) ++ ++ @staticmethod ++ def set_current_process_visible_accelerator_ids( ++ visible_mlu_devices: List[str], ++ ) -> None: ++ if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR): ++ return ++ ++ os.environ[ ++ MLUAcceleratorManager.get_visible_accelerator_ids_env_var() ++ ] = ",".join([str(i) for i in visible_mlu_devices]) +diff --git a/python/ray/tests/accelerators/test_mlu.py b/python/ray/tests/accelerators/test_mlu.py +new file mode 100755 +index 0000000..70e81f7 +--- /dev/null ++++ b/python/ray/tests/accelerators/test_mlu.py +@@ -0,0 +1,92 @@ ++import os ++import sys ++import pytest ++from unittest.mock import patch ++ ++import ray ++from ray._private.accelerators import MLUAcceleratorManager as Accelerator ++ ++ ++@patch("glob.glob") ++@patch("os.listdir") ++def test_autodetect_num_mlus(mock_list, mock_glob): ++ mock_glob.return_value = [f"/dev/davinci{i}" for i in range(4)] ++ # mock_list.return_value = [] ++ assert Accelerator.get_current_node_num_accelerators() == 4 ++ ++ ++@patch("glob.glob") ++@patch("os.listdir") ++def test_autodetect_num_mlus_without_devices(mock_list, mock_glob): ++ mock_glob.side_effect = Exception ++ # mock_list.return_value = [] ++ assert Accelerator.get_current_node_num_accelerators() == 0 ++ ++ ++def test_mlu_accelerator_manager_api(): ++ assert Accelerator.get_resource_name() == "MLU" ++ assert Accelerator.get_visible_accelerator_ids_env_var() == "MLU_VISIBLE_DEVICES" ++ assert Accelerator.validate_resource_request_quantity(0.5) == (True, None) ++ assert Accelerator.validate_resource_request_quantity(1) == (True, None) ++ ++ ++def test_visible_mlu_type(monkeypatch, shutdown_only): ++ with patch.object( ++ Accelerator, "get_current_node_num_accelerators", return_value=4 ++ ), patch.object( ++ Accelerator, 
"get_current_node_accelerator_type", return_value="MLU370" ++ ): ++ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") ++ manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU") ++ assert manager.get_current_node_accelerator_type() == "MLU370" ++ ++@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows") ++def test_visible_mlu_ids(monkeypatch, shutdown_only): ++ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") ++ with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4): ++ ++ ray.init() ++ manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU") ++ assert manager.get_current_node_num_accelerators() == 4 ++ assert manager.__name__ == "MLUAcceleratorManager" ++ assert ray.available_resources()["MLU"] == 3 ++ ++def test_get_current_process_visible_accelerator_ids(monkeypatch, shutdown_only): ++ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") ++ assert Accelerator.get_current_process_visible_accelerator_ids() == ["0", "1", "2"] ++ ++ monkeypatch.delenv("MLU_VISIBLE_DEVICES") ++ assert Accelerator.get_current_process_visible_accelerator_ids() is None ++ ++ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "") ++ assert Accelerator.get_current_process_visible_accelerator_ids() == [] ++ ++ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "NoDevFiles") ++ assert Accelerator.get_current_process_visible_accelerator_ids() == [] ++ ++ ++def test_set_current_process_visible_accelerator_ids(shutdown_only): ++ Accelerator.set_current_process_visible_accelerator_ids(["0"]) ++ assert os.environ["MLU_VISIBLE_DEVICES"] == "0" ++ ++ Accelerator.set_current_process_visible_accelerator_ids(["0", "1"]) ++ assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1" ++ ++ Accelerator.set_current_process_visible_accelerator_ids(["0", "1", "2"]) ++ assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1,2" ++ ++ ++@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows") ++def test_auto_detected_more_than_visible(monkeypatch, shutdown_only): ++ with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4): ++ # If more MLUs are detected than visible. 
++ monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") ++ ++ ray.init() ++ assert ray.available_resources()["MLU"] == 3 ++ ++if __name__ == "__main__": ++ if os.environ.get("PARALLEL_CI"): ++ sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) ++ else: ++ sys.exit(pytest.main(["-sv", __file__])) +diff --git a/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl b/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl +new file mode 100644 +index 0000000..8628a88 +Binary files /dev/null and b/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl differ diff --git a/tools/ray_mlu/diff_for_dump_info.patch b/tools/ray_mlu/diff_for_dump_info.patch new file mode 100644 index 0000000..ab76162 --- /dev/null +++ b/tools/ray_mlu/diff_for_dump_info.patch @@ -0,0 +1,11 @@ +diff --git a/ray_mlu/mlu.py b/ray_mlu/mlu.py +index 21a57719..2c63fd5b 100755 +--- a/ray_mlu/mlu.py ++++ b/ray_mlu/mlu.py +@@ -87,6 +87,3 @@ class MLUAcceleratorManager(AcceleratorManager): + if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR): + return + +- os.environ[ +- MLUAcceleratorManager.get_visible_accelerator_ids_env_var() +- ] = ",".join([str(i) for i in visible_mlu_devices]) diff --git a/tools/ray_mlu/mlu.py b/tools/ray_mlu/mlu.py new file mode 100755 index 0000000..8403399 --- /dev/null +++ b/tools/ray_mlu/mlu.py @@ -0,0 +1,94 @@ +import os +import glob +import logging +from typing import Optional, List, Tuple +import torch +import torch_mlu +from ray._private.accelerators.accelerator import AcceleratorManager + +logger = logging.getLogger(__name__) + +MLU_VISIBLE_DEVICES_ENV_VAR = "MLU_VISIBLE_DEVICES" +NOSET_MLU_VISIBLE_DEVICES_ENV_VAR = ( + "RAY_EXPERIMENTAL_NOSET_MLU_VISIBLE_DEVICES" +) + + +class MLUAcceleratorManager(AcceleratorManager): + """Cambricon MLU accelerators.""" + + @staticmethod + def get_resource_name() -> str: + return "GPU" + + @staticmethod + def get_visible_accelerator_ids_env_var() -> str: + return MLU_VISIBLE_DEVICES_ENV_VAR + + @staticmethod + def get_current_process_visible_accelerator_ids() -> Optional[List[str]]: + mlu_visible_devices = os.environ.get( + MLUAcceleratorManager.get_visible_accelerator_ids_env_var(), None + ) + + if mlu_visible_devices is None: + return None + + if mlu_visible_devices == "": + return [] + + if mlu_visible_devices == "NoDevFiles": + return [] + + return list(mlu_visible_devices.split(",")) + + @staticmethod + def get_current_node_num_accelerators() -> int: + """Attempt to detect the number of MLUs on this machine. + + MLU chips are represented as devices within `/dev/`, either as `/dev/davinci?`. + + Returns: + The number of MLUs if any were detected, otherwise 0. + """ + try: + return torch.mlu.device_count() + except Exception as e: + logger.debug("Could not import CambriconCL: %s", e) + + try: + mlu_files = glob.glob("/dev/cambricon_dev?") + return len(mlu_files) + except Exception as e: + logger.debug("Failed to detect number of MLUs: %s", e) + return 0 + + @staticmethod + def get_current_node_accelerator_type() -> Optional[str]: + """Get the type of the Cambricon MLU on the current node. + + Returns: + A string of the type, such as "MLU370". 
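+            Returns None if detection fails (e.g. torch_mlu is unavailable).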
+ """ + try: + return torch.mlu.get_device_name(0) + except Exception: + logger.exception("Failed to detect MLU type.") + return None + + @staticmethod + def validate_resource_request_quantity( + quantity: float, + ) -> Tuple[bool, Optional[str]]: + return (True, None) + + @staticmethod + def set_current_process_visible_accelerator_ids( + visible_mlu_devices: List[str], + ) -> None: + if os.environ.get(NOSET_MLU_VISIBLE_DEVICES_ENV_VAR): + return + + os.environ[ + MLUAcceleratorManager.get_visible_accelerator_ids_env_var() + ] = ",".join([str(i) for i in visible_mlu_devices]) diff --git a/tools/ray_mlu/node.py b/tools/ray_mlu/node.py new file mode 100644 index 0000000..f876c60 --- /dev/null +++ b/tools/ray_mlu/node.py @@ -0,0 +1,1890 @@ +import atexit +import collections +import datetime +import errno +import json +import logging +import os +import random +import signal +import socket +import subprocess +import sys +import tempfile +import threading +import time +import traceback +from collections import defaultdict +from typing import IO, AnyStr, Dict, Optional, Tuple + +from filelock import FileLock + +import ray +import ray._private.ray_constants as ray_constants +import ray._private.services +from ray._common.network_utils import build_address, parse_address +from ray._common.ray_constants import LOGGING_ROTATE_BACKUP_COUNT, LOGGING_ROTATE_BYTES +from ray._common.utils import try_to_create_directory +from ray._private.resource_and_label_spec import ResourceAndLabelSpec +from ray._private.resource_isolation_config import ResourceIsolationConfig +from ray._private.services import get_address, serialize_config +from ray._private.utils import ( + is_in_test, + open_log, + try_to_symlink, + validate_socket_filepath, +) +from ray._raylet import GcsClient, get_session_key_from_storage + +import psutil + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray configures it by default automatically +# using logging.basicConfig in its entry/init points. +logger = logging.getLogger(__name__) + + +class Node: + """An encapsulation of the Ray processes on a single node. + + This class is responsible for starting Ray processes and killing them, + and it also controls the temp file policy. + + Attributes: + all_processes: A mapping from process type (str) to a list of + ProcessInfo objects. All lists have length one except for the Redis + server list, which has multiple. + """ + + def __init__( + self, + ray_params, + head: bool = False, + shutdown_at_exit: bool = True, + spawn_reaper: bool = True, + connect_only: bool = False, + default_worker: bool = False, + ray_init_cluster: bool = False, + ): + """Start a node. + + Args: + ray_params: The RayParams to use to configure the node. + head: True if this is the head node, which means it will + start additional processes like the Redis servers, monitor + processes, and web UI. + shutdown_at_exit: If true, spawned processes will be cleaned + up if this process exits normally. + spawn_reaper: If true, spawns a process that will clean up + other spawned processes if this process dies unexpectedly. + connect_only: If true, connect to the node without starting + new processes. + default_worker: Whether it's running from a ray worker or not + ray_init_cluster: Whether it's a cluster created by ray.init() + """ + if shutdown_at_exit: + if connect_only: + raise ValueError( + "'shutdown_at_exit' and 'connect_only' cannot both be true." 
+ ) + self._register_shutdown_hooks() + self._default_worker = default_worker + self.head = head + self.kernel_fate_share = bool( + spawn_reaper and ray._private.utils.detect_fate_sharing_support() + ) + self.resource_isolation_config: ResourceIsolationConfig = ( + ray_params.resource_isolation_config + ) + self.all_processes: dict = {} + self.removal_lock = threading.Lock() + + self.ray_init_cluster = ray_init_cluster + if ray_init_cluster: + assert head, "ray.init() created cluster only has the head node" + + # Set up external Redis when `RAY_REDIS_ADDRESS` is specified. + redis_address_env = os.environ.get("RAY_REDIS_ADDRESS") + if ray_params.external_addresses is None and redis_address_env is not None: + external_redis = redis_address_env.split(",") + + # Reuse primary Redis as Redis shard when there's only one + # instance provided. + if len(external_redis) == 1: + external_redis.append(external_redis[0]) + ray_params.external_addresses = external_redis + ray_params.num_redis_shards = len(external_redis) - 1 + + if ( + ray_params._system_config + and len(ray_params._system_config) > 0 + and (not head and not connect_only) + ): + raise ValueError( + "System config parameters can only be set on the head node." + ) + + ray_params.update_if_absent( + include_log_monitor=True, + resources={}, + worker_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "workers", + "default_worker.py", + ), + setup_worker_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "workers", + ray_constants.SETUP_WORKER_FILENAME, + ), + ) + + self._resource_and_label_spec = None + self._localhost = socket.gethostbyname("localhost") + self._ray_params = ray_params + self._config = ray_params._system_config or {} + + self._dashboard_agent_listen_port = ray_params.dashboard_agent_listen_port + + # Configure log rotation parameters. + self.max_bytes = int(os.getenv("RAY_ROTATION_MAX_BYTES", LOGGING_ROTATE_BYTES)) + self.backup_count = int( + os.getenv("RAY_ROTATION_BACKUP_COUNT", LOGGING_ROTATE_BACKUP_COUNT) + ) + + assert self.max_bytes >= 0 + assert self.backup_count >= 0 + + self._redis_address = ray_params.redis_address + if head: + ray_params.update_if_absent(num_redis_shards=1) + self._gcs_address = ray_params.gcs_address + self._gcs_client = None + + if not self.head: + self.validate_ip_port(self.address) + self._init_gcs_client() + + # Register the temp dir. + self._session_name = ray_params.session_name + if self._session_name is None: + if head: + # We expect this the first time we initialize a cluster, but not during + # subsequent restarts of the head node. 
+ maybe_key = self.check_persisted_session_name() + if maybe_key is None: + # date including microsecond + date_str = datetime.datetime.today().strftime( + "%Y-%m-%d_%H-%M-%S_%f" + ) + self._session_name = f"session_{date_str}_{os.getpid()}" + else: + self._session_name = ray._common.utils.decode(maybe_key) + else: + assert not self._default_worker + session_name = ray._private.utils.internal_kv_get_with_retry( + self.get_gcs_client(), + "session_name", + ray_constants.KV_NAMESPACE_SESSION, + num_retries=ray_constants.NUM_REDIS_GET_RETRIES, + ) + self._session_name = ray._common.utils.decode(session_name) + + # Initialize webui url + if head: + self._webui_url = None + else: + if ray_params.webui is None: + assert not self._default_worker + self._webui_url = ray._private.services.get_webui_url_from_internal_kv() + else: + self._webui_url = build_address( + ray_params.dashboard_host, ray_params.dashboard_port + ) + + # It creates a session_dir. + self._init_temp() + + node_ip_address = ray_params.node_ip_address + if node_ip_address is None: + if connect_only: + node_ip_address = self._wait_and_get_for_node_address() + else: + node_ip_address = ray.util.get_node_ip_address() + + assert node_ip_address is not None + ray_params.update_if_absent(node_ip_address=node_ip_address) + self._node_ip_address = node_ip_address + if not connect_only: + ray._private.services.write_node_ip_address( + self.get_session_dir_path(), node_ip_address + ) + + self._object_spilling_config = self._get_object_spilling_config() + logger.debug( + f"Starting node with object spilling config: {self._object_spilling_config}" + ) + + # Obtain the fallback directoy from the object spilling config + # Currently, we set the fallback directory to be the same as the object spilling + # path when the object spills to file system + self._fallback_directory = None + if self._object_spilling_config: + config = json.loads(self._object_spilling_config) + if config.get("type") == "filesystem": + directory_path = config.get("params", {}).get("directory_path") + if isinstance(directory_path, list): + self._fallback_directory = directory_path[0] + elif isinstance(directory_path, str): + self._fallback_directory = directory_path + + # If it is a head node, try validating if external storage is configurable. + if head: + self.validate_external_storage() + + if connect_only: + # Get socket names from the configuration. + self._plasma_store_socket_name = ray_params.plasma_store_socket_name + self._raylet_socket_name = ray_params.raylet_socket_name + self._node_id = ray_params.node_id + + # If user does not provide the socket name, get it from Redis. + if ( + self._plasma_store_socket_name is None + or self._raylet_socket_name is None + or self._ray_params.node_manager_port is None + or self._node_id is None + ): + # Get the address info of the processes to connect to + # from Redis or GCS. + node_info = ray._private.services.get_node_to_connect_for_driver( + self.gcs_address, + self._node_ip_address, + ) + self._plasma_store_socket_name = node_info["object_store_socket_name"] + self._raylet_socket_name = node_info["raylet_socket_name"] + self._ray_params.node_manager_port = node_info["node_manager_port"] + self._node_id = node_info["node_id"] + else: + # If the user specified a socket name, use it. 
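+            # _prepare_socket_file() keeps a user-provided socket path as-is and
+            # otherwise generates one under the session's sockets directory.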
+ self._plasma_store_socket_name = self._prepare_socket_file( + self._ray_params.plasma_store_socket_name, default_prefix="plasma_store" + ) + self._raylet_socket_name = self._prepare_socket_file( + self._ray_params.raylet_socket_name, default_prefix="raylet" + ) + if ( + self._ray_params.env_vars is not None + and "RAY_OVERRIDE_NODE_ID_FOR_TESTING" in self._ray_params.env_vars + ): + node_id = self._ray_params.env_vars["RAY_OVERRIDE_NODE_ID_FOR_TESTING"] + logger.debug( + f"Setting node ID to {node_id} " + "based on ray_params.env_vars override" + ) + self._node_id = node_id + elif os.environ.get("RAY_OVERRIDE_NODE_ID_FOR_TESTING"): + node_id = os.environ["RAY_OVERRIDE_NODE_ID_FOR_TESTING"] + logger.debug(f"Setting node ID to {node_id} based on env override") + self._node_id = node_id + else: + node_id = ray.NodeID.from_random().hex() + logger.debug(f"Setting node ID to {node_id}") + self._node_id = node_id + + # The dashboard agent port is assigned first to avoid + # other processes accidentally taking its default port + self._dashboard_agent_listen_port = self._get_cached_port( + "dashboard_agent_listen_port", + default_port=ray_params.dashboard_agent_listen_port, + ) + + self.metrics_agent_port = self._get_cached_port( + "metrics_agent_port", default_port=ray_params.metrics_agent_port + ) + self._metrics_export_port = self._get_cached_port( + "metrics_export_port", default_port=ray_params.metrics_export_port + ) + self._runtime_env_agent_port = self._get_cached_port( + "runtime_env_agent_port", + default_port=ray_params.runtime_env_agent_port, + ) + + ray_params.update_if_absent( + metrics_agent_port=self.metrics_agent_port, + metrics_export_port=self._metrics_export_port, + dashboard_agent_listen_port=self._dashboard_agent_listen_port, + runtime_env_agent_port=self._runtime_env_agent_port, + ) + + # Pick a GCS server port. + if head: + gcs_server_port = os.getenv(ray_constants.GCS_PORT_ENVIRONMENT_VARIABLE) + if gcs_server_port: + ray_params.update_if_absent(gcs_server_port=int(gcs_server_port)) + if ray_params.gcs_server_port is None or ray_params.gcs_server_port == 0: + ray_params.gcs_server_port = self._get_cached_port("gcs_server_port") + + if not connect_only and spawn_reaper and not self.kernel_fate_share: + self.start_reaper_process() + if not connect_only: + self._ray_params.update_pre_selected_port() + + # Start processes. + if head: + self.start_head_processes() + + if not connect_only: + self.start_ray_processes() + # Wait for the node info to be available in the GCS so that + # we know it's started up. + + # Grace period to let the Raylet register with the GCS. + # We retry in a loop in case it takes longer than expected. + time.sleep(0.1) + start_time = time.monotonic() + raylet_start_wait_time_s = 30 + node_info = None + while True: + try: + # Will raise a RuntimeError if the node info is not available. + node_info = ray._private.services.get_node( + self.gcs_address, + self._node_id, + ) + break + except RuntimeError as e: + logger.info(f"Failed to get node info {e}") + if time.monotonic() - start_time > raylet_start_wait_time_s: + raise Exception( + "The current node timed out during startup. This " + "could happen because some of the raylet failed to " + "startup or the GCS has become overloaded." + ) + # Use node info to update port + if self._ray_params.node_manager_port == 0: + self._ray_params.node_manager_port = node_info["node_manager_port"] + + if connect_only: + # Fetch node info to get labels. 
+ node_info = ray._private.services.get_node( + self.gcs_address, + self._node_id, + ) + # Set node labels from GCS if provided at node init. + self._node_labels = node_info.get("labels", {}) + + # Makes sure the Node object has valid addresses after setup. + self.validate_ip_port(self.address) + self.validate_ip_port(self.gcs_address) + + if not connect_only: + self._record_stats() + + def check_persisted_session_name(self): + if self._ray_params.external_addresses is None: + return None + self._redis_address = self._ray_params.external_addresses[0] + redis_ip_address, redis_port, enable_redis_ssl = get_address( + self._redis_address, + ) + # Address is ip:port or redis://ip:port + if int(redis_port) < 0: + raise ValueError( + f"Invalid Redis port provided: {redis_port}." + "The port must be a non-negative integer." + ) + + return get_session_key_from_storage( + redis_ip_address, + int(redis_port), + self._ray_params.redis_username, + self._ray_params.redis_password, + enable_redis_ssl, + serialize_config(self._config), + b"session_name", + ) + + @staticmethod + def validate_ip_port(ip_port): + """Validates the address is in the ip:port format""" + parts = parse_address(ip_port) + if parts is None: + raise ValueError(f"Port is not specified for address {ip_port}") + try: + _ = int(parts[1]) + except ValueError: + raise ValueError( + f"Unable to parse port number from {parts[1]} (full address = {ip_port})" + ) + + def check_version_info(self): + """Check if the Python and Ray version of this process matches that in GCS. + + This will be used to detect if workers or drivers are started using + different versions of Python, or Ray. + + Raises: + Exception: An exception is raised if there is a version mismatch. + """ + import ray._common.usage.usage_lib as ray_usage_lib + + cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client()) + if cluster_metadata is None: + cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client()) + + if not cluster_metadata: + return + node_ip_address = ray._private.services.get_node_ip_address() + ray._private.utils.check_version_info( + cluster_metadata, f"node {node_ip_address}" + ) + + def _register_shutdown_hooks(self): + # Register the atexit handler. In this case, we shouldn't call sys.exit + # as we're already in the exit procedure. + def atexit_handler(*args): + self.kill_all_processes(check_alive=False, allow_graceful=True) + + atexit.register(atexit_handler) + + # Register the handler to be called if we get a SIGTERM. + # In this case, we want to exit with an error code (1) after + # cleaning up child processes. + def sigterm_handler(signum, frame): + self.kill_all_processes(check_alive=False, allow_graceful=True) + sys.exit(1) + + ray._private.utils.set_sigterm_handler(sigterm_handler) + + def _init_temp(self): + # Create a dictionary to store temp file index. 
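+        # Keyed by (suffix, prefix, directory); _make_inc_temp() uses it to pick
+        # the next unused index when generating incremental file names.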
+ self._incremental_dict = collections.defaultdict(lambda: 0) + + if self.head: + self._ray_params.update_if_absent( + temp_dir=ray._common.utils.get_ray_temp_dir() + ) + self._temp_dir = self._ray_params.temp_dir + else: + if self._ray_params.temp_dir is None: + assert not self._default_worker + temp_dir = ray._private.utils.internal_kv_get_with_retry( + self.get_gcs_client(), + "temp_dir", + ray_constants.KV_NAMESPACE_SESSION, + num_retries=ray_constants.NUM_REDIS_GET_RETRIES, + ) + self._temp_dir = ray._common.utils.decode(temp_dir) + else: + self._temp_dir = self._ray_params.temp_dir + + try_to_create_directory(self._temp_dir) + + if self.head: + self._session_dir = os.path.join(self._temp_dir, self._session_name) + else: + if self._temp_dir is None or self._session_name is None: + assert not self._default_worker + session_dir = ray._private.utils.internal_kv_get_with_retry( + self.get_gcs_client(), + "session_dir", + ray_constants.KV_NAMESPACE_SESSION, + num_retries=ray_constants.NUM_REDIS_GET_RETRIES, + ) + self._session_dir = ray._common.utils.decode(session_dir) + else: + self._session_dir = os.path.join(self._temp_dir, self._session_name) + session_symlink = os.path.join(self._temp_dir, ray_constants.SESSION_LATEST) + + # Send a warning message if the session exists. + try_to_create_directory(self._session_dir) + try_to_symlink(session_symlink, self._session_dir) + # Create a directory to be used for socket files. + self._sockets_dir = os.path.join(self._session_dir, "sockets") + try_to_create_directory(self._sockets_dir) + # Create a directory to be used for process log files. + self._logs_dir = os.path.join(self._session_dir, "logs") + try_to_create_directory(self._logs_dir) + old_logs_dir = os.path.join(self._logs_dir, "old") + try_to_create_directory(old_logs_dir) + # Create a directory to be used for runtime environment. + self._runtime_env_dir = os.path.join( + self._session_dir, self._ray_params.runtime_env_dir_name + ) + try_to_create_directory(self._runtime_env_dir) + # Create a symlink to the libtpu tpu_logs directory if it exists. + user_temp_dir = ray._common.utils.get_user_temp_dir() + tpu_log_dir = f"{user_temp_dir}/tpu_logs" + if os.path.isdir(tpu_log_dir): + tpu_logs_symlink = os.path.join(self._logs_dir, "tpu_logs") + try_to_symlink(tpu_logs_symlink, tpu_log_dir) + + def get_resource_and_label_spec(self): + """Resolve and return the current ResourceAndLabelSpec for the node.""" + if not self._resource_and_label_spec: + self._resource_and_label_spec = ResourceAndLabelSpec( + self._ray_params.num_cpus, + self._ray_params.num_gpus, + self._ray_params.memory, + self._ray_params.object_store_memory, + self._ray_params.resources, + self._ray_params.labels, + ).resolve(is_head=self.head, node_ip_address=self.node_ip_address) + return self._resource_and_label_spec + + @property + def node_id(self): + """Get the node ID.""" + return self._node_id + + @property + def session_name(self): + """Get the current Ray session name.""" + return self._session_name + + @property + def node_ip_address(self): + """Get the IP address of this node.""" + return self._node_ip_address + + @property + def address(self): + """Get the address for bootstrapping, e.g. the address to pass to + `ray start` or `ray.init()` to start worker nodes, that has been + converted to ip:port format. 
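+        Here this is the GCS address (ip:port).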
+ """ + return self._gcs_address + + @property + def gcs_address(self): + """Get the gcs address.""" + assert self._gcs_address is not None, "Gcs address is not set" + return self._gcs_address + + @property + def redis_address(self): + """Get the cluster Redis address.""" + return self._redis_address + + @property + def redis_username(self): + """Get the cluster Redis username.""" + return self._ray_params.redis_username + + @property + def redis_password(self): + """Get the cluster Redis password.""" + return self._ray_params.redis_password + + @property + def plasma_store_socket_name(self): + """Get the node's plasma store socket name.""" + return self._plasma_store_socket_name + + @property + def unique_id(self): + """Get a unique identifier for this node.""" + return f"{self.node_ip_address}:{self._plasma_store_socket_name}" + + @property + def webui_url(self): + """Get the cluster's web UI url.""" + return self._webui_url + + @property + def raylet_socket_name(self): + """Get the node's raylet socket name.""" + return self._raylet_socket_name + + @property + def node_manager_port(self): + """Get the node manager's port.""" + return self._ray_params.node_manager_port + + @property + def metrics_export_port(self): + """Get the port that exposes metrics""" + return self._metrics_export_port + + @property + def runtime_env_agent_port(self): + """Get the port that exposes runtime env agent as http""" + return self._runtime_env_agent_port + + @property + def runtime_env_agent_address(self): + """Get the address that exposes runtime env agent as http""" + return f"http://{build_address(self._node_ip_address, self._runtime_env_agent_port)}" + + @property + def dashboard_agent_listen_port(self): + """Get the dashboard agent's listen port""" + return self._dashboard_agent_listen_port + + @property + def logging_config(self): + """Get the logging config of the current node.""" + return { + "log_rotation_max_bytes": self.max_bytes, + "log_rotation_backup_count": self.backup_count, + } + + @property + def address_info(self): + """Get a dictionary of addresses.""" + return { + "node_ip_address": self._node_ip_address, + "redis_address": self.redis_address, + "object_store_address": self._plasma_store_socket_name, + "raylet_socket_name": self._raylet_socket_name, + "webui_url": self._webui_url, + "session_dir": self._session_dir, + "metrics_export_port": self._metrics_export_port, + "gcs_address": self.gcs_address, + "address": self.address, + "dashboard_agent_listen_port": self.dashboard_agent_listen_port, + } + + @property + def node_labels(self): + """Get the node labels.""" + return self._node_labels + + def is_head(self): + return self.head + + def get_gcs_client(self): + if self._gcs_client is None: + self._init_gcs_client() + return self._gcs_client + + def _init_gcs_client(self): + if self.head: + gcs_process = self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][ + 0 + ].process + else: + gcs_process = None + + # TODO(ryw) instead of create a new GcsClient, wrap the one from + # CoreWorkerProcess to save a grpc channel. + for _ in range(ray_constants.NUM_REDIS_GET_RETRIES): + gcs_address = None + last_ex = None + try: + gcs_address = self.gcs_address + client = GcsClient( + address=gcs_address, + cluster_id=self._ray_params.cluster_id, # Hex string + ) + self.cluster_id = client.cluster_id + if self.head: + # Send a simple request to make sure GCS is alive + # if it's a head node. 
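+                    # Looking up an arbitrary key is enough to verify connectivity;
+                    # transient failures are retried by the surrounding loop.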
+ client.internal_kv_get(b"dummy", None) + self._gcs_client = client + break + except Exception: + if gcs_process is not None and gcs_process.poll() is not None: + # GCS has exited. + break + last_ex = traceback.format_exc() + logger.debug(f"Connecting to GCS: {last_ex}") + time.sleep(1) + + if self._gcs_client is None: + if hasattr(self, "_logs_dir"): + with open(os.path.join(self._logs_dir, "gcs_server.err")) as err: + # Use " C " or " E " to exclude the stacktrace. + # This should work for most cases, especitally + # it's when GCS is starting. Only display last 10 lines of logs. + errors = [e for e in err.readlines() if " C " in e or " E " in e][ + -10: + ] + error_msg = "\n" + "".join(errors) + "\n" + raise RuntimeError( + f"Failed to {'start' if self.head else 'connect to'} GCS. " + f" Last {len(errors)} lines of error files:" + f"{error_msg}." + f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}" + f" for details. Last connection error: {last_ex}" + ) + else: + raise RuntimeError( + f"Failed to {'start' if self.head else 'connect to'} GCS. Last " + f"connection error: {last_ex}" + ) + + ray.experimental.internal_kv._initialize_internal_kv(self._gcs_client) + + def get_temp_dir_path(self): + """Get the path of the temporary directory.""" + return self._temp_dir + + def get_runtime_env_dir_path(self): + """Get the path of the runtime env.""" + return self._runtime_env_dir + + def get_session_dir_path(self): + """Get the path of the session directory.""" + return self._session_dir + + def get_logs_dir_path(self): + """Get the path of the log files directory.""" + return self._logs_dir + + def get_sockets_dir_path(self): + """Get the path of the sockets directory.""" + return self._sockets_dir + + def _make_inc_temp( + self, suffix: str = "", prefix: str = "", directory_name: Optional[str] = None + ): + """Return an incremental temporary file name. The file is not created. + + Args: + suffix: The suffix of the temp file. + prefix: The prefix of the temp file. + directory_name (str) : The base directory of the temp file. + + Returns: + A string of file name. If there existing a file having + the same name, the returned name will look like + "{directory_name}/{prefix}.{unique_index}{suffix}" + """ + if directory_name is None: + directory_name = ray._common.utils.get_ray_temp_dir() + directory_name = os.path.expanduser(directory_name) + index = self._incremental_dict[suffix, prefix, directory_name] + # `tempfile.TMP_MAX` could be extremely large, + # so using `range` in Python2.x should be avoided. + while index < tempfile.TMP_MAX: + if index == 0: + filename = os.path.join(directory_name, prefix + suffix) + else: + filename = os.path.join( + directory_name, prefix + "." + str(index) + suffix + ) + index += 1 + if not os.path.exists(filename): + # Save the index. + self._incremental_dict[suffix, prefix, directory_name] = index + return filename + + raise FileExistsError(errno.EEXIST, "No usable temporary filename found") + + def should_redirect_logs(self): + redirect_output = self._ray_params.redirect_output + if redirect_output is None: + # Fall back to stderr redirect environment variable. + redirect_output = ( + os.environ.get( + ray_constants.LOGGING_REDIRECT_STDERR_ENVIRONMENT_VARIABLE + ) + != "1" + ) + return redirect_output + + # TODO(hjiang): Re-implement the logic in C++, and expose via cython. 
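+    # NOTE: get_log_file_names() below carries the vllm_mlu hijack that sends
+    # gcs_server/raylet stdout to /dev/null unless VLLM_DUMP_RAY_LOG_EN is set.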
+ def get_log_file_names( + self, + name: str, + unique: bool = False, + create_out: bool = True, + create_err: bool = True, + ) -> Tuple[Optional[str], Optional[str]]: + """Get filename to dump logs for stdout and stderr, with no files opened. + If output redirection has been disabled, no files will + be opened and `(None, None)` will be returned. + + Args: + name: descriptive string for this log file. + unique: if true, a counter will be attached to `name` to + ensure the returned filename is not already used. + create_out: if True, create a .out file. + create_err: if True, create a .err file. + + Returns: + A tuple of two file handles for redirecting optional (stdout, stderr), + or `(None, None)` if output redirection is disabled. + """ + if not self.should_redirect_logs(): + return None, None + + log_stdout = None + log_stderr = None + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: disable ray dump log to prevent log files from continuously growing + ''' + DEVNULL_PATH = '/dev/null' if sys.platform != 'win32' else 'NUL' + if ( + name in ["gcs_server", "raylet"] \ + and ("VLLM_DUMP_RAY_LOG_EN" not in os.environ or \ + os.environ["VLLM_DUMP_RAY_LOG_EN"].lower() not in ["true", "1"]) + ): + log_stdout = DEVNULL_PATH + elif create_out: + log_stdout = self._get_log_file_name(name, "out", unique=unique) + ''' + ================== + End of MLU Hijack + ================== + ''' + if create_err: + log_stderr = self._get_log_file_name(name, "err", unique=unique) + + return log_stdout, log_stderr + + def get_log_file_handles( + self, + name: str, + unique: bool = False, + create_out: bool = True, + create_err: bool = True, + ) -> Tuple[Optional[IO[AnyStr]], Optional[IO[AnyStr]]]: + """Open log files with partially randomized filenames, returning the + file handles. If output redirection has been disabled, no files will + be opened and `(None, None)` will be returned. + + Args: + name: descriptive string for this log file. + unique: if true, a counter will be attached to `name` to + ensure the returned filename is not already used. + create_out: if True, create a .out file. + create_err: if True, create a .err file. + + Returns: + A tuple of two file handles for redirecting optional (stdout, stderr), + or `(None, None)` if output redirection is disabled. + """ + log_stdout_fname, log_stderr_fname = self.get_log_file_names( + name, unique=unique, create_out=create_out, create_err=create_err + ) + log_stdout = None if log_stdout_fname is None else open_log(log_stdout_fname) + log_stderr = None if log_stderr_fname is None else open_log(log_stderr_fname) + return log_stdout, log_stderr + + def _get_log_file_name( + self, + name: str, + suffix: str, + unique: bool = False, + ) -> str: + """Generate partially randomized filenames for log files. + + Args: + name: descriptive string for this log file. + suffix: suffix of the file. Usually it is .out of .err. + unique: if true, a counter will be attached to `name` to + ensure the returned filename is not already used. + + Returns: + A tuple of two file names for redirecting (stdout, stderr). + """ + # strip if the suffix is something like .out. 
+ suffix = suffix.strip(".") + + if unique: + filename = self._make_inc_temp( + suffix=f".{suffix}", prefix=name, directory_name=self._logs_dir + ) + else: + filename = os.path.join(self._logs_dir, f"{name}.{suffix}") + return filename + + def _get_unused_port(self, allocated_ports=None): + if allocated_ports is None: + allocated_ports = set() + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + + # Try to generate a port that is far above the 'next available' one. + # This solves issue #8254 where GRPC fails because the port assigned + # from this method has been used by a different process. + for _ in range(ray_constants.NUM_PORT_RETRIES): + new_port = random.randint(port, 65535) + if new_port in allocated_ports: + # This port is allocated for other usage already, + # so we shouldn't use it even if it's not in use right now. + continue + new_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + new_s.bind(("", new_port)) + except OSError: + new_s.close() + continue + s.close() + new_s.close() + return new_port + logger.error("Unable to succeed in selecting a random port.") + s.close() + return port + + def _prepare_socket_file(self, socket_path: str, default_prefix: str): + """Prepare the socket file for raylet and plasma. + + This method helps to prepare a socket file. + 1. Make the directory if the directory does not exist. + 2. If the socket file exists, do nothing (this just means we aren't the + first worker on the node). + + Args: + socket_path: the socket file to prepare. + """ + result = socket_path + if sys.platform == "win32": + if socket_path is None: + result = ( + f"tcp://{build_address(self._localhost, self._get_unused_port())}" + ) + else: + if socket_path is None: + result = self._make_inc_temp( + prefix=default_prefix, directory_name=self._sockets_dir + ) + else: + try_to_create_directory(os.path.dirname(socket_path)) + + validate_socket_filepath(result.split("://", 1)[-1]) + return result + + def _get_cached_port( + self, port_name: str, default_port: Optional[int] = None + ) -> int: + """Get a port number from a cache on this node. + + Different driver processes on a node should use the same ports for + some purposes, e.g. exporting metrics. This method returns a port + number for the given port name and caches it in a file. If the + port isn't already cached, an unused port is generated and cached. + + Args: + port_name: The name of the port, e.g. metrics_export_port. + default_port: The port to return and cache if no port has already been + cached for the given port_name. If None, an unused port is generated + and cached. + + Returns: + int: The port number. + """ + file_path = os.path.join(self.get_session_dir_path(), "ports_by_node.json") + + # Make sure only the ports in RAY_CACHED_PORTS are cached. + assert port_name in ray_constants.RAY_ALLOWED_CACHED_PORTS + + # Maps a Node.unique_id to a dict that maps port names to port numbers. + ports_by_node: Dict[str, Dict[str, int]] = defaultdict(dict) + + with FileLock(file_path + ".lock"): + if not os.path.exists(file_path): + with open(file_path, "w") as f: + json.dump({}, f) + + with open(file_path, "r") as f: + ports_by_node.update(json.load(f)) + + if ( + self.unique_id in ports_by_node + and port_name in ports_by_node[self.unique_id] + ): + # The port has already been cached at this node, so use it. + port = int(ports_by_node[self.unique_id][port_name]) + else: + # Pick a new port to use and cache it at this node. 
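+                # Skip ports already cached for other purposes on this node; if the
+                # default port is taken, fall back to a random unused port.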
+ allocated_ports = set(ports_by_node[self.unique_id].values()) + + if default_port is not None and default_port in allocated_ports: + # The default port is already in use, so don't use it. + default_port = None + + port = default_port or self._get_unused_port(allocated_ports) + + ports_by_node[self.unique_id][port_name] = port + with open(file_path, "w") as f: + json.dump(ports_by_node, f) + + return port + + def _wait_and_get_for_node_address(self, timeout_s: int = 60) -> str: + """Wait until the RAY_NODE_IP_FILENAME file is avialable. + + RAY_NODE_IP_FILENAME is created when a ray instance is started. + + Args: + timeout_s: If the ip address is not found within this + timeout, it will raise ValueError. + Returns: + The node_ip_address of the current session if it finds it + within timeout_s. + """ + for i in range(timeout_s): + node_ip_address = ray._private.services.get_cached_node_ip_address( + self.get_session_dir_path() + ) + + if node_ip_address is not None: + return node_ip_address + + time.sleep(1) + if i % 10 == 0: + logger.info( + f"Can't find a `{ray_constants.RAY_NODE_IP_FILENAME}` " + f"file from {self.get_session_dir_path()}. " + "Have you started Ray instance using " + "`ray start` or `ray.init`?" + ) + + raise ValueError( + f"Can't find a `{ray_constants.RAY_NODE_IP_FILENAME}` " + f"file from {self.get_session_dir_path()}. " + f"for {timeout_s} seconds. " + "A ray instance hasn't started. " + "Did you do `ray start` or `ray.init` on this host?" + ) + + def start_reaper_process(self): + """ + Start the reaper process. + + This must be the first process spawned and should only be called when + ray processes should be cleaned up if this process dies. + """ + assert ( + not self.kernel_fate_share + ), "a reaper should not be used with kernel fate-sharing" + process_info = ray._private.services.start_reaper(fate_share=False) + assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes + if process_info is not None: + self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [ + process_info, + ] + + def start_log_monitor(self): + """Start the log monitor.""" + stdout_log_fname, stderr_log_fname = self.get_log_file_names( + "log_monitor", unique=True, create_out=True, create_err=True + ) + process_info = ray._private.services.start_log_monitor( + self.get_session_dir_path(), + self._logs_dir, + self.gcs_address, + fate_share=self.kernel_fate_share, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + stdout_filepath=stdout_log_fname, + stderr_filepath=stderr_log_fname, + ) + assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [ + process_info, + ] + + def start_api_server( + self, *, include_dashboard: Optional[bool], raise_on_failure: bool + ): + """Start the dashboard. + + Args: + include_dashboard: If true, this will load all dashboard-related modules + when starting the API server. Otherwise, it will only + start the modules that are not relevant to the dashboard. + raise_on_failure: If true, this will raise an exception + if we fail to start the API server. Otherwise it will print + a warning if we fail to start the API server. 
+ """ + stdout_log_fname, stderr_log_fname = self.get_log_file_names( + "dashboard", unique=True, create_out=True, create_err=True + ) + self._webui_url, process_info = ray._private.services.start_api_server( + include_dashboard, + raise_on_failure, + self._ray_params.dashboard_host, + self.gcs_address, + self.cluster_id.hex(), + self._node_ip_address, + self._temp_dir, + self._logs_dir, + self._session_dir, + port=self._ray_params.dashboard_port, + fate_share=self.kernel_fate_share, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + stdout_filepath=stdout_log_fname, + stderr_filepath=stderr_log_fname, + ) + assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes + if process_info is not None: + self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [ + process_info, + ] + self.get_gcs_client().internal_kv_put( + b"webui:url", + self._webui_url.encode(), + True, + ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + def start_gcs_server(self): + """Start the gcs server.""" + gcs_server_port = self._ray_params.gcs_server_port + assert gcs_server_port > 0 + assert self._gcs_address is None, "GCS server is already running." + assert self._gcs_client is None, "GCS client is already connected." + + stdout_log_fname, stderr_log_fname = self.get_log_file_names( + "gcs_server", unique=True, create_out=True, create_err=True + ) + process_info = ray._private.services.start_gcs_server( + self.redis_address, + log_dir=self._logs_dir, + stdout_filepath=stdout_log_fname, + stderr_filepath=stderr_log_fname, + session_name=self.session_name, + redis_username=self._ray_params.redis_username, + redis_password=self._ray_params.redis_password, + config=self._config, + fate_share=self.kernel_fate_share, + gcs_server_port=gcs_server_port, + metrics_agent_port=self._ray_params.metrics_agent_port, + node_ip_address=self._node_ip_address, + ) + assert ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [ + process_info, + ] + # Connecting via non-localhost address may be blocked by firewall rule, + # e.g. https://github.com/ray-project/ray/issues/15780 + # TODO(mwtian): figure out a way to use 127.0.0.1 for local connection + # when possible. + self._gcs_address = build_address(self._node_ip_address, gcs_server_port) + + def start_raylet( + self, + plasma_directory: str, + fallback_directory: str, + object_store_memory: int, + use_valgrind: bool = False, + use_profiler: bool = False, + ): + """Start the raylet. + + Args: + use_valgrind: True if we should start the process in + valgrind. + use_profiler: True if we should start the process in the + valgrind profiler. 
+ """ + raylet_stdout_filepath, raylet_stderr_filepath = self.get_log_file_names( + ray_constants.PROCESS_TYPE_RAYLET, + unique=True, + create_out=True, + create_err=True, + ) + ( + dashboard_agent_stdout_filepath, + dashboard_agent_stderr_filepath, + ) = self.get_log_file_names( + ray_constants.PROCESS_TYPE_DASHBOARD_AGENT, + unique=True, + create_out=True, + create_err=True, + ) + ( + runtime_env_agent_stdout_filepath, + runtime_env_agent_stderr_filepath, + ) = self.get_log_file_names( + ray_constants.PROCESS_TYPE_RUNTIME_ENV_AGENT, + unique=True, + create_out=True, + create_err=True, + ) + + self.resource_isolation_config.add_system_pids( + self._get_system_processes_for_resource_isolation() + ) + + process_info = ray._private.services.start_raylet( + self.redis_address, + self.gcs_address, + self._node_id, + self._node_ip_address, + self._ray_params.node_manager_port, + self._raylet_socket_name, + self._plasma_store_socket_name, + self.cluster_id.hex(), + self._ray_params.worker_path, + self._ray_params.setup_worker_path, + self._temp_dir, + self._session_dir, + self._runtime_env_dir, + self._logs_dir, + self.get_resource_and_label_spec(), + plasma_directory, + fallback_directory, + object_store_memory, + self.session_name, + is_head_node=self.is_head(), + min_worker_port=self._ray_params.min_worker_port, + max_worker_port=self._ray_params.max_worker_port, + worker_port_list=self._ray_params.worker_port_list, + object_manager_port=self._ray_params.object_manager_port, + redis_username=self._ray_params.redis_username, + redis_password=self._ray_params.redis_password, + metrics_agent_port=self._ray_params.metrics_agent_port, + runtime_env_agent_port=self._ray_params.runtime_env_agent_port, + metrics_export_port=self._metrics_export_port, + dashboard_agent_listen_port=self._ray_params.dashboard_agent_listen_port, + use_valgrind=use_valgrind, + use_profiler=use_profiler, + raylet_stdout_filepath=raylet_stdout_filepath, + raylet_stderr_filepath=raylet_stderr_filepath, + dashboard_agent_stdout_filepath=dashboard_agent_stdout_filepath, + dashboard_agent_stderr_filepath=dashboard_agent_stderr_filepath, + runtime_env_agent_stdout_filepath=runtime_env_agent_stdout_filepath, + runtime_env_agent_stderr_filepath=runtime_env_agent_stderr_filepath, + huge_pages=self._ray_params.huge_pages, + fate_share=self.kernel_fate_share, + socket_to_use=None, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + ray_debugger_external=self._ray_params.ray_debugger_external, + env_updates=self._ray_params.env_vars, + node_name=self._ray_params.node_name, + webui=self._webui_url, + resource_isolation_config=self.resource_isolation_config, + ) + assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] + + def start_monitor(self): + """Start the monitor. + + Autoscaling output goes to these monitor.err/out files, and + any modification to these files may break existing + cluster launching commands. 
+ """ + from ray.autoscaler.v2.utils import is_autoscaler_v2 + + stdout_log_fname, stderr_log_fname = self.get_log_file_names( + "monitor", unique=True, create_out=True, create_err=True + ) + process_info = ray._private.services.start_monitor( + self.gcs_address, + self._logs_dir, + stdout_filepath=stdout_log_fname, + stderr_filepath=stderr_log_fname, + autoscaling_config=self._ray_params.autoscaling_config, + fate_share=self.kernel_fate_share, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + monitor_ip=self._node_ip_address, + autoscaler_v2=is_autoscaler_v2(fetch_from_server=True), + ) + assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info] + + def start_ray_client_server(self): + """Start the ray client server process.""" + stdout_file, stderr_file = self.get_log_file_handles( + "ray_client_server", unique=True + ) + process_info = ray._private.services.start_ray_client_server( + self.address, + self._node_ip_address, + self._ray_params.ray_client_server_port, + stdout_file=stdout_file, + stderr_file=stderr_file, + redis_username=self._ray_params.redis_username, + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share, + runtime_env_agent_address=self.runtime_env_agent_address, + ) + assert ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER not in self.all_processes + self.all_processes[ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER] = [ + process_info + ] + + def _write_cluster_info_to_kv(self): + """Write the cluster metadata to GCS. + Cluster metadata is always recorded, but they are + not reported unless usage report is enabled. + Check `usage_stats_head.py` for more details. + """ + # Make sure the cluster metadata wasn't reported before. + import ray._common.usage.usage_lib as ray_usage_lib + + ray_usage_lib.put_cluster_metadata( + self.get_gcs_client(), ray_init_cluster=self.ray_init_cluster + ) + # Make sure GCS is up. + added = self.get_gcs_client().internal_kv_put( + b"session_name", + self._session_name.encode(), + False, + ray_constants.KV_NAMESPACE_SESSION, + ) + if not added: + curr_val = self.get_gcs_client().internal_kv_get( + b"session_name", ray_constants.KV_NAMESPACE_SESSION + ) + assert curr_val == self._session_name.encode("utf-8"), ( + f"Session name {self._session_name} does not match " + f"persisted value {curr_val}. Perhaps there was an " + f"error connecting to Redis." + ) + + self.get_gcs_client().internal_kv_put( + b"session_dir", + self._session_dir.encode(), + True, + ray_constants.KV_NAMESPACE_SESSION, + ) + self.get_gcs_client().internal_kv_put( + b"temp_dir", + self._temp_dir.encode(), + True, + ray_constants.KV_NAMESPACE_SESSION, + ) + # Add tracing_startup_hook to redis / internal kv manually + # since internal kv is not yet initialized. + if self._ray_params.tracing_startup_hook: + self.get_gcs_client().internal_kv_put( + b"tracing_startup_hook", + self._ray_params.tracing_startup_hook.encode(), + True, + ray_constants.KV_NAMESPACE_TRACING, + ) + + def start_head_processes(self): + """Start head processes on the node.""" + logger.debug( + f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}." 
+ ) + assert self._gcs_address is None + assert self._gcs_client is None + + self.start_gcs_server() + assert self.get_gcs_client() is not None + self._write_cluster_info_to_kv() + + if not self._ray_params.no_monitor: + self.start_monitor() + + if self._ray_params.ray_client_server_port: + self.start_ray_client_server() + + if self._ray_params.include_dashboard is None: + # Default + raise_on_api_server_failure = False + else: + raise_on_api_server_failure = self._ray_params.include_dashboard + + self.start_api_server( + include_dashboard=self._ray_params.include_dashboard, + raise_on_failure=raise_on_api_server_failure, + ) + + def start_ray_processes(self): + """Start all of the processes on the node.""" + logger.debug( + f"Process STDOUT and STDERR is being " f"redirected to {self._logs_dir}." + ) + + if not self.head: + # Get the system config from GCS first if this is a non-head node. + gcs_options = ray._raylet.GcsClientOptions.create( + self.gcs_address, + self.cluster_id.hex(), + allow_cluster_id_nil=False, + fetch_cluster_id_if_nil=False, + ) + global_state = ray._private.state.GlobalState() + global_state._initialize_global_state(gcs_options) + new_config = global_state.get_system_config() + assert self._config.items() <= new_config.items(), ( + "The system config from GCS is not a superset of the local" + " system config. There might be a configuration inconsistency" + " issue between the head node and non-head nodes." + f" Local system config: {self._config}," + f" GCS system config: {new_config}" + ) + self._config = new_config + + # Make sure we don't call `determine_plasma_store_config` multiple + # times to avoid printing multiple warnings. + resource_and_label_spec = self.get_resource_and_label_spec() + if resource_and_label_spec.labels.get( + ray._raylet.RAY_NODE_ACCELERATOR_TYPE_KEY + ): + from ray._common.usage import usage_lib + + usage_lib.record_hardware_usage( + resource_and_label_spec.labels.get( + ray._raylet.RAY_NODE_ACCELERATOR_TYPE_KEY + ) + ) + + ( + plasma_directory, + fallback_directory, + object_store_memory, + ) = ray._private.services.determine_plasma_store_config( + resource_and_label_spec.object_store_memory, + self._temp_dir, + plasma_directory=self._ray_params.plasma_directory, + fallback_directory=self._fallback_directory, + huge_pages=self._ray_params.huge_pages, + ) + + # add plasma store memory to the total system reserved memory + if self.resource_isolation_config.is_enabled(): + self.resource_isolation_config.add_object_store_memory(object_store_memory) + + if self._ray_params.include_log_monitor: + self.start_log_monitor() + + self.start_raylet(plasma_directory, fallback_directory, object_store_memory) + + def _get_system_processes_for_resource_isolation(self) -> str: + """Returns a list of system processes that will be isolated by raylet. + + NOTE: If a new system process is started before the raylet starts up, it needs to be + added to self.all_processes so it can be moved into the raylet's managed cgroup + hierarchy. + """ + system_process_pids = [ + str(p[0].process.pid) for p in self.all_processes.values() + ] + + # If the dashboard api server was started on the head node, then include all of the api server's + # child processes. 
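+        # psutil is used to enumerate the dashboard's direct child processes so
+        # they are also treated as system processes for resource isolation.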
+ if ray_constants.PROCESS_TYPE_DASHBOARD in self.all_processes: + dashboard_pid = self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][ + 0 + ].process.pid + dashboard_process = psutil.Process(dashboard_pid) + system_process_pids += [str(p.pid) for p in dashboard_process.children()] + + return ",".join(system_process_pids) + + def _kill_process_type( + self, + process_type, + allow_graceful: bool = False, + check_alive: bool = True, + wait: bool = False, + ): + """Kill a process of a given type. + + If the process type is PROCESS_TYPE_REDIS_SERVER, then we will kill all + of the Redis servers. + + If the process was started in valgrind, then we will raise an exception + if the process has a non-zero exit code. + + Args: + process_type: The type of the process to kill. + allow_graceful: Send a SIGTERM first and give the process + time to exit gracefully. If that doesn't work, then use + SIGKILL. We usually want to do this outside of tests. + check_alive: If true, then we expect the process to be alive + and will raise an exception if the process is already dead. + wait: If true, then this method will not return until the + process in question has exited. + + Raises: + This process raises an exception in the following cases: + 1. The process had already died and check_alive is true. + 2. The process had been started in valgrind and had a non-zero + exit code. + """ + + # Ensure thread safety + with self.removal_lock: + self._kill_process_impl( + process_type, + allow_graceful=allow_graceful, + check_alive=check_alive, + wait=wait, + ) + + def _kill_process_impl( + self, process_type, allow_graceful=False, check_alive=True, wait=False + ): + """See `_kill_process_type`.""" + if process_type not in self.all_processes: + return + process_infos = self.all_processes[process_type] + if process_type != ray_constants.PROCESS_TYPE_REDIS_SERVER: + assert len(process_infos) == 1 + for process_info in process_infos: + process = process_info.process + # Handle the case where the process has already exited. + if process.poll() is not None: + if check_alive: + raise RuntimeError( + "Attempting to kill a process of type " + f"'{process_type}', but this process is already dead." + ) + else: + continue + + if process_info.use_valgrind: + process.terminate() + process.wait() + if process.returncode != 0: + message = ( + "Valgrind detected some errors in process of " + f"type {process_type}. Error code {process.returncode}." + ) + if process_info.stdout_file is not None: + with open(process_info.stdout_file, "r") as f: + message += "\nPROCESS STDOUT:\n" + f.read() + if process_info.stderr_file is not None: + with open(process_info.stderr_file, "r") as f: + message += "\nPROCESS STDERR:\n" + f.read() + raise RuntimeError(message) + continue + + if process_info.use_valgrind_profiler: + # Give process signal to write profiler data. + os.kill(process.pid, signal.SIGINT) + # Wait for profiling data to be written. + time.sleep(0.1) + + if allow_graceful: + process.terminate() + # Allow the process one second to exit gracefully. + timeout_seconds = 1 + try: + process.wait(timeout_seconds) + except subprocess.TimeoutExpired: + pass + + # If the process did not exit, force kill it. + if process.poll() is None: + process.kill() + # The reason we usually don't call process.wait() here is that + # there's some chance we'd end up waiting a really long time. + if wait: + process.wait() + + del self.all_processes[process_type] + + def kill_redis(self, check_alive: bool = True): + """Kill the Redis servers. 
+ + Args: + check_alive: Raise an exception if any of the processes + were already dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_REDIS_SERVER, check_alive=check_alive + ) + + def kill_raylet(self, check_alive: bool = True): + """Kill the raylet. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_RAYLET, check_alive=check_alive + ) + + def kill_log_monitor(self, check_alive: bool = True): + """Kill the log monitor. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_LOG_MONITOR, check_alive=check_alive + ) + + def kill_dashboard(self, check_alive: bool = True): + """Kill the dashboard. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_DASHBOARD, check_alive=check_alive + ) + + def kill_monitor(self, check_alive: bool = True): + """Kill the monitor. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_MONITOR, check_alive=check_alive + ) + + def kill_gcs_server(self, check_alive: bool = True): + """Kill the gcs server. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_GCS_SERVER, check_alive=check_alive, wait=True + ) + # Clear GCS client and address to indicate no GCS server is running. + self._gcs_address = None + self._gcs_client = None + + def kill_reaper(self, check_alive: bool = True): + """Kill the reaper process. + + Args: + check_alive: Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_REAPER, check_alive=check_alive + ) + + def kill_all_processes(self, check_alive=True, allow_graceful=False, wait=False): + """Kill all of the processes. + + Note that This is slower than necessary because it calls kill, wait, + kill, wait, ... instead of kill, kill, ..., wait, wait, ... + + Args: + check_alive: Raise an exception if any of the processes were + already dead. + wait: If true, then this method will not return until the + process in question has exited. + """ + # Kill the raylet first. This is important for suppressing errors at + # shutdown because we give the raylet a chance to exit gracefully and + # clean up its child worker processes. If we were to kill the plasma + # store (or Redis) first, that could cause the raylet to exit + # ungracefully, leading to more verbose output from the workers. + if ray_constants.PROCESS_TYPE_RAYLET in self.all_processes: + self._kill_process_type( + ray_constants.PROCESS_TYPE_RAYLET, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + if ray_constants.PROCESS_TYPE_GCS_SERVER in self.all_processes: + self._kill_process_type( + ray_constants.PROCESS_TYPE_GCS_SERVER, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + # We call "list" to copy the keys because we are modifying the + # dictionary while iterating over it. + for process_type in list(self.all_processes.keys()): + # Need to kill the reaper process last in case we die unexpectedly + # while cleaning up. 
+ if process_type != ray_constants.PROCESS_TYPE_REAPER: + self._kill_process_type( + process_type, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + if ray_constants.PROCESS_TYPE_REAPER in self.all_processes: + self._kill_process_type( + ray_constants.PROCESS_TYPE_REAPER, + check_alive=check_alive, + allow_graceful=allow_graceful, + wait=wait, + ) + + def live_processes(self): + """Return a list of the live processes. + + Returns: + A list of the live processes. + """ + result = [] + for process_type, process_infos in self.all_processes.items(): + for process_info in process_infos: + if process_info.process.poll() is None: + result.append((process_type, process_info.process)) + return result + + def dead_processes(self): + """Return a list of the dead processes. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + A list of the dead processes ignoring the ones that have been + explicitly killed. + """ + result = [] + for process_type, process_infos in self.all_processes.items(): + for process_info in process_infos: + if process_info.process.poll() is not None: + result.append((process_type, process_info.process)) + return result + + def any_processes_alive(self): + """Return true if any processes are still alive. + + Returns: + True if any process is still alive. + """ + return any(self.live_processes()) + + def remaining_processes_alive(self): + """Return true if all remaining processes are still alive. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + True if any process that wasn't explicitly killed is still alive. + """ + return not any(self.dead_processes()) + + def destroy_external_storage(self): + object_spilling_config = self._config.get("object_spilling_config", {}) + if object_spilling_config: + object_spilling_config = json.loads(object_spilling_config) + from ray._private import external_storage + + storage = external_storage.setup_external_storage( + object_spilling_config, self._node_id, self._session_name + ) + storage.destroy_external_storage() + + def validate_external_storage(self): + """Make sure we can setup the object spilling external storage.""" + + automatic_spilling_enabled = self._config.get( + "automatic_object_spilling_enabled", True + ) + if not automatic_spilling_enabled: + return + + object_spilling_config = self._object_spilling_config + # Try setting up the storage. + # Configure the proper system config. + # We need to set both ray param's system config and self._config + # because they could've been diverged at this point. + deserialized_config = json.loads(object_spilling_config) + self._ray_params._system_config[ + "object_spilling_config" + ] = object_spilling_config + self._config["object_spilling_config"] = object_spilling_config + + is_external_storage_type_fs = deserialized_config["type"] == "filesystem" + self._ray_params._system_config[ + "is_external_storage_type_fs" + ] = is_external_storage_type_fs + self._config["is_external_storage_type_fs"] = is_external_storage_type_fs + + # Validate external storage usage. + from ray._private import external_storage + + # Node ID is available only after GCS is connected. However, + # validate_external_storage() needs to be called before it to + # be able to validate the configs early. Therefore, we use a + # dummy node ID here and make sure external storage can be set + # up based on the provided config. 
This storage is destroyed + # right after the validation. + dummy_node_id = ray.NodeID.from_random().hex() + storage = external_storage.setup_external_storage( + deserialized_config, dummy_node_id, self._session_name + ) + storage.destroy_external_storage() + external_storage.reset_external_storage() + + def _get_object_spilling_config(self): + """Consolidate the object spilling config from the ray params, environment + variable, and system config. The object spilling directory specified through + ray params will override the one specified through environment variable and + system config.""" + + object_spilling_directory = self._ray_params.object_spilling_directory + if not object_spilling_directory: + object_spilling_directory = self._config.get( + "object_spilling_directory", "" + ) + + if not object_spilling_directory: + object_spilling_directory = os.environ.get( + "RAY_object_spilling_directory", "" + ) + + if object_spilling_directory: + return json.dumps( + { + "type": "filesystem", + "params": {"directory_path": object_spilling_directory}, + } + ) + + object_spilling_config = self._config.get("object_spilling_config", {}) + if not object_spilling_config: + object_spilling_config = os.environ.get("RAY_object_spilling_config", "") + + # If the config is not specified in ray params, system config or environment + # variable, we fill up the default. + if not object_spilling_config: + object_spilling_config = json.dumps( + {"type": "filesystem", "params": {"directory_path": self._session_dir}} + ) + else: + if not is_in_test(): + logger.warning( + "The object spilling config is specified from an unstable " + "API - system config or environment variable. This is " + "subject to change in the future. You can use the stable " + "API - --object-spilling-directory in ray start or " + "object_spilling_directory in ray.init() to specify the " + "object spilling directory instead. If you need more " + "advanced settings, please open a github issue with the " + "Ray team." + ) + + return object_spilling_config + + def _record_stats(self): + # This is only called when a new node is started. + # Initialize the internal kv so that the metrics can be put + from ray._common.usage.usage_lib import ( + TagKey, + record_extra_usage_tag, + record_hardware_usage, + ) + + if not ray.experimental.internal_kv._internal_kv_initialized(): + ray.experimental.internal_kv._initialize_internal_kv(self.get_gcs_client()) + assert ray.experimental.internal_kv._internal_kv_initialized() + if self.head: + # record head node stats + gcs_storage_type = ( + "redis" if os.environ.get("RAY_REDIS_ADDRESS") is not None else "memory" + ) + record_extra_usage_tag(TagKey.GCS_STORAGE, gcs_storage_type) + cpu_model_name = ray._private.utils.get_current_node_cpu_model_name() + if cpu_model_name: + # CPU model name can be an arbitrary long string + # so we truncate it to the first 50 characters + # to avoid any issues. 
+ record_hardware_usage(cpu_model_name[:50]) diff --git a/tools/ray_mlu/nsight.py b/tools/ray_mlu/nsight.py new file mode 100644 index 0000000..c4b3fe0 --- /dev/null +++ b/tools/ray_mlu/nsight.py @@ -0,0 +1,142 @@ +import asyncio +import copy +import logging +import os +import subprocess +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from ray._common.utils import ( + try_to_create_directory, +) +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.plugin import RuntimeEnvPlugin +from ray.exceptions import RuntimeEnvSetupError + +default_logger = logging.getLogger(__name__) + +# Nsight options used when runtime_env={"_nsight": "default"} +# use default cnperf config, no need to specify any options +NSIGHT_DEFAULT_CONFIG = {} + +def parse_nsight_config(nsight_config: Dict[str, str]) -> List[str]: + """ + Function to convert dictionary of nsight options into + nsight command line + + The function returns: + - List[str]: nsys profile cmd line split into list of str + """ + nsight_cmd = ["cnperf-cli", "record"] + for option, option_val in nsight_config.items(): + # option standard based on + # https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html + if len(option) > 1: + nsight_cmd.append(f"--{option}={option_val}") + else: + nsight_cmd += [f"-{option}", option_val] + return nsight_cmd + + +class NsightPlugin(RuntimeEnvPlugin): + name = "_nsight" + + def __init__(self, resources_dir: str): + self.nsight_cmd = [] + + # replace this with better way to get logs dir + session_dir, runtime_dir = os.path.split(resources_dir) + self._nsight_dir = Path(session_dir) / "logs" / "nsight" + try_to_create_directory(self._nsight_dir) + + async def _check_nsight_script( + self, nsight_config: Dict[str, str] + ) -> Tuple[bool, str]: + """ + Function to validate if nsight_config is a valid nsight profile options + Args: + nsight_config: dictionary mapping nsight option to it's value + Returns: + a tuple consists of a boolean indicating if the nsight_config + is valid option and an error message if the nsight_config is invalid + """ + + # use empty as nsight report test filename + nsight_config_copy = copy.deepcopy(nsight_config) + try_to_create_directory(Path(self._nsight_dir) / "empty") + nsight_config_copy["o"] = str(Path(self._nsight_dir) / "empty/test") + nsight_cmd = parse_nsight_config(nsight_config_copy) + try: + nsight_cmd = nsight_cmd + ["python", "-c", '""'] + process = await asyncio.create_subprocess_exec( + *nsight_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + error_msg = stderr.strip() if stderr.strip() != "" else stdout.strip() + + # cleanup test.cnperf-rep file + clean_up_cmd = ["rm", f"{nsight_config_copy['o']}.cnperf-rep"] + cleanup_process = await asyncio.create_subprocess_exec( + *clean_up_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _, _ = await cleanup_process.communicate() + if process.returncode == 0: + return True, None + else: + return False, error_msg + except FileNotFoundError: + return False, ("cnperf-cli is not installed") + + async def create( + self, + uri: Optional[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: logging.Logger = default_logger, + ) -> int: + nsight_config = runtime_env.nsight() + if not nsight_config: + return 0 + + if nsight_config and sys.platform != "linux": + raise RuntimeEnvSetupError( + "CNPerf CLI is only available in Linux.\n" + "More 
information can be found in " + "https://docs.nvidia.com/nsight-compute/NsightComputeCli/index.html" + ) + + if isinstance(nsight_config, str): + if nsight_config == "default": + nsight_config = NSIGHT_DEFAULT_CONFIG + else: + raise RuntimeEnvSetupError( + f"Unsupported nsight config: {nsight_config}. " + "The supported config is 'default' or " + "Dictionary of cnperf options" + ) + + is_valid_nsight_cmd, error_msg = await self._check_nsight_script(nsight_config) + if not is_valid_nsight_cmd: + logger.warning(error_msg) + raise RuntimeEnvSetupError( + "cnperf-cli failed to run with the following " + f"error message:\n {error_msg}" + ) + self.nsight_cmd = parse_nsight_config(nsight_config) + return 0 + + def modify_context( + self, + uris: List[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: Optional[logging.Logger] = default_logger, + ): + context.py_executable = " ".join(self.nsight_cmd) + " python" + logger.info("Running CNPerf cmd: %s", context.py_executable) + diff --git a/tools/ray_mlu/test_mlu.py b/tools/ray_mlu/test_mlu.py new file mode 100755 index 0000000..70e81f7 --- /dev/null +++ b/tools/ray_mlu/test_mlu.py @@ -0,0 +1,92 @@ +import os +import sys +import pytest +from unittest.mock import patch + +import ray +from ray._private.accelerators import MLUAcceleratorManager as Accelerator + + +@patch("glob.glob") +@patch("os.listdir") +def test_autodetect_num_mlus(mock_list, mock_glob): + mock_glob.return_value = [f"/dev/davinci{i}" for i in range(4)] + # mock_list.return_value = [] + assert Accelerator.get_current_node_num_accelerators() == 4 + + +@patch("glob.glob") +@patch("os.listdir") +def test_autodetect_num_mlus_without_devices(mock_list, mock_glob): + mock_glob.side_effect = Exception + # mock_list.return_value = [] + assert Accelerator.get_current_node_num_accelerators() == 0 + + +def test_mlu_accelerator_manager_api(): + assert Accelerator.get_resource_name() == "MLU" + assert Accelerator.get_visible_accelerator_ids_env_var() == "MLU_VISIBLE_DEVICES" + assert Accelerator.validate_resource_request_quantity(0.5) == (True, None) + assert Accelerator.validate_resource_request_quantity(1) == (True, None) + + +def test_visible_mlu_type(monkeypatch, shutdown_only): + with patch.object( + Accelerator, "get_current_node_num_accelerators", return_value=4 + ), patch.object( + Accelerator, "get_current_node_accelerator_type", return_value="MLU370" + ): + monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") + manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU") + assert manager.get_current_node_accelerator_type() == "MLU370" + +@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows") +def test_visible_mlu_ids(monkeypatch, shutdown_only): + monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") + with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4): + + ray.init() + manager = ray._private.accelerators.get_accelerator_manager_for_resource("MLU") + assert manager.get_current_node_num_accelerators() == 4 + assert manager.__name__ == "MLUAcceleratorManager" + assert ray.available_resources()["MLU"] == 3 + +def test_get_current_process_visible_accelerator_ids(monkeypatch, shutdown_only): + monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") + assert Accelerator.get_current_process_visible_accelerator_ids() == ["0", "1", "2"] + + monkeypatch.delenv("MLU_VISIBLE_DEVICES") + assert Accelerator.get_current_process_visible_accelerator_ids() is None + + 
monkeypatch.setenv("MLU_VISIBLE_DEVICES", "") + assert Accelerator.get_current_process_visible_accelerator_ids() == [] + + monkeypatch.setenv("MLU_VISIBLE_DEVICES", "NoDevFiles") + assert Accelerator.get_current_process_visible_accelerator_ids() == [] + + +def test_set_current_process_visible_accelerator_ids(shutdown_only): + Accelerator.set_current_process_visible_accelerator_ids(["0"]) + assert os.environ["MLU_VISIBLE_DEVICES"] == "0" + + Accelerator.set_current_process_visible_accelerator_ids(["0", "1"]) + assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1" + + Accelerator.set_current_process_visible_accelerator_ids(["0", "1", "2"]) + assert os.environ["MLU_VISIBLE_DEVICES"] == "0,1,2" + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not supported mock on Windows") +def test_auto_detected_more_than_visible(monkeypatch, shutdown_only): + with patch.object(Accelerator, "get_current_node_num_accelerators", return_value=4): + # If more MLUs are detected than visible. + monkeypatch.setenv("MLU_VISIBLE_DEVICES", "0,1,2") + + ray.init() + assert ray.available_resources()["MLU"] == 3 + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/tools/ray_mlu/worker.py b/tools/ray_mlu/worker.py new file mode 100644 index 0000000..8232e46 --- /dev/null +++ b/tools/ray_mlu/worker.py @@ -0,0 +1,3785 @@ +import atexit +import faulthandler +import functools +import inspect +import io +import json +import logging +import os +import sys +import threading +import time +import traceback +import urllib +import warnings +from abc import ABCMeta, abstractmethod +from collections.abc import Mapping +from contextlib import contextmanager +from dataclasses import dataclass +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + Callable, + Dict, + Generic, + Iterator, + List, + Literal, + Optional, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from urllib.parse import urlparse + +if TYPE_CHECKING: + import torch + +import colorama + +import ray +import ray._private.node +import ray._private.parameter +import ray._private.profiling as profiling +import ray._private.ray_constants as ray_constants +import ray._private.serialization as serialization +import ray._private.services as services +import ray._private.state +import ray._private.worker + +# Ray modules +import ray.actor +import ray.cloudpickle as pickle # noqa +import ray.job_config +import ray.remote_function +from ray import ActorID, JobID, Language, ObjectRef +from ray._common import ray_option_utils +from ray._common.constants import RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR +from ray._common.utils import load_class +from ray._private.client_mode_hook import client_mode_hook +from ray._private.custom_types import TensorTransportEnum +from ray._private.function_manager import FunctionActorManager +from ray._private.inspect_util import is_cython +from ray._private.ray_logging import ( + global_worker_stdstream_dispatcher, + setup_logger, + stderr_deduplicator, + stdout_deduplicator, +) +from ray._private.ray_logging.logging_config import LoggingConfig +from ray._private.resource_isolation_config import ResourceIsolationConfig +from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR +from ray._private.runtime_env.py_modules import upload_py_modules_if_needed +from ray._private.runtime_env.setup_hook import ( + 
upload_worker_process_setup_hook_if_needed, +) +from ray._private.runtime_env.working_dir import upload_working_dir_if_needed +from ray._private.utils import get_ray_doc_version +from ray._raylet import ( + ObjectRefGenerator, + TaskID, + raise_sys_exit_with_custom_error_message, +) +from ray.actor import ActorClass +from ray.exceptions import ObjectStoreFullError, RayError, RaySystemError, RayTaskError +from ray.experimental import tqdm_ray +from ray.experimental.compiled_dag_ref import CompiledDAGRef +from ray.experimental.internal_kv import ( + _initialize_internal_kv, + _internal_kv_get, + _internal_kv_initialized, + _internal_kv_reset, +) +from ray.experimental.tqdm_ray import RAY_TQDM_MAGIC +from ray.runtime_env.runtime_env import _merge_runtime_env +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.util.debug import log_once +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from ray.util.tracing.tracing_helper import _import_from_string +from ray.widgets import Template +from ray.widgets.util import repr_with_fallback + +SCRIPT_MODE = 0 +WORKER_MODE = 1 +LOCAL_MODE = 2 +SPILL_WORKER_MODE = 3 +RESTORE_WORKER_MODE = 4 + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + + +T = TypeVar("T") +T0 = TypeVar("T0") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +T7 = TypeVar("T7") +T8 = TypeVar("T8") +T9 = TypeVar("T9") +R = TypeVar("R") + +DAGNode = TypeVar("DAGNode") + + +# Only used for type annotations as a placeholder +Undefined: Any = object() + + +# TypeVar for self-referential generics in `RemoteFunction[N]`. +RF = TypeVar("RF", bound="HasOptions") + + +class HasOptions(Protocol): + def options(self: RF, **task_options) -> RF: + ... + + +class RemoteFunctionNoArgs(HasOptions, Generic[R]): + def __init__(self, function: Callable[[], R]) -> None: + pass + + def remote( + self, + ) -> "ObjectRef[R]": + ... + + def bind( + self, + ) -> "DAGNode[R]": + ... + + +class RemoteFunction0(HasOptions, Generic[R, T0]): + def __init__(self, function: Callable[[T0], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction1(HasOptions, Generic[R, T0, T1]): + def __init__(self, function: Callable[[T0, T1], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction2(HasOptions, Generic[R, T0, T1, T2]): + def __init__(self, function: Callable[[T0, T1, T2], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + ) -> "DAGNode[R]": + ... 
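Editorial note (not part of the patch): the RemoteFunction0..RemoteFunction9 stubs above exist purely for static typing; their arity-specific .remote() signatures let a type checker infer the payload type of the returned ObjectRef and accept either plain values or ObjectRefs as arguments. A minimal sketch of what that enables, assuming only the public ray API (ray.init, @ray.remote, ray.get):

import ray

ray.init()

@ray.remote
def add(x: int, y: int) -> int:
    # An ordinary function; the decorator wraps it so it is invoked via .remote().
    return x + y

# With the two-argument stub, this call is meant to be seen by a type checker
# as ObjectRef[int]; arguments may be plain values or ObjectRefs from other tasks.
ref = add.remote(1, add.remote(2, 3))
assert ray.get(ref) == 6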
+ + +class RemoteFunction3(HasOptions, Generic[R, T0, T1, T2, T3]): + def __init__(self, function: Callable[[T0, T1, T2, T3], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction4(HasOptions, Generic[R, T0, T1, T2, T3, T4]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction5(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction6(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + __arg6: "Union[T6, ObjectRef[T6]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + __arg6: "Union[T6, DAGNode[T6]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction7(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + __arg6: "Union[T6, ObjectRef[T6]]", + __arg7: "Union[T7, ObjectRef[T7]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + __arg6: "Union[T6, DAGNode[T6]]", + __arg7: "Union[T7, DAGNode[T7]]", + ) -> "DAGNode[R]": + ... 
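Editorial note (not part of the patch): alongside .remote(), each stub also types .bind(), which records a lazy DAG node instead of launching the task. A rough sketch of the intended usage, assuming the ray.dag InputNode and execute() API (the exact import path and semantics may differ between Ray releases):

import ray
from ray.dag import InputNode

ray.init()

@ray.remote
def scale(x: int, factor: int) -> int:
    return x * factor

# bind() builds the DAG node; execute() submits the whole DAG and returns an
# ObjectRef holding the final result.
with InputNode() as inp:
    dag = scale.bind(inp, 10)

assert ray.get(dag.execute(3)) == 30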
+ + +class RemoteFunction8(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]): + def __init__( + self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R] + ) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + __arg6: "Union[T6, ObjectRef[T6]]", + __arg7: "Union[T7, ObjectRef[T7]]", + __arg8: "Union[T8, ObjectRef[T8]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + __arg6: "Union[T6, DAGNode[T6]]", + __arg7: "Union[T7, DAGNode[T7]]", + __arg8: "Union[T8, DAGNode[T8]]", + ) -> "DAGNode[R]": + ... + + +class RemoteFunction9(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]): + def __init__( + self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R] + ) -> None: + pass + + def remote( + self, + __arg0: "Union[T0, ObjectRef[T0]]", + __arg1: "Union[T1, ObjectRef[T1]]", + __arg2: "Union[T2, ObjectRef[T2]]", + __arg3: "Union[T3, ObjectRef[T3]]", + __arg4: "Union[T4, ObjectRef[T4]]", + __arg5: "Union[T5, ObjectRef[T5]]", + __arg6: "Union[T6, ObjectRef[T6]]", + __arg7: "Union[T7, ObjectRef[T7]]", + __arg8: "Union[T8, ObjectRef[T8]]", + __arg9: "Union[T9, ObjectRef[T9]]", + ) -> "ObjectRef[R]": + ... + + def bind( + self, + __arg0: "Union[T0, DAGNode[T0]]", + __arg1: "Union[T1, DAGNode[T1]]", + __arg2: "Union[T2, DAGNode[T2]]", + __arg3: "Union[T3, DAGNode[T3]]", + __arg4: "Union[T4, DAGNode[T4]]", + __arg5: "Union[T5, DAGNode[T5]]", + __arg6: "Union[T6, DAGNode[T6]]", + __arg7: "Union[T7, DAGNode[T7]]", + __arg8: "Union[T8, DAGNode[T8]]", + __arg9: "Union[T9, DAGNode[T9]]", + ) -> "DAGNode[R]": + ... + + +# Visible for testing. +def _unhandled_error_handler(e: Exception): + logger.error( + f"Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): {e}" + ) + + +class Worker: + """A class used to define the control flow of a worker process. + + Note: + The methods in this class are considered unexposed to the user. The + functions outside of this class are considered exposed. + + Attributes: + node (ray._private.node.Node): The node this worker is attached to. + mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and + WORKER_MODE. + """ + + def __init__(self): + """Initialize a Worker object.""" + self.node = None + self.mode = None + self.actors = {} + # GPU object manager to manage GPU object lifecycles, including coordinating out-of-band + # tensor transfers between actors, storing and retrieving GPU objects, and garbage collection. + # We create the GPU object manager lazily, if a user specifies a + # non-default tensor_transport, to avoid circular import and because it + # imports third-party dependencies like PyTorch. + self._gpu_object_manager = None + # When the worker is constructed. Record the original value of the + # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, HIP_VISIBLE_DEVICES, + # NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..) environment variables. + self.original_visible_accelerator_ids = ( + ray._private.utils.get_visible_accelerator_ids() + ) + # A dictionary that maps from driver id to SerializationContext + # TODO: clean up the SerializationContext once the job finished. 
+ self.serialization_context_map = {} + self.function_actor_manager = FunctionActorManager(self) + # This event is checked regularly by all of the threads so that they + # know when to exit. + self.threads_stopped = threading.Event() + # If this is set, the next .remote call should drop into the + # debugger, at the specified breakpoint ID. + self.debugger_breakpoint = b"" + # If this is set, ray.get calls invoked on the object ID returned + # by the worker should drop into the debugger at the specified + # breakpoint ID. + self.debugger_get_breakpoint = b"" + # If True, make the debugger external to the node this worker is + # running on. + self.ray_debugger_external = False + self._load_code_from_local = False + # Opened file descriptor to stdout/stderr for this python worker. + self._enable_record_actor_task_log = ( + ray_constants.RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING + ) + # Whether rotation is enabled for out file and err file, task log position report will be skipped if rotation enabled, since the position cannot be accurate. + self._file_rotation_enabled = False + self._out_filepath = None + self._err_filepath = None + # Create the lock here because the serializer will use it before + # initializing Ray. + self.lock = threading.RLock() + # By default, don't show logs from other drivers. This is set to true by Serve + # in order to stream logs from the controller and replica actors across + # different drivers that connect to the same Serve instance. + # See https://github.com/ray-project/ray/pull/35070. + self._filter_logs_by_job = True + # the debugger port for this worker + self._debugger_port = None + # Cache the job id from initialize_job_config() to optimize lookups. + # This is on the critical path of ray.get()/put() calls. + self._cached_job_id = None + # Indicates whether the worker is connected to the Ray cluster. + # It should be set to True in `connect` and False in `disconnect`. + self._is_connected: bool = False + + @property + def gpu_object_manager(self) -> "ray.experimental.GPUObjectManager": + if self._gpu_object_manager is None: + # We create the GPU object manager lazily, if a user specifies a + # non-default tensor_transport, to avoid circular import and because it + # imports third-party dependencies like PyTorch. 
+ from ray.experimental import GPUObjectManager + + self._gpu_object_manager = GPUObjectManager() + return self._gpu_object_manager + + @property + def connected(self): + """bool: True if Ray has been started and False otherwise.""" + return self._is_connected + + def set_is_connected(self, is_connected: bool): + self._is_connected = is_connected + + @property + def node_ip_address(self): + self.check_connected() + return self.node.node_ip_address + + @property + def load_code_from_local(self): + self.check_connected() + return self._load_code_from_local + + @property + def current_job_id(self): + if self._cached_job_id is not None: + return self._cached_job_id + elif hasattr(self, "core_worker"): + return self.core_worker.get_current_job_id() + return JobID.nil() + + @property + def actor_id(self): + if hasattr(self, "core_worker"): + return self.core_worker.get_actor_id() + return ActorID.nil() + + @property + def actor_name(self): + if hasattr(self, "core_worker"): + return self.core_worker.get_actor_name().decode("utf-8") + return None + + @property + def current_task_id(self): + return self.core_worker.get_current_task_id() + + @property + def current_task_name(self): + return self.core_worker.get_current_task_name() + + @property + def current_task_function_name(self): + return self.core_worker.get_current_task_function_name() + + @property + def current_node_id(self): + return self.core_worker.get_current_node_id() + + @property + def task_depth(self): + return self.core_worker.get_task_depth() + + @property + def namespace(self): + return self.core_worker.get_job_config().ray_namespace + + @property + def placement_group_id(self): + return self.core_worker.get_placement_group_id() + + @property + def worker_id(self): + return self.core_worker.get_worker_id().binary() + + @property + def should_capture_child_tasks_in_placement_group(self): + return self.core_worker.should_capture_child_tasks_in_placement_group() + + @property + def current_cluster_and_job(self): + """Get the current session index and job id as pair.""" + assert isinstance(self.node.cluster_id, ray.ClusterID) + assert isinstance(self.current_job_id, ray.JobID) + return self.node.cluster_id, self.current_job_id + + @property + def runtime_env(self): + """Get the runtime env in json format""" + return self.core_worker.get_current_runtime_env() + + @property + def debugger_port(self): + """Get the debugger port for this worker""" + worker_id = self.core_worker.get_worker_id() + return ray._private.state.get_worker_debugger_port(worker_id) + + @property + def job_logging_config(self): + """Get the job's logging config for this worker""" + if not hasattr(self, "core_worker"): + return None + job_config = self.core_worker.get_job_config() + if not job_config.serialized_py_logging_config: + return None + logging_config = pickle.loads(job_config.serialized_py_logging_config) + return logging_config + + @property + def current_node_labels(self): + # Return the node labels of this worker's current node. 
+ return self.node.node_labels + + def set_debugger_port(self, port): + worker_id = self.core_worker.get_worker_id() + ray._private.state.update_worker_debugger_port(worker_id, port) + + def set_cached_job_id(self, job_id): + """Set the cached job id to speed `current_job_id()`.""" + self._cached_job_id = job_id + + @contextmanager + def task_paused_by_debugger(self): + """Use while the task is paused by debugger""" + try: + self.core_worker.update_task_is_debugger_paused( + ray.get_runtime_context()._get_current_task_id(), True + ) + yield + finally: + self.core_worker.update_task_is_debugger_paused( + ray.get_runtime_context()._get_current_task_id(), False + ) + + @contextmanager + def worker_paused_by_debugger(self): + """ + Updates the worker num paused threads when the worker is paused by debugger + """ + try: + worker_id = self.core_worker.get_worker_id() + ray._private.state.update_worker_num_paused_threads(worker_id, 1) + yield + finally: + ray._private.state.update_worker_num_paused_threads(worker_id, -1) + + def set_file_rotation_enabled(self, rotation_enabled: bool) -> None: + """Set whether rotation is enabled for outfile and errfile.""" + self._file_rotation_enabled = rotation_enabled + + def set_err_file(self, err_filepath=Optional[AnyStr]) -> None: + """Set the worker's err file where stderr is redirected to""" + self._err_filepath = err_filepath + + def set_out_file(self, out_filepath=Optional[AnyStr]) -> None: + """Set the worker's out file where stdout is redirected to""" + self._out_filepath = out_filepath + + def record_task_log_start(self, task_id: TaskID, attempt_number: int): + """Record the task log info when task starts executing for + non concurrent actor tasks.""" + if not self._enable_record_actor_task_log and not self.actor_id.is_nil(): + # We are not recording actor task log if not enabled explicitly. + # Recording actor task log is expensive and should be enabled only + # when needed. + # https://github.com/ray-project/ray/issues/35598 + return + + if not hasattr(self, "core_worker"): + return + if self._file_rotation_enabled: + return + + self.core_worker.record_task_log_start( + task_id, + attempt_number, + self.get_out_file_path(), + self.get_err_file_path(), + self.get_current_out_offset(), + self.get_current_err_offset(), + ) + + def record_task_log_end(self, task_id: TaskID, attempt_number: int): + """Record the task log info when task finishes executing for + non concurrent actor tasks.""" + if not self._enable_record_actor_task_log and not self.actor_id.is_nil(): + # We are not recording actor task log if not enabled explicitly. + # Recording actor task log is expensive and should be enabled only + # when needed. + # https://github.com/ray-project/ray/issues/35598 + return + + if not hasattr(self, "core_worker"): + return + + # Disable file offset fetch if rotation enabled (since file offset doesn't make sense for rotated files). 
+ if self._file_rotation_enabled: + return + + self.core_worker.record_task_log_end( + task_id, + attempt_number, + self.get_current_out_offset(), + self.get_current_err_offset(), + ) + + def get_out_file_path(self) -> str: + """Get the out log file path""" + return self._out_filepath if self._out_filepath is not None else "" + + def get_err_file_path(self) -> str: + """Get the err log file path""" + return self._err_filepath if self._err_filepath is not None else "" + + def get_current_out_offset(self) -> int: + """Get the current offset of the out file if seekable, else 0""" + if self._out_filepath is not None: + return os.path.getsize(self._out_filepath) + return 0 + + def get_current_err_offset(self) -> int: + """Get the current offset of the err file if seekable, else 0""" + if self._err_filepath is not None: + return os.path.getsize(self._err_filepath) + return 0 + + def get_serialization_context(self): + """Get the SerializationContext of the job that this worker is processing. + + Returns: + The serialization context of the given job. + """ + # This function needs to be protected by a lock, because it will be + # called by`register_class_for_serialization`, as well as the import + # thread, from different threads. Also, this function will recursively + # call itself, so we use RLock here. + job_id = self.current_job_id + context_map = self.serialization_context_map + with self.lock: + if job_id not in context_map: + # The job ID is nil before initializing Ray. + if JobID.nil() in context_map: + # Transfer the serializer context used before initializing Ray. + context_map[job_id] = context_map.pop(JobID.nil()) + else: + context_map[job_id] = serialization.SerializationContext(self) + return context_map[job_id] + + def check_connected(self): + """Check if the worker is connected. + + Raises: + Exception: An exception is raised if the worker is not connected. + """ + if not self.connected: + raise RaySystemError( + "Ray has not been started yet. You can start Ray with 'ray.init()'." + ) + + def set_mode(self, mode): + """Set the mode of the worker. + + The mode SCRIPT_MODE should be used if this Worker is a driver that is + being run as a Python script or interactively in a shell. It will print + information about task failures. + + The mode WORKER_MODE should be used if this Worker is not a driver. It + will not print information about tasks. + + The mode LOCAL_MODE should be used if this Worker is a driver and if + you want to run the driver in a manner equivalent to serial Python for + debugging purposes. It will not send remote function calls to the + scheduler and will instead execute them in a blocking fashion. + + Args: + mode: One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. + """ + self.mode = mode + + def set_load_code_from_local(self, load_code_from_local): + self._load_code_from_local = load_code_from_local + + def put_object( + self, + value: Any, + owner_address: Optional[str] = None, + _is_experimental_channel: bool = False, + _tensor_transport: str = "object_store", + ): + """Put value in the local object store. + + If the plasma store is full, the worker will automatically + retry up to DEFAULT_PUT_OBJECT_RETRIES times. Each retry + will delay for an exponentially doubling amount of time, + starting with DEFAULT_PUT_OBJECT_DELAY. After this, exception + will be raised. + + Args: + value: The value to put in the object store. + owner_address: The serialized address of object's owner. + _is_experimental_channel: An experimental flag for mutable + objects. 
If True, then the returned object will not have a + valid value. The object must be written to using the + ray.experimental.channel API before readers can read. + _tensor_transport: [Alpha] The tensor transport backend to use. Currently, this supports "object_store" and "nixl". + Returns: + ObjectRef: The object ref the object was put under. + + Raises: + ray.exceptions.ObjectStoreFullError: This is raised if the attempt + to store the object fails because the object store is full even + after multiple retries. + """ + # Make sure that the value is not an object ref. + if isinstance(value, ObjectRef): + raise TypeError( + "Calling 'put' on an ray.ObjectRef is not allowed. " + "If you really want to do this, you can wrap the " + "ray.ObjectRef in a list and call 'put' on it." + ) + tensors = None + tensor_transport: TensorTransportEnum = TensorTransportEnum.from_str( + _tensor_transport + ) + if tensor_transport not in [ + TensorTransportEnum.OBJECT_STORE, + TensorTransportEnum.NIXL, + ]: + raise ValueError( + "Currently, Ray Direct Transport only supports 'object_store' and 'nixl' for tensor transport in ray.put()." + ) + try: + if tensor_transport != TensorTransportEnum.OBJECT_STORE: + ( + serialized_value, + tensors, + ) = self.get_serialization_context().serialize_gpu_objects(value) + else: + serialized_value = self.get_serialization_context().serialize(value) + except TypeError as e: + sio = io.StringIO() + ray.util.inspect_serializability(value, print_file=sio) + msg = ( + "Could not serialize the put value " + f"{repr(value)}:\n" + f"{sio.getvalue()}" + ) + raise TypeError(msg) from e + + # If the object is mutable, then the raylet should never read the + # object. Instead, clients will keep the object pinned. + pin_object = not _is_experimental_channel + + # This *must* be the first place that we construct this python + # ObjectRef because an entry with 0 local references is created when + # the object is Put() in the core worker, expecting that this python + # reference will be created. If another reference is created and + # removed before this one, it will corrupt the state in the + # reference counter. + ret = self.core_worker.put_object( + serialized_value, + pin_object=pin_object, + owner_address=owner_address, + inline_small_object=True, + _is_experimental_channel=_is_experimental_channel, + tensor_transport_val=tensor_transport.value, + ) + if tensors: + self.gpu_object_manager.put_object(ret, tensor_transport, tensors) + return ret + + def raise_errors(self, serialized_objects, object_refs): + out = self.deserialize_objects(serialized_objects, object_refs) + if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ: + return + for e in out: + _unhandled_error_handler(e) + + def deserialize_objects( + self, + serialized_objects, + object_refs, + tensor_transport_hint: Optional[TensorTransportEnum] = None, + ): + gpu_objects: Dict[str, List["torch.Tensor"]] = {} + for obj_ref, (_, _, tensor_transport) in zip(object_refs, serialized_objects): + # TODO: Here tensor_transport_hint is set by the user in ray.get(), tensor_transport is set + # in serialize_objects by ray.method(tensor_transport="xxx"), and obj_ref.tensor_transport() + # is set by ray.put(). We may clean up this logic in the future. + if ( + tensor_transport is None + or tensor_transport == TensorTransportEnum.OBJECT_STORE + ) and ( + obj_ref is None + or obj_ref.tensor_transport() == TensorTransportEnum.OBJECT_STORE.value + ): + # The object is not a gpu object, so we cannot use other external transport to + # fetch it. 
+ continue + + # If the object is a gpu object, we can choose to use the object store or other external + # transport to fetch it. The `tensor_transport_hint` has the highest priority, then the + # tensor_transport in obj_ref.tensor_transport(), then the tensor_transport in serialize_objects, + # then the default value `OBJECT_STORE`. + chosen_tensor_transport = ( + tensor_transport_hint + or ( + TensorTransportEnum(obj_ref.tensor_transport()) if obj_ref else None + ) + or tensor_transport + or TensorTransportEnum.OBJECT_STORE + ) + + object_id = obj_ref.hex() + if object_id not in gpu_objects: + # If using a non-object store transport, then tensors will be sent + # out-of-band. Get them before deserializing the object store data. + gpu_objects[object_id] = self.gpu_object_manager.get_gpu_object( + object_id, tensor_transport=chosen_tensor_transport + ) + + # Function actor manager or the import thread may call pickle.loads + # at the same time which can lead to failed imports + # TODO: We may be better off locking on all imports or injecting a lock + # into pickle.loads (https://github.com/ray-project/ray/issues/16304) + with self.function_actor_manager.lock: + context = self.get_serialization_context() + return context.deserialize_objects( + serialized_objects, object_refs, gpu_objects + ) + + def get_objects( + self, + object_refs: list, + timeout: Optional[float] = None, + return_exceptions: bool = False, + skip_deserialization: bool = False, + _tensor_transport: Optional[str] = None, + ) -> Tuple[List[serialization.SerializedRayObject], bytes]: + """Get the values in the object store associated with the IDs. + + Return the values from the local object store for object_refs. This + will block until all the values for object_refs have been written to + the local object store. + + Args: + object_refs: A list of the object refs + whose values should be retrieved. + timeout: The maximum amount of time in + seconds to wait before returning. + return_exceptions: If any of the objects deserialize to an + Exception object, whether to return them as values in the + returned list. If False, then the first found exception will be + raised. + skip_deserialization: If true, only the buffer will be released and + the object associated with the buffer will not be deserialized. + _tensor_transport: [Alpha] The tensor transport to use to fetch `torch.Tensors` found in the Ray Direct Transport object. Currently, this supports "object_store" and "nixl". + Returns: + list: List of deserialized objects or None if skip_deserialization is True. + bytes: UUID of the debugger breakpoint we should drop + into or b"" if there is no breakpoint. + """ + # Make sure that the values are object refs. + for object_ref in object_refs: + if not isinstance(object_ref, ObjectRef): + raise TypeError( + f"Attempting to call `get` on the value {object_ref}, " + "which is not an ray.ObjectRef." + ) + tensor_transport: TensorTransportEnum = ( + TensorTransportEnum.from_str(_tensor_transport) + if _tensor_transport is not None + else None + ) + assert tensor_transport in [ + TensorTransportEnum.OBJECT_STORE, + TensorTransportEnum.NIXL, + None, + ], "Currently, RDT only supports 'object_store' and 'nixl' for tensor transport in ray.get()." 
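        # Convert the user-facing timeout in seconds to the milliseconds the
        # core worker expects; both None and -1 are forwarded as -1.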
+ timeout_ms = ( + int(timeout * 1000) if timeout is not None and timeout != -1 else -1 + ) + serialized_objects: List[ + serialization.SerializedRayObject + ] = self.core_worker.get_objects( + object_refs, + timeout_ms, + ) + + debugger_breakpoint = b"" + for data, metadata, _ in serialized_objects: + if metadata: + metadata_fields = metadata.split(b",") + if len(metadata_fields) >= 2 and metadata_fields[1].startswith( + ray_constants.OBJECT_METADATA_DEBUG_PREFIX + ): + debugger_breakpoint = metadata_fields[1][ + len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) : + ] + if skip_deserialization: + return None, debugger_breakpoint + + values = self.deserialize_objects( + serialized_objects, object_refs, tensor_transport_hint=tensor_transport + ) + if not return_exceptions: + # Raise exceptions instead of returning them to the user. + for i, value in enumerate(values): + if isinstance(value, RayError): + if isinstance(value, ray.exceptions.ObjectLostError): + global_worker.core_worker.log_plasma_usage() + if isinstance(value, RayTaskError): + raise value.as_instanceof_cause() + else: + raise value + + return values, debugger_breakpoint + + def main_loop(self): + """The main loop a worker runs to receive and execute tasks.""" + + def sigterm_handler(signum, frame): + raise_sys_exit_with_custom_error_message( + "The process receives a SIGTERM.", exit_code=1 + ) + # Note: shutdown() function is called from atexit handler. + + ray._private.utils.set_sigterm_handler(sigterm_handler) + self.core_worker.run_task_loop() + sys.exit(0) + + def print_logs(self): + """Prints log messages from workers on all nodes in the same job.""" + subscriber = self.gcs_log_subscriber + subscriber.subscribe() + exception_type = ray.exceptions.RpcError + localhost = services.get_node_ip_address() + try: + # Number of messages received from the last polling. When the batch + # size exceeds 100 and keeps increasing, the worker and the user + # probably will not be able to consume the log messages as rapidly + # as they are coming in. + # This is meaningful only for GCS subscriber. + last_polling_batch_size = 0 + job_id_hex = self.current_job_id.hex() + while True: + # Exit if we received a signal that we should stop. + if self.threads_stopped.is_set(): + return + + data = subscriber.poll() + # GCS subscriber only returns None on unavailability. + if data is None: + last_polling_batch_size = 0 + continue + + if ( + self._filter_logs_by_job + and data["job"] + and data["job"] != job_id_hex + ): + last_polling_batch_size = 0 + continue + + data["localhost"] = localhost + global_worker_stdstream_dispatcher.emit(data) + + lagging = 100 <= last_polling_batch_size < subscriber.last_batch_size + if lagging: + logger.warning( + "The driver may not be able to keep up with the " + "stdout/stderr of the workers. To avoid forwarding " + "logs to the driver, use " + "'ray.init(log_to_driver=False)'." + ) + + last_polling_batch_size = subscriber.last_batch_size + + except (OSError, exception_type) as e: + logger.error(f"print_logs: {e}") + finally: + # Close the pubsub client to avoid leaking file descriptors. + subscriber.close() + + def get_accelerator_ids_for_accelerator_resource( + self, resource_name: str, resource_regex: str + ) -> Union[List[str], List[int]]: + """Get the accelerator IDs that are assigned to the given accelerator resource. + + Args: + resource_name: The name of the resource. + resource_regex: The regex of the resource. + + Returns: + (List[str]) The IDs that are assigned to the given resource pre-configured. 
+ (List[int]) The IDs that are assigned to the given resource. + """ + resource_ids = self.core_worker.resource_ids() + assigned_ids = set() + # Handle both normal and placement group accelerator resources. + # Note: We should only get the accelerator ids from the placement + # group resource that does not contain the bundle index! + import re + + for resource, assignment in resource_ids.items(): + if resource == resource_name or re.match(resource_regex, resource): + for resource_id, _ in assignment: + assigned_ids.add(resource_id) + + # If the user had already set the environment variables + # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, NEURON_RT_VISIBLE_CORES, + # TPU_VISIBLE_CHIPS, ..) then respect that in the sense that only IDs + # that appear in (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, + # HIP_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..) + # should be returned. + if self.original_visible_accelerator_ids.get(resource_name, None) is not None: + original_ids = self.original_visible_accelerator_ids[resource_name] + assigned_ids = {str(original_ids[i]) for i in assigned_ids} + # Give all accelerator ids in local_mode. + if self.mode == LOCAL_MODE: + if resource_name == ray_constants.GPU: + max_accelerators = self.node.get_resource_and_label_spec().num_gpus + else: + max_accelerators = ( + self.node.get_resource_and_label_spec().resources.get( + resource_name, None + ) + ) + if max_accelerators: + assigned_ids = original_ids[:max_accelerators] + return list(assigned_ids) + + def shutdown_gpu_object_manager(self): + if self._gpu_object_manager: + self._gpu_object_manager.shutdown() + + +_connect_or_shutdown_lock = threading.RLock() + + +def with_connect_or_shutdown_lock(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + with _connect_or_shutdown_lock: + return func(*args, **kwargs) + + return wrapper + + +@PublicAPI +@client_mode_hook +def get_gpu_ids() -> Union[List[int], List[str]]: + """Get the IDs of the GPUs that are available to the worker. + + This method should only be called inside of a task or actor, and not a driver. + + If the CUDA_VISIBLE_DEVICES environment variable was set when the worker + started up, then the IDs returned by this method will be a subset of the + IDs in CUDA_VISIBLE_DEVICES. If not, the IDs will fall in the range + [0, NUM_GPUS - 1], where NUM_GPUS is the number of GPUs that the node has. + + Returns: + A list of GPU IDs. + """ + worker = global_worker + worker.check_connected() + return worker.get_accelerator_ids_for_accelerator_resource( + ray_constants.GPU, f"^{ray_constants.GPU}_group_[0-9A-Za-z]+$" + ) + + +@Deprecated( + message="Use ray.get_runtime_context().get_assigned_resources() instead.", + warning=True, +) +def get_resource_ids(): + """Get the IDs of the resources that are available to the worker. + + Returns: + A dictionary mapping the name of a resource to a list of pairs, where + each pair consists of the ID of a resource and the fraction of that + resource reserved for this worker. + """ + worker = global_worker + worker.check_connected() + + if _mode() == LOCAL_MODE: + raise RuntimeError( + "ray._private.worker.get_resource_ids() does not work in local_mode." + ) + + return global_worker.core_worker.resource_ids() + + +@Deprecated(message="Use ray.init().address_info['webui_url'] instead.") +def get_dashboard_url(): + """Get the URL to access the Ray dashboard. + + Note that the URL does not specify which node the dashboard is on. + + Returns: + The URL of the dashboard as a string. 
+ """ + if ray_constants.RAY_OVERRIDE_DASHBOARD_URL in os.environ: + return _remove_protocol_from_url( + os.environ.get(ray_constants.RAY_OVERRIDE_DASHBOARD_URL) + ) + else: + worker = global_worker + worker.check_connected() + return _global_node.webui_url + + +def _remove_protocol_from_url(url: Optional[str]) -> str: + """ + Helper function to remove protocol from URL if it exists. + """ + if not url: + return url + parsed_url = urllib.parse.urlparse(url) + if parsed_url.scheme: + # Construct URL without protocol + scheme = f"{parsed_url.scheme}://" + return parsed_url.geturl().replace(scheme, "", 1) + return url + + +class BaseContext(metaclass=ABCMeta): + """ + Base class for RayContext and ClientContext + """ + + dashboard_url: Optional[str] + python_version: str + ray_version: str + + @abstractmethod + def disconnect(self): + """ + If this context is for directly attaching to a cluster, disconnect + will call ray.shutdown(). Otherwise, if the context is for a ray + client connection, the client will be disconnected. + """ + pass + + @abstractmethod + def __enter__(self): + pass + + @abstractmethod + def __exit__(self): + pass + + def _context_table_template(self): + if self.dashboard_url: + dashboard_row = Template("context_dashrow.html.j2").render( + dashboard_url="http://" + self.dashboard_url + ) + else: + dashboard_row = None + + return Template("context_table.html.j2").render( + python_version=self.python_version, + ray_version=self.ray_version, + dashboard_row=dashboard_row, + ) + + def _repr_html_(self): + return Template("context.html.j2").render( + context_logo=Template("context_logo.html.j2").render(), + context_table=self._context_table_template(), + ) + + @repr_with_fallback(["ipywidgets", "8"]) + def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]: + """Get the mimebundle for the widget representation of the context. + + Args: + **kwargs: Passed to the _repr_mimebundle_() function for the widget + + Returns: + Dictionary ("mimebundle") of the widget representation of the context. + """ + import ipywidgets + + disconnect_button = ipywidgets.Button( + description="Disconnect", + disabled=False, + button_style="", + tooltip="Disconnect from the Ray cluster", + layout=ipywidgets.Layout(margin="auto 0px 0px 0px"), + ) + + def disconnect_callback(button): + button.disabled = True + button.description = "Disconnecting..." + self.disconnect() + button.description = "Disconnected" + + disconnect_button.on_click(disconnect_callback) + left_content = ipywidgets.VBox( + [ + ipywidgets.HTML(Template("context_logo.html.j2").render()), + disconnect_button, + ], + layout=ipywidgets.Layout(), + ) + right_content = ipywidgets.HTML(self._context_table_template()) + widget = ipywidgets.HBox( + [left_content, right_content], layout=ipywidgets.Layout(width="100%") + ) + return widget._repr_mimebundle_(**kwargs) + + def _repr_mimebundle_(self, **kwargs): + bundle = self._get_widget_bundle(**kwargs) + + # Overwrite the widget html repr and default repr with those of the BaseContext + bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)}) + return bundle + + +@dataclass +class RayContext(BaseContext, Mapping): + """ + Context manager for attached drivers. 
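+
+    Entering the context returns the ``RayContext`` itself, and exiting it
+    calls ``ray.shutdown()``. A minimal usage sketch (``dashboard_url`` may be
+    ``None`` if the dashboard is not running):
+
+    .. testcode::
+
+        import ray
+
+        with ray.init() as ctx:
+            dashboard_url = ctx.dashboard_url
+        # The cluster connection is shut down when the block exits.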
+ """ + + dashboard_url: Optional[str] + python_version: str + ray_version: str + ray_commit: str + + def __init__(self, address_info: Dict[str, Optional[str]]): + super().__init__() + self.dashboard_url = get_dashboard_url() + self.python_version = "{}.{}.{}".format(*sys.version_info[:3]) + self.ray_version = ray.__version__ + self.ray_commit = ray.__commit__ + self.address_info = address_info + + def __getitem__(self, key): + if log_once("ray_context_getitem"): + warnings.warn( + f'Accessing values through ctx["{key}"] is deprecated. ' + f'Use ctx.address_info["{key}"] instead.', + DeprecationWarning, + stacklevel=2, + ) + return self.address_info[key] + + def __len__(self): + if log_once("ray_context_len"): + warnings.warn("len(ctx) is deprecated. Use len(ctx.address_info) instead.") + return len(self.address_info) + + def __iter__(self): + if log_once("ray_context_len"): + warnings.warn( + "iter(ctx) is deprecated. Use iter(ctx.address_info) instead." + ) + return iter(self.address_info) + + def __enter__(self) -> "RayContext": + return self + + def __exit__(self, *exc): + ray.shutdown() + + def disconnect(self): + # Include disconnect() to stay consistent with ClientContext + ray.shutdown() + + +global_worker = Worker() +"""Worker: The global Worker object for this worker process. + +We use a global Worker object to ensure that there is a single worker object +per worker process. +""" + +_global_node = None +"""ray._private.node.Node: The global node object that is created by ray.init().""" + + +def _maybe_modify_runtime_env( + runtime_env: Optional[Dict[str, Any]], _skip_env_hook: bool +) -> Dict[str, Any]: + """ + If you set RAY_ENABLE_UV_RUN_RUNTIME_ENV, which is the default, and run the driver with `uv run`, + this function sets up a runtime environment that replicates the driver's environment to the + workers. Otherwise, if a runtime environment hook is present it will modify the runtime environment. + """ + + if ray_constants.RAY_ENABLE_UV_RUN_RUNTIME_ENV: + from ray._private.runtime_env.uv_runtime_env_hook import ( + _get_uv_run_cmdline, + hook, + ) + + cmdline = _get_uv_run_cmdline() + if cmdline: + # This means the current driver is running in `uv run`, in which case we want + # to propagate the uv environment to the workers. 
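+            # The hook is expected to return a new runtime_env dict that pins
+            # the workers to the environment `uv run` resolved for the driver;
+            # exactly which fields it sets is an implementation detail of the
+            # uv runtime-env hook.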
+ return hook(runtime_env) + + if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook: + return load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(runtime_env) + + return runtime_env + + +@PublicAPI +@client_mode_hook +def init( + address: Optional[str] = None, + *, + num_cpus: Optional[int] = None, + num_gpus: Optional[int] = None, + resources: Optional[Dict[str, float]] = None, + labels: Optional[Dict[str, str]] = None, + object_store_memory: Optional[int] = None, + local_mode: bool = False, + ignore_reinit_error: bool = False, + include_dashboard: Optional[bool] = None, + dashboard_host: str = ray_constants.DEFAULT_DASHBOARD_IP, + dashboard_port: Optional[int] = None, + job_config: "ray.job_config.JobConfig" = None, + configure_logging: bool = True, + logging_level: int = ray_constants.LOGGER_LEVEL, + logging_format: Optional[str] = None, + logging_config: Optional[LoggingConfig] = None, + log_to_driver: Optional[bool] = None, + namespace: Optional[str] = None, + runtime_env: Optional[Union[Dict[str, Any], "RuntimeEnv"]] = None, # noqa: F821 + enable_resource_isolation: bool = False, + system_reserved_cpu: Optional[float] = None, + system_reserved_memory: Optional[int] = None, + **kwargs, +) -> BaseContext: + """ + Connect to an existing Ray cluster or start one and connect to it. + + This method handles two cases; either a Ray cluster already exists and we + just attach this driver to it or we start all of the processes associated + with a Ray cluster and attach to the newly started cluster. + Note: This method overwrite sigterm handler of the driver process. + + In most cases, it is enough to just call this method with no arguments. + This will autodetect an existing Ray cluster or start a new Ray instance if + no existing cluster is found: + + .. testcode:: + + ray.init() + + To explicitly connect to an existing local cluster, use this as follows. A + ConnectionError will be thrown if no existing local cluster is found. + + .. testcode:: + :skipif: True + + ray.init(address="auto") + + To connect to an existing remote cluster, use this as follows (substituting + in the appropriate address). Note the addition of "ray://" at the beginning + of the address. This requires `ray[client]`. + + .. testcode:: + :skipif: True + + ray.init(address="ray://123.45.67.89:10001") + + More details for starting and connecting to a remote cluster can be found + here: https://docs.ray.io/en/master/cluster/getting-started.html + + You can also define an environment variable called `RAY_ADDRESS` in + the same format as the `address` parameter to connect to an existing + cluster with ray.init() or ray.init(address="auto"). + + Args: + address: The address of the Ray cluster to connect to. The provided + address is resolved as follows: + 1. If a concrete address (e.g., localhost:) is provided, try to + connect to it. Concrete addresses can be prefixed with "ray://" to + connect to a remote cluster. For example, passing in the address + "ray://123.45.67.89:50005" will connect to the cluster at the given + address. + 2. If no address is provided, try to find an existing Ray instance + to connect to. This is done by first checking the environment + variable `RAY_ADDRESS`. If this is not defined, check the address + of the latest cluster started (found in + /tmp/ray/ray_current_cluster) if available. If this is also empty, + then start a new local Ray instance. + 3. If the provided address is "auto", then follow the same process + as above. 
However, if there is no existing cluster found, this will
+            throw a ConnectionError instead of starting a new local Ray
+            instance.
+            4. If the provided address is "local", start a new local Ray
+            instance, even if there is already an existing local Ray
+            instance.
+        num_cpus: Number of CPUs the user wishes to assign to each
+            raylet. By default, this is set based on virtual cores.
+        num_gpus: Number of GPUs the user wishes to assign to each
+            raylet. By default, this is set based on detected GPUs.
+        resources: A dictionary mapping the names of custom resources to the
+            quantities of those resources available.
+        labels: [Experimental] The key-value labels of the node.
+        object_store_memory: The amount of memory (in bytes) to start the
+            object store with.
+            By default, this is 30% of available system memory capped by
+            the shm size and 200G but can be set higher.
+        local_mode: Deprecated: consider using the Ray Distributed Debugger instead.
+        ignore_reinit_error: If true, Ray suppresses errors from calling
+            ray.init() a second time. Ray won't be restarted.
+        include_dashboard: Boolean flag indicating whether or not to start the
+            Ray dashboard, which displays the status of the Ray
+            cluster. If this argument is None, then the UI will be started if
+            the relevant dependencies are present.
+        dashboard_host: The host to bind the dashboard server to. Can either be
+            localhost (127.0.0.1) or 0.0.0.0 (available from all interfaces).
+            By default, this is set to localhost to prevent access from
+            external machines.
+        dashboard_port(int, None): The port to bind the dashboard server to.
+            Defaults to 8265 and Ray will automatically find a free port if
+            8265 is not available.
+        job_config (ray.job_config.JobConfig): The job configuration.
+        configure_logging: True (default) if configuration of logging is
+            allowed here. Otherwise, the user may want to configure it
+            separately.
+        logging_level: Logging level for the "ray" logger of the driver process,
+            defaults to logging.INFO. Ignored unless "configure_logging" is true.
+        logging_format: Logging format for the "ray" logger of the driver process,
+            defaults to a string containing a timestamp, filename, line number, and
+            message. See the source file ray_constants.py for details. Ignored unless
+            "configure_logging" is true.
+        logging_config: [Experimental] Logging configuration will be applied to the
+            root loggers for both the driver process and all worker processes belonging
+            to the current job. See :class:`~ray.LoggingConfig` for details.
+        log_to_driver: If true, the output from all of the worker
+            processes on all nodes will be directed to the driver.
+        namespace: A namespace is a logical grouping of jobs and named actors.
+        runtime_env: The runtime environment to use
+            for this job (see :ref:`runtime-environments` for details).
+        object_spilling_directory: The path to spill objects to. The same path will
+            be used as the object store fallback directory as well.
+        enable_resource_isolation: Enable resource isolation through cgroupv2 by reserving
+            memory and cpu resources for ray system processes. To use, only cgroupv2 (not cgroupv1)
+            must be enabled with read and write permissions for the raylet. Cgroup memory and
+            cpu controllers must also be enabled.
+        system_reserved_cpu: The number of cpu cores to reserve for ray system processes.
+            Cores can be fractional, i.e. 1.5 means one and a half cpu cores.
+            By default, the value will be at least 1 core and at most 3 cores.
The default value
+            is calculated using the formula min(3.0, max(1.0, 0.05 * num_cores_on_the_system)).
+            This option only works if enable_resource_isolation is True.
+        system_reserved_memory: The amount of memory (in bytes) to reserve for ray system processes.
+            By default, the value will be at least 500MB and at most 10GB. The default value is
+            calculated using the formula min(10GB, max(500MB, 0.10 * memory_available_on_the_system)).
+            This option only works if enable_resource_isolation is True.
+        _cgroup_path: The path for the cgroup the raylet should use to enforce resource isolation.
+            By default, the cgroup used for resource isolation will be /sys/fs/cgroup.
+            The process starting ray must have read/write permissions to this path.
+            Cgroup memory and cpu controllers must be enabled for this cgroup.
+            This option only works if enable_resource_isolation is True.
+        _enable_object_reconstruction: If True, when an object stored in
+            the distributed plasma store is lost due to node failure, Ray will
+            attempt to reconstruct the object by re-executing the task that
+            created the object. Arguments to the task will be recursively
+            reconstructed. If False, then ray.ObjectLostError will be
+            thrown.
+        _plasma_directory: Override the plasma mmap file directory.
+        _node_ip_address: The IP address of the node that we are on.
+        _driver_object_store_memory: Deprecated.
+        _memory: Amount of reservable memory resource in bytes rounded
+            down to the nearest integer.
+        _redis_username: Prevents external clients without the username
+            from connecting to Redis if provided.
+        _redis_password: Prevents external clients without the password
+            from connecting to Redis if provided.
+        _temp_dir: If provided, specifies the root temporary
+            directory for the Ray process. Must be an absolute path. Defaults to an
+            OS-specific conventional location, e.g., "/tmp/ray".
+        _metrics_export_port: Port number Ray exposes system metrics
+            through a Prometheus endpoint. It is currently under active
+            development, and the API is subject to change.
+        _system_config: Configuration for overriding
+            RayConfig defaults. For testing purposes ONLY.
+        _tracing_startup_hook: If provided, turns on and sets up tracing
+            for Ray. Must be the name of a function that takes no arguments and
+            sets up a Tracer Provider, Remote Span Processors, and
+            (optional) additional instruments. See more at
+            docs.ray.io/tracing.html. It is currently under active development,
+            and the API is subject to change.
+        _node_name: User-provided node name or identifier. Defaults to
+            the node IP address.
+
+    Returns:
+        If the provided address includes a protocol, for example by prepending
+        "ray://" to the address to get "ray://1.2.3.4:10001", then a
+        ClientContext is returned with information such as settings, server
+        versions for ray and python, and the dashboard_url. Otherwise,
+        a RayContext is returned with ray and python versions, and address
+        information about the started processes.
+
+    Raises:
+        Exception: An exception is raised if an inappropriate combination of
+            arguments is passed in.
+    """
+    if log_to_driver is None:
+        log_to_driver = ray_constants.RAY_LOG_TO_DRIVER
+
+    # Configure the "ray" logger for the driver process.
+    if configure_logging:
+        setup_logger(logging_level, logging_format or ray_constants.LOGGER_FORMAT)
+    else:
+        logging.getLogger("ray").handlers.clear()
+
+    # Configure the logging settings for the driver process.
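+    # The config applied here only affects the driver process; further below it
+    # is also attached to the job_config so that worker processes belonging to
+    # this job pick it up as well.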
+ if logging_config or ray_constants.RAY_LOGGING_CONFIG_ENCODING: + logging_config = logging_config or LoggingConfig( + encoding=ray_constants.RAY_LOGGING_CONFIG_ENCODING + ) + logging_config._apply() + + # Parse the hidden options + _cgroup_path: str = kwargs.pop("_cgroup_path", None) + + _enable_object_reconstruction: bool = kwargs.pop( + "_enable_object_reconstruction", False + ) + _plasma_directory: Optional[str] = kwargs.pop("_plasma_directory", None) + _object_spilling_directory: Optional[str] = kwargs.pop( + "object_spilling_directory", None + ) + _node_ip_address: str = kwargs.pop("_node_ip_address", None) + _driver_object_store_memory: Optional[int] = kwargs.pop( + "_driver_object_store_memory", None + ) + _memory: Optional[int] = kwargs.pop("_memory", None) + _redis_username: str = kwargs.pop( + "_redis_username", ray_constants.REDIS_DEFAULT_USERNAME + ) + _redis_password: str = kwargs.pop( + "_redis_password", ray_constants.REDIS_DEFAULT_PASSWORD + ) + _temp_dir: Optional[str] = kwargs.pop("_temp_dir", None) + _metrics_export_port: Optional[int] = kwargs.pop("_metrics_export_port", None) + _system_config: Optional[Dict[str, str]] = kwargs.pop("_system_config", None) + _tracing_startup_hook: Optional[Callable] = kwargs.pop( + "_tracing_startup_hook", None + ) + _node_name: str = kwargs.pop("_node_name", None) + # Fix for https://github.com/ray-project/ray/issues/26729 + _skip_env_hook: bool = kwargs.pop("_skip_env_hook", False) + + resource_isolation_config = ResourceIsolationConfig( + enable_resource_isolation=enable_resource_isolation, + cgroup_path=_cgroup_path, + system_reserved_cpu=system_reserved_cpu, + system_reserved_memory=system_reserved_memory, + ) + + # terminate any signal before connecting driver + def sigterm_handler(signum, frame): + sys.exit(signum) + + if threading.current_thread() is threading.main_thread(): + ray._private.utils.set_sigterm_handler(sigterm_handler) + else: + logger.warning( + "SIGTERM handler is not set because current thread " + "is not the main thread." 
+ ) + + # If available, use RAY_ADDRESS to override if the address was left + # unspecified, or set to "auto" in the call to init + address_env_var = os.environ.get(ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE) + if address_env_var and (address is None or address == "auto"): + address = address_env_var + logger.info( + f"Using address {address_env_var} set in the environment " + f"variable {ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE}" + ) + + if address is not None and "://" in address: + # Address specified a protocol, use ray client + builder = ray.client(address, _deprecation_warn_enabled=False) + + # Forward any keyword arguments that were changed from their default + # values to the builder + init_sig = inspect.signature(init) + passed_kwargs = {} + for argument_name, param_obj in init_sig.parameters.items(): + if argument_name in {"kwargs", "address"}: + # kwargs and address are handled separately + continue + default_value = param_obj.default + passed_value = locals()[argument_name] + if passed_value != default_value: + # passed value is different than default, pass to the client + # builder + passed_kwargs[argument_name] = passed_value + passed_kwargs.update(kwargs) + builder._init_args(**passed_kwargs) + ctx = builder.connect() + from ray._common.usage import usage_lib + + if passed_kwargs.get("allow_multiple") is True: + with ctx: + usage_lib.put_pre_init_usage_stats() + else: + usage_lib.put_pre_init_usage_stats() + + usage_lib.record_library_usage("client") + return ctx + + if kwargs.get("allow_multiple"): + raise RuntimeError( + "`allow_multiple` argument is passed to `ray.init` when the " + "ray client is not used (" + f"https://docs.ray.io/en/{get_ray_doc_version()}/cluster" + "/running-applications/job-submission" + "/ray-client.html#connect-to-multiple-ray-clusters-experimental). " + "Do not pass the `allow_multiple` to `ray.init` to fix the issue." + ) + + if kwargs.get("storage"): + raise RuntimeError( + "Cluster-wide storage configuration has been removed. " + "The last Ray version supporting the `storage` argument is `ray==2.47`." + ) + + if kwargs: + # User passed in extra keyword arguments but isn't connecting through + # ray client. Raise an error, since most likely a typo in keyword + unknown = ", ".join(kwargs) + raise RuntimeError(f"Unknown keyword argument(s): {unknown}") + + # Try to increase the file descriptor limit, which is too low by + # default for Ray: https://github.com/ray-project/ray/issues/11239 + try: + import resource + + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + if soft < hard: + # https://github.com/ray-project/ray/issues/12059 + soft = max(soft, min(hard, 65536)) + logger.debug( + f"Automatically increasing RLIMIT_NOFILE to max value of {hard}" + ) + try: + resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) + except ValueError: + logger.debug("Failed to raise limit.") + soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + if soft < 4096: + logger.warning( + "File descriptor limit {} is too low for production " + "servers and may result in connection errors. " + "At least 8192 is recommended. 
--- " + "Fix with 'ulimit -n 8192'".format(soft) + ) + except ImportError: + logger.debug("Could not import resource module (on Windows)") + pass + + if job_config is None: + job_config = ray.job_config.JobConfig() + + if RAY_JOB_CONFIG_JSON_ENV_VAR in os.environ: + injected_job_config_json = json.loads( + os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR) + ) + injected_job_config: ray.job_config.JobConfig = ( + ray.job_config.JobConfig.from_json(injected_job_config_json) + ) + driver_runtime_env = runtime_env + runtime_env = _merge_runtime_env( + injected_job_config.runtime_env, + driver_runtime_env, + override=os.getenv("RAY_OVERRIDE_JOB_RUNTIME_ENV") == "1", + ) + if runtime_env is None: + # None means there was a conflict. + raise ValueError( + "Failed to merge the Job's runtime env " + f"{injected_job_config.runtime_env} with " + f"a ray.init's runtime env {driver_runtime_env} because " + "of a conflict. Specifying the same runtime_env fields " + "or the same environment variable keys is not allowed. " + "Use RAY_OVERRIDE_JOB_RUNTIME_ENV=1 to instruct Ray to " + "combine Job and Driver's runtime environment in the event of " + "a conflict." + ) + + runtime_env = _maybe_modify_runtime_env(runtime_env, _skip_env_hook) + + job_config.set_runtime_env(runtime_env) + # Similarly, we prefer metadata provided via job submission API + for key, value in injected_job_config.metadata.items(): + job_config.set_metadata(key, value) + + # RAY_JOB_CONFIG_JSON_ENV_VAR is only set at ray job manager level and has + # higher priority in case user also provided runtime_env for ray.init() + else: + runtime_env = _maybe_modify_runtime_env(runtime_env, _skip_env_hook) + + if runtime_env: + # Set runtime_env in job_config if passed in as part of ray.init() + job_config.set_runtime_env(runtime_env) + + # Pass the logging_config to job_config to configure loggers of all worker + # processes belonging to the job. + if logging_config is not None: + job_config.set_py_logging_config(logging_config) + + redis_address, gcs_address = None, None + bootstrap_address = services.canonicalize_bootstrap_address(address, _temp_dir) + if bootstrap_address is not None: + gcs_address = bootstrap_address + logger.info("Connecting to existing Ray cluster at address: %s...", gcs_address) + + if local_mode: + driver_mode = LOCAL_MODE + warnings.warn( + "`local_mode` is an experimental feature that is no " + "longer maintained and will be removed in the near future. " + "For debugging consider using the Ray distributed debugger.", + FutureWarning, + stacklevel=2, + ) + else: + driver_mode = SCRIPT_MODE + + global _global_node + + if global_worker.connected: + if ignore_reinit_error: + logger.info("Calling ray.init() again after it has already been called.") + node_id = global_worker.core_worker.get_current_node_id() + return RayContext(dict(_global_node.address_info, node_id=node_id.hex())) + else: + raise RuntimeError( + "Maybe you called ray.init twice by accident? " + "This error can be suppressed by passing in " + "'ignore_reinit_error=True' or by calling " + "'ray.shutdown()' prior to 'ray.init()'." + ) + + _system_config = _system_config or {} + if not isinstance(_system_config, dict): + raise TypeError("The _system_config must be a dict.") + + if bootstrap_address is None: + # In this case, we need to start a new cluster. + + # Don't collect usage stats in ray.init() unless it's a nightly wheel. 
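+        # For nightly wheels we show the interactive usage-stats prompt; for
+        # release wheels collection is explicitly disabled below for this
+        # locally started instance.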
+ from ray._common.usage import usage_lib + + if usage_lib.is_nightly_wheel(): + usage_lib.show_usage_stats_prompt(cli=False) + else: + usage_lib.set_usage_stats_enabled_via_env_var(False) + + # Use a random port by not specifying Redis port / GCS server port. + ray_params = ray._private.parameter.RayParams( + node_ip_address=_node_ip_address, + driver_mode=driver_mode, + redirect_output=None, + num_cpus=num_cpus, + num_gpus=num_gpus, + resources=resources, + labels=labels, + num_redis_shards=None, + redis_max_clients=None, + redis_username=_redis_username, + redis_password=_redis_password, + plasma_directory=_plasma_directory, + object_spilling_directory=_object_spilling_directory, + huge_pages=None, + include_dashboard=include_dashboard, + dashboard_host=dashboard_host, + dashboard_port=dashboard_port, + memory=_memory, + object_store_memory=object_store_memory, + plasma_store_socket_name=None, + temp_dir=_temp_dir, + _system_config=_system_config, + enable_object_reconstruction=_enable_object_reconstruction, + metrics_export_port=_metrics_export_port, + tracing_startup_hook=_tracing_startup_hook, + node_name=_node_name, + resource_isolation_config=resource_isolation_config, + ) + # Start the Ray processes. We set shutdown_at_exit=False because we + # shutdown the node in the ray.shutdown call that happens in the atexit + # handler. We still spawn a reaper process in case the atexit handler + # isn't called. + _global_node = ray._private.node.Node( + ray_params=ray_params, + head=True, + shutdown_at_exit=False, + spawn_reaper=True, + ray_init_cluster=True, + ) + else: + # In this case, we are connecting to an existing cluster. + if num_cpus is not None or num_gpus is not None: + raise ValueError( + "When connecting to an existing cluster, num_cpus " + "and num_gpus must not be provided." + ) + if resources is not None: + raise ValueError( + "When connecting to an existing cluster, " + "resources must not be provided." + ) + if labels is not None: + raise ValueError( + "When connecting to an existing cluster, " + "labels must not be provided." + ) + if object_store_memory is not None: + raise ValueError( + "When connecting to an existing cluster, " + "object_store_memory must not be provided." + ) + if _system_config is not None and len(_system_config) != 0: + raise ValueError( + "When connecting to an existing cluster, " + "_system_config must not be provided." + ) + if _enable_object_reconstruction: + raise ValueError( + "When connecting to an existing cluster, " + "_enable_object_reconstruction must not be provided." + ) + if _node_name is not None: + raise ValueError( + "_node_name cannot be configured when connecting to " + "an existing cluster." + ) + + # In this case, we only need to connect the node. + ray_params = ray._private.parameter.RayParams( + node_ip_address=_node_ip_address, + gcs_address=gcs_address, + redis_address=redis_address, + redis_username=_redis_username, + redis_password=_redis_password, + temp_dir=_temp_dir, + _system_config=_system_config, + enable_object_reconstruction=_enable_object_reconstruction, + metrics_export_port=_metrics_export_port, + ) + try: + _global_node = ray._private.node.Node( + ray_params, + head=False, + shutdown_at_exit=False, + spawn_reaper=False, + connect_only=True, + ) + except (ConnectionError, RuntimeError): + if gcs_address == ray._private.utils.read_ray_address(_temp_dir): + logger.info( + "Failed to connect to the default Ray cluster address at " + f"{gcs_address}. 
This is most likely due to a previous Ray " + "instance that has since crashed. To reset the default " + "address to connect to, run `ray stop` or restart Ray with " + "`ray start`." + ) + raise ConnectionError + + # Log a message to find the Ray address that we connected to and the + # dashboard URL. + if ray_constants.RAY_OVERRIDE_DASHBOARD_URL in os.environ: + dashboard_url = os.environ.get(ray_constants.RAY_OVERRIDE_DASHBOARD_URL) + else: + dashboard_url = _global_node.webui_url + # Add http protocol to dashboard URL if it doesn't + # already contain a protocol. + if dashboard_url and not urlparse(dashboard_url).scheme: + dashboard_url = "http://" + dashboard_url + + # We logged the address before attempting the connection, so we don't need + # to log it again. + info_str = "Connected to Ray cluster." + if gcs_address is None: + info_str = "Started a local Ray instance." + if dashboard_url: + logger.info( + info_str + " View the dashboard at %s%s%s %s%s", + colorama.Style.BRIGHT, + colorama.Fore.GREEN, + dashboard_url, + colorama.Fore.RESET, + colorama.Style.NORMAL, + ) + else: + logger.info(info_str) + + connect( + _global_node, + _global_node.session_name, + mode=driver_mode, + log_to_driver=log_to_driver, + worker=global_worker, + driver_object_store_memory=_driver_object_store_memory, + job_id=None, + namespace=namespace, + job_config=job_config, + entrypoint=ray._private.utils.get_entrypoint_name(), + ) + if job_config and job_config.code_search_path: + global_worker.set_load_code_from_local(True) + else: + # Because `ray.shutdown()` doesn't reset this flag, for multiple + # sessions in one process, the 2nd `ray.init()` will reuse the + # flag of last session. For example: + # ray.init(load_code_from_local=True) + # ray.shutdown() + # ray.init() + # # Here the flag `load_code_from_local` is still True if we + # # doesn't have this `else` branch. + # ray.shutdown() + global_worker.set_load_code_from_local(False) + + for hook in _post_init_hooks: + hook() + + # Check and show accelerator override warning during driver initialization + from ray._private.ray_constants import env_bool + + override_on_zero = env_bool( + ray._private.accelerators.RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO_ENV_VAR, + True, + ) + if override_on_zero and log_once("ray_accel_env_var_override_on_zero"): + warnings.warn( + "Tip: In future versions of Ray, Ray will no longer override accelerator " + "visible devices env var if num_gpus=0 or num_gpus=None (default). To enable " + "this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0", + FutureWarning, + ) + + node_id = global_worker.core_worker.get_current_node_id() + global_node_address_info = _global_node.address_info.copy() + global_node_address_info["webui_url"] = _remove_protocol_from_url(dashboard_url) + return RayContext(dict(global_node_address_info, node_id=node_id.hex())) + + +# Functions to run as callback after a successful ray init. +_post_init_hooks = [] + + +@PublicAPI +@client_mode_hook +@with_connect_or_shutdown_lock +def shutdown(_exiting_interpreter: bool = False): + """Disconnect the worker, and terminate processes started by ray.init(). + + This will automatically run at the end when a Python process that uses Ray + exits. It is ok to run this twice in a row. The primary use case for this + function is to cleanup state between tests. 
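+
+    For example, a test suite can start and stop an isolated Ray instance
+    around each test case:
+
+    .. testcode::
+
+        import ray
+
+        ray.init()
+        # ... run code that uses Ray ...
+        ray.shutdown()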
+ + Note that this will clear any remote function definitions, actor + definitions, and existing actors, so if you wish to use any previously + defined remote functions or actors after calling ray.shutdown(), then you + need to redefine them. If they were defined in an imported module, then you + will need to reload the module. + + Args: + _exiting_interpreter: True if this is called by the atexit hook + and false otherwise. If we are exiting the interpreter, we will + wait a little while to print any extra error messages. + """ + # Make sure to clean up compiled dag node if exists. + from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags + + _shutdown_all_compiled_dags() + global_worker.shutdown_gpu_object_manager() + + if _exiting_interpreter and global_worker.mode == SCRIPT_MODE: + # This is a duration to sleep before shutting down everything in order + # to make sure that log messages finish printing. + time.sleep(0.5) + disconnect(_exiting_interpreter) + + # disconnect internal kv + if hasattr(global_worker, "gcs_client"): + del global_worker.gcs_client + _internal_kv_reset() + + # We need to destruct the core worker here because after this function, + # we will tear down any processes spawned by ray.init() and the background + # IO thread in the core worker doesn't currently handle that gracefully. + if hasattr(global_worker, "core_worker"): + if global_worker.mode == SCRIPT_MODE or global_worker.mode == LOCAL_MODE: + global_worker.core_worker.shutdown_driver() + del global_worker.core_worker + # We need to reset function actor manager to clear the context + global_worker.function_actor_manager = FunctionActorManager(global_worker) + # Disconnect global state from GCS. + ray._private.state.state.disconnect() + + # Shut down the Ray processes. + global _global_node + if _global_node is not None: + if _global_node.is_head(): + _global_node.destroy_external_storage() + _global_node.kill_all_processes(check_alive=False, allow_graceful=True) + _global_node = None + + # TODO(rkn): Instead of manually resetting some of the worker fields, we + # should simply set "global_worker" to equal "None" or something like that. + global_worker.set_mode(None) + global_worker.set_cached_job_id(None) + + +atexit.register(shutdown, True) + +# Define a custom excepthook so that if the driver exits with an exception, we +# can push that exception to Redis. +normal_excepthook = sys.excepthook + + +def custom_excepthook(type, value, tb): + import ray.core.generated.common_pb2 as common_pb2 + + # If this is a driver, push the exception to GCS worker table. + if global_worker.mode == SCRIPT_MODE and hasattr(global_worker, "worker_id"): + error_message = "".join(traceback.format_tb(tb)) + worker_id = global_worker.worker_id + worker_type = common_pb2.DRIVER + worker_info = {"exception": error_message} + + ray._private.state.state._check_connected() + ray._private.state.state.add_worker(worker_id, worker_type, worker_info) + # Call the normal excepthook. 
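+    # `normal_excepthook` is whatever sys.excepthook was before we replaced it,
+    # so the traceback is still printed to stderr as usual.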
+ normal_excepthook(type, value, tb) + + +sys.excepthook = custom_excepthook + + +def print_to_stdstream(data, ignore_prefix: bool): + should_dedup = data.get("pid") not in ["autoscaler"] + + if data["is_err"]: + if should_dedup: + batches = stderr_deduplicator.deduplicate(data) + else: + batches = [data] + sink = sys.stderr + else: + if should_dedup: + batches = stdout_deduplicator.deduplicate(data) + else: + batches = [data] + sink = sys.stdout + + for batch in batches: + print_worker_logs(batch, sink, ignore_prefix) + + +# Start time of this process, used for relative time logs. +t0 = time.time() +autoscaler_log_fyi_printed = False + + +def filter_autoscaler_events(lines: List[str]) -> Iterator[str]: + """Given raw log lines from the monitor, return only autoscaler events. + + For Autoscaler V1: + Autoscaler events are denoted by the ":event_summary:" magic token. + For Autoscaler V2: + Autoscaler events are published from log_monitor.py which read + them from the `event_AUTOSCALER.log`. + """ + + if not ray_constants.AUTOSCALER_EVENTS: + return + + AUTOSCALER_LOG_FYI = ( + "Tip: use `ray status` to view detailed " + "cluster status. To disable these " + "messages, set RAY_SCHEDULER_EVENTS=0." + ) + + def autoscaler_log_fyi_needed() -> bool: + global autoscaler_log_fyi_printed + if not autoscaler_log_fyi_printed: + autoscaler_log_fyi_printed = True + return True + return False + + from ray.autoscaler.v2.utils import is_autoscaler_v2 + + if is_autoscaler_v2(): + from ray._private.event.event_logger import filter_event_by_level, parse_event + + for event_line in lines: + if autoscaler_log_fyi_needed(): + yield AUTOSCALER_LOG_FYI + + event = parse_event(event_line) + if not event or not event.message: + continue + + if filter_event_by_level( + event, ray_constants.RAY_LOG_TO_DRIVER_EVENT_LEVEL + ): + continue + + yield event.message + else: + # Print out autoscaler events only, ignoring other messages. + for line in lines: + if ray_constants.LOG_PREFIX_EVENT_SUMMARY in line: + if autoscaler_log_fyi_needed(): + yield AUTOSCALER_LOG_FYI + # The event text immediately follows the ":event_summary:" + # magic token. + yield line.split(ray_constants.LOG_PREFIX_EVENT_SUMMARY)[1] + + +def time_string() -> str: + """Return the relative time from the start of this job. + + For example, 15m30s. + """ + delta = time.time() - t0 + hours = 0 + minutes = 0 + while delta > 3600: + hours += 1 + delta -= 3600 + while delta > 60: + minutes += 1 + delta -= 60 + output = "" + if hours: + output += f"{hours}h" + if minutes: + output += f"{minutes}m" + output += f"{int(delta)}s" + return output + + +# When we enter a breakpoint, worker logs are automatically disabled via this. 
+_worker_logs_enabled = True + + +def print_worker_logs( + data: Dict[str, str], print_file: Any, ignore_prefix: bool = False +): + if not _worker_logs_enabled: + return + + def prefix_for(data: Dict[str, str]) -> str: + """The PID prefix for this log line.""" + if data.get("pid") in ["autoscaler", "raylet"]: + return "" + else: + res = "pid=" + if data.get("actor_name"): + res = f"{data['actor_name']} {res}" + elif data.get("task_name"): + res = f"{data['task_name']} {res}" + return res + + def message_for(data: Dict[str, str], line: str) -> str: + """The printed message of this log line.""" + if ray_constants.LOG_PREFIX_INFO_MESSAGE in line: + return line.split(ray_constants.LOG_PREFIX_INFO_MESSAGE)[1] + return line + + def color_for(data: Dict[str, str], line: str) -> str: + """The color for this log line.""" + if ( + data.get("pid") == "raylet" + and ray_constants.LOG_PREFIX_INFO_MESSAGE not in line + ): + return colorama.Fore.YELLOW + elif data.get("pid") == "autoscaler": + if "Error:" in line or "Warning:" in line: + return colorama.Fore.YELLOW + else: + return colorama.Fore.CYAN + elif os.getenv("RAY_COLOR_PREFIX") == "1": + colors = [ + # colorama.Fore.BLUE, # Too dark + colorama.Fore.MAGENTA, + colorama.Fore.CYAN, + colorama.Fore.GREEN, + # colorama.Fore.WHITE, # Too light + # colorama.Fore.RED, + colorama.Fore.LIGHTBLACK_EX, + colorama.Fore.LIGHTBLUE_EX, + # colorama.Fore.LIGHTCYAN_EX, # Too light + # colorama.Fore.LIGHTGREEN_EX, # Too light + colorama.Fore.LIGHTMAGENTA_EX, + # colorama.Fore.LIGHTWHITE_EX, # Too light + # colorama.Fore.LIGHTYELLOW_EX, # Too light + ] + pid = data.get("pid", 0) + try: + i = int(pid) + except ValueError: + i = 0 + return colors[i % len(colors)] + else: + return colorama.Fore.CYAN + + if data.get("pid") == "autoscaler": + pid = "autoscaler +{}".format(time_string()) + lines = filter_autoscaler_events(data.get("lines", [])) + else: + pid = data.get("pid") + lines = data.get("lines", []) + + ip = data.get("ip") + ip_prefix = "" if ip == data.get("localhost") else f", ip={ip}" + for line in lines: + if RAY_TQDM_MAGIC in line: + process_tqdm(line) + else: + hide_tqdm() + # If RAY_COLOR_PREFIX=0, do not wrap with any color codes + if os.getenv("RAY_COLOR_PREFIX") == "0": + color_pre = "" + color_post = "" + else: + color_pre = color_for(data, line) + color_post = colorama.Style.RESET_ALL + + if ignore_prefix: + print( + f"{message_for(data, line)}", + file=print_file, + ) + else: + print( + f"{color_pre}({prefix_for(data)}{pid}{ip_prefix}){color_post} " + f"{message_for(data, line)}", + file=print_file, + ) + + # Restore once at end of batch to avoid excess hiding/unhiding of tqdm. + restore_tqdm() + + +def process_tqdm(line): + """Experimental distributed tqdm: see ray.experimental.tqdm_ray.""" + try: + data = json.loads(line) + tqdm_ray.instance().process_state_update(data) + except Exception: + if log_once("tqdm_corruption"): + logger.warning( + f"[tqdm_ray] Failed to decode {line}, this may be due to " + "logging too fast. This warning will not be printed again." + ) + + +def hide_tqdm(): + """Hide distributed tqdm bars temporarily to avoid conflicts with other logs.""" + tqdm_ray.instance().hide_bars() + + +def restore_tqdm(): + """Undo hide_tqdm().""" + tqdm_ray.instance().unhide_bars() + + +def listen_error_messages(worker, threads_stopped): + """Listen to error messages in the background on the driver. + + This runs in a separate thread on the driver and pushes (error, time) + tuples to be published. 
+ + Args: + worker: The worker class that this thread belongs to. + threads_stopped (threading.Event): A threading event used to signal to + the thread that it should exit. + """ + + # TODO: we should just subscribe to the errors for this specific job. + worker.gcs_error_subscriber.subscribe() + + try: + if _internal_kv_initialized(): + # Get any autoscaler errors that occurred before the call to + # subscribe. + error_message = _internal_kv_get(ray_constants.DEBUG_AUTOSCALING_ERROR) + if error_message is not None: + logger.warning(error_message.decode()) + while True: + # Exit if received a signal that the thread should stop. + if threads_stopped.is_set(): + return + + _, error_data = worker.gcs_error_subscriber.poll() + if error_data is None: + continue + if error_data["job_id"] is not None and error_data["job_id"] not in [ + worker.current_job_id.binary(), + JobID.nil().binary(), + ]: + continue + + error_message = error_data["error_message"] + print_to_stdstream( + { + "lines": [error_message], + "pid": "raylet", + "is_err": False, + }, + ignore_prefix=False, + ) + except (OSError, ConnectionError) as e: + logger.error(f"listen_error_messages: {e}") + + +@PublicAPI +@client_mode_hook +def is_initialized() -> bool: + """Check if ray.init has been called yet. + + Returns: + True if ray.init has already been called and false otherwise. + """ + return ray._private.worker.global_worker.connected + + +@with_connect_or_shutdown_lock +def connect( + node, + session_name: str, + mode=WORKER_MODE, + log_to_driver: bool = False, + worker=global_worker, + driver_object_store_memory: Optional[int] = None, + job_id=None, + namespace: Optional[str] = None, + job_config=None, + runtime_env_hash: int = 0, + startup_token: int = 0, + ray_debugger_external: bool = False, + entrypoint: str = "", + worker_launch_time_ms: int = -1, + worker_launched_time_ms: int = -1, + debug_source: str = "", +): + """Connect this worker to the raylet, to Plasma, and to GCS. + + Args: + node (ray._private.node.Node): The node to connect. + session_name: The current Ray session name. + mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. + log_to_driver: If true, then output from all of the worker + processes on all nodes will be directed to the driver. + worker: The ray.Worker instance. + driver_object_store_memory: Deprecated. + job_id: The ID of job. If it's None, then we will generate one. + namespace: Namespace to use. + job_config (ray.job_config.JobConfig): The job configuration. + runtime_env_hash: The hash of the runtime env for this worker. + startup_token: The startup token of the process assigned to + it during startup as a command line argument. + ray_debugger_external: If True, make the debugger external to the + node this worker is running on. + entrypoint: The name of the entrypoint script. Ignored if the + mode != SCRIPT_MODE + worker_launch_time_ms: The time when the worker process for this worker + is launched. If the worker is not launched by raylet (e.g., + driver), this must be -1 (default value). + worker_launched_time_ms: The time when the worker process for this worker + finshes launching. If the worker is not launched by raylet (e.g., + driver), this must be -1 (default value). + debug_source: Source information for `CoreWorker`, used for debugging and informational purpose, rather than functional purpose. + """ + # Do some basic checking to make sure we didn't call ray.init twice. + error_message = "Perhaps you called ray.init twice by accident?" 
+ assert not worker.connected, error_message + + # FIXME: tmp disable faulthandler + # # Enable nice stack traces on SIGSEGV etc. + # try: + # if not faulthandler.is_enabled(): + # faulthandler.enable(all_threads=False) + # except io.UnsupportedOperation: + # pass # ignore + + worker.gcs_client = node.get_gcs_client() + assert worker.gcs_client is not None + _initialize_internal_kv(worker.gcs_client) + ray._private.state.state._initialize_global_state( + ray._raylet.GcsClientOptions.create( + node.gcs_address, + node.cluster_id.hex(), + allow_cluster_id_nil=False, + fetch_cluster_id_if_nil=False, + ) + ) + # Initialize some fields. + if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE): + # We should not specify the job_id if it's `WORKER_MODE`. + assert job_id is None + job_id = JobID.nil() + else: + # This is the code path of driver mode. + if job_id is None: + job_id = ray._private.state.next_job_id() + + if mode is not SCRIPT_MODE and mode is not LOCAL_MODE: + process_name = ray_constants.WORKER_PROCESS_TYPE_IDLE_WORKER + if mode is SPILL_WORKER_MODE: + process_name = ray_constants.WORKER_PROCESS_TYPE_SPILL_WORKER_IDLE + elif mode is RESTORE_WORKER_MODE: + process_name = ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE + ray._raylet.setproctitle(process_name) + + if not isinstance(job_id, JobID): + raise TypeError("The type of given job id must be JobID.") + + # All workers start out as non-actors. A worker can be turned into an actor + # after it is created. + worker.node = node + worker.set_mode(mode) + + # For driver's check that the version information matches the version + # information that the Ray cluster was started with. + try: + node.check_version_info() + except Exception as e: + if mode == SCRIPT_MODE: + raise e + elif mode == WORKER_MODE: + traceback_str = traceback.format_exc() + ray._private.utils.publish_error_to_driver( + ray_constants.VERSION_MISMATCH_PUSH_ERROR, + traceback_str, + gcs_client=worker.gcs_client, + ) + + driver_name = "" + interactive_mode = False + if mode == SCRIPT_MODE: + import __main__ as main + + if hasattr(main, "__file__"): + driver_name = main.__file__ + else: + interactive_mode = True + driver_name = "INTERACTIVE MODE" + elif not LOCAL_MODE: + raise ValueError("Invalid worker mode. Expected DRIVER, WORKER or LOCAL.") + + gcs_options = ray._raylet.GcsClientOptions.create( + node.gcs_address, + node.cluster_id.hex(), + allow_cluster_id_nil=False, + fetch_cluster_id_if_nil=False, + ) + if job_config is None: + job_config = ray.job_config.JobConfig() + + if namespace is not None: + ray._private.utils.validate_namespace(namespace) + + # The namespace field of job config may have already been set in code + # paths such as the client. + job_config.set_ray_namespace(namespace) + + # Make sure breakpoint() in the user's code will + # invoke the Ray debugger if we are in a worker or actor process + # (but not on the driver). + if mode == WORKER_MODE: + os.environ["PYTHONBREAKPOINT"] = "ray.util.rpdb.set_trace" + else: + # Add hook to suppress worker logs during breakpoint. + os.environ["PYTHONBREAKPOINT"] = "ray.util.rpdb._driver_set_trace" + + worker.ray_debugger_external = ray_debugger_external + + # If it's a driver and it's not coming from ray client, we'll prepare the + # environment here. If it's ray client, the environment will be prepared + # at the server side. 
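+    # "Preparing" means uploading any local `py_modules` / `working_dir`
+    # packages (and the worker_process_setup_hook) so that workers on other
+    # nodes can download them, then writing the resulting runtime_env back
+    # into the job_config.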
+ if mode == SCRIPT_MODE and not job_config._client_job and job_config.runtime_env: + scratch_dir: str = worker.node.get_runtime_env_dir_path() + runtime_env = job_config.runtime_env or {} + runtime_env = upload_py_modules_if_needed( + runtime_env, scratch_dir, logger=logger + ) + runtime_env = upload_working_dir_if_needed( + runtime_env, scratch_dir, logger=logger + ) + runtime_env = upload_worker_process_setup_hook_if_needed( + runtime_env, + worker, + ) + # Remove excludes, it isn't relevant after the upload step. + runtime_env.pop("excludes", None) + job_config.set_runtime_env(runtime_env, validate=True) + + if mode == SCRIPT_MODE: + # Add the directory containing the script that is running to the Python + # paths of the workers. Also add the current directory. Note that this + # assumes that the directory structures on the machines in the clusters + # are the same. + # When using an interactive shell, there is no script directory. + # We also want to skip adding script directory when running from dashboard. + code_paths = [] + if not interactive_mode and not ( + namespace and namespace == ray._raylet.RAY_INTERNAL_DASHBOARD_NAMESPACE + ): + script_directory = os.path.dirname(os.path.realpath(sys.argv[0])) + # If driver's sys.path doesn't include the script directory + # (e.g driver is started via `python -m`, + # see https://peps.python.org/pep-0338/), + # then we shouldn't add it to the workers. + if script_directory in sys.path: + code_paths.append(script_directory) + # In client mode, if we use runtime envs with "working_dir", then + # it'll be handled automatically. Otherwise, add the current dir. + if not job_config._client_job and not job_config._runtime_env_has_working_dir(): + current_directory = os.path.abspath(os.path.curdir) + code_paths.append(current_directory) + if len(code_paths) != 0: + job_config._py_driver_sys_path.extend(code_paths) + + serialized_job_config = job_config._serialize() + if not node.should_redirect_logs(): + # Logging to stderr, so give core worker empty logs directory. + logs_dir = "" + else: + logs_dir = node.get_logs_dir_path() + + worker.core_worker = ray._raylet.CoreWorker( + mode, + node.plasma_store_socket_name, + node.raylet_socket_name, + job_id, + gcs_options, + logs_dir, + node.node_ip_address, + node.node_manager_port, + (mode == LOCAL_MODE), + driver_name, + serialized_job_config, + node.metrics_agent_port, + runtime_env_hash, + startup_token, + session_name, + node.cluster_id.hex(), + "" if mode != SCRIPT_MODE else entrypoint, + worker_launch_time_ms, + worker_launched_time_ms, + debug_source, + ) + + if mode == SCRIPT_MODE: + worker_id = worker.worker_id + worker.gcs_error_subscriber = ray._raylet.GcsErrorSubscriber( + worker_id=worker_id, address=worker.gcs_client.address + ) + worker.gcs_log_subscriber = ray._raylet.GcsLogSubscriber( + worker_id=worker_id, address=worker.gcs_client.address + ) + + if driver_object_store_memory is not None: + logger.warning( + "`driver_object_store_memory` is deprecated" + " and will be removed in the future." + ) + + # If this is a driver running in SCRIPT_MODE, start a thread to print error + # messages asynchronously in the background. Ideally the scheduler would + # push messages to the driver's worker service, but we ran into bugs when + # trying to properly shutdown the driver's worker service, so we are + # temporarily using this implementation which constantly queries the + # scheduler for new error messages. 
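+    # The same SCRIPT_MODE branch also wires up log forwarding to the driver
+    # (the `print_logs` thread) when `log_to_driver` is enabled.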
+ if mode == SCRIPT_MODE: + worker.listener_thread = threading.Thread( + target=listen_error_messages, + name="ray_listen_error_messages", + args=(worker, worker.threads_stopped), + ) + worker.listener_thread.daemon = True + worker.listener_thread.start() + # If the job's logging config is set, don't add the prefix + # (task/actor's name and its PID) to the logs. + ignore_prefix = global_worker.job_logging_config is not None + + if log_to_driver: + global_worker_stdstream_dispatcher.add_handler( + "ray_print_logs", + functools.partial(print_to_stdstream, ignore_prefix=ignore_prefix), + ) + worker.logger_thread = threading.Thread( + target=worker.print_logs, name="ray_print_logs" + ) + worker.logger_thread.daemon = True + worker.logger_thread.start() + + # Setup tracing here + tracing_hook_val = worker.gcs_client.internal_kv_get( + b"tracing_startup_hook", ray_constants.KV_NAMESPACE_TRACING + ) + if tracing_hook_val is not None: + ray.util.tracing.tracing_helper._enable_tracing() + if not getattr(ray, "__traced__", False): + _setup_tracing = _import_from_string(tracing_hook_val.decode("utf-8")) + _setup_tracing() + ray.__traced__ = True + + # Mark the worker as connected. + worker.set_is_connected(True) + + +def disconnect(exiting_interpreter=False): + """Disconnect this worker from the raylet and object store.""" + # Reset the list of cached remote functions and actors so that if more + # remote functions or actors are defined and then connect is called again, + # the remote functions will be exported. This is mostly relevant for the + # tests. + worker = global_worker + if worker.connected: + # Shutdown all of the threads that we've started. TODO(rkn): This + # should be handled cleanly in the worker object's destructor and not + # in this disconnect method. + worker.threads_stopped.set() + if hasattr(worker, "gcs_error_subscriber"): + worker.gcs_error_subscriber.close() + if hasattr(worker, "gcs_log_subscriber"): + worker.gcs_log_subscriber.close() + if hasattr(worker, "listener_thread"): + worker.listener_thread.join() + if hasattr(worker, "logger_thread"): + worker.logger_thread.join() + worker.threads_stopped.clear() + + # Ignore the prefix if the logging config is set. + ignore_prefix = worker.job_logging_config is not None + for leftover in stdout_deduplicator.flush(): + print_worker_logs(leftover, sys.stdout, ignore_prefix) + for leftover in stderr_deduplicator.flush(): + print_worker_logs(leftover, sys.stderr, ignore_prefix) + global_worker_stdstream_dispatcher.remove_handler("ray_print_logs") + + worker.node = None # Disconnect the worker from the node. + worker.serialization_context_map.clear() + try: + ray_actor = ray.actor + except AttributeError: + ray_actor = None # This can occur during program termination + if ray_actor is not None: + ray_actor._ActorClassMethodMetadata.reset_cache() + + # Mark the worker as disconnected. + worker.set_is_connected(False) + + +@contextmanager +def _changeproctitle(title, next_title): + if _mode() is not LOCAL_MODE: + ray._raylet.setproctitle(title) + try: + yield + finally: + if _mode() is not LOCAL_MODE: + ray._raylet.setproctitle(next_title) + + +@DeveloperAPI +def show_in_dashboard(message: str, key: str = "", dtype: str = "text"): + """Display message in dashboard. + + Display message for the current task or actor in the dashboard. + For example, this can be used to display the status of a long-running + computation. + + Args: + message: Message to be displayed. + key: The key name for the message. 
Multiple message under + different keys will be displayed at the same time. Messages + under the same key will be overridden. + dtype: The type of message for rendering. One of the + following: text, html. + """ + worker = global_worker + worker.check_connected() + + acceptable_dtypes = {"text", "html"} + assert dtype in acceptable_dtypes, f"dtype accepts only: {acceptable_dtypes}" + + message_wrapped = {"message": message, "dtype": dtype} + message_encoded = json.dumps(message_wrapped).encode() + + worker.core_worker.set_webui_display(key.encode(), message_encoded) + + +# Global variable to make sure we only send out the warning once. +blocking_get_inside_async_warned = False + + +@overload +def get( + object_refs: "Sequence[ObjectRef[Any]]", *, timeout: Optional[float] = None +) -> List[Any]: + ... + + +@overload +def get( + object_refs: "Sequence[ObjectRef[R]]", *, timeout: Optional[float] = None +) -> List[R]: + ... + + +@overload +def get(object_refs: "ObjectRef[R]", *, timeout: Optional[float] = None) -> R: + ... + + +@overload +def get( + object_refs: Sequence[CompiledDAGRef], *, timeout: Optional[float] = None +) -> List[Any]: + ... + + +@overload +def get(object_refs: CompiledDAGRef, *, timeout: Optional[float] = None) -> Any: + ... + + +@PublicAPI +@client_mode_hook +def get( + object_refs: Union[ + "ObjectRef[Any]", + Sequence["ObjectRef[Any]"], + CompiledDAGRef, + Sequence[CompiledDAGRef], + ], + *, + timeout: Optional[float] = None, + _tensor_transport: Optional[str] = None, +) -> Union[Any, List[Any]]: + """Get a remote object or a list of remote objects from the object store. + + This method blocks until the object corresponding to the object ref is + available in the local object store. If this object is not in the local + object store, it will be shipped from an object store that has it (once the + object has been created). If object_refs is a list, then the objects + corresponding to each object in the list will be returned. + + Ordering for an input list of object refs is preserved for each object + returned. That is, if an object ref to A precedes an object ref to B in the + input list, then A will precede B in the returned list. + + This method will issue a warning if it's running inside async context, + you can use ``await object_ref`` instead of ``ray.get(object_ref)``. For + a list of object refs, you can use ``await asyncio.gather(*object_refs)``. + + Passing :class:`~ObjectRefGenerator` is not allowed. + + Related patterns and anti-patterns: + + - :doc:`/ray-core/patterns/ray-get-loop` + - :doc:`/ray-core/patterns/unnecessary-ray-get` + - :doc:`/ray-core/patterns/ray-get-submission-order` + - :doc:`/ray-core/patterns/ray-get-too-many-objects` + + + Args: + object_refs: Object ref of the object to get or a list of object refs + to get. + timeout (Optional[float]): The maximum amount of time in seconds to + wait before returning. Set this to None will block until the + corresponding object becomes available. Setting ``timeout=0`` will + return the object immediately if it's available, else raise + GetTimeoutError in accordance with the above docstring. + _tensor_transport: [Alpha] The tensor transport to use to fetch `torch.Tensors` found in the Ray Direct Transport object. Currently, this supports "object_store" and "nixl". + + Returns: + A Python object or a list of Python objects. + + Raises: + GetTimeoutError: A GetTimeoutError is raised if a timeout is set and + the get takes longer than timeout to return. 
+ Exception: An exception is raised immediately if any task that created + the object or that created one of the objects raised an exception, + without waiting for the remaining ones to finish. + """ + worker = global_worker + worker.check_connected() + + if hasattr(worker, "core_worker") and worker.core_worker.current_actor_is_asyncio(): + global blocking_get_inside_async_warned + if not blocking_get_inside_async_warned: + if ray_constants.env_bool( + RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR, + True, + ): + logger.warning( + "Using blocking ray.get inside async actor. " + "This blocks the event loop. Please use `await` " + "on object ref with asyncio.gather if you want to " + "yield execution to the event loop instead." + ) + blocking_get_inside_async_warned = True + + with profiling.profile("ray.get"): + # TODO(sang): Should make ObjectRefGenerator + # compatible to ray.get for dataset. + if isinstance(object_refs, ObjectRefGenerator): + return object_refs + + if isinstance(object_refs, CompiledDAGRef): + return object_refs.get(timeout=timeout) + + if isinstance(object_refs, list): + all_compiled_dag_refs = True + any_compiled_dag_refs = False + for object_ref in object_refs: + is_dag_ref = isinstance(object_ref, CompiledDAGRef) + all_compiled_dag_refs = all_compiled_dag_refs and is_dag_ref + any_compiled_dag_refs = any_compiled_dag_refs or is_dag_ref + if all_compiled_dag_refs: + return [object_ref.get(timeout=timeout) for object_ref in object_refs] + elif any_compiled_dag_refs: + raise ValueError( + "Invalid type of object refs. 'object_refs' must be a list of " + "CompiledDAGRefs if there is any CompiledDAGRef within it. " + ) + + is_individual_id = isinstance(object_refs, ray.ObjectRef) + if is_individual_id: + object_refs = [object_refs] + + if not isinstance(object_refs, list): + raise ValueError( + f"Invalid type of object refs, {type(object_refs)}, is given. " + "'object_refs' must either be an ObjectRef or a list of ObjectRefs. " + ) + + values, debugger_breakpoint = worker.get_objects( + object_refs, timeout=timeout, _tensor_transport=_tensor_transport + ) + for i, value in enumerate(values): + if isinstance(value, RayError): + if isinstance(value, ray.exceptions.ObjectLostError): + worker.core_worker.log_plasma_usage() + if isinstance(value, RayTaskError): + raise value.as_instanceof_cause() + else: + raise value + + if is_individual_id: + values = values[0] + + if debugger_breakpoint != b"": + frame = sys._getframe().f_back + rdb = ray.util.pdb._connect_ray_pdb( + host=None, + port=None, + patch_stdstreams=False, + quiet=None, + breakpoint_uuid=( + debugger_breakpoint.decode() if debugger_breakpoint else None + ), + debugger_external=worker.ray_debugger_external, + ) + rdb.set_trace(frame=frame) + + return values + + +@PublicAPI +@client_mode_hook +def put( + value: Any, + *, + _owner: Optional["ray.actor.ActorHandle"] = None, + _tensor_transport: str = "object_store", +) -> "ray.ObjectRef": + """Store an object in the object store. + + The object may not be evicted while a reference to the returned ID exists. + + Related patterns and anti-patterns: + + - :doc:`/ray-core/patterns/return-ray-put` + - :doc:`/ray-core/patterns/pass-large-arg-by-value` + - :doc:`/ray-core/patterns/closure-capture-large-objects` + + Args: + value: The Python object to be stored. + _owner [Experimental]: The actor that should own this object. This + allows creating objects with lifetimes decoupled from that of the + creating process. 
The owner actor must be passed a reference to the + object prior to the object creator exiting, otherwise the reference + will still be lost. *Note that this argument is an experimental API + and should be avoided if possible.* + _tensor_transport: [Alpha] The tensor transport to use for the GPU object. Currently, this supports "object_store" and "nixl" for tensor transport in ray.put(). + + Returns: + The object ref assigned to this value. + """ + worker = global_worker + worker.check_connected() + + if _owner is None: + serialize_owner_address = None + elif isinstance(_owner, ray.actor.ActorHandle): + # Ensure `ray._private.state.state.global_state_accessor` is not None + ray._private.state.state._check_connected() + serialize_owner_address = ( + ray._raylet._get_actor_serialized_owner_address_or_none( + ray._private.state.state.global_state_accessor.get_actor_info( + _owner._actor_id + ) + ) + ) + if not serialize_owner_address: + raise RuntimeError(f"{_owner} is not alive, it's worker_id is empty!") + else: + raise TypeError(f"Expect an `ray.actor.ActorHandle`, but got: {type(_owner)}") + + with profiling.profile("ray.put"): + try: + object_ref = worker.put_object( + value, + owner_address=serialize_owner_address, + _tensor_transport=_tensor_transport, + ) + except ObjectStoreFullError: + logger.info( + "Put failed since the value was either too large or the " + "store was full of pinned objects." + ) + raise + return object_ref + + +# Global variable to make sure we only send out the warning once. +blocking_wait_inside_async_warned = False + + +@PublicAPI +@client_mode_hook +def wait( + ray_waitables: List[Union[ObjectRef, ObjectRefGenerator]], + *, + num_returns: int = 1, + timeout: Optional[float] = None, + fetch_local: bool = True, +) -> Tuple[ + List[Union[ObjectRef, ObjectRefGenerator]], + List[Union[ObjectRef, ObjectRefGenerator]], +]: + """Return a list of IDs that are ready and a list of IDs that are not. + + If timeout is set, the function returns either when the requested number of + IDs are ready or when the timeout is reached, whichever occurs first. If it + is not set, the function simply waits until that number of objects is ready + and returns that exact number of object refs. + + `ray_waitables` is a list of :class:`~ray.ObjectRef` and + :class:`~ray.ObjectRefGenerator`. + + The method returns two lists, ready and unready `ray_waitables`. + + ObjectRef: + object refs that correspond to objects that are available + in the object store are in the first list. + The rest of the object refs are in the second list. + + ObjectRefGenerator: + Generators whose next reference (that will be obtained + via `next(generator)`) has a corresponding object available + in the object store are in the first list. + All other generators are placed in the second list. + + Ordering of the input list of ray_waitables is preserved. That is, if A + precedes B in the input list, and both are in the ready list, then A will + precede B in the ready list. This also holds true if A and B are both in + the remaining list. + + This method will issue a warning if it's running inside an async context. + Instead of ``ray.wait(ray_waitables)``, you can use + ``await asyncio.wait(ray_waitables)``. + + Related patterns and anti-patterns: + + - :doc:`/ray-core/patterns/limit-pending-tasks` + - :doc:`/ray-core/patterns/ray-get-submission-order` + + Args: + ray_waitables: List of :class:`~ObjectRef` or + :class:`~ObjectRefGenerator` for objects that may or may + not be ready. Note that these must be unique. 
+ num_returns: The number of ray_waitables that should be returned. + timeout: The maximum amount of time in seconds to wait before + returning. + fetch_local: If True, wait for the object to be downloaded onto + the local node before returning it as ready. If the `ray_waitable` + is a generator, it will wait until the next object in the generator + is downloaed. If False, ray.wait() will not trigger fetching of + objects to the local node and will return immediately once the + object is available anywhere in the cluster. + + Returns: + A list of object refs that are ready and a list of the remaining object + IDs. + """ + worker = global_worker + worker.check_connected() + + if ( + hasattr(worker, "core_worker") + and worker.core_worker.current_actor_is_asyncio() + and timeout != 0 + ): + global blocking_wait_inside_async_warned + if not blocking_wait_inside_async_warned: + logger.debug( + "Using blocking ray.wait inside async method. " + "This blocks the event loop. Please use `await` " + "on object ref with asyncio.wait. " + ) + blocking_wait_inside_async_warned = True + + if isinstance(ray_waitables, ObjectRef) or isinstance( + ray_waitables, ObjectRefGenerator + ): + raise TypeError( + "wait() expected a list of ray.ObjectRef or ray.ObjectRefGenerator" + ", got a single ray.ObjectRef or ray.ObjectRefGenerator " + f"{ray_waitables}" + ) + + if not isinstance(ray_waitables, list): + raise TypeError( + "wait() expected a list of ray.ObjectRef or " + "ray.ObjectRefGenerator, " + f"got {type(ray_waitables)}" + ) + + if timeout is not None and timeout < 0: + raise ValueError( + "The 'timeout' argument must be nonnegative. " f"Received {timeout}" + ) + + for ray_waitable in ray_waitables: + if not isinstance(ray_waitable, ObjectRef) and not isinstance( + ray_waitable, ObjectRefGenerator + ): + raise TypeError( + "wait() expected a list of ray.ObjectRef or " + "ray.ObjectRefGenerator, " + f"got list containing {type(ray_waitable)}" + ) + worker.check_connected() + + # TODO(swang): Check main thread. + with profiling.profile("ray.wait"): + # TODO(rkn): This is a temporary workaround for + # https://github.com/ray-project/ray/issues/997. However, it should be + # fixed in Arrow instead of here. + if len(ray_waitables) == 0: + return [], [] + + if len(ray_waitables) != len(set(ray_waitables)): + raise ValueError("Wait requires a list of unique ray_waitables.") + if num_returns <= 0: + raise ValueError("Invalid number of objects to return %d." % num_returns) + if num_returns > len(ray_waitables): + raise ValueError( + "num_returns cannot be greater than the number " + "of ray_waitables provided to ray.wait." + ) + + timeout = timeout if timeout is not None else 10**6 + timeout_milliseconds = int(timeout * 1000) + ready_ids, remaining_ids = worker.core_worker.wait( + ray_waitables, + num_returns, + timeout_milliseconds, + fetch_local, + ) + return ready_ids, remaining_ids + + +@PublicAPI +@client_mode_hook +def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle": + """Get a handle to a named actor. + + Gets a handle to an actor with the given name. The actor must + have been created with Actor.options(name="name").remote(). This + works for both detached & non-detached actors. + + This method is a sync call and it'll timeout after 60s. This can be modified + by setting OS env RAY_gcs_server_request_timeout_seconds before starting + the cluster. + + Args: + name: The name of the actor. + namespace: The namespace of the actor, or None to specify the current + namespace. 
+ + Returns: + ActorHandle to the actor. + + Raises: + ValueError: if the named actor does not exist. + """ + if not name: + raise ValueError("Please supply a non-empty value to get_actor") + + if namespace is not None: + ray._private.utils.validate_namespace(namespace) + + worker = global_worker + worker.check_connected() + return worker.core_worker.get_named_actor_handle(name, namespace or "") + + +@PublicAPI +@client_mode_hook +def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True): + """Kill an actor forcefully. + + This will interrupt any running tasks on the actor, causing them to fail + immediately. ``atexit`` handlers installed in the actor will not be run. + + If you want to kill the actor but let pending tasks finish, + you can call ``actor.__ray_terminate__.remote()`` instead to queue a + termination task. Any ``atexit`` handlers installed in the actor *will* + be run in this case. + + If the actor is a detached actor, subsequent calls to get its handle via + ray.get_actor will fail. + + Args: + actor: Handle to the actor to kill. + no_restart: Whether or not this actor should be restarted if + it's a restartable actor. + """ + worker = global_worker + worker.check_connected() + if not isinstance(actor, ray.actor.ActorHandle): + raise ValueError( + "ray.kill() only supported for actors. For tasks, try ray.cancel(). " + "Got: {}.".format(type(actor)) + ) + worker.core_worker.kill_actor(actor._ray_actor_id, no_restart) + + +@PublicAPI +@client_mode_hook +def cancel( + ray_waitable: Union["ObjectRef[R]", "ObjectRefGenerator[R]"], + *, + force: bool = False, + recursive: bool = True, +) -> None: + """Cancels a task. + + Cancel API has a different behavior depending on if it is a remote function + (Task) or a remote Actor method (Actor Task). + + Task: + If the specified Task is pending execution, it is cancelled and not + executed. If the Task is currently executing, the behavior depends + on the `force` flag. When `force=False`, a KeyboardInterrupt is + raised in Python and when `force=True`, the executing Task + immediately exits. If the Task is already finished, nothing happens. + + Cancelled Tasks aren't retried. `max_task_retries` aren't respected. + + Calling ray.get on a cancelled Task raises a TaskCancelledError + if the Task has been scheduled or interrupted. + It raises a WorkerCrashedError if `force=True`. + + If `recursive=True`, all the child Tasks and Actor Tasks + are cancelled. If `force=True` and `recursive=True`, `force=True` + is ignored for child Actor Tasks. + + Actor Task: + If the specified Task is pending execution, it is cancelled and not + executed. If the Task is currently executing, the behavior depends + on the execution model of an Actor. If it is a regular Actor + or a threaded Actor, the execution isn't cancelled. + Actor Tasks cannot be interrupted because Actors have + states. If it is an async Actor, Ray cancels a `asyncio.Task`. + The semantic of cancellation is equivalent to asyncio's cancellation. + https://docs.python.org/3/library/asyncio-task.html#task-cancellation + If the Task has finished, nothing happens. + + Only `force=False` is allowed for an Actor Task. Otherwise, it raises + `ValueError`. Use `ray.kill(actor)` instead to kill an Actor. + + Cancelled Tasks aren't retried. `max_task_retries` aren't respected. + + Calling ray.get on a cancelled Task raises a TaskCancelledError + if the Task has been scheduled or interrupted. Also note that + only async actor tasks can be interrupted. 
+ + If `recursive=True`, all the child Tasks and actor Tasks + are cancelled. + + Args: + ray_waitable: :class:`~ObjectRef` and + :class:`~ObjectRefGenerator` + returned by the task that should be canceled. + force: Whether to force-kill a running task by killing + the worker that is running the task. + recursive: Whether to try to cancel tasks submitted by the + task specified. + """ + worker = ray._private.worker.global_worker + worker.check_connected() + + if isinstance(ray_waitable, ray._raylet.ObjectRefGenerator): + assert hasattr(ray_waitable, "_generator_ref") + ray_waitable = ray_waitable._generator_ref + + if not isinstance(ray_waitable, ray.ObjectRef): + raise TypeError( + "ray.cancel() only supported for object refs. " + f"For actors, try ray.kill(). Got: {type(ray_waitable)}." + ) + return worker.core_worker.cancel_task(ray_waitable, force, recursive) + + +def _mode(worker=global_worker): + """This is a wrapper around worker.mode. + + We use this wrapper so that in the remote decorator, we can call _mode() + instead of worker.mode. The difference is that when we attempt to + serialize remote functions, we don't attempt to serialize the worker + object, which cannot be serialized. + """ + return worker.mode + + +def _make_remote(function_or_class, options): + if not function_or_class.__module__: + function_or_class.__module__ = "global" + + if inspect.isfunction(function_or_class) or is_cython(function_or_class): + ray_option_utils.validate_task_options(options, in_options=False) + return ray.remote_function.RemoteFunction( + Language.PYTHON, + function_or_class, + None, + options, + ) + + if inspect.isclass(function_or_class): + ray_option_utils.validate_actor_options(options, in_options=False) + return ray.actor._make_actor(function_or_class, options) + + raise TypeError( + "The @ray.remote decorator must be applied to either a function or a class." + ) + + +class RemoteDecorator(Protocol): + @overload + def __call__(self, __function: Callable[[], R]) -> RemoteFunctionNoArgs[R]: + ... + + @overload + def __call__(self, __function: Callable[[T0], R]) -> RemoteFunction0[R, T0]: + ... + + @overload + def __call__(self, __function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2], R] + ) -> RemoteFunction2[R, T0, T1, T2]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3], R] + ) -> RemoteFunction3[R, T0, T1, T2, T3]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4], R] + ) -> RemoteFunction4[R, T0, T1, T2, T3, T4]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5], R] + ) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6], R] + ) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R] + ) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R] + ) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]: + ... + + @overload + def __call__( + self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R] + ) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: + ... + + # Pass on typing actors for now. The following makes it so no type errors + # are generated for actors. 
+ @overload + def __call__(self, __t: type) -> Any: + ... + + +@overload +def remote(__t: Type[T]) -> ActorClass[T]: + ... + + +@overload +def remote(__function: Callable[[], R]) -> RemoteFunctionNoArgs[R]: + ... + + +@overload +def remote(__function: Callable[[T0], R]) -> RemoteFunction0[R, T0]: + ... + + +@overload +def remote(__function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]: + ... + + +@overload +def remote(__function: Callable[[T0, T1, T2], R]) -> RemoteFunction2[R, T0, T1, T2]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3], R] +) -> RemoteFunction3[R, T0, T1, T2, T3]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4], R] +) -> RemoteFunction4[R, T0, T1, T2, T3, T4]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5], R] +) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6], R] +) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R] +) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R] +) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]: + ... + + +@overload +def remote( + __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R] +) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: + ... + + +# Passing options +@overload +def remote( + *, + num_returns: Union[int, Literal["streaming"]] = Undefined, + num_cpus: Union[int, float] = Undefined, + num_gpus: Union[int, float] = Undefined, + resources: Dict[str, float] = Undefined, + accelerator_type: str = Undefined, + memory: Union[int, float] = Undefined, + max_calls: int = Undefined, + max_restarts: int = Undefined, + max_task_retries: int = Undefined, + max_retries: int = Undefined, + runtime_env: Dict[str, Any] = Undefined, + retry_exceptions: bool = Undefined, + scheduling_strategy: Union[ + None, Literal["DEFAULT"], Literal["SPREAD"], PlacementGroupSchedulingStrategy + ] = Undefined, + label_selector: Dict[str, str] = Undefined, + fallback_strategy: List[Dict[str, Any]] = Undefined, +) -> RemoteDecorator: + ... + + +@PublicAPI +def remote( + *args, **kwargs +) -> Union[ray.remote_function.RemoteFunction, ray.actor.ActorClass]: + """Defines a remote function or an actor class. + + This function can be used as a decorator with no arguments + to define a remote function or actor as follows: + + .. testcode:: + + import ray + + @ray.remote + def f(a, b, c): + return a + b + c + + object_ref = f.remote(1, 2, 3) + result = ray.get(object_ref) + assert result == (1 + 2 + 3) + + @ray.remote + class Foo: + def __init__(self, arg): + self.x = arg + + def method(self, a): + return self.x + a + + actor_handle = Foo.remote(123) + object_ref = actor_handle.method.remote(321) + result = ray.get(object_ref) + assert result == (123 + 321) + + Equivalently, use a function call to create a remote function or actor. + + .. 
testcode:: + + def g(a, b, c): + return a + b + c + + remote_g = ray.remote(g) + object_ref = remote_g.remote(1, 2, 3) + assert ray.get(object_ref) == (1 + 2 + 3) + + class Bar: + def __init__(self, arg): + self.x = arg + + def method(self, a): + return self.x + a + + RemoteBar = ray.remote(Bar) + actor_handle = RemoteBar.remote(123) + object_ref = actor_handle.method.remote(321) + result = ray.get(object_ref) + assert result == (123 + 321) + + + It can also be used with specific keyword arguments as follows: + + .. testcode:: + + @ray.remote(num_gpus=1, max_calls=1, num_returns=2) + def f(): + return 1, 2 + + @ray.remote(num_cpus=2, resources={"CustomResource": 1}) + class Foo: + def method(self): + return 1 + + Remote task and actor objects returned by @ray.remote can also be + dynamically modified with the same arguments as above using + ``.options()`` as follows: + + .. testcode:: + :hide: + + ray.shutdown() + + ray.init(num_cpus=5, num_gpus=5) + + .. testcode:: + + @ray.remote(num_gpus=1, max_calls=1, num_returns=2) + def f(): + return 1, 2 + + f_with_2_gpus = f.options(num_gpus=2) + object_refs = f_with_2_gpus.remote() + assert ray.get(object_refs) == [1, 2] + + @ray.remote(num_cpus=2, resources={"CustomResource": 1}) + class Foo: + def method(self): + return 1 + + Foo_with_no_resources = Foo.options(num_cpus=1, resources=None) + foo_actor = Foo_with_no_resources.remote() + assert ray.get(foo_actor.method.remote()) == 1 + + + A remote actor will be terminated when all actor handle to it + in Python is deleted, which will cause them to complete any outstanding + work and then shut down. If you only have 1 reference to an actor handle, + calling ``del actor`` *could* trigger actor deletion. Note that your program + may have multiple references to the same ActorHandle, and actor termination + will not occur until the reference count goes to 0. See the Python + documentation for more context about object deletion. + https://docs.python.org/3.9/reference/datamodel.html#object.__del__ + + If you want to kill actors immediately, you can also call ``ray.kill(actor)``. + + .. tip:: + Avoid repeatedly passing in large arguments to remote task or method calls. + + Instead, use ray.put to create a copy of the object in the object store. + + See :ref:`more info here `. + + Args: + num_returns: This is only for *remote functions*. It specifies + the number of object refs returned by the remote function + invocation. The default value is 1. + Pass "dynamic" to allow the task to decide how many + return values to return during execution, and the caller will + receive an ObjectRef[DynamicObjectRefGenerator]. + See :ref:`dynamic generators ` for more details. + num_cpus: The quantity of CPU resources to reserve + for this task or for the lifetime of the actor. + By default, tasks use 1 CPU resource and actors use 1 CPU + for scheduling and 0 CPU for running + (This means, by default, actors cannot get scheduled on a zero-cpu node, + but an infinite number of them can run on any non-zero cpu node. + The default value for actors was chosen for historical reasons. + It's recommended to always explicitly set num_cpus for actors + to avoid any surprises. + If resources are specified explicitly, + they are required for both scheduling and running.) + See :ref:`specifying resource requirements ` + for more details. + num_gpus: The quantity of GPU resources to reserve + for this task or for the lifetime of the actor. + The default value is 0. + See :ref:`Ray GPU support ` for more details. 
+ resources (Dict[str, float]): The quantity of various + :ref:`custom resources ` + to reserve for this task or for the lifetime of the actor. + This is a dictionary mapping strings (resource names) to floats. + By default it is empty. + label_selector: [Experimental] If specified, the labels required for the node on + which this actor can be scheduled on. The label selector consist of key-value pairs, + where the keys are label names and the value are expressions consisting of an operator + with label values or just a value to indicate equality. + fallback_strategy: [Experimental] If specified, expresses soft constraints for scheduling + through a list of dicts of decorator options to fall back on when scheduling on a node. + Decorator options are evaluated together during scheduling. The first satisfied + dict of options is used. Currently only `label_selector` is a supported option. + accelerator_type: If specified, requires that the task or actor run + on a node with the specified type of accelerator. + See :ref:`accelerator types `. + memory: The heap memory request in bytes for this task/actor, + rounded down to the nearest integer. + max_calls: Only for *remote functions*. This specifies the + maximum number of times that a given worker can execute + the given remote function before it must exit + (this can be used to address :ref:`memory leaks ` in third-party + libraries or to reclaim resources that cannot easily be + released, e.g., GPU memory that was acquired by TensorFlow). + By default this is infinite for CPU tasks and 1 for GPU tasks + (to force GPU tasks to release resources after finishing). + max_restarts: Only for *actors*. This specifies the maximum + number of times that the actor should be restarted when it dies + unexpectedly. The minimum valid value is 0 (default), + which indicates that the actor doesn't need to be restarted. + A value of -1 indicates that an actor should be restarted + indefinitely. + See :ref:`actor fault tolerance ` for more details. + max_task_retries: Only for *actors*. How many times to + retry an actor task if the task fails due to a system error, + e.g., the actor has died. If set to -1, the system will + retry the failed task until the task succeeds, or the actor + has reached its max_restarts limit. If set to `n > 0`, the + system will retry the failed task up to n times, after which the + task will throw a `RayActorError` exception upon :obj:`ray.get`. + Note that Python exceptions are not considered system errors + and will not trigger retries. + The default value is 0. + See :ref:`actor fault tolerance ` for more details. + max_retries: Only for *remote functions*. This specifies + the maximum number of times that the remote function + should be rerun when the worker process executing it + crashes unexpectedly. The minimum valid value is 0, + the default value is 3, and a value of -1 indicates + infinite retries. + See :ref:`task fault tolerance ` for more details. + allow_out_of_order_execution: Only for *actors*. Whether Ray executes actor + tasks out of order. If you're using multi-threaded (``max_concurrency > 1``) + or async actors, you can't set this to False. Defaults to True if you're + using multi-threaded or async actors, and False otherwise. Actor task + retries are always executed out of order. + runtime_env (Dict[str, Any]): Specifies the runtime environment for + this actor or task and its children. See + :ref:`runtime-environments` for detailed documentation. + retry_exceptions: Only for *remote functions*. 
This specifies whether + application-level errors should be retried up to max_retries times. + This can be a boolean or a list of exceptions that should be retried. + See :ref:`task fault tolerance ` for more details. + scheduling_strategy: Strategy about how to + schedule a remote function or actor. Possible values are + None: ray will figure out the scheduling strategy to use, it + will either be the PlacementGroupSchedulingStrategy using parent's + placement group if parent has one and has + placement_group_capture_child_tasks set to true, + or "DEFAULT"; + "DEFAULT": default hybrid scheduling; + "SPREAD": best effort spread scheduling; + `PlacementGroupSchedulingStrategy`: + placement group based scheduling; + `NodeAffinitySchedulingStrategy`: + node id based affinity scheduling. + See :ref:`Ray scheduling strategies ` + for more details. + _labels: The key-value labels of a task or actor. + """ + # "callable" returns true for both function and class. + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + # "args[0]" is the class or function under the decorator. + return _make_remote(args[0], {}) + assert len(args) == 0 and len(kwargs) > 0, ray_option_utils.remote_args_error_string + return functools.partial(_make_remote, options=kwargs) diff --git a/vllm_mlu/__init__.py b/vllm_mlu/__init__.py new file mode 100644 index 0000000..1690cec --- /dev/null +++ b/vllm_mlu/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + + +def register_mlu_platform(): + """Register the MLU platform.""" + return "vllm_mlu.platforms.mlu.MLUPlatform" + + +def register_mlu_hijack(): + """Register the MLU models and hijack.""" + from vllm_mlu import mlu_hijack + from vllm_mlu.model_executor.models import register_model + register_model() + return \ No newline at end of file diff --git a/vllm_mlu/_mlu_ops.py b/vllm_mlu/_mlu_ops.py new file mode 100644 index 0000000..22f8f71 --- /dev/null +++ b/vllm_mlu/_mlu_ops.py @@ -0,0 +1,1853 @@ +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union + +import torch + +import math +import triton +import triton.language as tl + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import torch_mlu_ops as tmo + import torch_mlu_ops.triton_ops as triton_ops +except ImportError as e: + logger.warning("Failed to import from TMO OPS with %r", e) + + +from vllm.distributed import ( + get_ep_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + get_data_parallel_group_world_size, + get_tp_group, + get_tp_world_group, + get_dp_group, + get_data_parallel_group_rank, + get_tp_world_world_size, + get_tp_world_rank, + get_parallel_rank_with_group, +) + + +@triton.jit +def _triton_advance_step(input_tokens_ptr, + sampled_token_ids_ptr, + input_positions_ptr, + seq_lens_ptr, + slot_mapping_ptr, + block_tables_ptr, + block_tables_stride, + num_seqs, + num_queries, + block_size, + TILE_SIZE: tl.constexpr, +): + """ + The triton implementation of advance step. + Reference: https://github.com/vllm-project/vllm/blob/v0.6.1/csrc/prepare_inputs/advance_step.cu#L14-L55 + """ + # Set meta info. + pid = tl.program_id(axis=0) + offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE) + mask = offsets < num_queries + + # Update input_tokens. 
+    sampled_token_ids = tl.load(sampled_token_ids_ptr + offsets, mask=mask)
+    tl.store(input_tokens_ptr + offsets, sampled_token_ids, mask=mask)
+
+    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
+    next_seq_lens = seq_lens + 1
+    next_input_pos = next_seq_lens - 1
+
+    # Update seq_lens.
+    tl.store(seq_lens_ptr + offsets, next_seq_lens, mask=mask)
+
+    # Update input_positions.
+    tl.store(input_positions_ptr + offsets, next_input_pos, mask=mask)
+
+    # Calculate slot num.
+    block_index = next_input_pos // block_size
+    block_offset = next_input_pos % block_size
+    block_tables = tl.load(block_tables_ptr + block_tables_stride * offsets + block_index, mask=mask)
+    slot_num = block_tables * block_size + block_offset
+
+    # Update slot_mapping.
+    tl.store(slot_mapping_ptr + offsets, slot_num, mask=mask)
+
+
+def rotary_embedding(
+    input: torch.Tensor,
+    sin_cache: torch.Tensor,
+    cos_cache: torch.Tensor,
+    position_ids: Optional[torch.Tensor],
+    cu_seqlens: Optional[torch.Tensor],
+    interleaved: bool,
+    discrete: bool,
+    dynamic_ntk: bool,
+    max_seqlen: int,
+) -> torch.Tensor:
+    return tmo.apply_rotary(
+        input, sin_cache, cos_cache,
+        position_ids, cu_seqlens, interleaved,
+        discrete, dynamic_ntk, max_seqlen)
+
+
+def fused_rms_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    gamma: torch.Tensor,
+    beta: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    store_output_before_norm: bool,
+    quant_scale: torch.Tensor = None,
+    dynamic_quant: bool = False,
+    out: torch.Tensor = None,
+) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+    return tmo.fused_rms_norm(
+        x, residual, gamma, beta, bias,
+        eps, store_output_before_norm, quant_scale,
+        out, dynamic_quant)
+
+
+def fused_layer_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    gamma: torch.Tensor,
+    beta: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    store_output_before_norm: bool,
+    quant_scale: torch.Tensor = None,
+    dynamic_quant: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+    return tmo.fused_layer_norm(
+        x, residual, gamma, beta, bias,
+        eps, store_output_before_norm, quant_scale,
+        None, dynamic_quant)
+
+
+def flash_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: Optional[torch.Tensor],
+    cu_seq_lens_q: Optional[torch.Tensor],
+    cu_seq_lens_kv: Optional[torch.Tensor],
+    alibi_slope: Optional[torch.Tensor],
+    attn_bias: Optional[torch.Tensor],
+    max_seq_len_q: int,
+    max_seq_len_kv: int,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    compute_dtype: torch.dtype = torch.float,
+    return_lse: bool = False,
+    block_tables: Optional[torch.Tensor] = None,
+    out_quant_scale: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.half,
+    q_quant_dtype: Optional[torch.dtype] = None,
+    k_quant_dtype: Optional[torch.dtype] = None,
+    v_quant_dtype: Optional[torch.dtype] = None,
+    q_quant_scale: Optional[torch.Tensor] = None,
+    k_quant_scale: Optional[torch.Tensor] = None,
+    v_quant_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if v is None:
+        v = k
+    return tmo.flash_attention(
+        q, k, v, out,
+        cu_seq_lens_q, cu_seq_lens_kv,
+        alibi_slope, attn_bias,
+        max_seq_len_q, max_seq_len_kv,
+        softmax_scale, is_causal,
+        window_size_left, window_size_right,
+        compute_dtype, return_lse,
+        block_tables, k_quant_scale,
+        v_quant_scale, q_quant_scale,
+        out_quant_scale, out_dtype)
+
+
+def split_head_nums(q_head_num, kv_head_num, max_q_head_num):
+    """
+    Split q_head_num such that:
+    1. The maximum value of the split q_head_num does not exceed max_q_head_num.
+    2.
kv_head_num is split into the same number of parts as q_head_num. + 3. Each split q_head_num can be evenly divided by the corresponding kv_head_num. + 4. If kv_head_num < 1, it is adjusted to 1. + + Parameters: + - q_head_num: int, the q_head_num to be split. + - kv_head_num: int, the kv_head_num to be split. + - max_q_head_num: int, the maximum supported q_head_num after splitting. + + Returns: + - q_splits: list, the split q_head_num. + - kv_splits: list, the split kv_head_num. + """ + if q_head_num <= 0 or kv_head_num <= 0: + return "q_head_num and kv_head_num must be positive integers!" + + q_splits = [] + kv_splits = [] + + # Residual value + remaining_q = q_head_num + remaining_kv = kv_head_num + + while remaining_q > 0: + # Attempt to split q_head_num such that the maximum value does not exceed max_q_head_num. + for q_part in range(min(max_q_head_num, remaining_q), 0, -1): + # Ensure that q_part can be allocated and the corresponding kv_part is greater than or equal to 1. + if remaining_q % q_part == 0: + # Ensure that kv_part is greater than or equal to 1. + kv_part = max(remaining_kv // (remaining_q // q_part), 1) + # Ensure that q_part is divisible by kv_part. + if q_part % kv_part == 0: + # Record the split values. + q_splits.append(q_part) + kv_splits.append(kv_part) + remaining_q -= q_part + remaining_kv -= kv_part + break + else: + err_msg = f"Unable to find split method for q_head_num:{q_head_num}, kv_head_num:{kv_head_num}" + raise RuntimeError(err_msg) + + return q_splits, kv_splits + + +def repeat_elements(input_list, n): + """ + Repeat each element in the list n times consecutively. + + Parameters: + - input_list: list, the input list. + - n: int, the number of times each element should be repeated. + + Returns: + - list, a new list containing the repeated elements. + """ + if not isinstance(input_list, list) or not isinstance(n, int) or n < 0: + raise ValueError("The input must be a list, and the repetition count n must be an integer greater than or equal to 0.") + + # Repeat each element n times using a list comprehension. 
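+    # e.g. repeat_elements(["a", "b"], 2) -> ["a", "a", "b", "b"]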
+ return [item for item in input_list for _ in range(n)] + + +def single_query_cached_kv_attn( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + out: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + k_cache_quant_scale: Optional[torch.Tensor], + v_cache_quant_scale: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_contxt_len: int, + windows_size_left: int, + windows_size_right: int, + softmax_scale: float, + return_lse: bool = False, + q_head_dim: Optional[int] = 2, + kv_head_dim: Optional[int] = 1, + seq_q_dim: Optional[int] = 1, + max_seq_q_mul_q_divide_kv: Optional[int] = 128, + head_size_v: Optional[int] = -1, + compute_dtype: Optional[torch.dtype] = torch.float32, + q_quant_scale: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + q_quant_dtype: Optional[torch.dtype] = None, + q_scale_dtype: Optional[torch.dtype] = None, + learnable_sink: Optional[torch.Tensor] = None, +) -> None: + windows_size_right = -1 + seq_q = q.shape[seq_q_dim] + if q_quant_dtype is not None and q.dtype != q_quant_dtype and q_quant_scale is None: + q, q_quant_scale = tmo.scaled_quantize(q.contiguous(), quant_type=q_quant_dtype, quant_mode="dynamic_per_token") + if k_cache is not None and k_cache.dtype == torch.uint8: + k_cache = k_cache.view(torch.float8_e4m3fn) + if v_cache is not None and v_cache.dtype == torch.uint8: + v_cache = v_cache.view(torch.float8_e4m3fn) + + if k_cache is not None and k_cache.dtype == torch.bfloat16: + max_seq_q_mul_q_divide_kv = 256 + + # single_query_cached_kv_attn limits seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv now, + # and this limitation only applies when using kv8 or floating point computation. + # When the limitation is fixed, we should delete the split process. 
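+    # Illustrative numbers only: with q_head_num=64 and kv_head_num=8, q_divide_kv is 8,
+    # so a 32-token query chunk gives seq_q * q_divide_kv = 256, which exceeds the default
+    # limit of 128 and takes the head-splitting fallback below.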
+    q_head_num = q.shape[q_head_dim]
+    kv_head_num = k_cache.shape[kv_head_dim]
+    q_divide_kv = q_head_num // kv_head_num
+    if seq_q * q_divide_kv <= max_seq_q_mul_q_divide_kv or q_quant_scale is not None:
+        tmo.single_query_cached_kv_attn(
+            q, k_cache, v_cache, out,
+            block_tables, context_lens,
+            k_cache_quant_scale, v_cache_quant_scale,
+            alibi_slopes, max_contxt_len,
+            windows_size_left, windows_size_right, softmax_scale, return_lse,
+            q_quant_scale=q_quant_scale,
+            head_size_v=head_size_v,
+            compute_dtype=compute_dtype,
+            mask=mask,
+            learnable_sink=learnable_sink)
+    else:
+        max_q_head_num = max_seq_q_mul_q_divide_kv * kv_head_num // seq_q
+        q_head_num_sizes, kv_head_num_sizes = split_head_nums(q_head_num, kv_head_num, max_q_head_num)
+        parts_num = len(q_head_num_sizes)
+        q_parts = torch.split(q, q_head_num_sizes, dim=q_head_dim)
+        out_parts = torch.split(out, q_head_num_sizes, dim=q_head_dim)
+        alibi_slopes_parts = [None] * parts_num
+        if alibi_slopes is not None:
+            alibi_slopes_parts = torch.split(alibi_slopes, q_head_num_sizes, dim=0)
+
+        kv_parts_num = parts_num
+        if parts_num > kv_head_num:
+            assert parts_num % kv_head_num == 0, f"parts_num:{parts_num} must be divisible by kv_head_num:{kv_head_num} when parts_num > kv_head_num"
+            kv_parts_num = kv_head_num
+        kv_head_num_sizes = kv_head_num_sizes[:kv_parts_num]
+
+        if len(kv_head_num_sizes) > 1:
+            k_cache_parts = torch.split(k_cache, kv_head_num_sizes, dim=kv_head_dim)
+            v_cache_parts = torch.split(v_cache, kv_head_num_sizes, dim=kv_head_dim)
+            k_cache_quant_scale_parts = [None] * kv_parts_num
+            v_cache_quant_scale_parts = [None] * kv_parts_num
+            if k_cache_quant_scale is not None:
+                k_cache_quant_scale_dim = 1 if k_cache_quant_scale.dim() == 2 else kv_head_dim
+                k_cache_quant_scale_parts = torch.split(k_cache_quant_scale, kv_head_num_sizes, dim=k_cache_quant_scale_dim)
+            if v_cache_quant_scale is not None:
+                v_cache_quant_scale_dim = 1 if v_cache_quant_scale.dim() == 2 else kv_head_dim
+                v_cache_quant_scale_parts = torch.split(v_cache_quant_scale, kv_head_num_sizes, dim=v_cache_quant_scale_dim)
+        else:
+            k_cache_parts = [k_cache]
+            v_cache_parts = [v_cache]
+            k_cache_quant_scale_parts = [k_cache_quant_scale]
+            v_cache_quant_scale_parts = [v_cache_quant_scale]
+
+        if parts_num > kv_parts_num:
+            repeat_num = parts_num // kv_parts_num
+            k_cache_parts = repeat_elements(k_cache_parts, repeat_num)
+            v_cache_parts = repeat_elements(v_cache_parts, repeat_num)
+            k_cache_quant_scale_parts = repeat_elements(k_cache_quant_scale_parts, repeat_num)
+            v_cache_quant_scale_parts = repeat_elements(v_cache_quant_scale_parts, repeat_num)
+
+        for q_value, k_cache_value, v_cache_value, out_value, k_cache_quant_scale_value, v_cache_quant_scale_value, alibi_slopes_value in zip(
+                q_parts, k_cache_parts, v_cache_parts, out_parts, k_cache_quant_scale_parts, v_cache_quant_scale_parts,
+                alibi_slopes_parts):
+            tmo.single_query_cached_kv_attn(
+                q_value, k_cache_value.contiguous(), v_cache_value.contiguous() if v_cache_value is not None else None,
+                out_value, block_tables, context_lens,
+                k_cache_quant_scale_value, v_cache_quant_scale_value,
+                alibi_slopes_value, max_contxt_len,
+                windows_size_left, windows_size_right, softmax_scale, return_lse,
+                head_size_v=head_size_v,
+                compute_dtype=compute_dtype)
+
+    return (None, None)  # TODO(liangxuegang): to fix return (output, lse)
+
+
+def reshape_paged_cache(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_mapping: torch.Tensor
+) -> None:
+    tmo.reshape_paged_cache(k, v, k_cache, v_cache,
slot_mapping) + + +def swap_blocks( + dst: torch.Tensor, + src: torch.Tensor, + block_mapping: torch.Tensor +) -> None: + # FIXME: Remove this conversion after + # tmo.swap_blocks support block_mapping tensor. + block_mapping = block_mapping.tolist() + block_mapping = {src: dst for src, dst in block_mapping} + return tmo.swap_blocks(dst, src, block_mapping) + + +def copy_blocks( + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + block_mapping: torch.Tensor +) -> None: + # FIXME: Remove this conversion after + # tmo.swap_blocks support block_mapping tensor. + block_mapping = block_mapping.tolist() + result_dict = {} + for row in block_mapping: + key = row[0] + values = row[1:] + if key in result_dict: + result_dict[key].extend(values) + else: + result_dict[key] = values + return tmo.copy_blocks(k_caches, v_caches, result_dict) + + +def active( + input: torch.Tensor, + act_mode: str, + is_gated: bool +) -> torch.Tensor: + return tmo.active(input, act_mode, is_gated) + +def fused_moe(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], + residual: Optional[torch.Tensor], + input_smooth: Optional[torch.Tensor], + act_smooth: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + topk: int, + renormalize: bool, + gated: bool, + act_mode: str, + start_expert_id: int = 0, + block_n: int = 0, + cncl_comm: int = 0, + avg_moe: bool=False, + class_reduce_weight: Optional[torch.Tensor] = None, + class_expert_id: Optional[torch.Tensor] = None, + w1_quant_flag: Optional[List] = None, + w2_quant_flag: Optional[List] = None, + world_size: int = 0, + shared_expert_num: int = 0, + parallel_mode: str = 'ep'): + dtype = hidden_states.dtype + ori_input_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + tokens = hidden_states.size(0) + gating_output = gating_output.reshape(-1, gating_output.size(-1)) + residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None + expert_num = gating_output.size(-1) + expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1) + + per_token_sq = False + # check quant + check_list = [input_smooth, act_smooth, w1_scale, w2_scale] + if all(x is not None for x in check_list): + per_token_sq = True + + if not (all(x is None for x in check_list) or all(x is not None for x in check_list)): + raise ValueError( + "input_smooth, act_smooth, w1_scale and w2_scale must be " + "present and absent at the same time." 
+ ) + + # softmax_topk + reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalize) + + # append shared + if shared_expert_num > 0: + reduce_weight, expert_id = tmo.moe_append_shared_expert(reduce_weight, expert_id, expert_num, + shared_expert_num, world_size, parallel_mode) + if parallel_mode == "ep": + avg_shared_expert_num = (world_size + shared_expert_num - 1) // world_size + expert_num += avg_shared_expert_num * world_size + else: + expert_num += shared_expert_num + + if avg_moe: + n_tokens = hidden_states.shape[0] + reduce_weight = class_reduce_weight[:n_tokens] + expert_id = class_expert_id[:n_tokens] + # gen_idx + expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num) + + if per_token_sq: + quant_input, input_scale = tmo.moe_quantize(hidden_states, + input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx, + cusum_token_count[start_expert_id].unsqueeze(0)) + else: + expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx, + cusum_token_count, start_expert_id, expert_size) + + # group gemm + if per_token_sq: + gemm1_out = tmo.smooth_quant_group_gemm(quant_input, + w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, + input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag) + else: + gemm1_out = tmo.group_gemm(expand_hidden_states, + w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, + None, + None, + None, tokens) + if per_token_sq: + quant_input = quant_input[:, :gemm1_out.shape[-1] // 2] if gated else quant_input[:, :gemm1_out.shape[-1]] + input_scale = input_scale[:gemm1_out.shape[0]] + quant_input, input_scale = tmo.moe_quantize(gemm1_out, act_smooth, None, + token_count[start_expert_id:start_expert_id+expert_size], + output=quant_input, + output_scale=input_scale, + act_mode=act_mode, + is_gated=gated) + else: + act_out = gemm1_out[:, :gemm1_out.shape[-1] // 2] if gated else gemm1_out + act_out = tmo.moe_active(gemm1_out, act_mode, gated, act_out, bias1, cusum_token_count, start_expert_id, expert_size) + if cncl_comm > 0: + raise ValueError("not support communication and computing fusion currently.") + else: + if per_token_sq: + gemm2_out = tmo.smooth_quant_group_gemm(quant_input, + w2, token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag) + else: + gemm2_out = tmo.group_gemm(act_out, + w2, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, tokens) + + output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx, + residual, cusum_token_count, start_expert_id, + expert_size, bias2) + return output.reshape(ori_input_shape) + + +def matmul( + a: torch.Tensor, + b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + act_mode: str = 'none', + alpha: float = 1.0, + beta: float = .0 +) -> torch.Tensor: + return tmo.matmul(a, b, bias, c, act_mode, alpha, beta) + + +def weight_only_quant_matmul( + a: torch.Tensor, + b: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor = None, + bias: torch.Tensor = None, + c: torch.Tensor = None, + act_mode: str = "none", + quant_bit_size: int = 8, + alpha: float = 1.0, + beta: float = 1.0 +) -> torch.Tensor: + assert False, "[weight_only_quant_matmul] is deprecated." 
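+# Illustrative shapes for the smooth-quant GEMM wrapper defined below (see the
+# scaled_matmul docstring later in this file): `a` is an int8 activation of
+# shape (M, K) with a per-token scale `a_scale` of shape (M,), `b` is an int8
+# weight of shape (N, K) with a per-channel scale `b_scale` of shape (N,), and
+# the result is an (M, N) tensor in the requested floating-point output dtype.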
+ + +def smooth_quant_matmul( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + dtype: torch.dtype, + bias: torch.Tensor = None, + c: torch.Tensor = None, + act_mode: str = "none", + alpha: float = 1.0, + beta: float = 1.0, + use_hp_active: bool = False, + b_quant_bit_size: int = 8, + output: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return tmo.scaled_matmul(a, b, a_scale, b_scale, dtype, bias, c, act_mode, + b_quant_bit_size, alpha, beta, use_hp_active) + + +def per_token_smooth_quantize(x: torch.Tensor, + smooth: torch.Tensor, + zero: torch.Tensor = None, + token_count: torch.Tensor = None, + act_mode: str = "none", + active_coef: float = 1.0, + is_gated: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + if act_mode == "none": + is_gated = False + if token_count is None: + output, output_scale = tmo.scaled_quantize(x, smooth, zero, None, torch.int8, + "dynamic_per_token", act_mode, active_coef, + is_gated) + else: + output, output_scale = tmo.moe_quantize(x, smooth, zero, token_count, None, None, None, + None, True, act_mode, active_coef, is_gated) + return (output, output_scale) + + +def quantize( + x: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor = None +) -> torch.Tensor: + assert False, "[quantize] is deprecated." + + +def quant_to_paged_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_cache_quant_scale: torch.Tensor, + v_cache_quant_scale: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + if k_cache is not None and k_cache.dtype == torch.uint8: + k_cache = k_cache.view(torch.float8_e4m3fn) + if v_cache is not None and v_cache.dtype == torch.uint8: + v_cache = v_cache.view(torch.float8_e4m3fn) + return tmo.quant_to_paged_cache( + k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping + ) + + +def advance_step(num_seqs: int, + num_queries: int, + block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, + slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + TILE_SIZE: int = 64) -> None: + """ + Advance a step on MLU for existing inputs for a multi-step runner, which + will update input_tokens/seq_lens/input_positions/slot_mapping inplace. + """ + def verify_tensor( + name: str, + tensor: torch.Tensor, + size_0: int, + size_1: int, + dtype: torch.dtype, + ): + """ + Auxiliary function to check whether input is valid. 
+ """ + size_0_cond = (size_0 == -1 or tensor.size(0) == size_0) + size_1_cond = (size_1 == -1 or tensor.size(1) == size_1) + if not (size_0_cond and size_1_cond and tensor.is_contiguous and tensor.dtype == dtype): + raise ValueError( + f"The input to advance_step is invalid with tensor name = {name}, " + f"shape = {tensor.shape}, " + f"is_cont = {tensor.is_contiguous()}, " + f"type = {tensor.dtype}, " + f"is not as expected: shape[{size_0}, {size_1}], type = {dtype}" + ) + + verify_tensor("input_tokens", input_tokens, num_seqs, -1, torch.int64) + verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, torch.int64) + verify_tensor("input_positions", input_positions, num_seqs, -1, torch.int32) + verify_tensor("seq_lens", seq_lens, num_seqs, -1, torch.int32) + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, torch.int32) + verify_tensor("block_tables", block_tables, num_seqs, -1, torch.int32) + + grid = (math.ceil(num_queries / TILE_SIZE), ) + _triton_advance_step[grid](input_tokens, + sampled_token_ids, + input_positions, + seq_lens, + slot_mapping, + block_tables, + block_tables.stride(0), + num_seqs, + num_queries, + block_size, + TILE_SIZE) + + +#Moe inner kernels +def moe_softmax_topk(input: torch.Tensor, + topk: int, + normalize: bool = False, + num_expert_group: int = -1, + topk_group: int = 0, + mask: Optional[torch.Tensor] = None, + normed_by : str = "topk_logit", + route_scale : float = 1.0, + reduce_weight: Optional[torch.Tensor] = None, + expert_id: Optional[torch.Tensor] = None, + score_bias: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]: + return tmo.moe_softmax_topk(input, topk, normalize, num_expert_group, + topk_group, mask, normed_by, route_scale, + reduce_weight, expert_id, score_bias) + +def moe_sigmoid_topk(input: torch.Tensor, + topk: int, + normalize: bool = False, + num_expert_group: int = -1, + topk_group: int = 0, + route_scale: float = 1.0, + score_bias: Optional[torch.Tensor] = None, + reduce_weight: Optional[torch.Tensor] = None, + expert_id: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]: + return tmo.moe_sigmoid_topk(input, topk, normalize, num_expert_group, + topk_group, route_scale = route_scale, + score_bias = score_bias, + reduce_weight=reduce_weight, + expert_id=expert_id) + +def moe_softplus_topk( + input: torch.Tensor, + topk: int, + input_ids: Optional[torch.Tensor] = None, + tid2eid: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + route_scale: float = 1.0, + reduce_weight: Optional[torch.Tensor] = None, + expert_id: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return tmo.moe_softplus_topk( + input, + topk, + input_ids, + tid2eid, + bias, + route_scale, + reduce_weight, + expert_id, + ) + +def moe_gen_idx(expert_id: torch.Tensor, + expert_num: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return tmo.moe_gen_idx(expert_id, expert_num) + +def moe_expand_input(input: torch.Tensor, + gather_idx: torch.Tensor, + cusum_token_count: Optional[torch.Tensor] = None, + start_expert_id: int = 0, + expert_size: int = 0) -> torch.Tensor: + return tmo.moe_expand_input(input, gather_idx, + cusum_token_count, + start_expert_id, expert_size) + +def moe_active(input: torch.Tensor, + act_mode: str, + is_gated: bool, + output: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cusum_token_count: Optional[torch.Tensor] = None, + start_expert_id: int = 0, + expert_size: int = 0) -> torch.Tensor: + return tmo.moe_active(input, act_mode, 
is_gated, output, + bias, cusum_token_count, + start_expert_id, expert_size) + +def group_gemm(a: torch.Tensor, + b: torch.Tensor, + m_list: torch.Tensor, + expand_idx: Optional[torch.Tensor], + c: Optional[torch.Tensor], + alpha: Optional[torch.Tensor], + beta: Optional[torch.Tensor], + max_m: int = 0, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return tmo.group_gemm(a, b, m_list, expand_idx, + c, alpha, beta, max_m, d=output) + +def smooth_quant_group_gemm(a: torch.Tensor, + b: torch.Tensor, + m_list: torch.Tensor, + expand_idx: Optional[torch.Tensor], + c: Optional[torch.Tensor], + alpha: Optional[torch.Tensor], + beta: Optional[torch.Tensor], + a_scale: torch.Tensor, + b_scale: torch.Tensor, + dtype, + max_m: int = 0, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return tmo.smooth_quant_group_gemm(a, b, m_list, expand_idx, c, alpha, beta, + a_scale, b_scale, dtype, max_m, d=output) + +def moe_combine_result(input: torch.Tensor, + reduce_weight: torch.Tensor, + gather_ids: torch.Tensor, + residual: Optional[torch.Tensor], + cusum_token_count: Optional[torch.Tensor], + start_expert_id: int, + expert_size: int, + bias: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return tmo.moe_combine_result(input, reduce_weight, gather_ids, + residual, cusum_token_count, + start_expert_id, expert_size, bias, output=output) + +def moe_quantize(x: torch.Tensor, + smooth: torch.Tensor, + zero: Optional[torch.Tensor] = None, + token_count: Optional[torch.Tensor] = None, + gather_index: Optional[torch.Tensor] = None, + gather_index_start_position: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + dynamic_quant: bool = True, + act_mode: str = "none", + active_coef: float = 1.0, + is_gated: bool = False, + quant_type: torch.dtype = torch.int8 + ) -> Tuple[torch.Tensor, torch.Tensor]: + return tmo.moe_quantize(x, smooth, zero, token_count, gather_index, gather_index_start_position, + output, output_scale, dynamic_quant, act_mode, active_coef, is_gated, quant_type) + + +def dequant_from_paged_cache(key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + key_cache_quant_scale: torch.Tensor, + value_cache_quant_scale: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + context_seq_offset: Optional[torch.Tensor], + block_tables: torch.Tensor, + quant_mode: int = 0, + quant_bit: int = 8) -> None: + tmo.dequant_from_paged_cache( + key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, + context_lengths, max_context_len, context_seq_offset, block_tables, quant_mode, quant_bit) + +def random_sample( + probs: torch.Tensor, + is_gumbel_max: bool, + generators: dict[int, torch.Generator], +) -> torch.Tensor: + return tmo.random_sample(probs, is_gumbel_max, generators) + +def rejection_sample(draft_token_ids: torch.Tensor, + num_draft_tokens: torch.Tensor, + cu_num_draft_tokens: torch.Tensor, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + uniform_rand: torch.Tensor, + uniform_probs: torch.Tensor, + max_spec_len: int, + high_acc: bool = True) -> torch.Tensor: + return tmo.rejection_sample( + draft_token_ids, num_draft_tokens, cu_num_draft_tokens, draft_probs, + target_probs, bonus_token_ids, uniform_rand, uniform_probs, max_spec_len, high_acc) + +def apply_topkp_v2(logits: torch.Tensor, + index_in: 
torch.Tensor, + temperature_list: torch.Tensor, + minp_list: torch.Tensor, + topk_list: torch.Tensor, + topp_list: torch.Tensor, + logits_out: Optional[torch.Tensor] = None, + sorted_logits_out: Optional[torch.Tensor] = None, + index_out: Optional[torch.Tensor] = None, + true_select_len: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return tmo.apply_topkp_v2(logits, index_in, temperature_list, minp_list, topk_list, topp_list, + logits_out, sorted_logits_out, index_out, true_select_len) + + +def scaled_quantize( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + zero: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + quant_type: torch.dtype = torch.int8, + quant_mode: str = "dynamic_per_token", + act_mode: str = "none", + active_coef: float = 1.0, + is_gated: bool = False +) -> Tuple[torch.Tensor]: + """ + Apply activation and quantization to the input tensor x. + + Args: + x (torch.Tensor): The tensor to be quantized, shape is (..., C), must be continuous between 0 and -2 dimensions. + scale (Optional[torch.Tensor], optional): The scale multipled to the input tensor. Shape is (C) or (1). + zero (Optional[torch.Tensor], optional): Not supported, must pass None. + scale_ub (Optional[torch.Tensor], optional): The output_scale upper bound. + Take effect only if quant_type == torch.float8_e4m3fn and quant_mode == "dynamic_per_token". + quant_type (optional): Output data type, can be torch.int8, torch.float8_e4m3fn. Defaults to torch.int8. + quant_mode (str, optional): quantize mode, which can be "dynamic_per_token", "dynamic_per_tensor", "static_per_tensor" + and "static_per_channel". Defaults to "dynamic_per_token". + act_mode (str): The mode of activation, must be "none", "gelu", "silu", "swish". + active_coef(float): The coefficient used in the swish activation. Default is 1.0. + is_gated (bool): A boolean parameter that indicates whether a gating mechanism is applied. It only + takes effect when act_mode is not "none". + + Type: + input: float, half, bfloat16. + scale: float. + scale_ub: float. + act_mode: str + active_coef: float + is_gated: bool + + Returns: + Tuple[torch.Tensor]: Returns (output, output_scale) if quant_mode is "dynamic_per_token" or "dynamic_per_tensor", + otherwise returns output only. + """ + return tmo.scaled_quantize(input, + scale, + zero, + scale_ub, + quant_type, + quant_mode, + act_mode, + active_coef, + is_gated) + +def scaled_matmul(a: torch.Tensor, + b: torch.Tensor, + a_scale: Optional[torch.Tensor], + b_scale: torch.Tensor, + output_dtype: torch.dtype, + bias: torch.Tensor = None, + c: torch.Tensor = None, + act_mode: str = "none", + quant_bit_size: int = 8, + alpha: float = 1.0, + beta: float = 1.0, + use_hp_active: bool = False, + a_quant_bit_size: int = 8, + a_calib: Optional[torch.Tensor] = None, + b_calib: Optional[torch.Tensor] = None,): + """ + Perform quantized matrix multiplication on tensor a and b. + + Args: + a (torch.Tensor): Shape is (M, K). + b (torch.Tensor): If quant_bit_size = 8, shape is (N, K). + If quant_bit_size = 4, shape is (N, K//2). + a_scale (Optional[torch.Tensor]): Shape can be (M). + b_scale (torch.Tensor): If use groupwise quantization, shape must be (N, group_num), data type must be + the same as a; otherwise shape must be (N), data type must be float. + output_dtype (torch.dtype): Specify the data type of output, must be torch.half or torch.bfloat16. + bias (torch.Tensor, optional): Shape is (N). 
+ c (torch.Tensor, optional): Shape is (M, N). + act_mode (str, optional): Choose the activation algorithm, must be 'silu', 'gelu' or 'none'. If use groupwise + quantization, act_mode must be 'none'. + quant_bit_size (int, optional): The data format of b. Defaults to 8. + alpha (float, optional): coefficient of acted. Defaults to 1.0. + beta (float, optional): coefficient of c. Defaults to 1.0. + use_hp_active (bool, optional): Describing the algorithm that used in the implementation of the activation function. + When the value is true, use the high-precision algorithm, otherwise use the fastest algorithm of activation. + Defaults to False. + a_quant_bit_size(int, optional):The data format of a. Defaults to -1. + a_calib (Optional[torch.Tensor]): The calibration of a, shape can be (M, 2). + b_calib (Optional[torch.Tensor]): The calibration of b, shape can be (M, 2). + + Type: + a: int8, half, bfloat16, float8_e4m3fn, int4X2 + a_scale: float + b: int8, float8_e4m3fn, int4X2 + b_scale: float, half, bfloat16 + bias: half, float, bfloat16 + c: half, float, bfloat16 + output: half, bfloat16 + a_calib: float + b_calib: float + + Returns: + A tensor with the shape of (M, N). + """ + return tmo.scaled_matmul(a, + b, + a_scale, + b_scale, + output_dtype, + bias, + c, + act_mode, + quant_bit_size, + alpha, + beta, + use_hp_active, + a_quant_bit_size, + a_calib, + b_calib,) + +def fused_mla_kv(kv: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + position_id: torch.Tensor, + gamma: torch.Tensor, + kv_cache: torch.Tensor, + kv_cache_scale: Optional[torch.Tensor], + slot_mapping: Optional[torch.Tensor], + cache_bs_id: Optional[torch.Tensor] = None, + cache_seq_offset: Optional[torch.Tensor] = None, + is_paged_cache: bool = True, + eps: float = 1e-5, + interleaved: bool = True): + quant_mode = "static_per_channel" if kv_cache_scale is None else "dynamic_per_token" + return tmo.fused_mla_kv( + kv, sin, cos, position_id, gamma, kv_cache, kv_cache_scale, slot_mapping, cache_bs_id, + cache_seq_offset, + quant_mode=quant_mode, + is_paged_cache=is_paged_cache, + eps=eps, + interleaved=interleaved, + ) + +def fused_mla_q(q: torch.Tensor, + gamma: torch.Tensor, + smooth_quant_scale: torch.Tensor, + weight_b: torch.Tensor, + weight_b_scale: torch.Tensor, + weight_c: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + position_id: torch.Tensor, + output: Optional[torch.Tensor] = None, + eps: float = 1e-6, + interleaved: bool = True, + output_quant_mode: str = 'none', + output_scale: Optional[torch.Tensor] = None, + output_norm: Optional[torch.Tensor] = None) -> torch.Tensor: + return tmo.fused_mla_q( + q, gamma, smooth_quant_scale, weight_b, weight_b_scale, weight_c, sin, cos, position_id, + output, eps, interleaved, output_quant_mode, output_scale, + store_norm=(output_norm is not None), + output_norm= output_norm, + ) + + +def gather_cache( + kv_cache: List[torch.Tensor], # [[1, num_blocks, num_kv_heads, block_size, head_size] + # [1, num_blocks, num_kv_heads, block_size] if kv_cache_dtype=int8] + dst: torch.Tensor, # [tot_tokens, entrys...] + block_table: torch.Tensor, # [batch, block_indices] + cu_seq_lens: torch.Tensor, # [batch+1] + batch_size: int, + seq_starts: torch.Tensor = None, # Optional: [batch] + kv_cache_dtype: str = 'auto', +) -> None: + """ + Gathers sequences from src_cache into dst based on block_table and cu_seq_lens. 
+ + Args: + src_cache: Source KV cache tensor of shape [[1, num_blocks, num_kv_heads, block_size, head_size], + [1, num_blocks, num_kv_heads, block_size] if cache_dtype=int8]. + dst: Destination tensor of shape [tot_tokens, entrys...]. + block_table: Tensor of shape [batch, block_indices] mapping sequences to blocks. + cu_seq_lens: Tensor of shape [batch+1] with cumulative sequence lengths. + batch_size: Number of sequences in the batch. + seq_starts: Optional tensor of shape [batch] for block index offsets. + """ + assert len(kv_cache) > 0 and kv_cache[0].numel() > 0, "kv cache can't be empty in gather_cache" + src_cache = kv_cache[0][0] + # Validate inputs + assert src_cache.device == dst.device == block_table.device == cu_seq_lens.device, \ + "All tensors must be on the same device" + assert block_table.dtype == torch.int32, "block_table must be int32" + assert cu_seq_lens.dtype == torch.int32, "cu_seq_lens must be int32" + quant_kv_cache = kv_cache_dtype != 'auto' + if not quant_kv_cache: + assert src_cache.dtype == dst.dtype, "src_cache and dst must have the same dtype when no quantized" + if seq_starts is not None: + assert seq_starts.dtype == torch.int32, "seq_starts must be int32" + assert seq_starts.device == src_cache.device, "seq_starts must be on the same device" + + # Extract dimensions + num_blocks, num_kv_heads, block_size, head_size = src_cache.shape + # When using MLA during decode it becomes MQA, the num_kv_heads is fixed to 1, + # so src_cache can be view to [num_blocks, block_size, head_size] + assert num_kv_heads == 1, "mla force num_kv_heads to 1" + src_cache = src_cache.view(num_blocks, block_size, -1) + entry_shape = src_cache.shape[2:] # ENTRIES... + tot_tokens = cu_seq_lens[-1] + assert tot_tokens > 0, "tot_tokens should > 0" + assert tot_tokens <= dst.shape[0], "tot_tokens should <= dst.shape[0]" + dst_cache = dst[:tot_tokens] + + # Ensure cu_seq_lens matches batch_size + assert cu_seq_lens.size(0) == batch_size + 1, "cu_seq_lens must have batch_size + 1 elements" + + # Compute sequence lengths + seq_lens = cu_seq_lens[1:] - cu_seq_lens[:-1] # [BATCH] + tot_blocks_per_seq = (seq_lens + block_size - 1) // block_size # ceil_div + + # Handle seq_starts offset + block_offsets = torch.zeros(batch_size, dtype=torch.int32, device=src_cache.device) + if seq_starts is not None: + block_offsets = seq_starts // block_size + + # Flatten src_cache for easier indexing: [NUM_BLOCKS * BLOCK_SIZE, ENTRIES...] 
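+    # A token stored at (block_id, offset) in the paged cache maps to the flat index
+    # block_id * block_size + offset; the per-sequence loop below builds these indices
+    # so that all requested tokens can be gathered with a single advanced-index read.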
+ src_flat = src_cache.view(num_blocks * block_size, *entry_shape) + + # Prepare output indices + dst_indices = [] + for bid in range(batch_size): + seq_len = seq_lens[bid] + if seq_len <= 0: + continue + seq_start = cu_seq_lens[bid] + tot_blocks = tot_blocks_per_seq[bid] + offset = block_offsets[bid] + + # Compute block indices for this sequence + block_ids = block_table[bid, offset:offset + tot_blocks] + + # Compute token indices within blocks + token_indices = torch.arange(seq_len, device=src_cache.device) + block_indices = token_indices // block_size + within_block = token_indices % block_size + + # Map to src_flat indices + src_indices = block_ids[block_indices] * block_size + within_block + dst_indices.append(src_indices) + + # Concatenate all indices + dst_indices = torch.cat(dst_indices) + + # Gather data + dst_flat = src_flat[dst_indices] + if quant_kv_cache: + src_cache_scale = kv_cache[1][0] + src_scale_flat = src_cache_scale.view(num_blocks * block_size) + dst_scale_flat = src_scale_flat[dst_indices] + dst_flat = dst_flat * dst_scale_flat.unsqueeze(-1) + + dst_cache.view(-1, *entry_shape).copy_(dst_flat.view(tot_tokens, *entry_shape)) + + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + """ + Merges partial attention states (prefix and suffix) into a single output. + Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005. + + Args: + output: Output tensor of shape [num_tokens, num_query_heads, head_size]. + prefix_output: Prefix attention output, same shape as output. + prefix_lse: Prefix log-sum-exp, shape [num_query_heads, num_tokens]. + suffix_output: Suffix attention output, same shape as output. + suffix_lse: Suffix log-sum-exp, same shape as prefix_lse. + output_lse: Optional output log-sum-exp, same shape as prefix_lse. 
+ """ + # Input validation + assert output.shape == prefix_output.shape == suffix_output.shape, \ + "Output and input tensors must have the same shape" + assert prefix_lse.shape == suffix_lse.shape, \ + "Prefix and suffix LSE tensors must have the same shape" + if output_lse is not None: + assert output_lse.shape == prefix_lse.shape, \ + "Output LSE must have the same shape as input LSE tensors" + + # Handle inf values (replace inf with -inf for consistency) + p_lse = torch.where( + prefix_lse == float('inf'), + torch.tensor(float('-inf'), device=prefix_lse.device), + prefix_lse + ) + s_lse = torch.where( + suffix_lse == float('inf'), + torch.tensor(float('-inf'), device=suffix_lse.device), + suffix_lse + ) + + # Compute maximum LSE for numerical stability + max_lse = torch.maximum(p_lse, s_lse) # Shape: [num_query_heads, num_tokens] + + # Normalize LSE terms + p_lse = p_lse - max_lse # Shape: [num_query_heads, num_tokens] + s_lse = s_lse - max_lse # Shape: [num_query_heads, num_tokens] + + # Compute sum of exponentials + out_se = torch.exp(p_lse) + torch.exp(s_lse) # Shape: [num_query_heads, num_tokens] + + # Compute output_lse if provided + if output_lse is not None: + output_lse.copy_(torch.log(out_se) + max_lse) + + # Compute scaling factors + p_scale = torch.exp(p_lse) / out_se # Shape: [num_query_heads, num_tokens] + s_scale = torch.exp(s_lse) / out_se # Shape: [num_query_heads, num_tokens] + + # Reshape scales for broadcasting + p_scale = p_scale.unsqueeze(-1) # Shape: [num_query_heads, num_tokens, 1] + s_scale = s_scale.unsqueeze(-1) # Shape: [num_query_heads, num_tokens, 1] + + # Transpose outputs to match scaling dimensions + prefix_output = prefix_output.permute(1, 0, 2) # Shape: [num_query_heads, num_tokens, head_size] + suffix_output = suffix_output.permute(1, 0, 2) # Shape: [num_query_heads, num_tokens, head_size] + + # Compute merged output + out = prefix_output * p_scale + suffix_output * s_scale # Shape: [num_query_heads, num_tokens, head_size] + + # Transpose back and store in output + output.copy_(out.permute(1, 0, 2)) # Shape: [num_tokens, num_query_heads, head_size] + +def moe_all2all_create(dispatch_token_byte: int, + combine_token_byte: int, + max_expert_num: int, + max_token_num: int, + rank: int, + nrank: int) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Create the handle of MOE All-to-All communication. + API call order: + 1.Call torch_mlu_ops.moe_all2all_create(...) to obtain the CNCLEP handle and buffer tensor for All-to-All communication. Only needs to be done once. + 2.Gather all_exchange_info by performing an All-Gather operation on exchange_info across nrank processes. Only needs to be done once. + 3.Call torch.distributed.barrier() to ensure step 2 finish. Only needs to be done once. + 4.Call torch_mlu_ops.moe_all2all_init(...) to configure the all_exchange_info into the handle. Only needs to be done once. + 5.Call torch_mlu_ops.moe_all2all_dispatch(...) to route tokens to their designated experts. + 6.Call torch_mlu_ops.moe_all2all_combine(...) to restore tokens to their original locations. + 7.Call torch_mlu_ops.moe_all2all_destroy(...) to release the CNCLEP handle. Only needs to be done once. + + Args: + dispatch_token_byte (int): Byte size of a single token for dispatch All-to-All operation. + combine_token_byte (int): Byte size of a single token for combine All-to-All operation. + max_expert_num (int): Maximum number of experts participating in the All-to-All operation. 
+ max_token_num (int): Maximum number of tokens to be processed. + rank (int): Rank ID of the current process [0~nrank-1]. + nrank (int): Total number of processes in the distributed group. + + Return: + A tuple of (handle, exchange_info_size, exchange_info, dispatch_send, dispatch_recv, combine_send and combine_recv). + handle: The CNCLEP handle with type of integer. + exchange_info_size: The size of exchange_info. + exchange_info: CPU tensor, shape is [exchange_info_size], and data type is torch.int8. + dispatch_send: MLU tensor, shape is [max_token_num * dispatch_token_byte], and data type is torch.int8. + dispatch_recv: MLU tensor, shape is [nrank * max_token_num * dispatch_token_byte], and data type is torch.int8. + combine_send: MLU tensor, shape is [max_token_num * combine_token_byte], and data type is torch.int8. + combine_recv: MLU tensor, shape is [nrank * max_token_num * combine_token_byte], and data type is torch.int8. + """ + return tmo.moe_all2all_create(dispatch_token_byte, combine_token_byte, max_expert_num, max_token_num, rank, nrank) + +def moe_all2all_init(handle: int, + all_exchange_info: torch.Tensor) -> None: + tmo.moe_all2all_init(handle, all_exchange_info) + +def moe_all2all_destroy(handle: int) -> None: + tmo.moe_all2all_destroy(handle) + +def moe_all2all_dispatch(handle: int, + token_byte: int, + token_num: int, + send_layout: torch.Tensor, + send_token_num: torch.Tensor, + recv_layout: torch.Tensor, + recv_token_num: torch.Tensor, + send_token: Optional[torch.Tensor] = None, + recv_token: Optional[torch.Tensor] = None, + ) -> None: + tmo.moe_all2all_dispatch(handle, token_byte, token_num, send_layout, send_token_num, recv_layout, recv_token_num, send_token, recv_token) + +def moe_all2all_combine(handle: int, + token_byte: int, + token_num: int, + send_src_layout: torch.Tensor, + send_dst_layout: torch.Tensor, + send_token: Optional[torch.Tensor] = None, + recv_token: Optional[torch.Tensor] = None, + ) -> None: + tmo.moe_all2all_combine(handle, token_byte, token_num, send_src_layout, send_dst_layout, send_token, recv_token) + +def gather_split(input: torch.Tensor, + gather_index: torch.Tensor, + valid_token_num: torch.Tensor, + output1: torch.Tensor, + output2: Optional[torch.Tensor] = None) -> None: + tmo.gather_split(input, + gather_index, + valid_token_num, + output1, + output2) + +def moe_all2all_gen_send_layout(token_count: torch.Tensor, + nrank: int) -> torch.Tensor: + return tmo.moe_all2all_gen_send_layout(token_count, nrank) + +def moe_all2all_gen_gather_index(token_num: torch.Tensor, pad_num: int, + return_cusum_token_count: bool = False): + if not return_cusum_token_count: + gather_by_expert_index, gather_by_rank_index, token_count, token_sum = \ + tmo.moe_all2all_gen_gather_index(token_num, pad_num) + return gather_by_expert_index, gather_by_rank_index, token_count, token_sum + else: + gather_by_expert_index, gather_by_rank_index, token_count, token_sum, cusum_token_count = \ + tmo.moe_all2all_gen_gather_index(token_num, pad_num, return_cusum_token_count=True) + return gather_by_expert_index, gather_by_rank_index, token_count, token_sum, cusum_token_count + +def reshape_from_cache( + key: torch.Tensor, + value: Optional[torch.Tensor], + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + context_lengths: torch.Tensor, + max_context_len: int, + context_seq_offset: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + cache_seq_offset: Optional[torch.Tensor] = None, +) -> None: + tmo.reshape_from_cache( + key=key, + 
value=value, + key_cache=key_cache, + value_cache=value_cache, + context_lengths=context_lengths, + max_context_len=max_context_len, + context_seq_offset=context_seq_offset, + block_tables=block_tables, + cache_seq_offset=cache_seq_offset, + ) + + +def masked_indexer_select_paged_kv(query: torch.Tensor, + k_cache: torch.Tensor, + weights: torch.Tensor, + kv_cache_block_table: torch.Tensor, + cu_seq_q_lens: Optional[torch.Tensor], + cu_seq_k_lens: Optional[torch.Tensor], + k_context_lens: Optional[torch.Tensor], + k_cache_block_table: Optional[torch.Tensor], + is_prefill: bool, + index_topk: int, + kv_cache_block_size: int, + softmax_scale: float, + q_scale: Optional[torch.Tensor] = None, + k_scale_cache: Optional[torch.Tensor] = None, + sparse_block_table: Optional[torch.Tensor] = None, + sparse_context_lens: Optional[torch.Tensor] = None): + tmo.masked_indexer_select_paged_kv(query=query, + k_cache=k_cache, + weights=weights, + kv_cache_block_table=kv_cache_block_table, + cu_seq_q_lens=cu_seq_q_lens, + cu_seq_k_lens=cu_seq_k_lens, + k_context_lens=k_context_lens, + k_cache_block_table=k_cache_block_table, + is_prefill=is_prefill, + index_topk=index_topk, + kv_cache_block_size=kv_cache_block_size, + softmax_scale=softmax_scale, + q_scale=q_scale, + k_scale_cache=k_scale_cache, + sparse_block_table=sparse_block_table, + sparse_context_lens=sparse_context_lens) + +def masked_indexer_select_paged_kv_prefill( + query: torch.Tensor, + key_value: torch.Tensor, + weights: torch.Tensor, + kv_cache_block_table: torch.Tensor, + cu_seq_q_lens: torch.Tensor, + cu_seq_k_lens: torch.Tensor, + index_topk: int, + kv_cache_block_size: int, + softmax_scale: float, + q_scale: Optional[torch.Tensor] = None, + k_scale_cache: Optional[torch.Tensor] = None, + sparse_block_table: Optional[torch.Tensor] = None, + sparse_context_lens: Optional[torch.Tensor] = None, + kv_cache_block_table_offset: Optional[torch.Tensor] = None, + compress_ratio: int = 1, +): + return tmo.masked_indexer_select_paged_kv( + query=query, + k_cache=key_value, + weights=weights, + kv_cache_block_table=kv_cache_block_table, + cu_seq_q_lens=cu_seq_q_lens, + cu_seq_k_lens=cu_seq_k_lens, + k_context_lens=None, + k_cache_block_table=None, + is_prefill=True, + index_topk=index_topk, + kv_cache_block_size=kv_cache_block_size, + softmax_scale=softmax_scale, + q_scale=q_scale, + k_scale_cache=k_scale_cache, + sparse_block_table=sparse_block_table, + sparse_context_lens=sparse_context_lens, + kv_cache_block_table_offset=kv_cache_block_table_offset, + compress_ratio=compress_ratio, + is_score_float=True, + ) + +def masked_indexer_select_paged_kv_decode( + query: torch.Tensor, + k_cache: torch.Tensor, + weights: torch.Tensor, + kv_cache_block_table: torch.Tensor, + k_context_lens: Optional[torch.Tensor], + k_cache_block_table: Optional[torch.Tensor], + index_topk: int, + kv_cache_block_size: int, + softmax_scale: float, + q_scale: Optional[torch.Tensor] = None, + k_scale_cache: Optional[torch.Tensor] = None, + sparse_block_table: Optional[torch.Tensor] = None, + sparse_context_lens: Optional[torch.Tensor] = None, + kv_cache_block_table_offset: Optional[torch.Tensor] = None, + compress_ratio: int = 1, +): + query_len = query.shape[1] + #k_context_lens = k_context_lens // compress_ratio + return tmo.masked_indexer_select_paged_kv( + query=query, + k_cache=k_cache, + weights=weights, + kv_cache_block_table=kv_cache_block_table, + cu_seq_q_lens=None, + cu_seq_k_lens=None, + k_context_lens=k_context_lens, + k_cache_block_table=k_cache_block_table, + 
        is_prefill=False,
        index_topk=index_topk,
        kv_cache_block_size=kv_cache_block_size,
        softmax_scale=softmax_scale,
        q_scale=q_scale,
        k_scale_cache=k_scale_cache,
        sparse_block_table=sparse_block_table,
        sparse_context_lens=sparse_context_lens,
        kv_cache_block_table_offset=kv_cache_block_table_offset,
        compress_ratio=compress_ratio,
        is_score_float=True,
    )


def concat_block_table(
    first_block_table: torch.Tensor,
    first_context_lens: torch.Tensor,
    second_block_table: torch.Tensor,
    second_context_lens: torch.Tensor,
    new_block_table: Optional[torch.Tensor] = None,
    new_context_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Concatenate two block tables and return the concatenated result.
    Math:
        new_context_lens = first_context_lens + second_context_lens
        total_seq = first_context_lens.size(0)
        for i in range(total_seq):
            new_block_table[i, :first_context_lens[i]] = first_block_table[i, :first_context_lens[i]]
            new_block_table[i, first_context_lens[i]:first_context_lens[i]+second_context_lens[i]] = second_block_table[i, :second_context_lens[i]]
    Args:
        first_block_table (torch.Tensor):
            The first block table of shape `[total_seq, first_max_blkn]`.
        first_context_lens (torch.Tensor):
            The context lens of the first block table of shape `[total_seq,]`.
        second_block_table (torch.Tensor):
            The second block table of shape `[total_seq, second_max_blkn]`.
        second_context_lens (torch.Tensor):
            The context lens of the second block table of shape `[total_seq,]`.
        new_block_table (Optional[torch.Tensor]):
            The new block table of shape `[total_seq, max_new_block_number]`.
            If not None, max_new_block_number must be large enough to hold the concatenated block table.
            Default: `None`.
        new_context_lens (Optional[torch.Tensor]):
            The new context lens of shape `[total_seq,]`. Default: `None`.

    Returns:
        new_block_table (torch.Tensor):
            The concatenated block table of shape `[total_seq, max_new_block_number]`.
        new_context_lens (torch.Tensor):
            The new context lens of shape `[total_seq,]`, equal to first_context_lens + second_context_lens.
    Type:
        INT32
    """
    return tmo.concat_block_table(
        first_block_table,
        first_context_lens,
        second_block_table,
        second_context_lens,
        new_block_table,
        new_context_lens,
    )

def fused_mhc_post(
    x: torch.Tensor,              # (N, D) float|bf16
    residual: torch.Tensor,       # (N, HC, D) float|bf16
    post: torch.Tensor,           # (N, HC), always float
    comb: torch.Tensor,           # (N, HC, HC), always float
    compute_rms: bool,
    eps: float,
    output: torch.Tensor = None,  # (N, HC, D), same dtype as the inputs
    output_rms=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """

    Math:
        output = post * x + (comb * residual).sum(dim=1)
        output_rms = rsqrt(x.square().mean(dim=-1))

    Args:
        x (torch.Tensor): Shape is [N, D].
        residual (torch.Tensor): Shape is [N, HC, D].
        post (torch.Tensor): Shape is [N, HC].
        comb (torch.Tensor): Shape is [N, HC, HC].
        compute_rms (bool): Whether to compute output_rms.
        eps (float): The eps of normalization.
        output (torch.Tensor, optional): Shape is [N, HC, D]. Defaults to None.
        output_rms (torch.Tensor, optional): Shape is [N]. Defaults to None.

    Returns:
        output, output_rms

    Limitation:
        D must be 4096.
        HC must be 4.
+ """ + out = tmo.fused_mhc_post( + x, + residual, + post, + comb, + compute_rms, + eps, + output, + output_rms, + ) + + return out if compute_rms else (out, None) + +def fused_compress_multi_kv(kv: torch.Tensor, # (BS, D) float|bf16 + score: torch.Tensor, # (BS, D) float|bf16 + kv_state: torch.Tensor, # (max_B, coff * R, D) float + score_state: torch.Tensor, # (max_B, coff * R, D) float + batch_ids: torch.Tensor, # (B,) int32 + cu_seqlens: torch.Tensor, # (B,) int32 + ape: torch.Tensor, # (R, D) float + max_seqlen:int, + overlap: bool, + compressed_kv: torch.Tensor # (BS, head_dim) float|bf16 + ): + tmo.fused_compress_multi_kv( + kv = kv, + score = score, + kv_state = kv_state, + score_state = score_state, + cu_seqlens = cu_seqlens, + batch_ids = batch_ids, + ape = ape, + max_seqlen = max_seqlen, + overlap = overlap, + compressed_kv = compressed_kv, + ) + +def fused_compress_single_kv( + kv: torch.Tensor, # (T, D) float|bf16 + score: torch.Tensor, # (T, D) float|bf16 + position: torch.Tensor, # (B,) int32 + ape: torch.Tensor, # (ratio, D) float|bf16 + kv_state: torch.Tensor, # (B, R, D) float|bf16 + score_state: torch.Tensor, # (B, R, D) float|bf16 + gamma: torch.Tensor, # (d) + sin: torch.Tensor, # (-1, rope_dim) + cos: torch.Tensor, # (-1, rope_dim) + hadamard_matrix: Optional[torch.Tensor], # (d, d) + slot_mapping: torch.Tensor, # (B,) int32 + kv_cache: torch.Tensor, # (-1, BLKS, head_dim) bf16|int8|fp8 + kv_cache_scale: Optional[torch.Tensor], # (-1, BLKS) float + eps: float, + overlap: bool, + rotate: bool, + state_idx: torch.Tensor, + cu_query_len: torch.Tensor | None = None, +): + """ + + Math: + + + Args: + kv (torch.Tensor): Shape is [B, S, D]. + score (torch.Tensor): Shape is [B, S, D]. + position (torch.Tensor): Shape is [B]. + ape (torch.Tensor): Shape is [ratio, D]. + kv_state (torch.Tensor): Shape is [max_B, R, D]. + score_state (torch.Tensor): Shape is [max_B, R, D]. + gamma (torch.Tensor): Shape is [head_dim]. + sin (torch.Tensor): Shape is [table_len, rope_dim]. + cos (torch.Tensor): Shape is [table_len, rope_dim]. + hadamard_matrix (torch.Tensor): Shape is [head_dim, head_dim]. + slot_mapping (torch.Tensor): Shape is [B]. + kv_cache (torch.Tensor): Shape is [cache_len, block_size, hs]. + kv_cache_scale (torch.Tensor): Shape is [cache_len, block_size]. + eps (flost): The eps of normalization. + overlap (bool): Whether to overlap. + rotate (bool): Whether to rotate. + + Type: + kv: BF16, FP32 + score: same as kv + position: INT32 + ape: FP32 + kv_state: FP32 + score_state: FP32 + gamma: same as kv + sin: same as kv + cos: same as kv + hadamard_matrix: same as kv + slot_mapping: INT32 + kv_cache: BF16, FP32 + kv_cache_scale: FP32 + + Returns: + Only support inplace outputs, include kv_state, score_state, kv_cache, kv_cache_scale + + Note: + coff = overlap + 1 + D = coff * head_dim + R = coff * ratio + """ + token_num, coff_dim = kv.shape + + # TODO: force user_tmo = 0 after supporting mtp. 
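+    # tmo.fused_compress_single_kv expects (B, S, D) inputs: each incoming row here is
+    # a single token, so a singleton sequence dimension is inserted below, and a 4-D
+    # paged kv_cache (MLA keeps a single kv head) is viewed down to
+    # (paged_num, block_size, head_dim) before the call.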
+ bsz = state_idx.numel() + kv = kv.unsqueeze(1) + score = score.unsqueeze(1) + if kv_cache.dim() == 4: + paged_num, head_num, block_size, head_dim = kv_cache.shape + assert head_num == 1 + kv_cache = kv_cache.view(paged_num, block_size, head_dim) + return tmo.fused_compress_single_kv( + kv=kv, + score=score, + position=position, + state_ids=state_idx, + ape=ape, + kv_state=kv_state, + score_state=score_state, + gamma=gamma, + sin=sin, + cos=cos, + hadamard_matrix=hadamard_matrix if rotate else None, + slot_mapping=slot_mapping, + kv_cache=kv_cache, + kv_cache_scale=kv_cache_scale, + eps=eps, + overlap=overlap, + ) + +def convertBlockTable(block_table, blks, incseq): + if blks == 1: + return block_table + else: + expanded = block_table.unsqueeze(1).repeat(1, blks) + result = expanded * blks + incseq + return result.flatten() + +def get_window_block_tables(window_size : int, + block_size : int, #blocksize of block_table + seq_k_lens: torch.Tensor, + query_start_loc: torch.Tensor, + block_table: Optional[torch.Tensor]=None, # shape (batch, max_blocks) + window_block_tables:Optional[torch.Tensor]=None, # shape (total_seq, max_blocks) + window_context_lens:Optional[torch.Tensor]=None): # shape (total_seq) + tmo.get_window_block_tables(window_block_tables = window_block_tables, + window_context_lens = window_context_lens, + seq_k_lens = seq_k_lens, + query_start_loc = query_start_loc, + block_table = block_table, + block_size = block_size, + window_size = window_size,) + +def get_compress_block_tables(ratio: int, + block_size: int, + seq_k_lens: torch.Tensor, # k lens before compression, shape (batch) + query_start_loc: torch.Tensor, # shape (batch+1) + offset: torch.Tensor, # shape (batch) + block_table: torch.Tensor, # shape (batch, max_blocks) + compress_block_tables: torch.Tensor, # shape (total_seq, max_blocks) + compress_context_lens: torch.Tensor): # shape (total_seq) + tmo.get_compress_block_tables( + compress_block_tables = compress_block_tables, + compress_context_lens = compress_context_lens, + seq_k_lens = seq_k_lens, + query_start_loc = query_start_loc, + offset = offset, + block_table = block_table, + block_size = block_size, + ratio = ratio, + ) + + +def hc_split_sinkhorn(mixes: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + pre_scale: Optional[torch.Tensor] = None, + hc_mult: int = 4, + sinkhorn_iter: int = 20, + eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return tmo.hc_split_sinkhorn( + mixes = mixes, + hc_scale = hc_scale, + hc_base = hc_base, + pre_scale = pre_scale, + hc_mult = hc_mult, + sinkhorn_iter = sinkhorn_iter, + eps = eps, + ) + + +def fused_indexer_q(q: torch.Tensor, + w_q: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + position_id: torch.Tensor, + output: Optional[torch.Tensor] = None, + hadamard_matrix: Optional[torch.Tensor] = None, + w_q_scale: Optional[torch.Tensor] = None, + output_quant_mode: str = 'none', + output_scale: Optional[torch.Tensor] = None, + interleaved: bool = True, + rope_at_front: bool = True): + return tmo.fused_indexer_q( + q = q, + w_q = w_q, + sin = sin, + cos = cos, + position_id = position_id, + output = output, + hadamard_matrix = hadamard_matrix, + w_q_scale = w_q_scale, + output_quant_mode = output_quant_mode, + output_scale = output_scale, + interleaved = interleaved, + rope_at_front = rope_at_front) + +def fused_mla_q_v2( + input_q: torch.Tensor, + gamma: torch.Tensor, + smooth_quant_scale: Optional[torch.Tensor], + weight_b: torch.Tensor, + weight_b_scale: 
Optional[torch.Tensor], + sin: torch.Tensor, + cos: torch.Tensor, + position_id: torch.Tensor, + output: Optional[torch.Tensor] = None, + eps: float = 1e-6, + interleaved: bool = True, + store_norm: bool = False, + output_norm: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """ + This function applies MLA (Multi-head Latent Attention) v2 Query (Q) preprocessing. + The fusion logic includes: RMSNorm -> Quant(Optional) -> MatMul -> RMSNorm -> RoPE. + + Math: + qr = rmsnorm(input_q, gamma, eps) + if quant: + qr, q_scale = per_token_quant(norm_out, smooth_quant_scale) + q = matmul(qr, q_scale, weight_b, weight_b_scale) + q = q.reshape(batch, seq, n_local_heads, head_dim) + q = rsqrt(q.square().mean(-1, keepdim=True) + eps) + out = apply_rotary_embedding(q, sin, cos, position_id, interleaved) + + Args: + input_q (torch.Tensor): + The input latent query tensor. Shape is (batch, seq, q_lora_rank). + gamma (torch.Tensor): + The scaling parameter for the initial RMSNorm. Shape is (q_lora_rank). + smooth_quant_scale (Optional[torch.Tensor]): + Scale tensor for SmoothQuant migration. Can be None. Shape is (q_lora_rank). + weight_b (torch.Tensor): + The Q-projection weight tensor. Shape is (n_local_heads, head_dim, q_lora_rank). + weight_b_scale (Optional[torch.Tensor]): + The per-channel quantization scales for weight_b. Shape is (n_local_heads, head_dim). + sin (torch.Tensor): + Rotary embedding sine table. Shape is (max_rotary_seq_len, rotary_head_dim). + cos (torch.Tensor): + Rotary embedding cosine table. Shape is (max_rotary_seq_len, rotary_head_dim). + position_id (torch.Tensor): + Indices for the RoPE tables. Shape is (batch,). + output (Optional[torch.Tensor]): + Optional output tensor for the final processed Q. Shape is (batch, seq, n_local_heads, head_dim). + eps (float): + Small constant for RMSNorm numerical stability. Default: 1e-6. + interleaved (bool): + If True, apply interleaved rotary embedding, otherwise folded. Default: True. + store_norm (bool): + If True, the intermediate RMSNorm result (pre-MatMul) will be returned. Default: False. + output_norm (Optional[torch.Tensor]): + Optional tensor to store the intermediate RMSNorm result. Shape: (batch, seq, q_lora_rank). + + Type: + input_q, gamma, sin, cos: bfloat16. + weight_b: int8, same as input_q. + weight_b_scale, smooth_quant_scale: float32. + position_id: int32. + output: same as input_q. + + Return: + Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + - If store_norm=False: output + - If store_norm=True: (..., output_norm) is appended to the return. 
+ """ + return tmo.fused_mla_q_v2( + input_q=input_q, + gamma=gamma, + smooth_quant_scale=smooth_quant_scale, + weight_b=weight_b, + weight_b_scale=weight_b_scale, + sin=sin, + cos=cos, + position_id=position_id, + output=output, + eps=eps, + interleaved=interleaved, + store_norm=store_norm, + output_norm=output_norm, + ) + +def update_compressor_states( + kv_state, # (max_batch, (overlap+1)*ratio + K, dim) + score_state, # (max_batch, (overlap+1)*ratio + K, dim) + accept_tokens: torch.Tensor, # (bsz,) + batch_to_kv_state: torch.Tensor, # (bsz,) + positions: torch.Tensor, # (bsz,) + cu_query_len: torch.Tensor, # (bsz+1,) + overlap: bool, + K: int +): + bsz = batch_to_kv_state.numel() + ratio = (kv_state.size(1) - K) // (overlap + 1) + start_positions = positions[cu_query_len[:bsz]] + end_positions = start_positions + accept_tokens + + for i in range(bsz): + start_pos = start_positions[i] + end_pos = end_positions[i] + # Skip if sequence len does not exceed coff * ratio. + if (overlap and end_pos < 2 * ratio) or (not overlap and end_pos < ratio): + continue + + # Skip if compression condition does not meets. + if (start_pos // ratio) == (end_pos // ratio) and start_pos % ratio != 0: + continue + + state_idx = batch_to_kv_state[i] + + if overlap: + length = end_pos - start_pos + start_pos % ratio + else: + length = end_pos % ratio + start = ratio + end = start + length + + if length == 0: + continue + + kv_state[state_idx, :length] = kv_state[state_idx, start:end].clone() + score_state[state_idx, :length] = score_state[state_idx, start:end].clone() diff --git a/vllm_mlu/_mlu_utils.py b/vllm_mlu/_mlu_utils.py new file mode 100644 index 0000000..19f29fd --- /dev/null +++ b/vllm_mlu/_mlu_utils.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import os +import torch +import vllm.envs as envs + + +def _check_env(env, default=False): + if env in os.environ: + return os.environ[env].lower() in ["true", "1"] + return default + + +def _check_env_value(env, default=0): + if env in os.environ: + if not os.environ[env].isdigit(): + raise ValueError(f"'{env}' should be set with integer") + value = int(os.environ[env]) + return value + return default + + +def _check_env_float(env, default=0): + if env in os.environ: + try: + value = float(os.environ[env]) + except ValueError: + raise ValueError(f"'{env}' should be set with float") + return value + return default + + +# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency. +VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False) + +# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info(without device) for benchmark latency. +VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False) + +# VLLM_DUMP_TENSORS: Dump each layer outputs when running vLLM inference. +VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False) + +# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference. +VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False) + +# VLLM_DUMP_MLU_INFO_DEBUG: Dump device debug info when running vLLM inference. +VLLM_DUMP_MLU_INFO_DEBUG = _check_env("VLLM_DUMP_MLU_INFO_DEBUG", default=False) + +# VLLM_SCHEDULER_PROFILE: Profiling vLLM scheduler. +VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False) + +# VLLM_GRAPH_DEBUG: Debug the graph status when running decoder, default value is True. +# Set to False to disable warning messages. 
+VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)
+
+# VLLM_AVG_MOE_EN: balance the workload across MoE experts, default value is False.
+VLLM_AVG_MOE_EN = _check_env("VLLM_AVG_MOE_EN", default=False) or _check_env("VLLM_RANDOM_MOE_EN", default=False)
+VLLM_RANDOM_MOE_EN = _check_env("VLLM_RANDOM_MOE_EN", default=False)
+
+# VLLM_LOGITS_USE_ALL_GATHER: use all-gather for logits collection, default value is False.
+VLLM_LOGITS_USE_ALL_GATHER = _check_env("VLLM_LOGITS_USE_ALL_GATHER", default=False)
+
+VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
+VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
+VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
+VLLM_DUMP_MLU_INFO_DEBUG = (VLLM_DUMP_MLU_INFO_DEBUG and VLLM_DUMP_MLU_INFO_EN)
+
+# VLLM_V1_USE_UNCHUNK_SCHED: v1 uses the unchunked scheduler, default value is True.
+VLLM_V1_USE_UNCHUNK_SCHED = _check_env("VLLM_V1_USE_UNCHUNK_SCHED", default=True)
+
+# VLLM_V1_MIN_PREFILL_BATCH: the minimum scheduling batch in v1, default is 1.
+VLLM_V1_MIN_PREFILL_BATCH = _check_env_value("VLLM_V1_MIN_PREFILL_BATCH", default=1)
+
+# VLLM_V1_USE_FULL_GRAPH: v1 uses full graph capture, default value is True.
+VLLM_V1_USE_FULL_GRAPH = _check_env("VLLM_V1_USE_FULL_GRAPH", default=True)
+
+# VLLM_V1_BENCHMARK: v1 benchmark mode, default value is False.
+VLLM_V1_BENCHMARK = _check_env("VLLM_V1_BENCHMARK", default=False)
+
+# VLLM_MTP_DEBUG: used to show the MTP acceptance rate, default value is False.
+VLLM_MTP_DEBUG = _check_env("VLLM_MTP_DEBUG", default=False)
+
+# VLLM_MTP_NO_QUANT: MTP uses the original dtype and sets quant_config to None.
+VLLM_MTP_NO_QUANT = _check_env("VLLM_MTP_NO_QUANT", default=False)
+
+# VLLM_MTP_FIXED_ACCEPTANCE_RATE: use a fixed acceptance rate, default value is None.
+VLLM_MTP_FIXED_ACCEPTANCE_RATE = _check_env_float("VLLM_MTP_FIXED_ACCEPTANCE_RATE", default=None)
+
+# VLLM_V1_UNCHUNK_SCHED_LOG: print the v1 unchunked scheduler state.
+VLLM_V1_UNCHUNK_SCHED_LOG = _check_env("VLLM_V1_UNCHUNK_SCHED_LOG", default=False)
+
+# VLLM_MOE_PREFILL_CHUNK_SIZE: in number of tokens, enabled when > 0.
+VLLM_MOE_PREFILL_CHUNK_SIZE = _check_env_value("VLLM_MOE_PREFILL_CHUNK_SIZE", default=0)
+
+# VLLM_CI_ACCURACY_TEST: CI accuracy test, default value is False.
+VLLM_CI_ACCURACY_TEST = _check_env("VLLM_CI_ACCURACY_TEST", default=False) + +# VLLM_DISAGG_TRANS_ALL_BLOCKS: optimize the performance of disagg +VLLM_DISAGG_TRANS_ALL_BLOCKS = _check_env("VLLM_DISAGG_TRANS_ALL_BLOCKS", default=True) + +# vllm disagg debug +VLLM_DISAGG_CNPX_EXECUTE = _check_env("VLLM_DISAGG_CNPX_EXECUTE", default=False) +VLLM_DISAGG_CNPX_REQUEST = _check_env("VLLM_DISAGG_CNPX_REQUEST", default=False) +VLLM_DISAGG_FAKE_DECODER = _check_env("VLLM_DISAGG_FAKE_DECODER", default=False) \ No newline at end of file diff --git a/vllm_mlu/attention/__init__.py b/vllm_mlu/attention/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/attention/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/attention/layer.py b/vllm_mlu/attention/layer.py new file mode 100644 index 0000000..121164d --- /dev/null +++ b/vllm_mlu/attention/layer.py @@ -0,0 +1,351 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Any, cast + +import torch +from torch import nn + +import vllm.envs as envs +from vllm.attention import AttentionType +from vllm.attention.backends.abstract import MLAAttentionImpl +from vllm.attention.layer import Attention, MLAAttention, _init_kv_cache_quant +from vllm.attention.selector import get_attn_backend + +from vllm.config.cache import CacheConfig +from vllm.config.vllm import QuantizationConfig, VllmConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.platforms import current_platform +from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm_mlu.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm_mlu.v1.kv_cache_interface import ( + MLUFullAttentionSpec, + MLUMLAAttentionSpec, + MLUSlidingWindowSpec, +) + +@maybe_transfer_kv_layer +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + kwargs: dict[str, Any] = {}, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add return for self.impl.forward and it's param kwargs + ''' + output = self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + kwargs=kwargs, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + return output + +class Attention_MluHijack(Attention): + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention. 
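+        # Decoder layers are mapped to MLU-specific cache specs below:
+        # sliding-window layers get MLUSlidingWindowSpec, all other layers
+        # get MLUFullAttentionSpec.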
+ assert self.attn_type == AttentionType.DECODER + if self.sliding_window is not None: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: replace SlidingWindowSpec with MLUSlidingWindowSpec. + ''' + return MLUSlidingWindowSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + sliding_window=self.sliding_window, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + else: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: replace FullAttentionSpec with MLUFullAttentionSpec. + ''' + return MLUFullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + +class MLAAttention_MluHijack(MLAAttention): + def __init__( + self, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + kv_b_proj: ColumnParallelLinear, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_sparse: bool = False, + indexer: object | None = None, + **extra_impl_args, + ) -> None: + nn.Module.__init__(self) + self.num_heads = num_heads + self.scale = scale + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + # self.head_size = kv_lora_rank + qk_rope_head_dim + self.layer_name = prefix + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: insert num_kv_heads for mlu platform + ''' + self.head_size = qk_nope_head_dim + qk_rope_head_dim + self.num_kv_heads = extra_impl_args.pop("num_kv_heads", None) + if self.num_kv_heads is None: + self.num_kv_heads = num_heads + + self.decoder_attn_dtype = None + decoder_attn_dtype = get_current_vllm_config().mlu_config.decoder_attn_dtype + if decoder_attn_dtype in ["int8", "fp8_e4m3", "fp8"]: + self.decoder_attn_dtype = ( + torch.int8 if decoder_attn_dtype == "int8" + else torch.float8_e4m3fn + ) + extra_impl_args['decoder_attn_dtype'] = self.decoder_attn_dtype + ''' + ================== + End of MLU Hijack + ================== + ''' + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + calculate_kv_scales = False + + # Initialize KV cache quantization attributes + _init_kv_cache_quant( + self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales + ) + + dtype = torch.get_default_dtype() + self.attn_backend = get_attn_backend( + self.head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=True, + use_sparse=use_sparse, + ) + impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) + self.impl = impl_cls( + self.num_heads, + self.head_size, + self.scale, + self.num_kv_heads, + None, # alibi_slops + None, # sliding_window + kv_cache_dtype, + None, # logits_soft_cap + AttentionType.DECODER, # attn_dtype + None, # kv_sharing_target_layer_name + **extra_impl_args, + ) + self.dtype = dtype + + self.use_direct_call = not current_platform.opaque_attention_op() + + if current_platform.is_out_of_tree(): + 
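+            # Out-of-tree platforms such as MLU always dispatch through the
+            # registered vLLM attention op; direct calls into the impl are
+            # not supported on this path.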
self.use_direct_call = False + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support kv8 and deepseek v3.2 + ''' + self.kv_cache = [ + [torch.tensor([]), torch.tensor([]), torch.tensor([])] + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) + ] + self.impl.use_mla = True + ''' + ================== + End of MLU Hijack + ================== + ''' + + self.use_sparse = use_sparse + + # Initialize q/k/v range constants. + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + kv_cache_dtype = kv_cache_dtype_str_to_dtype( + self.kv_cache_dtype, vllm_config.model_config + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: replace MLAAttentionSpec with MLUMLAAttentionSpec. + ''' + index_head_dim, index_n_heads = 0, 0 + if vllm_config.model_config.hf_text_config.model_type == "deepseek_v32": + index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim + index_n_heads = 1 + + if vllm_config.model_config.hf_text_config.model_type == "deepseek_v4": + index_head_dim = vllm_config.model_config.hf_text_config.index_head_dim + index_n_heads = 1 + + return MLUMLAAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_size, + dtype=kv_cache_dtype, + cache_dtype_str=vllm_config.cache_config.cache_dtype, + index_head_dim=index_head_dim, + index_n_heads=index_n_heads, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output_shape: torch.Size | None = None, + kwargs: dict[str, Any] = {}, + ) -> torch.Tensor: + if self.calculate_kv_scales: + torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) + + assert not self.use_direct_call, "MLU-V1 does not support direct call." + if self.attn_backend.accept_output_buffer: + output_lse = None + output_shape = (output_shape if output_shape is not None else query.shape) + output_shape = [output_shape[0], self.num_heads * self.v_head_dim] + + output = torch.empty( + output_shape, + dtype=self.dtype if query.dtype == torch.int8 else query.dtype, + device=query.device, + ) + hidden_size = output_shape[-1] + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. 
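+            # For MLA the per-head output width is v_head_dim, which can differ
+            # from the qk head_size, so query and output are reshaped with
+            # different per-head sizes below.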
+ query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.v_head_dim) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.v_head_dim) + if not kwargs: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name + ) + attn_output_list = output + else: + attn_output_list = unified_attention_with_output( + query, key, value, output, self.layer_name, kwargs=kwargs) + if isinstance(attn_output_list, (list, tuple)) and len(attn_output_list) > 1: + output_lse = attn_output_list[1] + if output_lse is not None: + return output.view(-1, hidden_size), output_lse + else: + return output.view(-1, hidden_size) + ''' + ================== + End of MLU Hijack + ================== + ''' + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name + ) + +MluHijackObject.apply_hijack( + Attention, + Attention.get_kv_cache_spec, + Attention_MluHijack.get_kv_cache_spec, +) +MluHijackObject.apply_hijack( + MLAAttention, + MLAAttention.__init__, + MLAAttention_MluHijack.__init__, +) +MluHijackObject.apply_hijack( + MLAAttention, + MLAAttention.get_kv_cache_spec, + MLAAttention_MluHijack.get_kv_cache_spec, +) +MluHijackObject.apply_hijack( + MLAAttention, + MLAAttention.forward, + MLAAttention_MluHijack.forward, +) diff --git a/vllm_mlu/attention/utils/__init__.py b/vllm_mlu/attention/utils/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/attention/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/attention/utils/kv_transfer_utils.py b/vllm_mlu/attention/utils/kv_transfer_utils.py new file mode 100644 index 0000000..3fb9ee1 --- /dev/null +++ b/vllm_mlu/attention/utils/kv_transfer_utils.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +import inspect +from collections.abc import Callable +from functools import wraps + +from vllm.distributed.kv_transfer import ( + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group, +) + + +def maybe_transfer_kv_layer(func: Callable) -> Callable: + """Decorator that handles KV layer transfer prior and after execution of + an attention layer, if enabled. Otherwise, the wrapper is a no-op. + + On entry: waits for the KV layer from the connector. + On exit: saves the KV layer to the connector. + """ + # Import at runtime to avoid circular dependency + from vllm.attention.layer import get_attention_context + + # Inspect the signature ONCE when the decorator is applied. + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + + # Find the index of 'layer_name' parameter. 
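+    # Resolving the positional index once at decoration time lets the wrapper read
+    # layer_name from *args on every call without re-inspecting the signature
+    # (e.g. unified_attention_with_output passes layer_name as its fifth argument).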
+ try: + layer_name_index = param_names.index("layer_name") + except ValueError as e: + raise TypeError( + f"Function {func.__name__} must have a 'layer_name' parameter" + ) from e + + @wraps(func) + def wrapper(*args, **kwargs): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return func(*args, **kwargs) + + layer_name: str = args[layer_name_index] + + # Extract attention context (layer-specific metadata, layer, and kv_cache) + attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name) + connector = get_kv_transfer_group() + if attn_metadata is None or not connector.has_connector_metadata(): + return func(*args, **kwargs) + + # Wait for KV layer on entry + connector.wait_for_layer_load(layer_name) + + # Execute the function + result = func(*args, **kwargs) + + # Save KV cache layer on exit + + if kwargs is None or kwargs.get("save_kv_layer", True): + connector.save_kv_layer(layer_name, kv_cache, attn_metadata) + + return result + + return wrapper diff --git a/vllm_mlu/benchmarks/__init__.py b/vllm_mlu/benchmarks/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/benchmarks/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/benchmarks/datasets.py b/vllm_mlu/benchmarks/datasets.py new file mode 100644 index 0000000..6a4bd8d --- /dev/null +++ b/vllm_mlu/benchmarks/datasets.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena +""" + +from tempfile import NamedTemporaryFile + +import numpy as np + +from vllm.benchmarks.datasets import RandomMultiModalDataset +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video( + self, width: int, height: int, num_frames: int + ) -> dict: + """Generate synthetic video with random values. + + Creates a video with random pixel values, encodes it to MP4 format, + and returns the content as bytes. 
+ """ + import cv2 + + random_pixels = self._rng.integers( + 0, + 256, + (num_frames, height, width, 3), + dtype=np.uint8, + ) + + # Create a temporary video file in memory + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + fps = 30 # frames per second + + with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file: + temp_path = temp_file.name + + # Create video writer + video_writer = cv2.VideoWriter( + temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height) + ) + + if not video_writer.isOpened(): + raise RuntimeError("Failed to create video writer") + + for frame in random_pixels: + video_writer.write(frame) + + video_writer.release() + temp_file.close() + + # Read the video file content + with open(temp_path, "rb") as f: + video_content = f.read() + + return {"bytes": video_content} + + +MluHijackObject.apply_hijack( + RandomMultiModalDataset, + RandomMultiModalDataset.generate_synthetic_video, + vllm__benchmarks__datasets__RandomMultiModalDataset__generate_synthetic_video, +) \ No newline at end of file diff --git a/vllm_mlu/compilation/__init__.py b/vllm_mlu/compilation/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/compilation/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/compilation/fix_functionalization.py b/vllm_mlu/compilation/fix_functionalization.py new file mode 100644 index 0000000..f72fd73 --- /dev/null +++ b/vllm_mlu/compilation/fix_functionalization.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 + +import operator +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.platforms import current_platform +from vllm.logger import init_logger + +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fx_utils import is_func + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +class FixFunctionalizationPass_MluHijack(FixFunctionalizationPass): + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph): + # XPU does not support auto-functionalization yet. + # Will enable this when switch to vllm-xpu-kernels. + if current_platform.is_xpu(): + logger.debug( + "XPU platform does not support fix functionalizationpass currently." 
+ ) + return + + self.nodes_to_remove: list[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: skip custom op on mlu + ''' + if current_platform.is_out_of_tree(): + continue # skip the count on mlu + ''' + ================== + End of MLU Hijack + ================== + ''' + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs["query"] + key = kwargs["key"] + getitem_nodes = self.getitem_users(node) + + if ( + is_func(query, operator.getitem) + and is_func(key, operator.getitem) + and query.args[0] == key.args[0] + and is_func(query.args[0], torch.ops.aten.split_with_sizes.default) + and all( + is_func(user, torch.ops.aten.slice_scatter.default) + for getitem_node in getitem_nodes.values() + for user in getitem_node.users + ) + ): + # Pattern where query and key are slices of an mm_node. + # While functionalized, results at [1] and [2] are scattered + # back into mm_node. So after de-functionalization, we can + # just use mm_node directly. + + mm_node = query.args[0].args[0] + for user in getitem_nodes.values(): + for user_of_getitem in user.users: + if is_func( + user_of_getitem, torch.ops.aten.slice_scatter.default + ): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + else: + # Directly replace the auto_functionalize(rotary_embedding) + # with the inplace rotary_embedding. In theory, we shouldn't + # do this blindly, but in practice in vLLM it's ok. The best + # solution is to use auto_functionalization_v2 and then use + # inductor's builtin defunctionalization (reinplacing) pass. + mutated_args = {1: "query", 2: "key"} + self.defunctionalize(graph, node, mutated_args) + + # rms_norm replacements avoid the most copies for LLaMa. + elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: "input", 2: "residual"} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + mutated_args = {1: "result", 2: "residual"} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: "result", 2: "scale", 3: "residual"} + self.defunctionalize(graph, node, mutated_args) + elif at_target in [ + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default, + ]: + mutated_args = {1: "result"} + self.defunctionalize(graph, node, mutated_args) + # For some reason we need to specify the args for both + # silu_and_mul and silu_and_mul_quant. The kwargs + # pathway gets the wrong answer. 
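# A minimal, standalone PyTorch toy (not the _C kernels handled above) of what
# de-functionalization trades: auto_functionalized exposes a pure form of an op
# that returns fresh tensors, while the underlying kernel mutates a
# caller-provided buffer; this pass rewires the graph back to the mutating form
# (e.g. output index 1 -> the "result" kwarg) so no extra copies remain.
import torch

def rms_norm_functional(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # pure form: allocates and returns a new tensor, inputs untouched
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def rms_norm_inplace(result: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> None:
    # mutating form: writes into a preallocated output buffer
    result.copy_(x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight)

x, w = torch.randn(4, 8), torch.ones(8)
result = torch.empty_like(x)
rms_norm_inplace(result, x, w)
assert torch.allclose(result, rms_norm_functional(x, w))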
+ elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: "result"} + self.defunctionalize( + graph, node, mutated_args, args=("result", "input") + ) + elif at_target == torch.ops._C.silu_and_mul_quant.default: + mutated_args = {1: "result"} + self.defunctionalize( + graph, node, mutated_args, args=("result", "input", "scale") + ) + elif ( + hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant") + and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default + ): + mutated_args = {1: "result", 2: "result_block_scale"} + self.defunctionalize( + graph, + node, + mutated_args, + args=( + "result", + "result_block_scale", + "input", + "input_global_scale", + ), + ) + # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper. + elif at_target == torch.ops._C.fused_qk_norm_rope.default: + mutated_args = {1: "qkv"} + args = ( + "qkv", + "num_heads_q", + "num_heads_k", + "num_heads_v", + "head_dim", + "eps", + "q_weight", + "k_weight", + "cos_sin_cache", + "is_neox", + "position_ids", + ) + self.defunctionalize(graph, node, mutated_args=mutated_args, args=args) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug( + "De-functionalized %s nodes, removed %s nodes", count, count_removed + ) + self.nodes_to_remove.clear() + + +MluHijackObject.apply_hijack( + FixFunctionalizationPass, + FixFunctionalizationPass.__call__, + FixFunctionalizationPass_MluHijack.__call__ +) \ No newline at end of file diff --git a/vllm_mlu/compilation/mlu_graph.py b/vllm_mlu/compilation/mlu_graph.py new file mode 100644 index 0000000..f1ac7ed --- /dev/null +++ b/vllm_mlu/compilation/mlu_graph.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import dataclasses +from collections.abc import Callable +from contextlib import ExitStack +from typing import Any +from unittest.mock import patch + +import torch + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.torch_utils import weak_ref_tensors +from vllm.compilation.cuda_graph import ( + CUDAGraphEntry, + CUDAGraphWrapper, + CUDAGraphOptions, +) +from vllm_mlu.v1.attention.backends.utils import MLUInferMode + +logger = init_logger(__name__) + + +''' +============================= +Modify by vllm_mlu +============================= +@brief: specialized graph entry for prefill graphs +''' +@dataclasses.dataclass +class PrefillGraphEntry: + batch_size: int = 0 + seq_len: int = 0 + cudagraph: torch.mlu.MLUGraph | None = None + output: Any | None = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: list[int] | None = None +''' +================== +End of MLU Hijack +================== +''' + + +class MLUGraphWrapper(CUDAGraphWrapper): + + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + cudagraph_options: CUDAGraphOptions | None = None, + ): + 
super().__init__(runnable, vllm_config, runtime_mode, cudagraph_options) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add separate dict for prefill graph entries + ''' + self.prefill_mlugraph_entry: PrefillGraphEntry | None = None + ''' + ================== + End of MLU Hijack + ================== + ''' + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: check if running in prefill mode + ''' + def is_running_in_prefill(self, entry: PrefillGraphEntry | None = None) -> bool: + forward_context = get_forward_context() + if forward_context.attn_metadata is None: + return False + infer_mode = forward_context.attn_metadata['common_metadata'].infer_mode + seq_lens_cpu = forward_context.attn_metadata['common_metadata'].seq_lens_cpu + if entry is not None \ + and infer_mode == MLUInferMode.PREFILL_ONLY \ + and seq_lens_cpu.size(0) == entry.batch_size \ + and (seq_lens_cpu == entry.seq_len).all().item(): + return True + return False + ''' + ================== + End of MLU Hijack + ================== + ''' + + def __call__( + self, + is_capturing_prefill: bool = False, + prefill_enable_mlugraph: bool = False, + prefill_batch_size: int = 0, + prefill_seq_len: int = 0, + is_running_drafter: bool = False, + *args, **kwargs): + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + + if ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode != self.runtime_mode + ): + # CUDAGraphMode.NONE could mean the profile run, a warmup run, or + # running without cudagraphs. + # We do not trigger capture/replay if the runtime mode is not + # matches. This enables properly dispatching to the correct + # CUDAGraphWrapper when nesting multiple instances with different + # runtime modes. + return self.runnable(*args, **kwargs) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: handle prefill graph separately + @brief: skip check in running drafter model + ''' + if is_capturing_prefill: # PREFILL capture + self.prefill_mlugraph_entry = PrefillGraphEntry( + batch_size=prefill_batch_size, + seq_len=prefill_seq_len) + else: # FULL/DECODE capture + if batch_descriptor not in self.concrete_cudagraph_entries: + # create a new entry for this batch descriptor + self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry( + batch_descriptor=batch_descriptor + ) + + if ((self.is_running_in_prefill(self.prefill_mlugraph_entry) and prefill_enable_mlugraph) + or is_capturing_prefill): + entry = self.prefill_mlugraph_entry + logger.debug( + f"Hitting a prefill cudagraph on {self.runtime_mode.name}, " + f"batch_size: {entry.batch_size}, seq_len: {entry.seq_len}") + else: # FULL/DECODE capture + entry = self.concrete_cudagraph_entries[batch_descriptor] + logger.debug( + "Hitting a decode cudagraph on (%s, %s)", + self.runtime_mode.name, + entry.batch_descriptor, + ) + + if entry.cudagraph is None: + if self.cudagraph_options.debug_log_enable: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. E.g. we only log it for the first subgraph in + # piecewise mode. 
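# A small standalone sketch (assumed helper name) of the prefill-graph hit rule
# implemented by is_running_in_prefill() above: besides the attention metadata
# reporting PREFILL_ONLY mode, a captured prefill graph is replayed only when
# the batch has exactly the captured batch_size and every sequence has exactly
# the captured seq_len.
import torch

def matches_prefill_entry(seq_lens_cpu: torch.Tensor, batch_size: int, seq_len: int) -> bool:
    return seq_lens_cpu.size(0) == batch_size and bool((seq_lens_cpu == seq_len).all())

assert matches_prefill_entry(torch.tensor([128, 128]), batch_size=2, seq_len=128)
assert not matches_prefill_entry(torch.tensor([128, 96]), batch_size=2, seq_len=128)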
+ if is_capturing_prefill: + logger.debug( + "Capturing a prefill cudagraph on (%s, batch_size=%d, seq_len=%d)", + self.runtime_mode.name, + entry.batch_size, + entry.seq_len, + ) + else: + logger.debug( + "Capturing a decode cudagraph on (%s, %s)", + self.runtime_mode.name, + entry.batch_descriptor, + ) + if ((not is_capturing_prefill) and (not is_running_drafter)): + # validate that cudagraph capturing is legal at this point. + validate_cudagraph_capturing_enabled() + ''' + ================== + End of MLU Hijack + ================== + ''' + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.mlu.MLUGraph() + + with ExitStack() as stack: + if self.cudagraph_options.gc_disable: + # during every model forward for piecewise cudagraph + # mode, we will capture many pieces of cudagraphs + # (roughly one per layer). running gc again and again + # across layers will make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context(patch("torch.mlu.empty_cache", lambda: None)) + + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) + # mind-exploding: carefully manage the reference and memory. + with torch.mlu.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = self.runnable(*args, **kwargs) + if self.cudagraph_options.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph in piecewise cuadgraph mode, because + # the output of the last graph will not be used by + # any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + f"Input addresses for cudagraphs are different " + f"during replay. 
Expected {entry.input_addresses}, " + f"got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/vllm_mlu/config/__init__.py b/vllm_mlu/config/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/config/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/config/model.py b/vllm_mlu/config/model.py new file mode 100644 index 0000000..e8849b7 --- /dev/null +++ b/vllm_mlu/config/model.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.config.model import ModelConfig +from vllm.logger import init_logger + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +def vllm__config__model__ModelConfig__is_embedding_task(self) -> bool: + return self.runner_type == "pooling" + +def vllm__config__model__ModelConfig__get_head_size(self) -> int: + # TODO remove hard code + if self.is_deepseek_mla: + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) + if self.use_mla: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim + + if hasattr(self.hf_text_config, "model_type") and ( + self.hf_text_config.model_type == "zamba2" + ): + return self.hf_text_config.attention_head_dim + + if self.is_attention_free: + return 0 + + # NOTE: Some configs may set head_dim=None in the config + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + + # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` + if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: + return self.hf_text_config.hidden_size_per_head + + # FIXME(woosuk): This may not be true for all models. + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust num_heads and num_attention_heads. 
+ ''' + if hasattr(self.hf_text_config, "num_heads"): + num_attention_heads = self.hf_text_config.num_heads + else: + num_attention_heads = self.hf_text_config.num_attention_heads + + return (self.hf_text_config.hidden_size // num_attention_heads) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +MluHijackObject.apply_hijack( + ModelConfig, + "is_embedding_task", + vllm__config__model__ModelConfig__is_embedding_task, +) +MluHijackObject.apply_hijack( + ModelConfig, + ModelConfig.get_head_size, + vllm__config__model__ModelConfig__get_head_size, +) \ No newline at end of file diff --git a/vllm_mlu/config/scheduler.py b/vllm_mlu/config/scheduler.py new file mode 100644 index 0000000..57abd59 --- /dev/null +++ b/vllm_mlu/config/scheduler.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + + +from typing_extensions import Self + +from vllm.config.scheduler import SchedulerConfig +from vllm.logger import init_logger + +from vllm_mlu._mlu_utils import VLLM_V1_BENCHMARK +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +def vllm__config__scheduler__SchedulerConfig__verify_max_model_len( + self, max_model_len: int, +) -> Self: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: This restriction is removed when VLLM_V1_BENCHMARK is set to True + ''' + if not VLLM_V1_BENCHMARK: + if ( + self.max_num_batched_tokens < max_model_len + and not self.enable_chunked_prefill + ): + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len." + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs})." + ) + + if self.max_num_batched_tokens > self.max_num_seqs * max_model_len: + logger.warning( + "max_num_batched_tokens (%d) exceeds max_num_seqs " + "* max_model_len (%d). This may lead to unexpected behavior.", + self.max_num_batched_tokens, + self.max_num_seqs * max_model_len, + ) + + if self.max_num_partial_prefills > 1: + if not self.enable_chunked_prefill: + raise ValueError( + "Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1." + ) + + if self.long_prefill_token_threshold > max_model_len: + raise ValueError( + "long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than the max_model_len ({max_model_len})." + ) + + if self.max_long_partial_prefills > self.max_num_partial_prefills: + raise ValueError( + f"{self.max_long_partial_prefills=} must be less than or equal to " + f"{self.max_num_partial_prefills=}." 
+ ) + + return self + + +MluHijackObject.apply_hijack( + SchedulerConfig, + SchedulerConfig.verify_max_model_len, + vllm__config__scheduler__SchedulerConfig__verify_max_model_len, +) \ No newline at end of file diff --git a/vllm_mlu/config/speculative.py b/vllm_mlu/config/speculative.py new file mode 100644 index 0000000..8ef9ed8 --- /dev/null +++ b/vllm_mlu/config/speculative.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.config.parallel import ParallelConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.logger import init_logger + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + +@staticmethod +def vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, +) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. + """ + ''' + ============================= + Modify by vllm_mlu + @brief: add draft data parallel parameters + ============================= + ''' + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config.pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + distributed_executor_backend=target_parallel_config.distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + # add draft data parallel parameters + data_parallel_size=target_parallel_config.data_parallel_size, + data_parallel_size_local=target_parallel_config.data_parallel_size_local, + data_parallel_master_ip=target_parallel_config.data_parallel_master_ip, + data_parallel_rpc_port=target_parallel_config.data_parallel_rpc_port, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + return draft_parallel_config + + +vllm__config__speculative__SpeculativeConfig____post_init___org = SpeculativeConfig.__post_init__ +def vllm__config__speculative__SpeculativeConfig____post_init__(self): + if self.model is None and self.num_speculative_tokens is not None and self.method is None: + self.method = "mtp" + vllm__config__speculative__SpeculativeConfig____post_init___org(self) + + +MluHijackObject.apply_hijack( + SpeculativeConfig, + SpeculativeConfig.create_draft_parallel_config, + vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config, +) +MluHijackObject.apply_hijack( + SpeculativeConfig, + SpeculativeConfig.__post_init__, + vllm__config__speculative__SpeculativeConfig____post_init__, +) \ No newline at end of file diff --git a/vllm_mlu/config/vllm.py b/vllm_mlu/config/vllm.py new file mode 100644 index 0000000..c6bf221 --- /dev/null +++ b/vllm_mlu/config/vllm.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import os + +from vllm.config.vllm import VllmConfig +from vllm.config.compilation import CUDAGraphMode +from vllm.logger import init_logger + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +def vllm__config__vllm__VllmConfig___set_cudagraph_sizes(self): + """ + 
vLLM defines the default candidate list of batch sizes for CUDA graph + capture as: + + ```python + max_graph_size = min(max_num_seqs * 2, 512) + # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16 + # up to max_graph_size + cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( + range(256, max_graph_size + 1, 16)) + + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in ascending order). + + These sizes are used to capture and reuse CUDA graphs for + performance-critical paths (e.g., decoding). Capturing enables + significantly faster kernel dispatch by avoiding Python overhead. The + list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on + most GPUs), which controls the total allowed number of tokens in a + batch. Since each sequence may have a variable number of tokens, the + maximum usable batch size will depend on actual sequence lengths. + + Example: + With `max_num_batched_tokens = 8192`, and typical sequences + averaging ~32 tokens, most practical batch sizes fall below 256. + However, the system will still allow capture sizes up to 512 if + shape and memory permit. + + Note: + If users explicitly specify cudagraph capture sizes in the + compilation config, those will override this default logic. + At runtime: + + - If batch size <= one of the `cudagraph_capture_sizes`, the closest + padded CUDA graph will be used. + - If batch size > largest `cudagraph_capture_sizes`, cudagraph will + not be used. + """ + if hasattr(self.compilation_config, "_has_set_capture_list"): + # avoid set capture list twice while init + return + + if ( + self.model_config is not None + and not self.model_config.enforce_eager + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + # determine the initial max_cudagraph_capture_size + max_cudagraph_capture_size = ( + self.compilation_config.max_cudagraph_capture_size + ) + if max_cudagraph_capture_size is None: + max_cudagraph_capture_size = min( + self.scheduler_config.max_num_seqs * 2, 512 + ) + max_num_tokens = self.scheduler_config.max_num_batched_tokens + max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size) + + assert max_cudagraph_capture_size >= 1, ( + "Maximum cudagraph size should be greater than or equal to 1 " + "when using cuda graph." + ) + + # determine the cudagraph_capture_sizes + if self.compilation_config.cudagraph_capture_sizes is not None: + assert len(self.compilation_config.cudagraph_capture_sizes) > 0, ( + "cudagraph_capture_sizes should contain at least one element " + "when using cuda graph." 
+ ) + # de-duplicate the sizes provided by the config + dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes)) + cudagraph_capture_sizes = [ + i for i in dedup_sizes if i <= max_num_tokens + ] + # sort to make sure the sizes are in ascending order + cudagraph_capture_sizes.sort() + else: + cudagraph_capture_sizes = [ + i for i in [1, 2, 4] if i <= max_cudagraph_capture_size + ] + if max_cudagraph_capture_size >= 8: + # Step size 8 for small batch sizes, up to 256(not included) + cudagraph_capture_sizes += list( + range(8, min(max_cudagraph_capture_size + 1, 256), 8) + ) + if max_cudagraph_capture_size >= 256: + # Step size 16 for larger batch sizes + cudagraph_capture_sizes += list( + range(256, max_cudagraph_capture_size + 1, 16) + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: + 1) check batch_size_capture_list when enable mtp because bs * (K + 1) + may greater than max_num_batched_tokens + 2) capture MLUGraph by given batch list + ''' + mlu_graph_capture_list = os.getenv("MLU_GRAPH_CAPTURE_LIST", None) + if mlu_graph_capture_list: + if "-" in mlu_graph_capture_list: + batch_info = mlu_graph_capture_list.split("-") + assert len(batch_info) == 3, \ + f"Got invalid graph_capture_list={mlu_graph_capture_list}, " + \ + f"but expected format 'min_bs-max_bs(may not include)-step'." + start, end, step = mlu_graph_capture_list.split("-") + cudagraph_capture_sizes = [1, 2, 4] + [ + i for i in range(int(start), int(end), int(step)) + ] + cudagraph_capture_sizes = sorted(list(set(cudagraph_capture_sizes))) + else: + cudagraph_capture_sizes = [int(x) for x in mlu_graph_capture_list.split(",")] + + if (self.speculative_config is not None + and self.speculative_config.num_speculative_tokens > 0 + ): + K = self.speculative_config.num_speculative_tokens + cudagraph_capture_sizes = [x * (1 + K) for x in cudagraph_capture_sizes] + + cudagraph_capture_sizes = [ + size for size in cudagraph_capture_sizes + if size <= self.scheduler_config.max_num_batched_tokens + ] + ''' + ================== + End of MLU Hijack + ================== + ''' + + if ( + self.parallel_config.tensor_parallel_size > 1 + and self.compilation_config.pass_config.enable_sequence_parallelism + ): + cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism( + cudagraph_capture_sizes + ) + + # user-specific compilation_config.max_cudagraph_capture_size get + # truncated to valid_max_size when they are inconsistent. 
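# A standalone sketch (illustrative values only) of the MLU_GRAPH_CAPTURE_LIST
# handling above: "start-end-step" expands to a range (end excluded) merged with
# [1, 2, 4], a comma-separated list is taken literally, and with K speculative
# tokens each size is scaled by (K + 1); the real code additionally filters out
# sizes above max_num_batched_tokens.
def parse_capture_list(spec: str, num_spec_tokens: int = 0) -> list[int]:
    if "-" in spec:
        start, end, step = (int(x) for x in spec.split("-"))
        sizes = sorted(set([1, 2, 4] + list(range(start, end, step))))
    else:
        sizes = [int(x) for x in spec.split(",")]
    if num_spec_tokens > 0:
        sizes = [s * (num_spec_tokens + 1) for s in sizes]
    return sizes

assert parse_capture_list("8-64-8") == [1, 2, 4, 8, 16, 24, 32, 40, 48, 56]
assert parse_capture_list("1,2,4", num_spec_tokens=1) == [2, 4, 8]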
+ valid_max_size = ( + cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0 + ) + if ( + self.compilation_config.max_cudagraph_capture_size is not None + and self.compilation_config.max_cudagraph_capture_size != valid_max_size + ): + # raise error only when both two flags are user-specified + # and they are inconsistent with each other + if self.compilation_config.cudagraph_capture_sizes is not None: + raise ValueError( + "customized max_cudagraph_capture_size" + f"(={self.compilation_config.max_cudagraph_capture_size}) " + "should be consistent with the max value of " + f"cudagraph_capture_sizes(={valid_max_size})" + ) + + logger.warning( + "Truncating max_cudagraph_capture_size to %d", + valid_max_size, + ) + # always set the final max_cudagraph_capture_size + self.compilation_config.max_cudagraph_capture_size = valid_max_size + + if self.compilation_config.cudagraph_capture_sizes is not None and len( + cudagraph_capture_sizes + ) < len(self.compilation_config.cudagraph_capture_sizes): + # If users have specified capture sizes, we only need to + # compare the lens before and after modification since the modified + # list is only the subset of the original list. + logger.warning( + ( + "cudagraph_capture_sizes specified in compilation_config" + " %s is overridden by config %s" + ), + self.compilation_config.cudagraph_capture_sizes, + cudagraph_capture_sizes, + ) + # always write back the final sizes + self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes + + else: + # no cudagraph in use + self.compilation_config.max_cudagraph_capture_size = 0 + self.compilation_config.cudagraph_capture_sizes = [] + + # complete the remaining process. + self.compilation_config.post_init_cudagraph_sizes() + + setattr(self.compilation_config, "_has_set_capture_list", True) + + +MluHijackObject.apply_hijack( + VllmConfig, + VllmConfig._set_cudagraph_sizes, + vllm__config__vllm__VllmConfig___set_cudagraph_sizes, +) \ No newline at end of file diff --git a/vllm_mlu/device_allocator/__init__.py b/vllm_mlu/device_allocator/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/device_allocator/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/device_allocator/cnmem.py b/vllm_mlu/device_allocator/cnmem.py new file mode 100644 index 0000000..d0e09cc --- /dev/null +++ b/vllm_mlu/device_allocator/cnmem.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# cn_api based pytorch pluggable allocator to implement sleep mode. + +import dataclasses +import gc +import os +from collections.abc import Callable +from contextlib import contextmanager +from typing import Any + +import torch + +from vllm.logger import init_logger +from vllm.utils.platform_utils import is_pin_memory_available + +logger = init_logger(__name__) + + +def find_loaded_library(lib_name) -> str | None: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
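# An illustrative /proc/self/maps entry (made-up addresses and path) of the form
# this helper parses: everything from the first "/" is taken as the library
# path, and the filename up to ".so" must start with the requested lib_name.
line = ("7f3a2c000000-7f3a2c200000 r-xp 00000000 08:01 123456  "
        "/opt/venv/lib/python3.10/site-packages/vllm_mlu/vllm_mlu_C.cpython-310-x86_64-linux-gnu.so")
path = line[line.index("/"):].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith("vllm_mlu_C")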
+ """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), ( + f"Unexpected filename: {filename} for library {lib_name}" + ) + return path + + +cnmem_available = False +try: + from vllm_mlu.vllm_mlu_C import ( + init_module, + python_create_and_map, + python_unmap_and_release, + python_cn_memcpy, + ) + + lib_name = find_loaded_library("vllm_mlu_C") + cnmem_available = True +except ModuleNotFoundError as e: + logger.error("Failed to import cnmem_allocator:%s", e) + init_module = None + python_create_and_map = None + python_unmap_and_release = None + lib_name = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: torch.Tensor | None = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]] +) -> torch.mlu.memory.MLUPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.mlu.memory.MLUPluggableAllocator( + lib_name, "my_malloc", "my_free" + ) + return new_alloc + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]]): + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.mlu.memory.MemPool(new_alloc._allocator) + with torch.mlu.memory.use_mem_pool(mem_pool): + yield mem_pool, new_alloc + + +class CnMemAllocator: + """ + A singleton class that manages a memory pool for MLU tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + + instance: "CnMemAllocator" = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CnMemAllocator": + """ + CnMemAllocator is a singleton class. + We cannot call the constructor directly. + Call this method to get the instance. 
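# A hedged usage sketch (requires an MLU device, torch_mlu and the vllm_mlu_C
# extension; the method names are the ones defined on this class): tensors
# created inside use_memory_pool() are tagged, sleep() offloads tagged
# allocations to pinned CPU memory and discards the rest, wake_up() maps the
# device memory again and restores the backed-up contents.
import torch
from vllm_mlu.device_allocator.cnmem import CnMemAllocator

allocator = CnMemAllocator.get_instance()
with allocator.use_memory_pool(tag="weights"):
    weights = torch.empty(1 << 20, dtype=torch.bfloat16, device="mlu")
allocator.sleep(offload_tags=("weights",))   # back up "weights", discard the rest
allocator.wake_up(tags=["weights"])          # restore "weights" on the device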
+ """ + assert cnmem_available, "cnmem allocator is not available" + if CnMemAllocator.instance is None: + CnMemAllocator.instance = CnMemAllocator() + return CnMemAllocator.instance + + def __init__(self): + conf = os.environ.get("PYTORCH_MLU_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, ( + "Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates." + ) + + self.pointer_to_data: dict[int, AllocationData] = {} + self.current_tag: str = CnMemAllocator.default_tag + self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. + # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback + + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag + ) + logger.debug( + "Allocated %s bytes for %s with address %s from cnmem allocator", + allocation_handle[1], + self.current_tag, + py_d_mem, + ) + return + + def _python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + logger.debug( + "Freed %s bytes for %s with address %s from cnmem allocator", + data.handle[1], + data.tag, + ptr, + ) + return data.handle + + def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CnMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags,) + + assert isinstance(offload_tags, tuple) + + total_bytes = 0 + backup_bytes = 0 + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + total_bytes += handle[1] + if data.tag in offload_tags: + backup_bytes += handle[1] + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device="cpu", + pin_memory=is_pin_memory_available(), + ) + cpu_ptr = cpu_backup_tensor.data_ptr() + python_cn_memcpy(cpu_ptr, ptr, size_in_bytes) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + logger.info( + "CnMemAllocator: sleep freed %.2f GiB memory in total, of which " + "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " + "directly.", + total_bytes / 1024**3, + backup_bytes / 1024**3, + (total_bytes - backup_bytes) / 1024**3, + ) + + gc.collect() + torch.mlu.empty_cache() + + def wake_up(self, tags: list[str] | None = None) -> None: + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory. 
+ + :param tags: The tags of the memory allocation that will be loaded + back to GPU memory. If None, all memory allocation will be loaded + back to GPU memory. + """ + for ptr, data in self.pointer_to_data.items(): + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = ( + cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() + ) + cpu_ptr = cpu_backup_tensor.data_ptr() + python_cn_memcpy(ptr, cpu_ptr, size_in_bytes) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: str | None = None): + """ + A context manager to use the memory pool. + All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. + + :param tag: The tag of the memory allocation. If None, the default tag + will be used. + """ + if tag is None: + tag = CnMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator( + self.python_malloc_callback, self.python_free_callback + ) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released, e.g. in online quantization, + # where the model is created in higher precision, and then + # quantized in lower precision. + # Find all unused allocations and manually release them. + # TODO: we should expose `empty_cache` method in the memory pool. + # TODO: ask for help from PyTorch team to expose this method. + allocations = data[0].snapshot() + for allocation in allocations: + if allocation["allocated_size"] == 0: + handle = self._python_free_callback(allocation["address"]) + unmap_and_release(handle) + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. 
+ """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm_mlu/distributed/__init__.py b/vllm_mlu/distributed/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/distributed/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/distributed/device_communicators/__init__.py b/vllm_mlu/distributed/device_communicators/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/distributed/device_communicators/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/distributed/device_communicators/mlu_communicator.py b/vllm_mlu/distributed/device_communicators/mlu_communicator.py new file mode 100644 index 0000000..e28af38 --- /dev/null +++ b/vllm_mlu/distributed/device_communicators/mlu_communicator.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase, +) + + +class MLUCommunicator(DeviceCommunicatorBase): + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "" + ): + super().__init__(cpu_group, device, device_group, unique_name) + # init device according to rank + self.device = torch.mlu.current_device() + self.ca_comm: CustomAllreduce | None = None \ No newline at end of file diff --git a/vllm_mlu/distributed/kv_transfer/kv_connector/factory.py b/vllm_mlu/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000..06b657c --- /dev/null +++ b/vllm_mlu/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory + + +MLUKVConnectors: dict[str, tuple[str, str]] = { + "MLUSharedStorageConnector": ( + "vllm_mlu.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", + "SharedStorageConnector" + ), + "MLUNixlConnector": ( + "vllm_mlu.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "MLUNixlConnector" + ), +} + +for name, (module_path, class_name) in MLUKVConnectors.items(): + if name not in KVConnectorFactory._registry: + KVConnectorFactory.register_connector(name, module_path, class_name) \ No newline at end of file diff --git a/vllm_mlu/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 0000000..c9d9ff1 --- /dev/null +++ b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 
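# A minimal sketch of what a hijack helper like MluHijackObject.apply_hijack
# presumably does (its implementation is not part of this diff, so this is an
# assumption): swap a method on the target class for the replacement, accepting
# either the attribute name or the original function, and remember the original
# so it could be restored later.
class HijackSketch:
    _originals = {}  # (class, attribute name) -> original attribute

    @classmethod
    def apply_hijack(cls, target_cls, method_or_name, replacement):
        name = (method_or_name if isinstance(method_or_name, str)
                else method_or_name.__name__)
        cls._originals[(target_cls, name)] = getattr(target_cls, name, None)
        setattr(target_cls, name, replacement)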
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1 +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +class LMCacheConnectorV1_MluHijack(LMCacheConnectorV1): + + def response_remote_alloc_once(self) -> None: + self._lmcache_engine.response_remote_alloc_once() + + def request_remote_memory_send(self) -> None: + self._lmcache_engine.request_remote_memory_send() + + +MluHijackObject.apply_hijack(LMCacheConnectorV1, + "response_remote_alloc_once", + LMCacheConnectorV1_MluHijack.response_remote_alloc_once) +MluHijackObject.apply_hijack(LMCacheConnectorV1, + "request_remote_memory_send", + LMCacheConnectorV1_MluHijack.request_remote_memory_send) \ No newline at end of file diff --git a/vllm_mlu/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 0000000..cd02668 --- /dev/null +++ b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +import math +import threading +import time +import uuid +from collections import defaultdict +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch +import zmq + +from vllm import envs +from vllm.attention.selector import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.platforms import _Backend +from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + EngineId, NixlConnectorWorker, NixlAgentMetadata, NixlConnectorScheduler, NixlConnector) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +Transfer = tuple[int, float] # (xfer_handle, start_time) +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +class MLUNixlConnector(NixlConnector): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super(NixlConnector, self).__init__(vllm_config, role, kv_cache_config) + + assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler : MLUNixlConnectorScheduler | None = ( + MLUNixlConnectorScheduler(vllm_config, self.engine_id) + ) + self.connector_worker: 
MLUNixlConnectorWorker | None = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MLUNixlConnectorWorker(vllm_config, self.engine_id) + + +class MLUNixlConnectorScheduler(NixlConnectorScheduler): + """Implementation of Scheduler side methods""" + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: kv transfer info + ''' + if request.kv_transfer_params.get("do_remote_prefill", False): + logger.info(f"NIXLConnector update_state_after_alloc: request_id={request.request_id}, " + f"num_prompt_tokens={request.num_prompt_tokens}, " + f"num_external_tokens={num_external_tokens}, " + f"kv_transfer_params={request.kv_transfer_params}") + ''' + ================== + End of MLU Hijack + ================== + ''' + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, + params, + ) + + if not params: + return + + if params.get("do_remote_decode"): + self._reqs_in_batch.add(request.request_id) + if self.use_host_buffer and params.get("do_remote_decode"): + # NOTE: when accelerator is not directly supported by Nixl, + # prefilled blocks need to be saved to host memory before transfer. + + # save all blocks + block_ids = blocks.get_block_ids()[0] + # TODO: skip the blocks that are already in the host xfer buffer. + # Currently, the host xfer buffer block is 1-to-1 mapped to device + # kv blocks, so host blocks won't be flushed as long as its device + # block is not overwritten; and it will be safe to skip saving them + # to host xfer buffer. + if block_ids: + self._reqs_need_save[request.request_id] = (request, block_ids) + elif params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all( + p in params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + local_block_ids = ( + blocks.get_unhashed_block_ids() + if num_external_tokens > 0 + else [] + ) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, + local_block_ids, + ) + + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", + params, + ) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. 
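# Illustrative only (made-up values): the kv_transfer_params shape handled by
# the do_remote_prefill branch above; the decode side needs the remote engine
# id, host and port alongside the remote block ids before it will schedule a
# pull of the prefilled KV blocks.
params = {
    "do_remote_prefill": True,
    "remote_block_ids": [3, 4, 5],
    "remote_engine_id": "prefill-node-0",
    "remote_host": "10.0.0.2",
    "remote_port": 5600,
}
assert params.get("remote_block_ids")
assert all(k in params for k in ("remote_engine_id", "remote_host", "remote_port"))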
+ params["do_remote_prefill"] = False + + +class MLUNixlConnectorWorker(NixlConnectorWorker): + """Implementation of Worker side methods""" + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + _, first_kv_cache = next(iter(kv_caches.items())) + + ''' + ============================= + Add by vllm_mlu + ============================= + @brief: not support kv8 + ''' + if not isinstance(first_kv_cache, torch.Tensor): + kv_caches = {key: value[0] for key, value in kv_caches.items()} + _, first_kv_cache = next(iter(kv_caches.items())) + ''' + ================== + End of MLU Hijack + ================== + ''' + + kv_elem_size = first_kv_cache.element_size() + + # TODO(tms): Find a more robust way to detect and handle MLA + # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected + # KV memory layout is HND, as opposed to the default NHD. Note that it + # will only affects the strides. For MLA instead, we make require no + # such thing and resort to the standard layout. + + ''' + ============================= + Add by vllm_mlu + ============================= + @brief: support mla + ''' + use_mla = first_kv_cache.shape[0] == 1 + ''' + ================== + End of MLU Hijack + ================== + ''' + + assert use_mla == self.use_mla + + # TODO (NickLucche) not compatible with hybrid allocator. Enforce check + # once it goes live, as a single kv layout is expected for xfers. + if use_mla: + # MLA case. + + ''' + ============================= + Add by vllm_mlu + ============================= + @brief: support mla + ''' + self.num_blocks = first_kv_cache.shape[1] + ''' + ================== + End of MLU Hijack + ================== + ''' + + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, kv_latent_dim = block_shape + self.slot_size_bytes = kv_elem_size * kv_latent_dim + else: + # [2 (k and v), num_blocks, ...] + if self._use_flashinfer: + # FlashInfer swaps 2<->num_blocks dimensions. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 4 # [2, block_size, kv_heads, head_dim] + else: + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + ''' + ============================= + Add by vllm_mlu + ============================= + @brief: MLU kv_cache layout is [2 (k and v), num_blocks, kv_heads, block_size, head_dim] + ''' + n_kv_heads, block_size, head_dim = block_shape[-3:] + ''' + ================== + End of MLU Hijack + ================== + ''' + # head size in bytes. + self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim + assert block_size == self.block_size + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + # block size in bytes + self.block_len = kv_elem_size * math.prod(block_shape) + logger.info( + "Registering KV_Caches: use_mla: %s, num_blocks: %s, " + "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla, + self.num_blocks, block_shape, first_kv_cache.shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. 
Advantage is that we can + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). + # Conversely for FlashInfer, K and V are transferred in the same tensor + # to better exploit the memory layout (ie num_blocks is the first dim). + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \ + else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append( + (base_addr, region_len, cache.device.index, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + self.num_layers = len(self.kv_caches.keys()) + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_descs.append(descs) + + # Register local/src descr for NIXL xfer. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.tp_rank) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + # NIXL_INIT_AGENT to be used for preparations of local descs. + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # After KV Caches registered, listen for new connections. 
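# A standalone toy (made-up sizes and a made-up base address) of the descriptor
# math used above: each registered region is one contiguous cache tensor, and
# block i within it is described as (base_addr + i * block_len, block_len,
# device_id) with block_len = kv_elem_size * prod(block_shape).
import math

kv_elem_size = 2                      # e.g. bf16
block_shape = (8, 16, 128)            # illustrative (kv_heads, block_size, head_dim)
num_blocks = 4
base_addr = 0x7F0000000000            # made-up device pointer

block_len = kv_elem_size * math.prod(block_shape)
blocks_data = [(base_addr + i * block_len, block_len, 0) for i in range(num_blocks)]
assert blocks_data[1][0] - blocks_data[0][0] == block_len == 32768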
+ metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + tp_size=self.world_size, + block_len=self.block_len, + attn_backend_name=self.backend_name) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() \ No newline at end of file diff --git a/vllm_mlu/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py new file mode 100644 index 0000000..fdaf9f3 --- /dev/null +++ b/vllm_mlu/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -0,0 +1,450 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +import hashlib +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional + +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +from vllm_mlu.v1.attention.backends.flash_mla import MLAFlashAttentionCommonMetadata + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + mm_hashes: list[str] + + @staticmethod + def make_meta( + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + mm_hashes: list[str], + ) -> "ReqMeta": + valid_num_tokens = align_to_block_size(len(token_ids), block_size) + token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_tensor.reshape((num_blocks, 1)) * block_size + ) + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + mm_hashes=mm_hashes, + ) + + +@dataclass +class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] = field(default_factory=list) + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + mm_hashes: list[str], + ) -> None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes) + ) + + +class SharedStorageConnector(KVConnectorBase_V1): + # NOTE: This is Simple debug implementation of the KV connector. + # It save / load the KV cache to / from the disk. 
+ # It does extra work which will overwrite the existing prefix-cache in GPU + # - to remove the overhead, need to add some "mask" in the ReqMeta class + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + self._storage_path = self._kv_transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.info(self._kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1 + ) + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1 + ) + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, SharedStorageConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning("In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + logger.info( + "Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping), + ) + for layer_name in forward_context.no_compile_layers: + layer = forward_context.no_compile_layers[layer_name] + + # Only process layers that have kv_cache + # attribute (attention layers) Skip non-attention + # layers like FusedMoE/MLP etc. 
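# [Editor's sketch, not part of the patch] Each (request, layer) pair is
# persisted as a single-tensor safetensors file keyed "kv_cache"; the
# load_file call just below is the mirror image of save_kv_layer further down.
# A minimal round-trip with the same file layout (the path is illustrative):
import safetensors.torch
import torch

def roundtrip_kv(path: str, kv: torch.Tensor) -> torch.Tensor:
    safetensors.torch.save_file({"kv_cache": kv.detach().cpu()}, path)
    return safetensors.torch.load_file(path)["kv_cache"]

# e.g. roundtrip_kv("/tmp/layer0.safetensors", torch.randn(2, 16, 8)) returns
# a CPU tensor equal to the input.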
+ kv_cache_attr = getattr(layer, "kv_cache", None) + if kv_cache_attr is None: + continue + + kv_cache_layer = kv_cache_attr[forward_context.virtual_engine] + + filename = self._generate_filename_debug( + layer_name, request.token_ids, request.mm_hashes + ) + kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + for request in connector_metadata.requests: + if request.is_store: + filename = self._generate_filename_debug( + layer_name, request.token_ids, request.mm_hashes + ) + kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) + tensors = {"kv_cache": kv_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + + def wait_for_save(self): + return + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + # NOTE: in this debug implementation, we assume that the prompt is + # cached_prompt + newly_generated_single_token + # Therefore, we use prompt_token_ids[:-1] to determine the folder name + + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned + # with the block granularity. And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. 
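# [Editor's sketch, not part of the patch] As the note above says, hits are
# reported at block granularity. align_to_block_size (defined at the bottom of
# this file) rounds a token count down to a block boundary, treating an exact
# multiple as the previous boundary, so only fully cached blocks count:
def _align_to_block_size(num_tokens: int, block_size: int) -> int:
    # Same formula as align_to_block_size below, repeated for a worked example.
    return (num_tokens - 1) // block_size * block_size

# With block_size=16: a 33-token prompt checks its first 16 tokens
# (align(33 - 1) == 16) and a 49-token prompt its first 32 (align(48) == 32).
assert _align_to_block_size(32, 16) == 16
assert _align_to_block_size(48, 16) == 32
assert _align_to_block_size(16, 16) == 0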
+ if not self._found_match_for_request(request): + return 0, False + + logger.info("External Cache Hit!") + + # Now, first num_tokens_to_check tokens are hit, we need to prepare + # the metadata for the worker connector to correctly load the KV + token_ids = request.prompt_token_ids or [] + num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size) + + return num_tokens_to_check - num_computed_tokens, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = SharedStorageConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + token_ids = new_req.prompt_token_ids or [] + mm_hashes = [f.identifier for f in new_req.mm_features] + if new_req.req_id in self._requests_need_load: + meta.add_request( + token_ids=token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=False, + mm_hashes=mm_hashes, + ) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. + # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. + if not self._found_match_for_prompt(token_ids, mm_hashes): + meta.add_request( + token_ids=token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=True, + mm_hashes=mm_hashes, + ) + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + resumed_from_preemption = req_id in cached_reqs.resumed_req_ids + if not resumed_from_preemption or req_id not in self._requests_need_load: + continue + + num_computed_tokens = cached_reqs.num_computed_tokens[i] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + new_block_ids = cached_reqs.new_block_ids[i] + + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[req_id] + total_tokens = num_computed_tokens + num_new_tokens + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. 
+ assert new_block_ids is not None + block_ids = new_block_ids[0] + + meta.add_request( + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in request.mm_features], + ) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_request( + self, + request: "Request", + ) -> bool: + """Check if the cache is hit for the request.""" + return self._found_match_for_prompt( + list(request.prompt_token_ids or []), + [f.identifier for f in request.mm_features], + ) + + def _found_match_for_prompt( + self, + prompt_token_ids: list[int], + mm_hashes: list[str], + ) -> bool: + num_tokens_to_check = align_to_block_size( + len(prompt_token_ids) - 1, self._block_size + ) + foldername = self._generate_foldername_debug( + torch.tensor(prompt_token_ids)[:num_tokens_to_check], + mm_hashes, + create_folder=False, + ) + return os.path.exists(foldername) + + def _generate_foldername_debug( + self, + token_ids: torch.Tensor, + mm_hashes: list[str], + create_folder=False, + ) -> str: + """Generate a folder name based on the hash of the bytes of the input + ids. + """ + token_bytes = token_ids.numpy().tobytes() + # Add mm_hashes to the bytes being hashed to avoid path traversal and + # to create a canonical key. + if mm_hashes: + mm_str = "-".join(mm_hashes) + token_bytes += mm_str.encode("utf-8") + input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest() + + foldername = os.path.join(self._storage_path, input_ids_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug( + self, + layer_name: str, + token_ids: torch.Tensor, + mm_hashes: list[str], + ) -> str: + """Generate a file name based on the layer name and the hash + of the bytes of the input ids. + """ + foldername = self._generate_foldername_debug( + token_ids, mm_hashes=mm_hashes, create_folder=True + ) + return os.path.join(foldername, f"{layer_name}.safetensors") + + +def align_to_block_size(num_tokens: int, block_size) -> int: + """Align the number of tokens to the block size.""" + return (num_tokens - 1) // block_size * block_size diff --git a/vllm_mlu/distributed/parallel_state.py b/vllm_mlu/distributed/parallel_state.py new file mode 100644 index 0000000..4a5e57a --- /dev/null +++ b/vllm_mlu/distributed/parallel_state.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from contextlib import contextmanager, nullcontext +from typing import Optional +from dataclasses import dataclass + +import torch + +from vllm.distributed.parallel_state import ( + GroupCoordinator, + GraphCaptureContext, + get_pp_group, + get_tp_group, +) +from vllm.distributed.mlu_parallel_state import( + get_moe_expert_parallel_world_size, + get_moe_expert_parallel_rank, + get_moe_expert_parallel_group, +) +from vllm.logger import init_logger + +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm_mlu import _mlu_ops as mlu_ops + +logger = init_logger(__name__) + +@dataclass +class MLUGraphCaptureContext: + stream: torch.mlu.Stream + + +@contextmanager +def mlu_graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. 
Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + context = MLUGraphCaptureContext(torch.mlu.Stream(device=device)) + with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context): + yield context + + +@contextmanager +def vllm__distributed__parallel_state__GroupCoordinator__graph_capture( + self, + graph_capture_context: GraphCaptureContext | None = None, +): + if graph_capture_context is None: + stream = torch.mlu.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + # only cuda uses this function, + # so we don't abstract it into the base class + maybe_ca_context = nullcontext() + from vllm_mlu.distributed.device_communicators.mlu_communicator import ( + MLUCommunicator, + ) + + if self.device_communicator is not None: + assert isinstance(self.device_communicator, MLUCommunicator) + ca_comm = self.device_communicator.ca_comm + if ca_comm is not None: + maybe_ca_context = ca_comm.capture() # type: ignore + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.mlu.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.mlu.stream(stream), maybe_ca_context: + yield graph_capture_context + +@dataclass +class CnclEPBuffer: + dispatch_send_token_tensor: torch.Tensor + dispatch_recv_token_tensor: torch.Tensor + combine_send_token_tensor: torch.Tensor + combine_recv_token_tensor: torch.Tensor + +class CnclEP: + + def __init__(self, + dispatch_token_size: int, + combine_token_size: int, + max_num_tokens_per_rank: int, + num_global_experts: int, + use_quant_dispatch: bool = True) -> None: + nranks = get_moe_expert_parallel_world_size() + rank = get_moe_expert_parallel_rank() + moe_ep_group = get_moe_expert_parallel_group() + self.max_num_tokens_per_rank = max_num_tokens_per_rank + self.use_quant_dispatch = use_quant_dispatch + + ( + handle, + exchange_info_size, + exchange_info, + dispatch_send_token_tensor, + dispatch_recv_token_tensor, + combine_send_token_tensor, + combine_recv_token_tensor + ) = mlu_ops.moe_all2all_create(dispatch_token_size, + combine_token_size, + num_global_experts, + max_num_tokens_per_rank, + rank, + nranks) + self.handle = handle + self.buffer = CnclEPBuffer( + dispatch_send_token_tensor, + dispatch_recv_token_tensor, + combine_send_token_tensor, + combine_recv_token_tensor) + + assert exchange_info.ndim == 1, "exchange_info should be 1D" + all_exchange_info = torch.empty((nranks, exchange_info.size(0)), + dtype=exchange_info.dtype, + device=exchange_info.device) + exchange_info = exchange_info.unsqueeze(0) + torch.distributed.all_gather_into_tensor(all_exchange_info, + exchange_info, + group=moe_ep_group.cpu_group, + async_op=False) + mlu_ops.moe_all2all_init(self.handle, all_exchange_info) + 
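# [Editor's sketch, not part of the patch] The CnclEP handshake above relies
# on torch.distributed.all_gather_into_tensor: every EP rank contributes its
# own (1, info_size) row and gets back the stacked (nranks, info_size) table,
# so each rank sees the exchange info of all peers before moe_all2all_init
# runs. A minimal reproduction of that collective's shape contract
# (gather_exchange_info is a hypothetical name):
import torch
import torch.distributed as dist

def gather_exchange_info(exchange_info: torch.Tensor, nranks: int,
                         group=None) -> torch.Tensor:
    assert exchange_info.ndim == 1, "exchange_info should be 1D"
    table = torch.empty((nranks, exchange_info.size(0)),
                        dtype=exchange_info.dtype,
                        device=exchange_info.device)
    dist.all_gather_into_tensor(table, exchange_info.unsqueeze(0), group=group)
    return table

# Single-process smoke test:
#   dist.init_process_group("gloo", rank=0, world_size=1,
#                           init_method="tcp://127.0.0.1:29500")
#   gather_exchange_info(torch.arange(4), nranks=1)  # -> tensor([[0, 1, 2, 3]])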
torch.distributed.barrier(group=moe_ep_group.cpu_group) + + def dispatch(self, + token_byte: int, + token_num: int, + send_layout: torch.Tensor, + send_token_num: torch.Tensor, + recv_layout: torch.Tensor, + recv_token_num: torch.Tensor, + send_token: Optional[torch.Tensor] = None, + recv_token: Optional[torch.Tensor] = None, + ) -> None: + ''' + The returned tensors are in-placed modified, we could directly use them + after dispatch finishes. + ''' + mlu_ops.moe_all2all_dispatch(self.handle, + token_byte, + token_num, + send_layout, + send_token_num, + recv_layout, + recv_token_num, + send_token, + recv_token) + + def combine(self, + token_byte: int, + token_num: int, + send_src_layout: torch.Tensor, + send_dst_layout: torch.Tensor, + send_token: Optional[torch.Tensor] = None, + recv_token: Optional[torch.Tensor] = None, + ) ->None: + mlu_ops.moe_all2all_combine(self.handle, + token_byte, + token_num, + send_src_layout, + send_dst_layout, + send_token, + recv_token) + + def destroy(self) -> None: + mlu_ops.moe_all2all_destroy(self.handle) + +_CNCLEP: CnclEP | None = None +_CNCLEP_BF16: CnclEP | None = None + +def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP: + if use_quant_dispatch: + assert _CNCLEP is not None, "cnclep is not initialized" + return _CNCLEP + else: + assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized" + return _CNCLEP_BF16 + +def init_cnclep(dispatch_token_size: int, + combine_token_size: int, + max_num_tokens_per_rank: int, + num_global_experts: int, + use_quant_dispatch: bool = True): + if use_quant_dispatch: + global _CNCLEP + assert _CNCLEP is None, "cnclep has been initialized" + _CNCLEP = CnclEP(dispatch_token_size, + combine_token_size, + max_num_tokens_per_rank, + num_global_experts, + use_quant_dispatch) + else: + global _CNCLEP_BF16 + assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized" + _CNCLEP_BF16 = CnclEP(dispatch_token_size, + combine_token_size, + max_num_tokens_per_rank, + num_global_experts, + use_quant_dispatch) + +def cnclep_dispatch(token_byte: int, + token_num: int, + send_layout: torch.Tensor, + send_token_num: torch.Tensor, + recv_layout: torch.Tensor, + recv_token_num: torch.Tensor, + send_token: Optional[torch.Tensor] = None, + recv_token: Optional[torch.Tensor] = None, + use_quant_dispatch: bool = True, +): + if use_quant_dispatch: + _CNCLEP.dispatch(token_byte, + token_num, + send_layout, + send_token_num, + recv_layout, + recv_token_num, + send_token, + recv_token) + else: + _CNCLEP_BF16.dispatch(token_byte, + token_num, + send_layout, + send_token_num, + recv_layout, + recv_token_num, + send_token, + recv_token) + +def cnclep_combine(token_byte: int, + token_num: int, + send_src_layout: torch.Tensor, + send_dst_layout: torch.Tensor, + send_token: Optional[torch.Tensor] = None, + recv_token: Optional[torch.Tensor] = None, + use_quant_dispatch: bool = True, +): + if use_quant_dispatch: + _CNCLEP.combine(token_byte, + token_num, + send_src_layout, + send_dst_layout, + send_token, + recv_token) + else: + _CNCLEP_BF16.combine(token_byte, + token_num, + send_src_layout, + send_dst_layout, + send_token, + recv_token) + + +def destroy_cnclep(): + global _CNCLEP + + if _CNCLEP: + _CNCLEP.destroy() + _CNCLEP = None + + global _CNCLEP_BF16 + + if _CNCLEP_BF16: + _CNCLEP_BF16.destroy() + _CNCLEP_BF16 = None + + +MluHijackObject.apply_hijack(GroupCoordinator, + GroupCoordinator.graph_capture, + vllm__distributed__parallel_state__GroupCoordinator__graph_capture) + diff --git a/vllm_mlu/engine/__init__.py 
b/vllm_mlu/engine/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/engine/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/engine/arg_utils.py b/vllm_mlu/engine/arg_utils.py new file mode 100644 index 0000000..09974b0 --- /dev/null +++ b/vllm_mlu/engine/arg_utils.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import get_args + +from vllm.platforms import current_platform +from vllm.config import ( + ModelConfig, + VllmConfig, + SchedulerConfig, +) +from vllm.config.cache import CacheDType +from vllm.engine.arg_utils import ( + EngineArgs, + _raise_unsupported_error, +) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +@classmethod +def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults( + cls, + model_config: ModelConfig, +) -> tuple[bool, bool]: + if model_config.runner_type != "pooling": + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: mlu-v1 default use unchunked scheduler + ''' + if mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED: + default_chunked_prefill = False + else: + default_chunked_prefill = True + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Disable prefix caching default for hybrid models + # since the feature is still experimental. + default_prefix_caching = not model_config.is_hybrid + else: + assert model_config.pooler_config is not None + + pooling_type = model_config.pooler_config.pooling_type + incremental_prefill_supported = ( + pooling_type is not None + and pooling_type.lower() == "last" + and getattr(model_config.hf_config, "is_causal", True) + ) + + default_chunked_prefill = incremental_prefill_supported + default_prefix_caching = incremental_prefill_supported + + return default_chunked_prefill, default_prefix_caching + +def vllm__engine__arg_utils__EngineArgs___set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig +) -> None: + """Set Default Arguments for V1 Engine.""" + ( + default_chunked_prefill, + default_prefix_caching, + ) = self.get_chunked_prefill_prefix_caching_defaults(model_config) + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = default_chunked_prefill + + logger.debug( + "%s chunked prefill by default", + "Enabling" if default_chunked_prefill else "Disabling", + ) + elif ( + model_config.runner_type == "pooling" + and self.enable_chunked_prefill + and not default_chunked_prefill + ): + logger.warning( + "This model does not officially support chunked prefill. " + "Enabling this manually may cause the engine to crash " + "or produce incorrect outputs.", + ) + + if self.enable_prefix_caching is None: + self.enable_prefix_caching = default_prefix_caching + + logger.debug( + "%s prefix caching by default", + "Enabling" if default_prefix_caching else "Disabling", + ) + elif ( + model_config.runner_type == "pooling" + and self.enable_prefix_caching + and not default_prefix_caching + ): + logger.warning( + "This model does not officially support prefix caching. 
" + "Enabling this manually may cause the engine to crash " + "or produce incorrect outputs.", + ) + + world_size = self.pipeline_parallel_size * self.tensor_parallel_size + ( + default_max_num_batched_tokens, + default_max_num_seqs, + ) = self.get_batch_defaults(world_size) + + orig_max_num_batched_tokens = self.max_num_batched_tokens + orig_max_num_seqs = self.max_num_seqs + + if self.max_num_seqs is None: + self.max_num_seqs = default_max_num_seqs.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_SEQS, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: only set max_num_batched_tokens when enable chunked_prefill + ''' + if self.max_num_batched_tokens is None: + self.max_num_batched_tokens = default_max_num_batched_tokens.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) + + if orig_max_num_batched_tokens is None: + if not self.enable_chunked_prefill: + # If max_model_len is too short, use the default for higher throughput. + self.max_num_batched_tokens = max( + model_config.max_model_len, + self.max_num_batched_tokens, + ) + + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. + self.max_num_batched_tokens = min( + self.max_num_seqs * model_config.max_model_len, + self.max_num_batched_tokens, + ) + + logger.debug( + "Defaulting max_num_batched_tokens to %d for %s usage context.", + self.max_num_batched_tokens, + usage_context.value if usage_context else None, + ) + + if orig_max_num_seqs is None: + if self.max_num_batched_tokens is not None: # For type checking + self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) + + logger.debug( + "Defaulting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, + usage_context.value if usage_context else None, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3'] + +def vllm__engine__arg_utils__EngineArgs__create_engine_config( + self, + usage_context: UsageContext | None = None, + headless: bool = False, +) -> VllmConfig: + """ + Create the VllmConfig. + + NOTE: If VllmConfig is incompatible, we raise an error. + """ + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add data parallel params to parallel config. + ''' + if self.mlu_config and "decoder_attn_dtype" in self.mlu_config: + if self.mlu_config.get("decoder_attn_dtype") in ["int8", "fp8", "fp8_e4m3"]: + self.kv_cache_dtype = self.mlu_config.get("decoder_attn_dtype") + + engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org( + self, usage_context, headless) + + world_size = engine_config.parallel_config.world_size_across_dp + tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size + embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size + if embedding_tp_size: + assert embedding_tp_size >= tensor_parallel_size and embedding_tp_size <= world_size, ( + f"embedding_tp_size = {embedding_tp_size} out of bounds. " + f"Require {tensor_parallel_size} ≤ size ≤ {world_size}") + dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size + if dense_mlp_tp_size: + assert dense_mlp_tp_size >= 1 and dense_mlp_tp_size <= world_size, ( + f"dense_mlp_tp_size = {dense_mlp_tp_size} out of bounds. 
Require 1 ≤ size ≤ {world_size}") + if dense_mlp_tp_size != world_size: + assert not engine_config.mlu_config.is_dpsk_mcc_enabled, ( + "dense_mlp_tp_size is not supported when dpsk mcc is enabled.") + if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1: + raise ValueError("For now, for longcat model, custom dense mlp tp split in data parallel requires dpXtp1. " + "Necessity of this constraint requires further investigation.") + if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size: + raise ValueError(f"For longcat model, custom dense mlp tp_size {dense_mlp_tp_size} " + f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}") + if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0: + raise ValueError(f"For deepseek mla model, custom mlp tp size {dense_mlp_tp_size} must " + f"be divisible by {tensor_parallel_size}") + + if ((engine_config.parallel_config.data_parallel_size > 1 or engine_config.speculative_config is not None + or engine_config.mlu_config.prefill_use_sequence_parallel) and engine_config.mlu_config.prefill_enable_mlugraph): + logger.info("Data parallel or sequence parallel or speculative is enabled, forcing context mlugraph to be disabled.") + engine_config.mlu_config.prefill_enable_mlugraph = False + if engine_config.mlu_config.decoder_attn_dtype: + if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType): + raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} " + f"decoder_attn_dtype for now") + is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and + engine_config.model_config.hf_text_config.model_type == "glm4_moe") + if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe) + and engine_config.mlu_config.decoder_attn_dtype != "auto"): + raise ValueError(f"mlu_config.decoder_attn_dtype only support deepseek_mla and glm4_moe model") + + # sequence parallel checks + if (engine_config.mlu_config.prefill_use_sequence_parallel + and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]): + raise ValueError("Prefill sequence parallel can only use in deepseek model.") + if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.scheduler_config.enable_chunked_prefill: + raise ValueError("Prefill sequence parallel can not use with chunked prefill for now.") + if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled: + raise ValueError("Prefill sequence parallel can not use with mcc.") + if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.parallel_config.data_parallel_size > 1: + raise ValueError("Prefill sequence parallel can not use with data parallel.") + if (engine_config.mlu_config.prefill_use_sequence_parallel + and engine_config.model_config.hf_text_config.model_type == "deepseek_v3" + and engine_config.quant_config.get_name() != "SmoothQuant"): + raise ValueError("Prefill sequence parallel can only use SmoothQuant for deepseek_v3.") + + # disagg constraint + # 1、only support deepseek-v3/r1 + # 2、unsupport kv8 + if self.kv_transfer_config is not None: + if engine_config.model_config.hf_config.model_type != "deepseek_v3": + raise ValueError("Disagg only support DeepDeek-V3/R1") + if engine_config.cache_config.cache_dtype == "int8": + raise ValueError("Disagg does not support KV cache dtype is int8") + if 
engine_config.cache_config.enable_prefix_caching: + raise ValueError("Disagg does not support prefix caching") + + if isinstance(self.kv_transfer_config, dict): + kv_connector = self.kv_transfer_config.get("kv_connector") + kv_role = self.kv_transfer_config.get("kv_role") + else: + kv_connector = self.kv_transfer_config.kv_connector + kv_role = self.kv_transfer_config.kv_role + + if kv_connector != "LMCacheConnectorV1": + raise ValueError("Disagg only support LMCacheConnectorV1 connector") + if kv_role == "kv_consumer": + if not self.enable_chunked_prefill: + raise ValueError("Disagg decoder only support chunk scheduler") + + ''' + ================== + End of MLU Hijack + ================== + ''' + return engine_config + + +MluHijackObject.apply_hijack(EngineArgs, + EngineArgs._set_default_args, + vllm__engine__arg_utils__EngineArgs___set_default_args) +MluHijackObject.apply_hijack(EngineArgs, + EngineArgs.create_engine_config, + vllm__engine__arg_utils__EngineArgs__create_engine_config) +MluHijackObject.apply_hijack(EngineArgs, + EngineArgs.get_chunked_prefill_prefix_caching_defaults, + vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults) diff --git a/vllm_mlu/entrypoints/__init__.py b/vllm_mlu/entrypoints/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/entrypoints/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/entrypoints/llm.py b/vllm_mlu/entrypoints/llm.py new file mode 100644 index 0000000..b116592 --- /dev/null +++ b/vllm_mlu/entrypoints/llm.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from tqdm import tqdm +from typing import Callable + +from vllm.entrypoints.llm import LLM +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.logger import init_logger + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu.mlu_metric import LLMMetric +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +logger = init_logger(__name__) + +def vllm__entrypoints__llm__LLM__get_mlu_metrics( + self, + metrics_idx_start, + only_average, + input_len, + output_len, + tp_nums, + quantization, + show_per_iter=False, + is_embedding_task=False, + mm_kwargs=None, + total_prefill_steps=1, + num_speculative_tokens=0, + dp_size=1, +) -> None: + ''' + @brief:该函数用来打印vLLM调用generate接口过程中代码统计的各项性能指标数据 + @params: + metrics_idx_start: 考虑存在调用generate接口为warmup过程的情况, + 因此设置该参数可忽略统计[0,metrics_idx_start)之间的数据,默认为0,即所有性能数据有效。 + only_average: True 只打印N次调用generate接口的平均性能 False 打印每次调用generate接口的性能及其均值 若N次性能数据波动较大,需自行排查测试环境是否稳定。 + 其余参数:均为模型配置参数 + ''' + if mlu_envs.VLLM_LATENCY_DEBUG_EN: + batch_size = self.metric.batch_size_list[-1] * dp_size + if mm_kwargs or is_embedding_task: + # The multimodal and pooling model doesn't support the hfu feature yet. 
+ hfu_info, io_efficiency = None, None + else: + hfu_info, io_efficiency = self.llm_engine.get_hfu_info(batch_size, input_len, output_len) + self.metric.calc_metric( + self.llm_engine.model_config.model, + self.llm_engine.model_config.dtype, + metrics_idx_start, only_average, + input_len, output_len, tp_nums, + quantization, show_per_iter, + is_embedding_task, mm_kwargs, total_prefill_steps, + num_speculative_tokens, dp_size=dp_size, hfu_info=hfu_info, io_efficiency=io_efficiency) + else: + print("Warnning:please set VLLM_LATENCY_DEBUG=true!") + + +def vllm__entrypoints__llm__LLM___run_engine( + self, *, use_tqdm: bool | Callable[..., tqdm] = True +) -> list[RequestOutput | PoolingRequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), + ) + + ''' + ============================= + Added by vllm_mlu + ============================= + ''' + if mlu_envs.VLLM_LATENCY_DEBUG_EN: + total_request_num = self.llm_engine.get_num_unfinished_requests() + e2e_start_time = self.metric.get_mlu_cost_time() + if not self.llm_engine.model_config.is_embedding_task(): + peak_memory, block_memory, num_gpu_blocks, num_cpu_blocks = \ + self.llm_engine.get_memory_usage() + self.metric.update_memory_usage(peak_memory, block_memory, + num_gpu_blocks, num_cpu_blocks) + ''' + ================== + End of addition + ================== + ''' + + # Run the engine. + outputs: list[RequestOutput | PoolingRequestOutput] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + n = len(output.outputs) + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) * n + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum( + len(stp.token_ids) for stp in output.outputs + ) + out_spd = total_out_toks / pbar.format_dict["elapsed"] + pbar.postfix = ( + f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s" + ) + pbar.update(n) + else: + pbar.update(1) + if pbar.n == num_requests: + pbar.refresh() + + if use_tqdm: + pbar.close() + ''' + ============================= + Added by vllm_mlu + ============================= + ''' + if mlu_envs.VLLM_LATENCY_DEBUG_EN: + e2e_end_time = self.metric.get_mlu_cost_time() + e2e_latency = e2e_end_time - e2e_start_time + + engine_step_latency, model_forward_latency, mm_encoder_latency = self.llm_engine.get_latency() + self.metric.update_step_latency(engine_step_latency) + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + self.metric.update_step_latency_device(model_forward_latency) + self.metric.update_mm_encoder_latency_device(mm_encoder_latency) + + self.metric.add_metrics(total_request_num, e2e_latency) + ''' + ================== + End of addition + ================== + ''' + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
+ return sorted(outputs, key=lambda x: int(x.request_id)) + + +LLM.metric = LLMMetric() +MluHijackObject.apply_hijack(LLM, + "get_mlu_metrics", + vllm__entrypoints__llm__LLM__get_mlu_metrics) +MluHijackObject.apply_hijack(LLM, + LLM._run_engine, + vllm__entrypoints__llm__LLM___run_engine) diff --git a/vllm_mlu/entrypoints/openai/__init__.py b/vllm_mlu/entrypoints/openai/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/entrypoints/openai/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/entrypoints/openai/api_server.py b/vllm_mlu/entrypoints/openai/api_server.py new file mode 100644 index 0000000..e73a19a --- /dev/null +++ b/vllm_mlu/entrypoints/openai/api_server.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from fastapi import Request +from fastapi.responses import Response + +import vllm_mlu._mlu_utils as mlu_envs +from vllm.entrypoints.openai.api_server import ( + router, engine_client +) +from vllm_mlu.logger import logger + +if mlu_envs.VLLM_SCHEDULER_PROFILE: + logger.info( + "vLLM V1 Scheduler Profiler is enabled in the API server. Please use " + "'tools/utils/post_scheduler_view_action.py' to dump profiling data " + "after all requests finished.") + + @router.post("/v1/start_scheduler_profile") + async def start_scheduler_profile(raw_request: Request): + logger.info("VLLM-V1 starting scheduler profiler...") + await engine_client(raw_request).start_scheduler_profile() + return Response(status_code=200) + + @router.post("/v1/stop_scheduler_profile") + async def stop_scheduler_profile(raw_request: Request): + logger.info("VLLM-V1 scheduler stopping profiler...") + await engine_client(raw_request).stop_scheduler_profile() + return Response(status_code=200) \ No newline at end of file diff --git a/vllm_mlu/envs.py b/vllm_mlu/envs.py new file mode 100644 index 0000000..9f06ddb --- /dev/null +++ b/vllm_mlu/envs.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import os +from typing import Any, Callable, Dict + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. 
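# [Editor's sketch, not part of the patch] envs.py uses PEP 562 module-level
# __getattr__ so each entry in env_variables is read from the environment
# lazily, at attribute-access time rather than at import time. The pattern in
# miniature (module and variable names are made up for illustration):
import os
from typing import Any, Callable, Dict

_LAZY_ENV: Dict[str, Callable[[], Any]] = {
    "MY_FLAG": lambda: bool(int(os.getenv("MY_FLAG", "0"))),
}

def __getattr__(name: str) -> Any:
    if name in _LAZY_ENV:
        return _LAZY_ENV[name]()   # re-evaluated on every access
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

# After `import my_envs`, reading `my_envs.MY_FLAG` reflects the current value
# of the MY_FLAG environment variable instead of a value frozen at import.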
+ +# begin-env-vars-definition + +env_variables: Dict[str, Callable[[], Any]] = { + # max compile thread num + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + "COMPILE_CUSTOM_KERNELS": + lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))), + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + "LD_LIBRARY_PATH": + lambda: os.getenv("LD_LIBRARY_PATH", None), + "CXX_COMPILER": + lambda: os.getenv("CXX_COMPILER", None), + "C_COMPILER": + lambda: os.getenv("C_COMPILER", None) +} + +# end-env-vars-definition + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in env_variables: + return env_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(env_variables.keys()) \ No newline at end of file diff --git a/vllm_mlu/executor/__init__.py b/vllm_mlu/executor/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/executor/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/logger.py b/vllm_mlu/logger.py new file mode 100644 index 0000000..6750b01 --- /dev/null +++ b/vllm_mlu/logger.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import logging +from typing import cast +from vllm.logger import _VllmLogger + + +class _ColorFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + if not record.name.startswith('vllm_mlu'): + return True + if record.levelno == logging.INFO: + record.msg = f"\033[32m{record.msg}\033[0m" + elif record.levelno == logging.WARNING: + record.msg = f"\033[33m{record.msg}\033[0m" + return True + + +def _apply_mlu_color(logger): + if not logger.handlers: + return + for h in logger.handlers: + if any(isinstance(f, _ColorFilter) for f in h.filters): + return + h.addFilter(_ColorFilter()) + + +def _mlu_init_logger(name: str) -> logging.Logger: + """Initialize loggers for vllm_mlu module, + and keep the configuration consistent with the vllm module""" + mlu_logger = logging.getLogger(name) + vllm_logger = logging.Logger.manager.loggerDict.get('vllm', None) + if vllm_logger: + mlu_logger.setLevel(vllm_logger.level) + mlu_logger.propagate = vllm_logger.propagate + mlu_logger.handlers = vllm_logger.handlers + return mlu_logger + + +def init_logger(name: str) -> _VllmLogger: + vllm_logger = cast(_VllmLogger, _mlu_init_logger(name)) + _apply_mlu_color(vllm_logger) + return vllm_logger + + +logger = init_logger(__name__) \ No newline at end of file diff --git a/vllm_mlu/lora/__init__.py b/vllm_mlu/lora/__init__.py new file mode 100644 index 0000000..1f684a8 --- /dev/null +++ b/vllm_mlu/lora/__init__.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, +) +from vllm.lora.layers.fused_moe import FusedMoEWithLoRA +from vllm.lora.layers.logits_processor import 
LogitsProcessorWithLoRA +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA +from vllm.lora.layers.row_parallel_linear import ( + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, +) +from vllm.lora.layers.utils import LoRAMapping +from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA + +__all__ = [ + "BaseLayerWithLoRA", + "VocabParallelEmbeddingWithLoRA", + "LogitsProcessorWithLoRA", + "ColumnParallelLinearWithLoRA", + "ColumnParallelLinearWithShardedLoRA", + "MergedColumnParallelLinearWithLoRA", + "MergedColumnParallelLinearWithShardedLoRA", + "MergedQKVParallelLinearWithLoRA", + "MergedQKVParallelLinearWithShardedLoRA", + "QKVParallelLinearWithLoRA", + "QKVParallelLinearWithShardedLoRA", + "RowParallelLinearWithLoRA", + "RowParallelLinearWithShardedLoRA", + "ReplicatedLinearWithLoRA", + "LoRAMapping", + "FusedMoEWithLoRA", +] diff --git a/vllm_mlu/lora/layers/__init__.py b/vllm_mlu/lora/layers/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/lora/layers/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/lora/layers/base_linear.py b/vllm_mlu/lora/layers/base_linear.py new file mode 100644 index 0000000..15c1ecd --- /dev/null +++ b/vllm_mlu/lora/layers/base_linear.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA +from vllm.platforms import current_platform + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply( + self, + x: torch.Tensor, + bias: torch.Tensor | None, + residual: torch.Tensor | None = None, +) -> torch.Tensor: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add residual in matmul + ''' + output = self.base_layer.quant_method.apply(self.base_layer, x, bias, residual) + ''' + ================== + End of MLU Hijack + ================== + ''' + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. 
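# [Editor's note, illustrative only] With the transformers backend both x and
# output arrive with a leading dummy batch dimension; flatten(0, 1) folds it
# into the sequence dimension so punica sees (seq_len, hidden_dim):
import torch
assert torch.randn(1, 8, 4096).flatten(0, 1).shape == torch.Size([8, 4096])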
+ if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices + ) + if not current_platform.can_update_inplace(): + output = lora_output + + return output + + +MluHijackObject.apply_hijack( + BaseLinearLayerWithLoRA, + BaseLinearLayerWithLoRA.apply, + vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply, +) \ No newline at end of file diff --git a/vllm_mlu/lora/layers/column_parallel_linear.py b/vllm_mlu/lora/layers/column_parallel_linear.py new file mode 100644 index 0000000..6c6e1f1 --- /dev/null +++ b/vllm_mlu/lora/layers/column_parallel_linear.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm.lora.layers.column_parallel_linear import ColumnParallelLinearWithLoRA +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward + + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add smooth_quant_scale and use_tp_weight parameters. +''' +def vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward( + self, + input_, + smooth_quant_scale: torch.Tensor | None = None, + use_tp_weight: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: + assert not use_tp_weight, "LoRa does not support use_tp_weight yet." + assert smooth_quant_scale is None, "LoRA does not support smooth quant yet." + return vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org(self, input_) +''' +================== +End of MLU Hijack +================== +''' + + +MluHijackObject.apply_hijack( + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithLoRA.forward, + vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward, +) \ No newline at end of file diff --git a/vllm_mlu/lora/layers/row_parallel_linear.py b/vllm_mlu/lora/layers/row_parallel_linear.py new file mode 100644 index 0000000..eb58b46 --- /dev/null +++ b/vllm_mlu/lora/layers/row_parallel_linear.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm.distributed import ( + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce, +) +from vllm.lora.layers.row_parallel_linear import ( + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, +) +from vllm.platforms import current_platform + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply( + self, + x: torch.Tensor, + bias: torch.Tensor | None = None, + residual: torch.Tensor | None = None, +) -> torch.Tensor: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add residual and bias in matmul + ''' + output = self.base_layer.quant_method.apply( + self.base_layer, x, bias, residual) + ''' + ================== + End of MLU Hijack + ================== + ''' + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + 
device=x.device, + ) + + shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0 + ) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + if self.tp_size > 1: + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. + shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + lora_output: torch.Tensor | None = self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + return output + + +def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward( + self, + input_: torch.Tensor, + residual: torch.Tensor | None = None, + smooth_quant_scale: torch.Tensor | None = None, + use_tp_weight: bool = False, + output: torch.Tensor | None = None, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add parameters `residual`, `smooth_quant_scale`, `use_tp_weight` and `output` + to keep parameters consistent with RowParallelLinear.forward. + ''' + assert (not use_tp_weight) and output is None, ( + f"RowParallelLinearWithLoRA.forward does not support use_tp_wight=True" + f" or pass output parameters.") + ''' + ================== + End of MLU Hijack + ================== + ''' + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size + ) + input_parallel = splitted_input[self.tp_rank].contiguous() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: 1) apply residual fusion in matmul like RowParallelLinear + 2) add bias in matmul, not after all reduce + ''' + # Matrix multiply. 
+ bias_ = ( + None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add) + else self.base_layer.bias + ) + residual_ = None if self.base_layer.tp_rank > 0 else residual + output_parallel = self.apply(input_parallel, bias_, residual_) + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.base_layer.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: do not add bias after all_reduce + ''' + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + ''' + ================== + End of MLU Hijack + ================== + ''' + + if not self.base_layer.return_bias: + return output + return output, output_bias + + +MluHijackObject.apply_hijack( + RowParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA.apply, + vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply, +) +MluHijackObject.apply_hijack( + RowParallelLinearWithLoRA, + RowParallelLinearWithLoRA.forward, + vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward, +) \ No newline at end of file diff --git a/vllm_mlu/lora/ops/__init__.py b/vllm_mlu/lora/ops/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/lora/ops/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/lora/ops/triton_ops/__init__.py b/vllm_mlu/lora/ops/triton_ops/__init__.py new file mode 100644 index 0000000..cd6ba6a --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm_mlu.lora.ops.triton_ops.sgmv_expand import sgmv_expand_mlu +from vllm_mlu.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice_mlu +from vllm_mlu.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink_mlu +from vllm_mlu.lora.ops.triton_ops.lora_shrink_op import lora_shrink +from vllm_mlu.lora.ops.triton_ops.lora_expand_op import lora_expand + +__all__ = [ + "sgmv_expand_mlu", + "sgmv_expand_slice_mlu", + "sgmv_shrink_mlu", + "lora_expand", + "lora_shrink" +] diff --git a/vllm_mlu/lora/ops/triton_ops/kernel_utils.py b/vllm_mlu/lora/ops/triton_ops/kernel_utils.py new file mode 100644 index 0000000..a173c31 --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/kernel_utils.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +""" +Utilities for Punica kernel construction. 
+""" +from vllm.triton_utils import tl, triton + +''' +============================= +Modify by vllm_mlu +============================= +@brief: modify mm triton + 1) add parameter offset_n: mlu add offset_n of matrix B, + value: tl.arange(0, BLOCK_N) + pid_n * BLOCK_N, shape: [BLOCK_N] + add parameter N: mlu add column number of matrix B + 2) tiled_b always need mask in case offset_n > N +''' + +@triton.jit +def mm_k( + a_ptr, + b_ptr, + ak_stride, + bk_stride, + offset_n, + offset_k, + K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + N: tl.constexpr, + CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr, +): + """ + Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of + B (k x n), iterate, through the K dimension to compute the partial/complete + matrix block product. + If SPLIT_K == 1, the output m x n product is complete. + If SPLIT_K > 1, the thread block computes partial outputs. The partial + outputs are then atomically summed in the caller code. + Args: + a_ptr: Array of pointers, identifying rows of A + b_ptr: Array of pointers, identifying columns of B + ak_stride: K dimension stride of the A matrix + bk_stride: K dimension stride of the B matrix + K: Length of the K dimension + BLOCK_M: M dimension of the output block m x n + BLOCK_N: N dimension of the output block m x n + BLOCK_K: K dimension atom + EVEN_K: True if the blocks of A and B can be loaded without any + masking. + SPLIT_K: Parameter signifying parallelism in the K dimension. + CAST_TYPE: if True, cast the values from the A matrix to the B + matrix dtype. + b_dtype: datatype of the B matrix + """ + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N, other=0.0) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + tiled_b = tl.load(b_ptr, + mask=(offset_k[:, None] + < K - k * (BLOCK_K * SPLIT_K)) & (offset_n < N)[None, :], + other=0.0) + if CAST_TYPE: + tiled_a = tiled_a.to(b_dtype) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * SPLIT_K * ak_stride + b_ptr += BLOCK_K * SPLIT_K * bk_stride + return accumulator + +''' +================== +End of MLU Hijack +================== +''' + + +@triton.jit +def do_expand_kernel( + pid_n, + lora_index, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SAME_STRIDE: tl.constexpr, + SLICE_NUM: tl.constexpr, + EVEN_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + ADD_INPUTS: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, + compute the matrix product and store in the appropriate output location. + Given that this is an expand kernel, we don't perform any split-K reduction + as the K dimension is assumed to be small. 
+ """ + + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: # 'same_stride': True + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + # Identify the input_ptr and lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: 1) remove rbn definition: mlu doesn't support contiguous and + will handle as head corruption + 2) re-write b_ptr, use offset_n to identify its position + ''' + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + # rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = tl.arange(0, BLOCK_K) + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride) + # b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + # offset_k[:, None] * cur_lora_d2_stride + + # rbn[None, :] * cur_lora_d1_stride) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + offset_n[None, :] * cur_lora_d1_stride) + + # Compute the block matrix product. + SPLIT_K = 1 + accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, offset_n, + offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N, + CAST_TYPE, cur_lora_ptr.dtype.element_ty) + + ''' + ================== + End of MLU Hijack + ================== + ''' + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = (out_ptr + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] + < (cur_slice_start + N)) + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@triton.jit +def do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, compute the + matrix product and store in the appropriate output location. + """ + + # Identify the lora_ptr from slice_id. 
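+    # When SLICE_NUM == 1 the caller passes the LoRA-A weight base pointer
+    # directly; otherwise lora_ptr holds one base address per slice, so the
+    # address for this slice is loaded and reinterpreted as a pointer to the
+    # input dtype before the B block pointers are built below.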
+ if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: 1) remove rbn definition: mlu doesn't support contiguous and + will handle as head corruption + 2) re-write b_ptr, use offset_n to identify its position + ''' + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + # rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + a_ptr = (input_ptr + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) + # b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + # rbn[None, :] * lora_d1_stride + + # offset_k[:, None] * lora_d2_stride) + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + offset_n[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) + + # Compute partial/complete block matrix product. + accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_n, offset_k, + K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N, False, + cur_lora_ptr.dtype.element_ty) + + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_cm = tl.arange(0, BLOCK_M) + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) + + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) diff --git a/vllm_mlu/lora/ops/triton_ops/lora_expand_op.py b/vllm_mlu/lora/ops/triton_ops/lora_expand_op.py new file mode 100644 index 0000000..ef3928b --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/lora_expand_op.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +''' +============================= +Modify by vllm_mlu +============================= +@brief: use vllm_mlu hijacked kernel +''' +from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_expand_kernel +''' +================== +End of MLU Hijack +================== +''' + +from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr + + +@triton.jit +def _lora_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr, +): + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_mn = tl.program_id(axis=0) + pid_m = pid_mn % cta_m_num + pid_n = (pid_mn // cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. + curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) + if pid_n * BLOCK_N >= curr_N: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = ( + token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + ) + + # Load all relevant row indices. 
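+    # offset_m wraps with "% cta_m_len" so that lanes beyond the number of rows
+    # assigned to this CTA still read a valid index; their results are discarded
+    # later by the M_LEN row mask when the output tile is stored.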
+ offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_expand_kernel( + pid_n, + lora_id, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + curr_N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M, + BLOCK_N, + BLOCK_K, + SAME_STRIDE, + SLICE_NUM, + EVEN_K, + CAST_TYPE, + ADD_INPUTS, + ) + + +@torch.inference_mode() +def _lora_expand( + inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank] + output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + no_lora_flag_cpu: torch.Tensor, # shape [1] + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (list[torch.Tensor]): lora'b weight + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + """ + + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. + return + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(0) == len(lora_b_weights) + assert output_tensor.is_contiguous() + + # metadata sanity check. 
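+    # inputs is laid out as [num_slices, num_tokens, lora_rank], so the token
+    # count M is its second dimension; the LoRA bookkeeping tensors below must
+    # all agree with that token count and with the number of LoRA ids.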
+ M = inputs.size(1) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + hidden_sizes_tensor, + same_stride, + MAX_N, + ) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device) + + K = lora_b_weights[0].shape[-1] # K= rank + ADD_INPUTS = add_inputs + MAX_LORAS = lora_ids.size(0) + CAST_TYPE = False + NUM_SLICES = len(lora_b_weights) + + # Triton kernel configs. + BLOCK_M = 64 + BLOCK_N = 128 + BLOCK_K = 16 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + + EVEN_K = K % BLOCK_K == 0 # type: ignore + + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only a few input tokens require + # LoRA. This might not be the best in all cases. + grid = ( + triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks simply exit. + MAX_LORAS, + ) + + _lora_expand_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + MAX_N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + hidden_sizes_tensor, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + NUM_SLICES, + same_stride, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + ) + + return + + +def _lora_expand_fake( + inputs: torch.Tensor, + lora_b_weights: list[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + return + +''' +============================= +Modify by vllm_mlu +============================= +@brief: use only vllm operand +''' + +lora_expand = _lora_expand + +''' +================== +End of MLU Hijack +================== +''' diff --git a/vllm_mlu/lora/ops/triton_ops/lora_shrink_op.py b/vllm_mlu/lora/ops/triton_ops/lora_shrink_op.py new file mode 100644 index 0000000..3d2f035 --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/lora_shrink_op.py @@ -0,0 +1,258 @@ + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +''' +============================= +Modify by vllm_mlu +============================= +@brief: use vllm_mlu hijacked kernel +''' +from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_shrink_kernel +''' +================== +End of MLU Hijack +================== +''' +from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr + + +@triton.jit +def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, + token_indices_sorted_by_lora_ids, num_tokens_per_lora, + lora_token_start_loc, lora_ids, scaling, + input_d0_stride, input_d1_stride, lora_d0_stride, + lora_d1_stride, lora_d2_stride, output_d0_stride, + output_d1_stride, output_d2_stride, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_sk_m_n = tl.program_id(axis=0) + pid_sk = pid_sk_m_n % SPLIT_K + pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num + pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. + offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + SLICE_NUM) + + +@torch.inference_mode() +def _lora_shrink( + inputs: torch.Tensor, # shape [num_tokens, hidden_size] + lora_a_weights: list[ + torch.Tensor], # shape [num_loras, lora_rank, hidden_size] + output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + no_lora_flag_cpu: torch.Tensor, # shape [1] + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): Input tensor + lora_a_weights (list[torch.Tensor]): LoRA weights + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. 
+ num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + scaling (float): Scaling factor. + """ + + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. + return + + assert inputs.dtype == lora_a_weights[0].dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(1) == lora_a_weights[0].size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + # metadata sanity check + M = inputs.size(0) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank + NUM_SLICES = len(lora_a_weights) + MAX_LORAS = lora_ids.size(0) + + # Triton kernel configs + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 256 if M < 128 else 32 + SPLIT_K = 64 if M < 128 else 8 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only few of the input tokens + # require LoRA. This might not be the best in all cases. + grid = ( + SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks exit early. 
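+        # Illustrative size (hypothetical workload): with M=1024 tokens,
+        # BLOCK_M=32, rank N=16, BLOCK_N=16 and SPLIT_K=8, axis 0 launches
+        # 8 * 32 * 1 = 256 program ids per slice and per LoRA.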
+ MAX_LORAS, + ) + + _lora_shrink_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + NUM_SLICES, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + ) + + return + + +def _lora_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: list[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + scaling: float, +) -> None: + return + + +''' +============================= +Modify by vllm_mlu +============================= +@brief: use only vllm operand +''' + +lora_shrink = _lora_shrink + +''' +================== +End of MLU Hijack +================== +''' \ No newline at end of file diff --git a/vllm_mlu/lora/ops/triton_ops/sgmv_expand.py b/vllm_mlu/lora/ops/triton_ops/sgmv_expand.py new file mode 100644 index 0000000..a3891df --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/sgmv_expand.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +import triton +import triton.language as tl + +from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size + +from vllm.utils.torch_utils import direct_register_custom_op + + +@triton.jit +def _sgmv_expand_kernel_mlu( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + The sgmv's expand triton kernel is based on GroupGEMM. + """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust kernel impl to fit mlu. + ''' + a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \ + offset_k[None, :] * xk_stride + b_ptr = lora_ptr + l0_stride * lora_index + \ + offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride + ''' + ================== + End of MLU Hijack + ================== + ''' + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust kernel impl to fit mlu. 
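+        Concretely, both the EVEN_K and the tail path below mask tiled_a rows
+        against M and tiled_b columns against N, so partially filled tiles never
+        read out of bounds even without contiguity hints.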
+ ''' + if EVEN_K: + tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M) + tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N) + else: + tiled_a = tl.load(a_ptr, + mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)), + other=0) + tiled_b = tl.load(b_ptr, + mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)), + other=0) + ''' + ================== + End of MLU Hijack + ================== + ''' + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand_mlu( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Workaround: Adjust block size to meet mlu restrictions. + + The grid of mlu triton kernel must less than 65536, it will be out of bound when + the input seq is very long, and causes runtime error. So we need to adjust the block + size to avoid this. 
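+
+    Illustrative numbers (hypothetical workload): with the default 32x32 blocks,
+    a 100,000-token prefill against a hidden size of 14336 would need
+    ceil(100000/32) * ceil(14336/32) = 3125 * 448 = 1,400,000 program ids on
+    grid axis 0, far above the 65536 limit; adjust_kernel_block_size therefore
+    enlarges the blocks until the product drops below that limit.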
+ ''' + BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32) + ''' + ================== + End of MLU Hijack + ================== + ''' + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: call _sgmv_expand_kernel_mlu + ''' + _sgmv_expand_kernel_mlu[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + return diff --git a/vllm_mlu/lora/ops/triton_ops/sgmv_expand_slice.py b/vllm_mlu/lora/ops/triton_ops/sgmv_expand_slice.py new file mode 100644 index 0000000..7f7035f --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/sgmv_expand_slice.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +import triton +import triton.language as tl + +from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size + +from vllm.utils.torch_utils import direct_register_custom_op + +@triton.jit +def _sgmv_expand_slice_kernel_mlu( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. + """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust kernel impl to fit mlu. 
+ ''' + a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \ + offset_k[None, :] * xk_stride + b_ptr = lora_ptr + l0_stride * lora_index + \ + offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride + ''' + ================== + End of MLU Hijack + ================== + ''' + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust kernel impl to fit mlu. + ''' + if EVEN_K: + tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M) + tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N) + else: + tiled_a = tl.load(a_ptr, + mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)), + other=0) + tiled_b = tl.load(b_ptr, + mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)), + other=0) + ''' + ================== + End of MLU Hijack + ================== + ''' + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < + (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand_slice_mlu( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + """_summary_ + + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + slice_offset (int): output_tensor's offset + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. 
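+
+        For example (hypothetical sizes): for a merged QKV projection with
+        per-slice hidden sizes [4096, 1024, 1024], the q/k/v expansions would
+        use (slice_offset, slice_size) of (0, 4096), (4096, 1024) and
+        (5120, 1024) respectively, all writing into the same output_tensor.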
+ """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Workaround: Adjust block size to meet mlu restrictions. + + The grid of mlu triton kernel must less than 65536, it will be out of bound when + the input seq is very long, and causes runtime error. So we need to adjust the block + size to avoid this. + ''' + BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32) + ''' + ================== + End of MLU Hijack + ================== + ''' + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: call _sgmv_expand_kernel_mlu + ''' + _sgmv_expand_slice_kernel_mlu[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + return diff --git a/vllm_mlu/lora/ops/triton_ops/sgmv_shrink.py b/vllm_mlu/lora/ops/triton_ops/sgmv_shrink.py new file mode 100644 index 0000000..cdb88ee --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/sgmv_shrink.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +import triton +import triton.language as tl + +from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size + +from vllm.utils.torch_utils import direct_register_custom_op + + +@triton.jit +def _sgmv_shrink_kernel_mlu( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_sk = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust kernel impl to fit mlu. + ''' + a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \ + offset_k[None, :] * xk_stride + b_ptr = lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride + \ + offset_k[:, None] * lora_n_stride + ''' + ================== + End of MLU Hijack + ================== + ''' + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust kernel impl to fit mlu. + ''' + if EVEN_K: + tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M) + tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, + mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)), + other=0.0) + tiled_b = tl.load(b_ptr, + mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)), + other=0.0) + ''' + ================== + End of MLU Hijack + ================== + ''' + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def sgmv_shrink_mlu( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. 
Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. + """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_a_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Workaround: adjust block size to meet mlu restrictions. + + The grid of mlu triton kernel must less than 65536, it will be out of bound when + the input seq is very long, and causes runtime error. So we need to adjust the block + size to avoid this. + ''' + BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16) + ''' + ================== + End of MLU Hijack + ================== + ''' + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K, + batches, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: call _sgmv_shrink_kernel_mlu + ''' + _sgmv_shrink_kernel_mlu[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + return diff --git a/vllm_mlu/lora/ops/triton_ops/utils.py b/vllm_mlu/lora/ops/triton_ops/utils.py new file mode 100644 index 0000000..176e73c --- /dev/null +++ b/vllm_mlu/lora/ops/triton_ops/utils.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Tuple +from math import ceil + +_MLU_MAX_GRID_SIZE = 65536 + +def adjust_kernel_block_size( + m: int, + block_m: int, + n: int, + block_n: int +) -> Tuple[int, int]: + """Adjust block size to meet mlu triton grid restrictions. + + Calculation of the max block size in candidates list: + + LLama3.1-8b-tp1 max n is 14336 + LLama3.1-70b-tp4 max n is 7168 + LLama3.1-405b-tp8 max n is 6656 + + when n is 14336, the max sequence length of block size 256 can be + floor(65536 / ceil(14336 / 256)) * 256 = 299520. 
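+
+    Example (illustrative trace of the loop below):
+        adjust_kernel_block_size(100000, 32, 14336, 32) returns (192, 192),
+        since ceil(100000/192) * ceil(14336/192) = 521 * 75 = 39075 < 65536;
+        if no candidate block size satisfies the bound, a ValueError is raised.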
+ """ + candidates_list = [16, 32, 64, 96, 128, 192, 256] + candidates_list_len = len(candidates_list) + m_idx = 1 + n_idx = 0 if block_n == 16 else 1 + while m_idx < candidates_list_len and n_idx < candidates_list_len: + block_m = candidates_list[m_idx] + block_n = candidates_list[n_idx] + if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE: + break + if m_idx < candidates_list_len: + m_idx += 1 + if n_idx < candidates_list_len: + n_idx += 1 + if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE: + raise ValueError(f"the max seq len {m} is too long for lora triton kernel") + return block_m, block_n diff --git a/vllm_mlu/lora/punica_wrapper/__init__.py b/vllm_mlu/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/lora/punica_wrapper/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/lora/punica_wrapper/punica_mlu.py b/vllm_mlu/lora/punica_wrapper/punica_mlu.py new file mode 100644 index 0000000..d6c3c2f --- /dev/null +++ b/vllm_mlu/lora/punica_wrapper/punica_mlu.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Optional, Tuple, Union, final + +import torch + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu + from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu + from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu + +from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU + + +@final +class PunicaWrapperMLU(PunicaWrapperCPU): + """ + PunicaWrapperMLU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica triton kernel. 
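+
+    Only the prefill shrink / expand / expand-slice paths are overridden here
+    to dispatch to the MLU sgmv Triton kernels; all remaining LoRA state
+    management is inherited from PunicaWrapperCPU.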
+ """ + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink_mlu( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_mlu( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice_mlu( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) \ No newline at end of file diff --git a/vllm_mlu/mlu_forward_context.py b/vllm_mlu/mlu_forward_context.py new file mode 100644 index 0000000..1d184a6 --- /dev/null +++ b/vllm_mlu/mlu_forward_context.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from dataclasses import dataclass +from typing import List + +from vllm.forward_context import DPMetadata + + +@dataclass +class MLUDPMetadata(DPMetadata): + # mlu platform arguments + # token num for current dp group + token_num: int = None + # token num offset for current dp group + token_num_offset: int = None + # whether we can use reduce scatter for both attn layer and mlp layer + layer_use_reduce_scatter: bool = False + # token num need to be pad for prefill, then we can do reduce scatter + + # all gather to optimize comm time + prefill_pad_to_token_num: int = -1 + # token num in each dp group, the list length is attn data parallel size + # used to do all gather in dp groups after all reduce in attn + token_split_list: List[int] = None + # token num in each card, the list length is world size + # used to do all gather in all cards after reduce scatter in attn + attn_token_split_list_reduce_scatter: List[int] = None + # token num in each tp group, the list length is tensor parallel size + # used to do all gather in tp groups after reduce scatter in moe + moe_token_split_list_reduce_scatter: List[int] = None + # prefill or decode stage in each dp group + dp_is_prefill: List[bool] = None + + # ADDITIONAL fields for merged compute and communication. + # Global sequence lengths for each batch size for prefill stage. + seq_lens: List[int] = None + # Batch sizes for each attn dp rank for prefill stage. 
+ batch_sizes: List[int] = None + + # ADDITIONAL fields for custom split for embedding, logits and dense mlp layer + # token num in each emb tp group, the list length is tensor parallel size + # used to do all gather in emb tp groups after reduce scatter in moe + emb_token_split_list: List[int] = None + # batch sizes in each logits tp group, the list length is tensor parallel size + # used to do all gather in logits tp groups after reduce scatter in moe + logits_batch_split_list: List[int] = None + # token num in each dense mlp group, the list length is dense mlp tp size + # used to do one more all gather after dense mlp and before reduce scatter + dense_attn_token_split_list: List[int] = None + + @staticmethod + def make_oot( + data_parallel_rank: int, + data_parallel_size: int, + tensor_parallel_size: int, + dp_token_nums: List[int], + dp_is_prefill: List[bool], + prefill_dispatch_use_RS_AG: bool, + seq_lens: List[int] = None, + batch_sizes: List[int] = None, + emb_query_lens: List[int] = None, + logits_batch_sizes: List[int] = None, + dense_attn_token_split_list: List[int] = None, + ) -> "MLUDPMetadata": + token_num_offset = sum(dp_token_nums[:data_parallel_rank]) + token_num = dp_token_nums[data_parallel_rank] + token_split_list = dp_token_nums + + attn_can_use_reduce_scatter = all( + (num != 0 and num % tensor_parallel_size == 0) + for num in token_split_list + ) + all_split_token_num_equal = all( + num == token_split_list[0] for num in token_split_list + ) + layer_can_use_reduce_scatter = ( + attn_can_use_reduce_scatter and all_split_token_num_equal + ) + + attn_token_split_list_reduce_scatter = None + moe_token_split_list_reduce_scatter = None + prefill_pad_to_token_num = -1 + tp_world_size = data_parallel_size * tensor_parallel_size + if layer_can_use_reduce_scatter: + attn_token_split_list_reduce_scatter = ( + [token_split_list[0] // tensor_parallel_size] * tp_world_size + ) + moe_token_split_list_reduce_scatter = ( + attn_token_split_list_reduce_scatter[:tensor_parallel_size] + ) + elif ( + prefill_dispatch_use_RS_AG + and all(is_prefill for is_prefill in dp_is_prefill) + ): + dp_group_max_token_nums = max(dp_token_nums) + prefill_pad_to_token_num = ( + (dp_group_max_token_nums + tensor_parallel_size - 1) + // tensor_parallel_size + ) * tensor_parallel_size + attn_token_split_list_reduce_scatter = ( + [prefill_pad_to_token_num // tensor_parallel_size] * tp_world_size + ) + + return MLUDPMetadata( + max_tokens_across_dp_cpu=None, + num_tokens_across_dp_cpu=None, + token_num=token_num, + token_num_offset=token_num_offset, + token_split_list=token_split_list, + layer_use_reduce_scatter=layer_can_use_reduce_scatter, + prefill_pad_to_token_num=prefill_pad_to_token_num, + attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter, + moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter, + seq_lens=seq_lens, + batch_sizes=batch_sizes, + dp_is_prefill=dp_is_prefill, + emb_token_split_list=emb_query_lens, + logits_batch_split_list=logits_batch_sizes, + dense_attn_token_split_list=dense_attn_token_split_list, + ) diff --git a/vllm_mlu/mlu_hijack.py b/vllm_mlu/mlu_hijack.py new file mode 100644 index 0000000..cc45def --- /dev/null +++ b/vllm_mlu/mlu_hijack.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import importlib.util +from vllm_mlu._mlu_utils import * +from vllm_mlu.logger import logger + + +def is_module_available(module_name): + spec = 
importlib.util.find_spec(module_name) + return spec is not None + +def check_environ_compatibility(): + if is_module_available('apex'): + logger.error(f"The `apex` package is currently present in your environment, " + f"which may cause model accuracy issues or other problems. It is " + f"strongly recommended that you uninstall it before using vLLM.") + +# Check environment compatibility first before applying mlu hijack. +check_environ_compatibility() + +logger.info(f"[MLU] Apply Monkey Patch.") + +# Apply v1 hijack +import vllm_mlu.v1.engine.core +import vllm_mlu.v1.engine.core_client +import vllm_mlu.v1.engine.llm_engine +import vllm_mlu.v1.engine.async_llm +import vllm_mlu.v1.core.sched.scheduler +import vllm_mlu.v1.core.single_type_kv_cache_manager +import vllm_mlu.v1.core.kv_cache_utils +import vllm_mlu.v1.core.kv_cache_manager +import vllm_mlu.v1.executor.abstract +import vllm_mlu.v1.executor.ray_executor +import vllm_mlu.v1.executor.multiproc_executor +import vllm_mlu.v1.sample.rejection_sampler +import vllm_mlu.v1.worker.lora_model_runner_mixin +import vllm_mlu.v1.worker.block_table +import vllm_mlu.v1.worker.gpu_input_batch +import vllm_mlu.v1.worker.kv_connector_model_runner_mixin +import vllm_mlu.v1.attention.backends.gdn_attn +import vllm_mlu.v1.attention.backends.mla.flashmla +import vllm_mlu.compilation.fix_functionalization + +# Apply common hijack +import vllm_mlu.attention.layer +import vllm_mlu.benchmarks.datasets +import vllm_mlu.config.model +import vllm_mlu.config.scheduler +import vllm_mlu.config.speculative +import vllm_mlu.config.vllm +import vllm_mlu.utils +import vllm_mlu.distributed.parallel_state +import vllm_mlu.distributed.kv_transfer.kv_connector.factory +import vllm_mlu.engine.arg_utils +import vllm_mlu.entrypoints.llm +import vllm_mlu.lora.layers.base_linear +import vllm_mlu.lora.layers.row_parallel_linear +import vllm_mlu.lora.layers.column_parallel_linear +import vllm_mlu.model_executor.parameter +import vllm_mlu.model_executor.layers.linear +import vllm_mlu.model_executor.layers.rotary_embedding +import vllm_mlu.model_executor.layers.quantization.utils.w8a8_utils +import vllm_mlu.model_executor.layers.quantization.fp8 +import vllm_mlu.model_executor.layers.activation +import vllm_mlu.model_executor.layers.layernorm +import vllm_mlu.model_executor.layers.fused_moe.layer +import vllm_mlu.model_executor.model_loader.tensorizer_loader +import vllm_mlu.model_executor.models.registry +import vllm_mlu.model_executor.models.config +import vllm_mlu.multimodal.utils +if is_module_available('lmcache'): + import vllm_mlu.distributed.kv_transfer.kv_connector.v1.lmcache_connector + +if VLLM_CI_ACCURACY_TEST: + import vllm_mlu.model_executor.model_loader.dummy_loader + +if VLLM_SCHEDULER_PROFILE: + import vllm_mlu.entrypoints.openai.api_server diff --git a/vllm_mlu/mlu_hijack_utils.py b/vllm_mlu/mlu_hijack_utils.py new file mode 100644 index 0000000..a08d304 --- /dev/null +++ b/vllm_mlu/mlu_hijack_utils.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +IS_GATED=False + +class MluHijackObject: + hijack_objs = [] + + @classmethod + def apply_hijack(cls, obj, org_func, hijack_func, + verify_orig_func_exists: bool = False): + """ + Optional Args: + verify_orig_func_exists (bool): If True, verifies that hijack succeeds + """ + cls.hijack_objs.append((obj, org_func, hijack_func)) + + if type(org_func) == str: + 
org_func_name = org_func + else: + if isinstance(org_func, property): + split_name = org_func.fget.__name__.split('__') + else: + split_name = org_func.__name__.split('__') + org_func_name = split_name[-1] + if org_func_name == "": + assert split_name[-2] != "", f"invalid {org_func.__name__} to apply hijack" + org_func_name = split_name[-2] + "__" + if len(split_name) >= 3 and split_name[-3] == "": + org_func_name = "__" + org_func_name + + if verify_orig_func_exists and not hasattr(obj, org_func_name): + raise AttributeError(f"function {org_func_name} is not part of {obj}") + + setattr(obj, org_func_name, hijack_func) + + if (verify_orig_func_exists and getattr(obj, org_func_name) is not hijack_func): + raise AttributeError( + f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}") + + + @classmethod + def undo_hijack(cls, obj_ = None, hijack_func_ = None): + if obj_ and hijack_func_: + for obj, org_func, hijack_func in cls.hijack_objs: + if obj_ == obj and hijack_func == hijack_func_: + if type(org_func) == str: + if hasattr(obj, org_func): + delattr(obj, org_func) + else: + org_func_name = org_func.__name__ + setattr(obj, org_func_name, org_func) + return + for obj, org_func, hijack_func in cls.hijack_objs: + if type(org_func) == str: + if hasattr(obj, org_func): + delattr(obj, org_func) + else: + org_func_name = org_func.__name__ + setattr(obj, org_func_name, org_func) + + +TypedDict = { + "hidden_size": 0, + "vocab_size": 0, + "ffn_inner_size": 0, + "moe_inner_size": 0, + "layer_num": 0, + "moe_layer_num": 0, + "head_num": 0, + "head_size": 0, + "head_num_kv": 0, + "tp_num": 0, + "shared_expert_intermediate_size": 0, + "shared_experts": 0, + "qk_nope_head_dim": 0, + "qk_rope_head_dim": 0, + "q_lora_rank": 0.0, + "num_attention_heads": 0, + "kv_lora_rank": 0, + "v_head_dim": 0, + "use_gated_ffn": False, + "experts_num": 0, + "topk_num": 0, + "use_causal_mask": False, + "cla_coeffient": 0, + "kv_cache_dtype": "", + "smooth_quant_type": "", + "data_type": "", + "model_type": "", + "filter_data_type": "", +} + + +def set_is_gated(flag): + global IS_GATED + IS_GATED=flag + +def get_is_gated(): + return IS_GATED diff --git a/vllm_mlu/mlu_metric.py b/vllm_mlu/mlu_metric.py new file mode 100644 index 0000000..3b0d09e --- /dev/null +++ b/vllm_mlu/mlu_metric.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +import time +import statistics +import pandas as pd +import numpy as np +import json +import os +from datetime import datetime + +from vllm.logger import init_logger +from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG_WITH_DEVICE_EN, VLLM_DUMP_MLU_INFO_EN +from vllm.model_executor.layers.quantization import get_quantization_config + +logger = init_logger(__name__) + +millisecond2second_unit = 1000 + +class LLMMetric: + def __init__(self)->None: + self.batch_size_list = [] + self.context_latency_list = [] + self.e2e_latency_list = [] + self.per_token_latency_list = [ [] ] + self.per_token_latency_device_list = [ [] ] + self.mm_encoder_latency_device_list = [ [] ] + self.peak_memory = 0 + self.block_memory = 0 + self.num_total_gpu_blocks = 0 + self.num_total_cpu_blocks = 0 + self.num_free_gpu_blocks_list = [ [] ] + self.num_free_cpu_blocks_list = [ [] ] + self.num_spec_tokens = 0 + self.draft_acceptance_rate = 0.0 + self.context_latency_device = 0.0 + self.generate_latency_device = 0.0 + self.mm_encoder_latency_device = 0.0 + + def reset_metric(self): + self.batch_size_list 
= [] + self.context_latency_list = [] + self.e2e_latency_list = [] + self.per_token_latency_list = [ [] ] + self.per_token_latency_device_list = [ [] ] + self.mm_encoder_latency_device_list = [ [] ] + self.num_free_gpu_blocks_list = [ [] ] + self.num_free_cpu_blocks_list = [ [] ] + self.num_spec_tokens = 0 + self.draft_acceptance_rate = 0.0 + + @classmethod + def get_mlu_cost_time(cls): + torch.mlu.synchronize() + return time.perf_counter() + + def is_prefill_stage(self): + return len(self.per_token_latency_list[-1]) == 0 + + def update_memory_usage(self, peak_memory, block_memory, num_total_gpu_blocks, num_total_cpu_blocks): + self.peak_memory = peak_memory + self.block_memory = block_memory + self.num_total_gpu_blocks = num_total_gpu_blocks + self.num_total_cpu_blocks = num_total_cpu_blocks + + def update_step_block_usage(self, num_free_gpu_blocks, num_free_cpu_blocks): + self.num_free_gpu_blocks_list[-1].append(num_free_gpu_blocks) + self.num_free_cpu_blocks_list[-1].append(num_free_cpu_blocks) + + def update_step_latency(self, step_latency): + if isinstance(step_latency, list): + self.per_token_latency_list[-1].extend(step_latency) + else: + self.per_token_latency_list[-1].append(step_latency) + + def update_step_latency_device(self, step_latency): + if isinstance(step_latency, list): + self.per_token_latency_device_list[-1].extend(step_latency) + else: + self.per_token_latency_device_list[-1].append(step_latency) + + def update_mm_encoder_latency_device(self, step_latency): + if isinstance(step_latency, list): + if len(step_latency) == 0: + return + assert len(step_latency) == 1, f"Not supported! Model with multi mm encoder steps. {len(step_latency)} {step_latency}" + self.mm_encoder_latency_device_list[-1].extend(step_latency) + else: + self.mm_encoder_latency_device_list[-1].append(step_latency) + + def update_spec_decode_metrics(self, spec_decode_metrics): + self.num_spec_tokens = spec_decode_metrics.num_spec_tokens + self.draft_acceptance_rate = spec_decode_metrics.draft_acceptance_rate + + def add_metrics(self, batch_size, e2e_latency)->None: + self.batch_size_list.append(batch_size) + self.e2e_latency_list.append(e2e_latency) + self.per_token_latency_list.append([]) # new iter + self.per_token_latency_device_list.append([]) + self.mm_encoder_latency_device_list.append([]) + self.num_free_gpu_blocks_list.append([]) + self.num_free_cpu_blocks_list.append([]) + + def get_weight_dtype_str(self, model_path, model_dtype, quantization) -> str: + # get weight dtype based on quantization config if exists + if quantization == 'fp8': + return quantization + if quantization is not None: + quant_method = get_quantization_config(quantization) + # combine the model path with the quantization config file name + quant_config_paths = quant_method.get_config_filenames() + # if there are multiple quantization config files, return the first one existed + for quant_config_path in quant_config_paths: + quant_config_path = os.path.join(model_path, quant_config_path) + # check if the quantization config file exists + if not os.path.exists(quant_config_path): + continue + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + quant_config = quant_method.from_config(quant_config) + # for smoothquant and weightonly, return the quantization name with the weight bits + if quantization == "smoothquant" or quantization == ["weightonly"]: + return "{}-int{}".format(quant_config.get_name(), quant_config.weight_bits) + else: + # for other quantization methods, return the quantization name + 
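(Illustrative note, not part of the patch.) The nested lists above are grown by update_step_latency() and add_metrics(): each benchmark iteration owns one inner list whose first total_prefill_steps entries are prefill steps and whose remainder are decode steps, which is how calc_metric() later splits context latency from per-token latency. A minimal standalone sketch with made-up numbers, ignoring the speculative-decoding correction calc_metric applies:

per_token_latency_list = [[]]          # same shape as the LLMMetric field above

def record_step(step_latency_s):
    per_token_latency_list[-1].append(step_latency_s)

def start_new_iteration():
    per_token_latency_list.append([])

# one iteration: 1 prefill step followed by 4 decode steps (seconds)
for latency in [0.120, 0.015, 0.016, 0.015, 0.014]:
    record_step(latency)
start_new_iteration()

total_prefill_steps = 1
steps = per_token_latency_list[0]
context_latency_ms = 1000 * sum(steps[:total_prefill_steps])
per_token_latency_ms = 1000 * sum(steps[total_prefill_steps:]) / len(steps[total_prefill_steps:])
print(context_latency_ms, per_token_latency_ms)  # roughly 120 ms and 15 ms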
return quant_config.get_name() + # if the quantization config file does not exist, just return the quanization name + return quant_config_path.get_name() + else: + # remove the prefix of model dtype from torch config + return str(model_dtype).split(".")[-1] + + def to_csv(self, filename: str, show_per_iter=False) -> None: + if show_per_iter: + df = pd.DataFrame(self.metrics_data) + df = pd.DataFrame([df.iloc[-1]], columns=df.columns) + memory_df = pd.DataFrame(self.memory_metrics_data) + memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns) + else: + df = pd.DataFrame(self.metrics_data) + memory_df = pd.DataFrame(self.memory_metrics_data) + df_mean = df.mean().round(3) + memory_df_mean = memory_df.mean().round(3) + header = ["datetime", "model", + "weight dtype", self.batch_size_name, + ] + header = header + list(self.mm_kwargs.keys()) + header = header + ["input len", "output len", "tp", + self.context_latency_name, self.per_token_latency_name] + data = [datetime.now().strftime("%Y-%m-%d %H:%M:%S"), self.model, + self.weight_dtype_str, int(self.metrics_data[self.batch_size_name][0])] + data = data + [self.mm_kwargs[k] for k in self.mm_kwargs.keys()] + data = data + [self.input_len, self.output_len, self.tp, + df_mean[self.context_latency_name], df_mean[self.per_token_latency_name]] + if self.num_spec_tokens > 0: + header += [self.per_step_latency_name] + data += [df_mean[self.per_step_latency_name]] + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + if self.is_v1_multimodal: + header += [self.mm_encoder_latency_device_name,] + data += [df_mean[self.mm_encoder_latency_device_name],] + header += [self.context_latency_device_name, self.per_token_latency_device_name] + data += [df_mean[self.context_latency_device_name], df_mean[self.per_token_latency_device_name]] + header += [self.e2e_latency_name, self.e2e_throughput_name, self.decoder_throughput_name,] + if self.num_spec_tokens > 0: + header += [self.k_name, self.acceptance_rate_name] + header += [self.decode_times_name, + self.peak_memory_name, self.block_memory_name] + data += [ + df_mean[self.e2e_latency_name], df_mean[self.e2e_throughput_name], df_mean[self.decoder_throughput_name], + ] + if self.num_spec_tokens > 0: + data += [self.num_spec_tokens, df_mean[self.acceptance_rate_name],] + data += [df_mean[self.decode_times_name], memory_df_mean[self.peak_memory_name], memory_df_mean[self.block_memory_name],] + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and self.save_hfu_info: + header += [self.context_hfu_name, self.decoder_hfu_name, self.decoder_io_efficiency_name] + data += [ + df_mean[self.context_hfu_name], df_mean[self.decoder_hfu_name], + df_mean[self.decoder_io_efficiency_name] + ] + data_dict = dict(zip(header, data)) + df_csv = pd.DataFrame(data_dict, index=[0]) + append = False + if os.path.isfile(filename): + try: + df_old = pd.read_csv(filename) + append = (df_old.columns.tolist() == header) + except Exception as e: + logger.info(f"Existing {filename} failed to be read and will be overwritten") + if append: + df_csv.to_csv(filename, mode='a', header=False, index=False) + logger.info(f"Metric appended to existing {filename}") + else: + df_csv.to_csv(filename, index=False) + logger.info(f"Metric written to {filename}") + + def calc_metric(self, model, model_dtype, metrics_idx_start, only_average, + input_len, output_len, tp_nums, quantization, + show_per_iter=False, is_embedding_task=False, mm_kwargs=None, + total_prefill_steps=1, num_spec_tokens=0, dp_size=1, hfu_info=None, io_efficiency=0.0) -> None: + keep_digits = 2 + 
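(Illustrative note, not part of the patch.) to_csv() above only appends to an existing report when the file's header matches the new column set; otherwise it rewrites the file. A small standalone sketch of that append-or-overwrite decision; the file name and columns below are made up:

import os
import pandas as pd

def write_metrics_row(filename, row):
    df_new = pd.DataFrame(row, index=[0])
    append = False
    if os.path.isfile(filename):
        try:
            df_old = pd.read_csv(filename)
            # append only when the schema is unchanged
            append = df_old.columns.tolist() == list(row.keys())
        except Exception:
            pass  # unreadable file: fall through and overwrite
    if append:
        df_new.to_csv(filename, mode="a", header=False, index=False)
    else:
        df_new.to_csv(filename, index=False)

write_metrics_row("metrics.csv", {"batch size": 8, "per token latency(ms)": 15.2})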
+ def round_fn(data): + return round(data, keep_digits) + + metrics_idx_end = len(self.per_token_latency_list) - 1 # without last [] + idx_range = range(metrics_idx_start, metrics_idx_end) + # specify entries to write to csv + self.is_v1_multimodal = mm_kwargs + self.mm_kwargs = mm_kwargs if mm_kwargs else {} # multimodal args + self.batch_size_name = "batch size" + self.input_len = input_len + self.output_len = output_len + self.tp = tp_nums + self.dp = dp_size + self.model = model + self.context_latency_name = "context latency(ms)" + self.mm_encoder_latency_device_name = "multimodal encoder latency device(ms)" + self.context_latency_device_name = "context latency device(ms)" + if num_spec_tokens > 0: + self.per_step_latency_name = "per step latency(ms)" + self.per_token_latency_device_name = "per step latency device(ms)" + else: + self.per_token_latency_device_name = "per token latency device(ms)" + self.per_token_latency_name = "per token latency(ms)" + self.e2e_latency_name = "e2e latency(ms)" + self.e2e_throughput_name = "e2e throughput(tokens/s)" + self.decoder_throughput_name = "decoder throughput(tokens/s)" + self.k_name = "K" + self.acceptance_rate_name = "acceptance rate" + self.decode_times_name = "decode times" + self.weight_dtype_str = self.get_weight_dtype_str(model, model_dtype, quantization) + self.num_spec_tokens = num_spec_tokens + rate_list=[] + rate=0 + if num_spec_tokens > 0: + for i in range(metrics_idx_end): + if len(self.per_token_latency_list[i]) - total_prefill_steps == 0: + logger.warning("For now output_len is 0, no need mtp info, if you need mtp info, please increase output_len.") + rate_list.append(0.0) + else: + rate_list.append(((self.output_len - 1) / (float)(len(self.per_token_latency_list[i]) - total_prefill_steps) - 1) / num_spec_tokens) + rate = statistics.fmean(rate_list[metrics_idx_start: metrics_idx_end]) + metrics_data = [ + ( + self.batch_size_name, [self.dp * int(self.batch_size_list[i]) for i in idx_range] + ), + ( + self.context_latency_name, [round_fn(millisecond2second_unit * sum(self.per_token_latency_list[i][:total_prefill_steps])) for i in idx_range] + ), + ( + self.per_token_latency_name, [ + 0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \ + round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * (len(self.per_token_latency_list[i]) - total_prefill_steps) / (self.output_len - 1) * millisecond2second_unit) for i in idx_range + ] + ), + ] + if num_spec_tokens > 0: + metrics_data += [(self.per_step_latency_name, [ + 0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \ + round_fn(statistics.fmean(self.per_token_latency_list[i][total_prefill_steps:]) * millisecond2second_unit) for i in idx_range + ])] + metrics_data += [ + ( + self.e2e_latency_name, [round_fn(millisecond2second_unit * self.e2e_latency_list[i]) for i in idx_range] + ), + ( + self.e2e_throughput_name, [ + round_fn(self.dp * (output_len / self.e2e_latency_list[i]) * self.batch_size_list[i]) \ + for i in idx_range + ] + ), + ( + self.decoder_throughput_name, [ + 0.0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \ + round_fn(self.dp * ((output_len-1) / sum(self.per_token_latency_list[i][total_prefill_steps:])) * self.batch_size_list[i]) \ + for i in idx_range + ] + ), + ( + self.decode_times_name, [ + 0 if len(self.per_token_latency_list[i]) <= total_prefill_steps else \ + len(self.per_token_latency_list[i][total_prefill_steps:]) for i in idx_range + ] + ), + ] + if num_spec_tokens > 0: + 
metrics_data.append((self.k_name, num_spec_tokens)) + metrics_data.append((self.acceptance_rate_name, [rate_list[i] for i in idx_range])) + + insert_latency_device = VLLM_LATENCY_DEBUG_WITH_DEVICE_EN + if insert_latency_device: + device_item_idx = 3 + if self.is_v1_multimodal: + mm_encoder_latency_device = [round_fn(sum(self.mm_encoder_latency_device_list[i])) for i in idx_range] + metrics_data.insert(device_item_idx, (self.mm_encoder_latency_device_name, mm_encoder_latency_device)) + device_item_idx = device_item_idx + 1 + context_latency_device = [round_fn(sum(self.per_token_latency_device_list[i][:total_prefill_steps])) for i in idx_range] + per_token_latency_device = [0.0 if len(self.per_token_latency_device_list[i]) <= total_prefill_steps else \ + round_fn(statistics.fmean(self.per_token_latency_device_list[i][total_prefill_steps:])) for i in idx_range] + metrics_data.insert(device_item_idx, (self.context_latency_device_name, context_latency_device)) + metrics_data.insert(device_item_idx + 1, (self.per_token_latency_device_name, per_token_latency_device)) + + self.metrics_data = dict(metrics_data) + + # Print + df = pd.DataFrame(self.metrics_data) + if show_per_iter: + df = pd.DataFrame([df.iloc[-1]], columns=df.columns) + else: + df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = df.mean().round(keep_digits) + if only_average: + df = pd.DataFrame([df.iloc[-1]], columns=df.columns) + + df.index.name = 'iter index' + df[self.batch_size_name] = df[self.batch_size_name].astype(int) + if num_spec_tokens > 0: + df[self.k_name] = df[self.k_name].astype(int) + + self.peak_memory_name = "profile memory(GB)" + self.block_memory_name = "total cache memory(GB)" + memory_metrics_data = [ + ( + self.peak_memory_name, [round_fn(self.peak_memory / 1024 / 1024 / 1024) for i in idx_range] + ), + ( + self.block_memory_name, [round_fn(self.block_memory / 1024 / 1024 / 1024) for i in idx_range] + ), + ] + + self.memory_metrics_data = dict(memory_metrics_data) + + # Print + memory_df = pd.DataFrame(self.memory_metrics_data) + if show_per_iter: + memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns) + else: + memory_df.loc["Average(" + str(metrics_idx_end-metrics_idx_start) + "iters)"] = memory_df.mean().round(keep_digits) + if only_average: + memory_df = pd.DataFrame([memory_df.iloc[-1]], columns=memory_df.columns) + + memory_df.index.name = 'iter index' + + pd.set_option('display.colheader_justify', 'center') + pd.set_option('display.max_columns', None) + pd.set_option('display.max_rows', None) + print("********************************* Test Info****************************") + mm_params_text = " ".join(f"{key}:{value}" for key, value in self.mm_kwargs.items()) + print("Generation Config {} input len:{} output len:{} tp_nums:{} quantization:{}".format( + mm_params_text, input_len,output_len,tp_nums,quantization)) + self.context_latency_device = np.mean(self.metrics_data['context latency device(ms)']) + self.generate_latency_device = np.mean(self.metrics_data[self.per_token_latency_device_name]) + if self.is_v1_multimodal: + self.mm_encoder_latency_device = np.mean(self.metrics_data[self.mm_encoder_latency_device_name]) + + print("*************************Performance Info******************************") + print(f"Total prefill steps: {total_prefill_steps}") + print(df.to_string()) + if not is_embedding_task: + # embedding task does not do profile run, so does not have memory infos + print(memory_df.to_string()) + if insert_latency_device : + context_latency = 
np.mean(self.metrics_data['context latency device(ms)'])
+ generate_latency = np.mean(self.metrics_data[self.per_token_latency_device_name])
+ if num_spec_tokens > 0:
+ print("MTP token accept rate: {:.2f}%".format(rate*100))
+ self.dump_performance_info(hfu_info, io_efficiency)
+
+ avg_latency_e2e = sum(sum(self.per_token_latency_list[i]) for i in idx_range) / len(idx_range)
+ print("Avg latency without host time is:", avg_latency_e2e)
+ print("***********************************************************************")
+ # collect context/decoder HFU and IO efficiency when device_info is available
+ self.save_hfu_info = False
+ if insert_latency_device:
+ if VLLM_DUMP_MLU_INFO_EN:
+ try:
+ import device_info
+ self.save_hfu_info = True
+ except ImportError:
+ logger.info("Failed to import device_info; try `pip install device_info`.")
+ self.context_hfu_name = "Context HFU"
+ self.decoder_hfu_name = "Decoder HFU"
+ self.decoder_io_efficiency_name = "Decoder IO Efficiency"
+ if self.save_hfu_info:
+ self.metrics_data[self.context_hfu_name] = hfu_info["context_hfu"] * 100
+ self.metrics_data[self.decoder_hfu_name] = hfu_info["decoder_hfu"] * 100
+ self.metrics_data[self.decoder_io_efficiency_name] = io_efficiency * 100
+ if csv_path := os.getenv("OUTPUT_CSV_PATH"):
+ try:
+ if dir_path := os.path.dirname(csv_path):
+ os.makedirs(dir_path, exist_ok=True)
+ self.to_csv(csv_path, show_per_iter=show_per_iter)
+ except Exception as e:
+ logger.error(f"Invalid OUTPUT_CSV_PATH: {csv_path} to dump metrics, Error: {e}")
+
+ def dump_performance_info(self, hfu_info, io_efficiency):
+ try:
+ if VLLM_DUMP_MLU_INFO_EN and hfu_info is not None:
+ hfu_info["context_hfu"] = hfu_info["context_hfu"] / (self.context_latency_device / millisecond2second_unit)
+ hfu_info["decoder_hfu"] = hfu_info["decoder_hfu"] / (self.generate_latency_device / millisecond2second_unit)
+ io_efficiency = io_efficiency / self.generate_latency_device
+ print(f"Context HFU-visible: {hfu_info['context_hfu']:.3%}")
+ print(f"Decoder HFU-visible: {hfu_info['decoder_hfu']:.3%}")
+ print(f"Decoder IO Efficiency: {io_efficiency:.3%}")
+ elif hfu_info is not None:
+ print(f"Context FLOPS-visible: {hfu_info['context_flops']}")
+ print(f"Decoder FLOPS-visible: {hfu_info['decoder_flops']}")
+ else:
+ logger.info("Dumping performance information is not supported")
+ except Exception as e:
+ logger.error(f"Failed to dump performance information: {str(e)}")
diff --git a/vllm_mlu/model_executor/__init__.py b/vllm_mlu/model_executor/__init__.py new file mode 100755 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/model_executor/__init__.py @@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
+
diff --git a/vllm_mlu/model_executor/layers/__init__.py b/vllm_mlu/model_executor/layers/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/model_executor/layers/__init__.py @@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
+
diff --git a/vllm_mlu/model_executor/layers/activation.py b/vllm_mlu/model_executor/layers/activation.py new file mode 100644 index 0000000..0f277ec --- /dev/null +++ b/vllm_mlu/model_executor/layers/activation.py @@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
+
+import torch
+from vllm.model_executor.layers.activation import QuickGELU
+from vllm_mlu.mlu_hijack_utils import MluHijackObject
+from vllm_mlu import _mlu_ops as mlu_ops
+
+def
vllm__model_executor__activation__QuickGELU__forward_oot(self, x: torch.Tensor) -> torch.Tensor: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: implement forward_oot + ''' + return mlu_ops.active(x, 'quick_gelu', False) + ''' + ================== + End of MLU Hijack + ================== + ''' + +MluHijackObject.apply_hijack(QuickGELU, + QuickGELU.forward_oot, + vllm__model_executor__activation__QuickGELU__forward_oot) diff --git a/vllm_mlu/model_executor/layers/compressor.py b/vllm_mlu/model_executor/layers/compressor.py new file mode 100644 index 0000000..f2503be --- /dev/null +++ b/vllm_mlu/model_executor/layers/compressor.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import math +from typing import Callable +from scipy.linalg import hadamard + +import torch +from torch import nn +import torch.nn.functional as F + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.v1.attention.backends.utils import get_common_metadata + + +def hadamard_transform_ref(x, scale=1.0): + """ + x: (..., dim) + out: (..., dim) + """ + x_shape = x.shape + dim = x.shape[-1] + x = x.reshape(-1, dim) + log_dim = math.ceil(math.log2(dim)) + dim_padded = 2 ** log_dim + if dim != dim_padded: + x = F.pad(x, (0, dim_padded - dim)) + out = F.linear( + x, + torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device), + ) + out = out * scale + return out[..., :dim].reshape(*x_shape) + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + hidden_size = x.size(-1) + return hadamard_transform_ref(x, scale=hidden_size ** -0.5) + + +class Compressor(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + rope, + compress_ratio: int = 4, + head_dim: int = 512, + rotate: bool = False, + prefix: str = "", + **kwargs,): + super().__init__() + config = vllm_config.model_config.hf_config + self.dim = config.dim + self.head_dim = head_dim + self.rope_head_dim =config.rope_head_dim + self.nope_head_dim = head_dim - config.rope_head_dim + self.compress_ratio = compress_ratio + self.overlap = compress_ratio == 4 + self.rotate = rotate + coff = 1 + self.overlap + self.norm_eps = config.norm_eps + self.window_size = config.window_size + + self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32)) + # wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient. + # The first half of dimensions for overlapping compression and second half for normal compression. 
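(Illustrative note, not part of the patch.) The hijack registered above works because MluHijackObject.apply_hijack() derives the attribute to patch from the hijack function's own name: it splits on '__' and keeps the last non-empty piece, restoring leading and trailing underscores for dunder targets. A small sketch mirroring that derivation; the second example name is hypothetical:

def derive_attr_name(hijack_func_name: str) -> str:
    parts = hijack_func_name.split("__")
    name = parts[-1]
    if name == "":
        # a trailing '__' means the target itself ends with '__', e.g. '__init__'
        assert parts[-2] != ""
        name = parts[-2] + "__"
        if len(parts) >= 3 and parts[-3] == "":
            name = "__" + name
    return name

# 'forward_oot' is the attribute replaced on QuickGELU above
print(derive_attr_name("vllm__model_executor__activation__QuickGELU__forward_oot"))  # forward_oot
print(derive_attr_name("vllm__some_module__SomeClass____init__"))                    # __init__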
+ + self.wkv = ReplicatedLinear( + self.dim, + coff * self.head_dim, + bias=False, + quant_config=None, + params_dtype = torch.float32, + prefix=f"{prefix}.wkv", + ) + + self.wgate = ReplicatedLinear( + self.dim, + coff * self.head_dim, + bias=False, + quant_config=None, + params_dtype = torch.float32, + prefix=f"{prefix}.wgate", + ) + + self.norm = RMSNorm(self.head_dim, self.norm_eps) + + self.rotary_emb = rope + + hf_config = vllm_config.model_config.hf_config + assert hasattr(hf_config, "cached_state_num"), \ + f"cached_state_num is not set in hf_config" + cached_state_num = hf_config.cached_state_num + self.register_buffer( + "kv_state", + torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "score_state", + torch.full( + (cached_state_num, coff * compress_ratio, coff * self.head_dim), + float("-inf"), + dtype=torch.float32, + ), + persistent=False, + ) + + self.hadamard_matrix = torch.tensor( + hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu") + + def overlap_transform(self, tensor: torch.Tensor, value=0): + # tensor: [b,s,r,2d] + b, s, _, _ = tensor.size() + ratio, d = self.compress_ratio, self.head_dim + new_tensor = tensor.new_full((b, s, 2 * ratio, d), value) + new_tensor[:, :, ratio:] = tensor[:, :, :, d:] + new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d] + return new_tensor + + def forward_decode( + self, + x: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + batch_to_kv_state: torch.Tensor, + kv_cache: torch.Tensor, + window_offset: int, + compressor_slot_mapping: torch.Tensor, + ): + x = x.float() + kv_pack, _ = self.wkv(x) + score_pack, _ = self.wgate(x) + + mlu_ops.fused_compress_single_kv( + kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D) + score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D) + position=positions, + ape=self.ape, + kv_state=self.kv_state, + score_state=self.score_state, + gamma=self.norm.weight, + sin=self.rotary_emb.sin_, + cos=self.rotary_emb.cos_, + hadamard_matrix=self.hadamard_matrix, + slot_mapping=compressor_slot_mapping, + kv_cache=kv_cache, + kv_cache_scale=None, + eps=self.norm_eps, + overlap=self.overlap, + rotate=self.rotate, + state_idx=batch_to_kv_state, + ) + + # Here, return fake compressed_kv. 
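(Illustrative note, not part of the patch.) overlap_transform() above rearranges each compressed step's 2*d features into two groups of d: the second half stays at the same sequence position (normal compression) while the first half is shifted forward by one position (overlapping compression). A toy standalone sketch of that [b, s, r, 2d] -> [b, s, 2r, d] layout; the function name and sizes are made up:

import torch

def overlap_transform_ref(tensor, ratio, d, value=0):
    b, s = tensor.shape[:2]
    out = tensor.new_full((b, s, 2 * ratio, d), value)
    out[:, :, ratio:] = tensor[:, :, :, d:]      # normal half, same position
    out[:, 1:, :ratio] = tensor[:, :-1, :, :d]   # overlapping half, shifted by one step
    return out

x = torch.arange(2 * 3 * 2 * 4, dtype=torch.float32).reshape(2, 3, 2, 4)  # b=2, s=3, r=2, 2d=4
print(overlap_transform_ref(x, ratio=2, d=2).shape)  # torch.Size([2, 3, 4, 2])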
+ return None + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + batch_to_kv_state: torch.Tensor, + kv_cache: torch.Tensor, + window_offset: int, + compressor_slot_mapping: torch.Tensor, + ): + common_metadata = get_common_metadata() + forward_func: Callable = ( + self.forward_prefill if common_metadata.is_prefill_only + else self.forward_decode + ) + return forward_func( + x, + positions, + attn_metadata, + batch_to_kv_state, + kv_cache, + window_offset, + compressor_slot_mapping, + ) + + def forward_prefill( + self, + x: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + batch_to_kv_state: torch.Tensor, + kv_cache: torch.Tensor, + window_offset: int, + compressor_slot_mapping: torch.Tensor, + ): + common_metadata = get_common_metadata() + seq_lens = common_metadata.seq_lens + query_start_loc = common_metadata.query_start_loc + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + ratio, overlap = self.compress_ratio, self.overlap + dtype = x.dtype + x = x.float() + kv_pack, _ = self.wkv(x) + score_pack, _ = self.wgate(x) + + compress_lens = query_lens // self.compress_ratio + cu_compress_lens = torch.cat([ + torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device), + torch.cumsum(compress_lens, dim=0)], + ) + + compress_positions = [] + for i in range(len(seq_lens)): + seqlen = (query_start_loc[i+1] - query_start_loc[i]).item() + remainder = seqlen % ratio + cutoff = seqlen - remainder + pos = positions[query_start_loc[i]: query_start_loc[i+1]] + positions_ = pos[:cutoff:ratio].contiguous() + compress_positions.append(positions_) + kv_positions = torch.cat(compress_positions, dim=0) + + + total_compress_len = cu_compress_lens[-1].item() + kv = torch.empty( + [total_compress_len, self.head_dim], + dtype=kv_pack.dtype, + device=kv_pack.device, + ) + + mlu_ops.fused_compress_multi_kv( + kv = kv_pack, + score = score_pack, + kv_state = self.kv_state, + score_state = self.score_state, + state_batch_idx = batch_to_kv_state, + cu_seqlens = query_start_loc, + ape = self.ape, + max_seqlen = common_metadata.max_query_len, + overlap = overlap, + compressed_kv = kv, + ) + + if kv.size(0) == 0: + return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size) + + + kv = self.norm(kv.to(dtype)) + + kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2) + # use compressed cu_seqlens here, so can not call rotary_emb directly + kv_rope = mlu_ops.rotary_embedding( + kv_rope, + self.rotary_emb.sin_, + self.rotary_emb.cos_, + kv_positions, + torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens + True, # interleaved + True, # discrete + False, + common_metadata.max_query_len, + ) + + if self.rotate: + kv = rotate_activation(kv) + + mlu_ops.reshape_paged_cache( + kv.unsqueeze(1), + None, + kv_cache, + None, + compressor_slot_mapping, + ) + + return kv.unsqueeze(-2) # (compress_token_num, 1, head_size) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/dp_logits_processor.py b/vllm_mlu/model_executor/layers/dp_logits_processor.py new file mode 100644 index 0000000..f7d5260 --- /dev/null +++ b/vllm_mlu/model_executor/layers/dp_logits_processor.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Optional + +import torch + +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, 
tensor_model_parallel_gather) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm_mlu.model_executor.models.dp_utils import ( + tensor_model_parallel_all_gather_dp, DataParallelRuntimeParams) + + +class DPLogitsProcessor(LogitsProcessor): + """DP LogitsProcessor.""" + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor], + dp_params: Optional[DataParallelRuntimeParams] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + batch_sizes = None + if (lm_head.tp_group is not None + and dp_params is not None + and dp_params.logits_batch_split_list is not None): + batch_sizes = dp_params.logits_batch_split_list + hidden_states = tensor_model_parallel_all_gather_dp( + group_num_tokens=batch_sizes, + rank=lm_head.tp_rank, + hidden_states=hidden_states, + group=lm_head.tp_group, + ) + + logits = lm_head.quant_method.apply( + lm_head, hidden_states, bias=embedding_bias) + + if self.use_all_gather: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits, tp_group=lm_head.tp_group) + else: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits, tp_group=lm_head.tp_group) + + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[..., : self.org_vocab_size] + + if batch_sizes is not None: + offset = sum(batch_sizes[:lm_head.tp_rank]) + logits = logits[offset : offset + batch_sizes[lm_head.tp_rank]] + + return logits + + def forward( + self, + lm_head: VocabParallelEmbedding, + hidden_states: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + dp_params: Optional[DataParallelRuntimeParams] = None, + ) -> Optional[torch.Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + # Get the logits for the next tokens. 
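(Illustrative note, not part of the patch.) After the all-gather in _get_logits() above, every rank holds logits for the concatenated data-parallel batch and recovers only its own rows by offsetting with the per-rank batch sizes. A toy sketch of that slice-back step; the sizes and tensor contents are made up:

import torch

batch_sizes = [2, 3, 1]  # rows contributed by ranks 0..2
gathered = torch.arange(sum(batch_sizes)).unsqueeze(-1)  # stand-in for the gathered logits

def local_rows(gathered, batch_sizes, rank):
    offset = sum(batch_sizes[:rank])
    return gathered[offset: offset + batch_sizes[rank]]

print(local_rows(gathered, batch_sizes, rank=1).squeeze(-1))  # tensor([2, 3, 4])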
+ logits = self._get_logits( + hidden_states, lm_head, embedding_bias, dp_params) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + return logits diff --git a/vllm_mlu/model_executor/layers/dp_vocab_parallel_embedding.py b/vllm_mlu/model_executor/layers/dp_vocab_parallel_embedding.py new file mode 100644 index 0000000..6ffe081 --- /dev/null +++ b/vllm_mlu/model_executor/layers/dp_vocab_parallel_embedding.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding, +) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod, + VocabParallelEmbedding, + DEFAULT_VOCAB_PADDING_SIZE, + get_masked_input_and_mask, + pad_vocab_size, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_reduce, +) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_logits_tp_group, + get_logits_tp_world_size, + get_logits_tp_rank, +) +from vllm_mlu.model_executor.models.dp_utils import ( + DataParallelRuntimeParams, + tensor_model_parallel_all_gather_dp, +) + + +class DPVocabParallelEmbedding(VocabParallelEmbedding): + """DP Embedding parallelized in the vocabulary dimension.""" + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + torch.nn.Module.__init__(self) + + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: add self.tp_group, world_size and tp_rank to support other parallel + """ + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_world_size = get_tensor_model_parallel_world_size() + self.tp_group = None + logits_tp_world_size = get_logits_tp_world_size() + if logits_tp_world_size != self.tp_world_size: + self.tp_group = get_logits_tp_group() + self.tp_world_size = logits_tp_world_size + self.tp_rank = get_logits_tp_rank() + + # Keep the input dimensions. 
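(Illustrative note, not part of the patch.) The soft-cap branch in forward() above divides by the cap, applies tanh, and rescales, which keeps logits within (-soft_cap, soft_cap) while staying nearly linear around zero. A standalone sketch with made-up values:

import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    return torch.tanh(logits / soft_cap) * soft_cap

logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap_logits(logits, soft_cap=30.0))  # every value stays inside (-30, 30)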
+ tp_rank = self.tp_rank + self.tp_size = self.tp_world_size + """ + ================= + End of MLU Hijack + ================= + """ + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + self.embedding_dim = embedding_dim + + quant_method = None + if quant_config is not None: + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self) is VocabParallelEmbedding + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method)) + if is_embedding_layer and not quant_method_implements_embedding: + raise NotImplementedError( + f"The class {type(quant_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod.") + + self.quant_method: QuantizeMethodBase = quant_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + # Divide the weight matrix along the vocaburaly dimension. + self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + + def forward(self, input_, + dp_params: Optional[DataParallelRuntimeParams] = None): + token_split_list = None + if (dp_params is not None + and self.tp_group is not None + and dp_params.emb_token_split_list is not None): + token_split_list = dp_params.emb_token_split_list + input_ = tensor_model_parallel_all_gather_dp( + group_num_tokens=token_split_list, + rank=self.tp_rank, + hidden_states=input_.reshape(-1, 1), + group=self.tp_group, + ).reshape(-1) + + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = get_masked_input_and_mask( + input_, + self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index, + ) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. 
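(Illustrative note, not part of the patch.) The forward() above relies on the usual vocab-parallel trick: token ids outside this rank's vocab shard are remapped to index 0, embedded, then zeroed before the all-reduce, so each id is contributed by exactly one rank. A simplified single-process sketch that ignores the added-vocab region and padding handled by get_masked_input_and_mask; the shard bounds and ids are made up:

import torch

vocab_start, vocab_end = 4, 8                   # this rank owns ids [4, 8)
input_ids = torch.tensor([1, 5, 7, 9])
mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
local_ids = (input_ids - vocab_start).masked_fill(mask, 0)

embedding = torch.nn.Embedding(vocab_end - vocab_start, 3)
out = embedding(local_ids)
out = out.masked_fill(mask.unsqueeze(-1), 0.0)  # rows for foreign ids contribute zeros
print(out.shape)  # torch.Size([4, 3]); rows 0 and 3 are all zeros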
+ if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group) + + if token_split_list is not None: + offset = sum(token_split_list[:self.tp_rank]) + output = output[offset : offset + token_split_list[self.tp_rank]] + + return output + + +class DPParallelLMHead(DPVocabParallelEmbedding): + """DP Parallelized LM head. + + NOTE: A copy of ParallelLMHead class, and only change its parent + from VocabParallelEmbedding to DPVocabParallelEmbedding. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size, quant_config, + prefix) + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def tie_weights(self, embed_tokens: VocabParallelEmbedding): + """Tie the weights with word embeddings.""" + # GGUF quantized embed_tokens. + if self.quant_config and self.quant_config.get_name() == "gguf": + return embed_tokens + else: + self.weight = embed_tokens.weight + return self + + def forward(self, input_): + del input_ + raise RuntimeError("LMHead's weights should be used in the sampler.") diff --git a/vllm_mlu/model_executor/layers/feed_forward.py b/vllm_mlu/model_executor/layers/feed_forward.py new file mode 100755 index 0000000..859fc2f --- /dev/null +++ b/vllm_mlu/model_executor/layers/feed_forward.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +import torch.nn.functional as F +from typing import Any + +from vllm.distributed import ( + get_parallel_world_size_with_group, + get_parallel_rank_with_group, +) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.logger import init_logger +from vllm.lora.layers import BaseLayerWithLoRA +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + ColumnParallelLinear, + RowParallelLinear +) + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.mlu_hijack_utils import set_is_gated + +logger = init_logger(__name__) + +class FeedForward(torch.nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + up_proj_name: str, + is_gated: bool, + down_proj_name: str, + bias: bool, + quant_config: QuantizationConfig | None = None, + skip_bias_add: bool = False, + reduce_results: bool = True, + prefix: str = "", + tp_group: Any = None, + keep_full_weights: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.is_gated = is_gated + self.bias = bias + self.up_proj_name = up_proj_name + self.down_proj_name = down_proj_name + self.quant_config = quant_config + self.is_initialized = False + self.skip_bias_add = skip_bias_add + self.reduce_results = reduce_results + self.use_bt_ffn = True + 
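(Illustrative note, not part of the patch.) When is_gated is true, the merged up/gate projection registered above produces 2 * intermediate_size features, and the forward paths below split them and compute act(gate) * up before the down projection. A minimal SwiGLU-style sketch with made-up sizes:

import torch
import torch.nn.functional as F

def gated_ffn(x, w_up, w_down, act=F.silu):
    fc1 = F.linear(x, w_up)            # merged gate/up projection
    d = fc1.shape[-1] // 2
    hidden = act(fc1[..., :d]) * fc1[..., d:]
    return F.linear(hidden, w_down)

x = torch.randn(2, 16)
w_up = torch.randn(2 * 32, 16)         # [2 * intermediate, hidden]
w_down = torch.randn(16, 32)           # [hidden, intermediate]
print(gated_ffn(x, w_up, w_down).shape)  # torch.Size([2, 16])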
set_is_gated(self.is_gated) + + # modify tp_size, tp_rank and tp_group when enable data parallel + self.tp_size = get_parallel_world_size_with_group(tp_group) + self.tp_rank = get_parallel_rank_with_group(tp_group) + self.tp_group = tp_group + self.keep_full_weights = keep_full_weights + if self.keep_full_weights: + self.tp_size = 1 + self.tp_rank = 0 + self.tp_group = None + + # up_proj with gate or not + if self.is_gated: + up_proj = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.{up_proj_name}", + tp_group=self.tp_group, + keep_full_weights=keep_full_weights) + else: + up_proj = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=bias, + skip_bias_add=skip_bias_add, + quant_config=quant_config, + prefix=f"{prefix}.{up_proj_name}", + tp_group=self.tp_group, + keep_full_weights=keep_full_weights) + self.register_module(up_proj_name, up_proj) + + # down_proj + down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=bias, + skip_bias_add=skip_bias_add, + reduce_results=reduce_results, + quant_config=quant_config, + prefix=f"{prefix}.{down_proj_name}", + tp_group=self.tp_group, + keep_full_weights=keep_full_weights) + + self.register_module(down_proj_name, down_proj) + + def prepare_weight(self): + if not self.is_initialized: + # alpha and beta are 1.0 and 0.0 respectively due to the fact that we don't need residual for now + self.alpha = 1.0 + self.beta = 0.0 + # place it here to avoid the overhead of calling it in the forward pass + self.is_initialized = True + + def _forward(self, hidden_states): + self.prepare_weight() + up_proj = getattr(self, self.up_proj_name) + down_proj = getattr(self, self.down_proj_name) + act_dict = { + "relu": F.relu, + "gelu": F.gelu, + "silu": F.silu, + } + fc1 = F.linear(hidden_states, up_proj.weight, bias=up_proj.bias) + if self.is_gated: + d = fc1.shape[-1] // 2 + fc1 = act_dict[self.hidden_act](fc1[..., :d]) * fc1[..., d:] + else: + fc1 = act_dict[self.hidden_act](fc1) + fc2 = F.linear(fc1, down_proj.weight, bias=None) + fc2 = tensor_model_parallel_all_reduce(fc2) + if not self.skip_bias_add: + fc2 = fc2 + down_proj.bias if down_proj.bias is not None else fc2 + return fc2 + + def forward_naive( + self, + hidden_states, + residual: torch.Tensor | None = None, + smooth_quant_scale: torch.Tensor | None = None + ): + ''' + used by quant_tools + ''' + assert self.quant_config is None, "ffn naive forward dosen't support quantization" + assert smooth_quant_scale is None, "ffn naive forward dosen't support smooth_quant_scale" + + up_proj = getattr(self, self.up_proj_name) + down_proj = getattr(self, self.down_proj_name) + residual_ = None if self.tp_rank > 0 else residual + fc1, bias = up_proj(hidden_states) + if bias is not None: + fc1 += bias + fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated) + out, bias = down_proj(fc1, residual=residual_) + + if self.skip_bias_add: + return out, bias + return out + + def forward( + self, + hidden_states, + residual: torch.Tensor | None = None, + smooth_quant_scale: torch.Tensor | None = None, + use_tp_weight: bool = False, + output: torch.Tensor | None = None, + ): + self.prepare_weight() + + if self.use_bt_ffn is False: + return self.forward_naive(hidden_states, residual, None) + + up_proj = getattr(self, self.up_proj_name) + down_proj = getattr(self, self.down_proj_name) + residual_ = None if self.tp_rank > 0 else residual + if (self.quant_config is None and not isinstance(up_proj, BaseLayerWithLoRA) + and not 
isinstance(down_proj, BaseLayerWithLoRA)): + # The matmul formula is the following: + # mul_out = alpha * (matmul(input, filter, transpose\_b=True) + bias) + beta * residual + # output = active(mul_out) + # Notes: We cannot use the activation function in matmul because it does not support gated operation + # we might support its in tmo matmul in the future + up_proj_weight = up_proj.weight + down_proj_weight = down_proj.weight + if self.keep_full_weights and use_tp_weight: + up_proj_weight = up_proj.tp_weight + down_proj_weight = down_proj.tp_weight + fc1 = mlu_ops.matmul(hidden_states.view(-1, self.hidden_size), up_proj_weight, up_proj.bias, + None, 'none', self.alpha, self.beta) + act_out = mlu_ops.active(fc1.float(), self.hidden_act, self.is_gated).to(dtype=fc1.dtype) + beta = 0.0 + if residual_ is not None: + beta = 1.0 + residual_ = residual_.view(-1, residual_.shape[-1]) + out_ = mlu_ops.matmul(act_out, down_proj_weight, None, residual_, 'none', self.alpha, beta) + # bias if existed need to add after second matmul according to the original design of vllm + if self.reduce_results: + out = tensor_model_parallel_all_reduce(out_, self.tp_group) + else: + out = out_ + # do the bias add if needed + if not self.skip_bias_add: + out = out + down_proj.bias if down_proj.bias is not None else out + else: + return out, down_proj.bias + else: + fc1, bias = up_proj(hidden_states, smooth_quant_scale=smooth_quant_scale, use_tp_weight=use_tp_weight) + if bias is not None: + fc1 += bias + input_scale= None + if (self.quant_config is not None and self.quant_config.get_name() == "SmoothQuant" and + self.quant_config.input_quant_method == "per_token" and not self.quant_config.is_fp8): + down_proj.quant_method.skip_quant_input = True + down_proj_smooth = down_proj.smooth + if self.keep_full_weights and use_tp_weight: + assert down_proj.tp_smooth is not None, "tp_smooth is not initialized" + down_proj_smooth = down_proj.tp_smooth + fc1, input_scale = mlu_ops.per_token_smooth_quantize( + fc1, down_proj_smooth, None, None, act_mode=self.hidden_act, is_gated=self.is_gated) + else: + fc1 = mlu_ops.active(fc1, self.hidden_act, self.is_gated) + out, bias = down_proj( + fc1, residual=residual_, smooth_quant_scale=input_scale, + use_tp_weight=use_tp_weight, output=output) + + if self.skip_bias_add: + return out, bias + return out diff --git a/vllm_mlu/model_executor/layers/fused_moe/__init__.py b/vllm_mlu/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/model_executor/layers/fused_moe/fused_moe.py b/vllm_mlu/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 0000000..8ba5f33 --- /dev/null +++ b/vllm_mlu/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,935 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + _get_config_dtype_str, +) +from 
vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_moe_kernel_gptq_awq, + write_zeros_to_output, + get_default_config, + try_get_optimal_moe_config, + _get_config_quant_dtype, +) +from vllm.model_executor.layers.fused_moe.utils import ( + activation_without_mul, + disable_inplace, +) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 +from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer + +from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize +import vllm_mlu._mlu_ops as mlu_ops + +logger = init_logger(__name__) + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_bias_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + stride_bbe, # bias expert stride + stride_bbn, # bias N stride + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. 
+ """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Split the program ID into two dimensions (pid_0 and pid_1) + ''' + pid_0 = tl.program_id(axis=0) + pid_1 = tl.program_id(axis=1) + pid = pid_1 * tl.num_programs(axis=0) + pid_0 + ''' + ================== + End of MLU Hijack + ================== + ''' + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + if use_int8_w8a16: + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8 or use_int8_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + # channel-wise + elif per_channel_quant: + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + if HAS_BIAS: + # bias shape: [num_experts, N] + bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn + bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. 
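(Illustrative note, not part of the patch.) Apart from the MLU-specific pid_0/pid_1 split, the grouped ordering computed at the top of the kernel above can be checked in plain Python: consecutive flattened program ids stay within a group of GROUP_SIZE_M row-tiles, which improves reuse of the expert weight tiles. A small sketch with made-up launch sizes:

def grouped_tile(pid, EM, N, BLOCK_M, BLOCK_N, GROUP_SIZE_M):
    num_pid_m = -(-EM // BLOCK_M)   # ceil division
    num_pid_n = -(-N // BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

print([grouped_tile(p, EM=256, N=256, BLOCK_M=64, BLOCK_N=64, GROUP_SIZE_M=2) for p in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]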
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + if use_fp8_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + if HAS_BIAS: + accumulator = accumulator + bias[None, :] + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: torch.Tensor | None, + B_scale: torch.Tensor | None, + B_zp: torch.Tensor | None, + topk_weights: torch.Tensor | None, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: list[int] | None = None, + B_bias: torch.Tensor | None = None, +) -> None: + assert topk_weights is not None or not mul_routed_weight + assert topk_weights is None or topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if use_fp8_w8a8 or use_int8_w8a8: + assert B_scale is not None + assert block_shape is None or triton.cdiv( + B.size(-2), block_shape[0] + ) == B_scale.size(-2) + assert block_shape is None or triton.cdiv( + B.size(-1), block_shape[1] + ) == B_scale.size(-1) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + + M = A.size(0) + num_tokens = M * top_k + + EM = sorted_token_ids.size(0) + if A.size(0) < 
config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Split the program ID into two dimensions (pid_0, pid_1) + ''' + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv( + B.shape[1], META['BLOCK_SIZE_N']), ) + + assert not (use_int8_w8a16 or use_int4_w4a16) + ''' + ================== + End of MLU Hijack + ================== + ''' + HAS_BIAS = B_bias is not None + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + use_moe_wna16_cuda = should_moe_wna16_use_cuda( + num_valid_tokens=num_tokens, + group_size=block_shape[1], + num_experts=B.size(0), + bit=4 if use_int4_w4a16 else 8, + ) + config = config.copy() + config.update( + get_moe_wna16_block_config( + config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"], + ) + ) + + if use_moe_wna16_cuda: + bit = 4 if use_int4_w4a16 else 8 + ops.moe_wna16_gemm( + A, + C, + B, + B_scale, + B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + bit, + ) + return + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + A.size(1), + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + else: + config = config.copy() + config["SPLIT_K"] = 1 + BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") + if block_shape is not None: + BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1])) + fused_moe_kernel[grid]( + A, + B, + C, + B_bias, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + B.size(2), + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_bias.stride(0) if B_bias is not None else 0, + B_bias.stride(1) if B_bias is not None else 0, + 0 if block_shape is None else block_shape[0], + 0 if 
block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, + HAS_BIAS=HAS_BIAS, + BLOCK_SIZE_K=BLOCK_SIZE_K, + **config, + ) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: str | None = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + ocp_mx_scheme, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) + + +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: str | None = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, +) -> None: + pass + + +direct_register_custom_op( + op_name="outplace_fused_experts_mlu", + op_func=outplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=outplace_fused_experts_fake, + dispatch_key="PrivateUse1", + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), +) + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: + return 
torch.ops.vllm.outplace_fused_experts_mlu( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + +SILU_NO_MUL: str = activation_without_mul("silu") +GELU_NO_MUL: str = activation_without_mul("gelu") +RELU2_NO_MUL: str = activation_without_mul("relu2") + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: str | None = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, +) -> torch.Tensor: + # Check constraints. + if use_int4_w4a16: + assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" + elif ocp_mx_scheme is not None: + if ocp_mx_scheme in { + "w_mxfp4_a_mxfp4", + "w_mxfp4_a_mxfp6_e3m2", + "w_mxfp4_a_mxfp6_e2m3", + }: + # 16bit activation and fp4x2 packed weight + assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" + elif ocp_mx_scheme in { + "w_mxfp6_e3m2_a_mxfp6_e3m2", + "w_mxfp6_e2m3_a_mxfp6_e2m3", + }: + assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( + "hidden size mismatch" + ) + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") + else: + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" + ) + + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + num_tokens = hidden_states.size(0) + E, N, _ = w1.size() + K = w2.size(1) + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.size(1) + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + + config_dtype = _get_config_dtype_str( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ocp_mx_scheme=ocp_mx_scheme, + dtype=hidden_states.dtype, + ) + + # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are + # quantized prior to calling fused_experts. 
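+ # quant_dtype below only reflects the activation-quantized paths
+ # (fp8 / int8 w8a8 and the OCP MX schemes); the weight-only
+ # w8a16 / w4a16 flags are not passed in, consistent with the note above.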
+ quant_dtype = _get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + ocp_mx_scheme=ocp_mx_scheme, + ) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.size(), + w2.size(), + top_k_num, + config_dtype, + block_shape=block_shape, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Only use the default config + ''' + config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1], + hidden_states.dtype, block_shape) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty( + M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) + + # This needs separate memory since it's used concurrently with cache1 + intermediate_cache2 = torch.empty( + (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + if inplace and not disable_inplace(): + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + if ocp_mx_scheme is not None: + # TODO: On platforms for which `current_platform.supports_mx()` is True + # and for which we have a native OCP mx fused MOE kernel, + # this dequantization step should not be done. + if ocp_mx_scheme in { + OCP_MX_Scheme.w_mxfp4_a_mxfp4, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3, + }: + # Weight has to be dequantized for mxfp4 emulation. + w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) + w1_scale = None + w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) + w2_scale = None + elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: + w1 = dequant_mxfp6( + w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype + ) + w1_scale = None + w2 = dequant_mxfp6( + w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype + ) + w2_scale = None + elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: + w1 = dequant_mxfp6( + w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype + ) + w1_scale = None + w2 = dequant_mxfp6( + w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype + ) + w2_scale = None + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.size() + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
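+ # Both caches are views into the shared `cache13` buffer, so narrowing
+ # them to `tokens_in_chunk` is a cheap re-slice rather than a new allocation.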
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + a1q_scale: Optional[torch.Tensor] = None + + if use_fp8_w8a8: + qcurr_hidden_states, a1q_scale = _fp8_quantize( + curr_hidden_states, a1_scale, block_shape) + else: + qcurr_hidden_states = curr_hidden_states + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) + + invoke_fused_moe_kernel( + qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w1_bias, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Activate by mlu_ops + ''' + intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N), + act_mode=activation, + is_gated=True) + ''' + ================== + End of MLU Hijack + ================== + ''' + a2q_scale: Optional[torch.Tensor] = None + + if use_fp8_w8a8: + qintermediate_cache2, a2q_scale = _fp8_quantize( + intermediate_cache2, a2_scale, block_shape) + else: + qintermediate_cache2 = intermediate_cache2 + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w2_bias, + ) + + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: replace moe_sum with torch.sum + Reference Links: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513 + ''' + if topk_ids.shape[1] == 2: + torch.add( + intermediate_cache3[:, 0], + intermediate_cache3[:, 1], + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], + ).squeeze(dim=1) + elif topk_ids.shape[1] > 2: + torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + return out_hidden_states + diff --git a/vllm_mlu/model_executor/layers/fused_moe/layer.py b/vllm_mlu/model_executor/layers/fused_moe/layer.py new file mode 100644 index 0000000..de374e4 --- /dev/null +++ b/vllm_mlu/model_executor/layers/fused_moe/layer.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Optional, Callable + +import torch + +from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +from vllm.model_executor.layers.fused_moe import FusedMoE +from 
vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts + + +def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, +) -> torch.Tensor: + #TODO: support `routed_scaling_factor` + assert routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU." + ) + use_fused_kernel = topk_group is None + if use_fused_kernel: + assert not enable_eplb, f"MLU not support eplb in fused_moe kernel." + assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \ + f"Following params: use_grouped_topk, num_expert_group, topk_group are not support yet." + return mlu_ops.fused_moe( + x, + router_logits, + layer.w13_weight, layer.w2_weight, + None, None, # bias1, bias2 + None, # residual + None, # input_smooth + None, # act_smooth + None, None, # w1_scale, w2_scale + top_k, + renormalize, + True, # gated + activation + ) + else: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) + + if self.rocm_aiter_moe_enabled: + assert expert_map is None + return self.rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) + else: + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + +MluHijackObject.apply_hijack( + UnquantizedFusedMoEMethod, + UnquantizedFusedMoEMethod.forward_oot, + vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot +) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm_mlu/model_executor/layers/fused_moe/moe_align_block_size.py new file mode 100644 index 0000000..496bdac --- /dev/null +++ b/vllm_mlu/model_executor/layers/fused_moe/moe_align_block_size.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import cdiv, round_up + + +''' 
+============================= +Modify by vllm_mlu +============================= +@brief: Implementation of moe_align_block_size_triton. +Note: the implemtentation has been removed from vllm since the +cuda implementation is more efficient. +''' +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, + numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts, ) + tokens_cnts = torch.zeros((num_experts + 1, num_experts), + dtype=torch.int32, + device=topk_ids.device) + cumsum = torch.zeros((num_experts + 1, ), + dtype=torch.int32, + device=topk_ids.device) + tokens_per_thread = cdiv(numel, num_experts) + sorted_token_ids.fill_(numel) + expert_ids.zero_() + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1, )]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + 
tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) +''' +================== +End of MLU Hijack +================== +''' + + +def moe_align_block_size( + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. + - pad_sorted_ids: A flag indicating whether the sorted_token_ids length + should be padded to a multiple of block_size, + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be zeroed out to prevent index out of bounds error while + # mapping global expert ids to local expert ids in expert parallelism. 
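+ # One expert id is recorded per `block_size`-sized block of sorted token
+ # ids, hence the buffer holds cdiv(max_num_tokens_padded, block_size) entries.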
+ expert_ids = torch.zeros((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Only use triton to implement moe_align_block_size + ''' + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/vllm_mlu/model_executor/layers/fused_moe/utils.py b/vllm_mlu/model_executor/layers/fused_moe/utils.py new file mode 100644 index 0000000..e3cff48 --- /dev/null +++ b/vllm_mlu/model_executor/layers/fused_moe/utils.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +from math import prod +from typing import List, Optional, Tuple + +import torch + +from vllm.utils.math_utils import cdiv + + +def _fp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + block_shape: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform fp8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) + assert block_shape is not None + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + return A, A_scale + + diff --git a/vllm_mlu/model_executor/layers/indexer.py b/vllm_mlu/model_executor/layers/indexer.py new file mode 100644 index 0000000..7f11b14 --- /dev/null +++ b/vllm_mlu/model_executor/layers/indexer.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + get_tp_group +) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.model_executor.layers.compressor import ( + Compressor, + rotate_activation, +) +from vllm_mlu.v1.attention.backends.utils import get_common_metadata + +logger = init_logger(__name__) + + +class Indexer(torch.nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + rope, + compress_ratio: int = 4, + prefix: str = "", + **kwargs, + ): + super().__init__() + config = vllm_config.model_config.hf_config + self.dim = config.dim + self.n_heads = config.index_n_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.n_local_heads = config.index_n_heads // self.tp_size + self.head_dim = config.index_head_dim + self.rope_head_dim = config.rope_head_dim + self.index_topk = config.index_topk + self.q_lora_rank = config.q_lora_rank + self.window_size = config.window_size + self.block_size = vllm_config.cache_config.block_size + + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=None, + prefix=f"{prefix}.wq_b", + ) + + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + 
bias=False, + quant_config=None, + params_dtype = torch.bfloat16, + prefix=f"{prefix}.weights_proj", + ) + + self.softmax_scale = self.head_dim ** -0.5 + self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5) + self.compress_ratio = compress_ratio + self.max_model_len = vllm_config.model_config.max_model_len + + self.rotary_emb = rope + self.tp_group = get_tp_group() + + self.compressor = Compressor(vllm_config, self.rotary_emb, compress_ratio, self.head_dim, True, f"{prefix}.compressor") + + self.freqs_cis = None + + def forward_prefill( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + weights: torch.Tensor, + attn_metadata: AttentionMetadata, + k_full: torch.Tensor, + context_lens: torch.Tensor, + ): + assert attn_metadata.prefill.chunked_context is None, \ + f"Prefill chunked context is not supported." + + query_start_loc = attn_metadata.prefill.query_start_loc + cu_seq_q_lens = query_start_loc + cu_seq_k_lens = torch.zeros( + context_lens.size(0) + 1, dtype=torch.int32, device=q.device, + ) + torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:]) + attn_metadata.prefill.query_start_loc + seq_lens = torch.diff(cu_seq_k_lens) + + batch_size = seq_lens.shape[0] + + new_block_tables = torch.empty( + [attn_metadata.num_prefill_tokens, self.index_topk], + dtype=torch.int32, + device=q.device, + ) + new_context_lens = torch.empty( + [attn_metadata.num_prefill_tokens], + dtype=torch.int32, + device=q.device, + ) + + q_seq_lens = cu_seq_q_lens[1:]-cu_seq_q_lens[:-1] + max_seq_len = q_seq_lens.max().item() + batch_size = q_seq_lens.size(0) + max_compressed_kv_len = max_seq_len // self.compress_ratio + kv_cache_block_table = torch.zeros([batch_size, max_compressed_kv_len], dtype=torch.int32, device=q.device) + + # The layout of linear kv is as follows: + # | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv | + for i in range(batch_size): + start = cu_seq_k_lens[i].item() + kv_cache_block_table[i] = torch.arange( + start, start + max_compressed_kv_len, + dtype=torch.int32, + device=q.device, + ) + # offset total origin_kv len + kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1] + + # query: (tokens, index_head, index_head_dim) + # k_full: (tokens, index_head_dim) + # weights: (tokens, index_head, 1) + mlu_ops.masked_indexer_select_paged_kv_prefill( + query=q, + key_value=k_full, + weights=weights.unsqueeze(-1), + kv_cache_block_table=kv_cache_block_table, + cu_seq_q_lens=cu_seq_q_lens, + cu_seq_k_lens=cu_seq_k_lens, + index_topk=self.index_topk, + kv_cache_block_size=self.block_size, + softmax_scale=self.merged_softmax_scale, + q_scale=None, + k_scale_cache=None, + sparse_block_table=new_block_tables, + sparse_context_lens=new_context_lens, + compress_ratio=self.compress_ratio, + kv_cache_block_table_offset=None, + ) + + return new_block_tables, new_context_lens + + def forward_decode( + self, + q: torch.Tensor, + x: torch.Tensor, + k_cache: torch.Tensor, + weights: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + block_table = attn_metadata.decode.block_table + batch_size = block_table.shape[0] + seq_len = x.shape[0] // batch_size + q = q.view(batch_size, seq_len, *q.shape[1:]) + weights = weights.view(batch_size, seq_len, *weights.shape[1:]) + + seq_lens = attn_metadata.decode.seq_lens + k_block_table = block_table + + seq_len = x.shape[0] // batch_size + new_block_tables = torch.empty( + [batch_size, seq_len, self.index_topk], + dtype=torch.int32, + device=block_table.device, + ) + new_context_lens = torch.empty( + 
[attn_metadata.num_decode_tokens], + dtype=torch.int32, + device=block_table.device, + ) + + kv_cache_block_table_offset=torch.empty( + [attn_metadata.num_decode_tokens], + dtype=torch.int32, + device=block_table.device, + ) + kv_cache_block_table_offset.fill_(self.window_size) + mlu_ops.masked_indexer_select_paged_kv_decode( + query=q, + k_cache=k_cache, + weights=weights.unsqueeze(-1), # (bsz, seq_q, head_num, 1) + kv_cache_block_table=block_table, + k_context_lens=seq_lens // self.compress_ratio, + k_cache_block_table=k_block_table, + index_topk=self.index_topk, + kv_cache_block_size=self.block_size, + softmax_scale=self.merged_softmax_scale, + q_scale=None, + k_scale_cache=None, + sparse_block_table=new_block_tables, + sparse_context_lens=new_context_lens, + compress_ratio=self.compress_ratio, + kv_cache_block_table_offset=kv_cache_block_table_offset, + ) + + # [batch, seq_q, index_topk] -> [batch, index_topk] + new_block_tables = new_block_tables.squeeze(1) + return new_block_tables, new_context_lens + + def forward(self, + x: torch.Tensor, + qr: torch.Tensor, + positions: torch.Tensor, + offsets: torch.Tensor, + attn_metadata: AttentionMetadata, + batch_to_kv_state: torch.Tensor, + indexer_kv_cache: torch.Tensor, + compressor_slot_mapping: torch.Tensor, + ): + common_metadata = get_common_metadata() + query_start_loc = common_metadata.query_start_loc + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + rd = self.rope_head_dim + q = self.wq_b(qr)[0] + q = q.unflatten(-1, (self.n_heads, self.head_dim)) + + self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False) + q_pack = rotate_activation(q) + weights_pack = self.weights_proj(x)[0] # (tokens, index_local_head) + + num_decode_tokens = attn_metadata.num_decode_tokens + compressed_kv = self.compressor( + x, + positions, + attn_metadata, + batch_to_kv_state, + indexer_kv_cache, + 0, + compressor_slot_mapping, + ) + if attn_metadata.prefill: + assert compressed_kv is not None and compressed_kv.dim() == 3 + compressed_kv = compressed_kv.squeeze(-2) + compressed_context_lens = query_lens // self.compress_ratio + + prefill_q = q_pack[num_decode_tokens:, ...] + prefill_weights = weights_pack[num_decode_tokens:, ...] + prefill_block_tables, prefill_context_lens = self.forward_prefill( + prefill_q, + indexer_kv_cache, + prefill_weights, + attn_metadata, + compressed_kv, + compressed_context_lens, + ) + + if attn_metadata.decode: + decode_x = x[:num_decode_tokens, ...] + decode_q = q_pack[:num_decode_tokens, ...] 
+ decode_weights = weights_pack[attn_metadata.num_prefills:] + decode_block_tables, decode_context_lens = self.forward_decode( + decode_q, + decode_x, + indexer_kv_cache, + decode_weights, + attn_metadata, + ) + + if attn_metadata.prefill and attn_metadata.decode: + new_block_tables = torch.cat([prefill_block_tables, decode_block_tables], dim=0) + new_context_lens = torch.cat([prefill_context_lens, decode_context_lens], dim=0) + elif attn_metadata.prefill: + new_block_tables = prefill_block_tables + new_context_lens = prefill_context_lens + else: + new_block_tables = decode_block_tables + new_context_lens = decode_context_lens + return new_block_tables, new_context_lens \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/layernorm.py b/vllm_mlu/model_executor/layers/layernorm.py new file mode 100644 index 0000000..d649bdb --- /dev/null +++ b/vllm_mlu/model_executor/layers/layernorm.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Tuple + +import torch + +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod + +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant + + +@CustomOp.register("quant_fusion_rms_norm") +class QuantFusionRMSNorm(RMSNorm): + def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase): + super().__init__(hidden_size, variance_epsilon) + assert not isinstance( + proj.quant_method, UnquantizedLinearMethod + ), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported" + proj.quant_method.skip_quant_input = True + if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config): + quant_scale = proj.smooth.data + else: + quant_scale = proj.scale_to_int.data + self.dynamic_quant = dynamic_quant + self.quant_scale = torch.nn.Parameter(quant_scale) + + def forward( + self, x: torch.Tensor, residual: torch.Tensor | None = None + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None: + return mlu_ops.fused_rms_norm( + x, + residual, + self.weight.data, + None, + None, + self.variance_epsilon, + False, + self.quant_scale.data, + self.dynamic_quant, + ) + + +@CustomOp.register("quant_fusion_layer_norm") +class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp): + def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase): + super().__init__(hidden_size, variance_epsilon) + assert not isinstance( + proj.quant_method, UnquantizedLinearMethod + ), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported" + proj.quant_method.skip_quant_input = True + if dynamic_quant := is_per_token_smoothquant(proj.quant_method.quant_config): + quant_scale = proj.smooth.data + else: + quant_scale = proj.scale_to_int.data + self.dynamic_quant = dynamic_quant + self.quant_scale = torch.nn.Parameter(quant_scale) + + def forward( + self, x: torch.Tensor, residual: torch.Tensor | None = None + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None: + bias = None if self.bias is None else self.bias.data + return mlu_ops.fused_layer_norm( + x, + residual, + self.weight.data, + bias, + None, + self.eps, + False, + self.quant_scale.data, + self.dynamic_quant, + ) + + +def 
vllm__model_executor__layers__layernorm__RMSNorm__forward_oot( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + out: torch.Tensor | None = None, +) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None: + + org_shape = x.shape + x = x.reshape(-1, self.weight.data.shape[0]) + if out is not None: + out = out.view(-1, self.weight.data.shape[0]) + if residual is not None: + residual = residual.view(-1, self.weight.data.shape[0]) + x = mlu_ops.fused_rms_norm( + x, + residual, + self.weight.data, + None, + None, + self.variance_epsilon, + True, + out=out, + ) + else: + x = mlu_ops.fused_rms_norm( + x, + residual, + self.weight.data, + None, + None, + self.variance_epsilon, + False, + out=out, + ) + + if out is not None: + return x + + if residual is None: + assert isinstance(x, torch.Tensor) + return x.view(org_shape) + + assert isinstance(x, tuple) + assert len(x) == 2 + return x[0].view(org_shape), x[1].view(org_shape) + +MluHijackObject.apply_hijack( + RMSNorm, + RMSNorm.forward_oot, + vllm__model_executor__layers__layernorm__RMSNorm__forward_oot, +) diff --git a/vllm_mlu/model_executor/layers/linear.py b/vllm_mlu/model_executor/layers/linear.py new file mode 100644 index 0000000..3cdee31 --- /dev/null +++ b/vllm_mlu/model_executor/layers/linear.py @@ -0,0 +1,693 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Optional, Any + +import torch +from torch.nn.parameter import Parameter + +from vllm.distributed import (divide, split_tensor_along_last_dim, + get_parallel_rank_with_group, get_parallel_world_size_with_group, + get_tp_world_group, get_tp_world_world_size, get_tp_world_rank) +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.linear import ( + WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase, + ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.utils import set_weight_attrs +from vllm.logger import init_logger + +from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm_mlu import _mlu_ops as mlu_ops + +logger = init_logger(__name__) + + +WEIGHT_LOADER_V2_SUPPORTED.extend([ + "GPTQMluLinearMethod", + "AWQMluLinearMethod" +]) + +vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__ +vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader +vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add residual parameter. +@brief: dispatch unquantized_gemm to mlu ops. 
+''' +def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + residual: torch.Tensor | None = None +) -> torch.Tensor: + beta = 0.0 + if residual is not None: + beta = 1.0 + residual = residual.view(-1, residual.shape[-1]) + res_shape = x.shape[0:-1] + (layer.weight.shape[0], ) + return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]), + layer.weight, + bias, residual, 'none', 1.0, beta).view(res_shape) +''' +================== +End of MLU Hijack +================== +''' + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add tp_group and keep_full_weights parameters. +''' +def vllm__module_executor__layers__linear__LinearBase____init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + *, + tp_group: Any = None, + keep_full_weights: bool = False, + return_bias: bool = True, + disable_tp: bool = False, +): + vllm__module_executor__layers__linear__LinearBase____init__org( + self=self, + input_size=input_size, + output_size=output_size, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add self.tp_group, world_size and tp_rank to support data parallel and moe expert parallel + ''' + self.tp_group = tp_group + self.tp_world_size = get_parallel_world_size_with_group(self.tp_group) + self.tp_size = self.tp_world_size + self.tp_rank = get_parallel_rank_with_group(self.tp_group) + + self.keep_full_weights = keep_full_weights + if self.keep_full_weights or disable_tp: + self.tp_group = None + self.tp_world_size = 1 + self.tp_size = self.tp_world_size + self.tp_rank = 0 + self.tp_world_size_org = get_tp_world_world_size() + self.tp_rank_org = get_tp_world_rank() + ''' + ================= + End of MLU Hijack + ================= + ''' +''' +================= +End of MLU Hijack +================= +''' + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add tp_group and keep_full_weights parameters. +''' +def vllm__module_executor__layers__linear__ColumnParallelLinear____init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + output_sizes: list[int] | None = None, + prefix: str = "", + *, + tp_group: Any = None, + keep_full_weights: bool = False, + return_bias: bool = True, + disable_tp: bool = False, +): + super(ColumnParallelLinear, self).__init__( + input_size, + output_size, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + tp_group=tp_group, + keep_full_weights=keep_full_weights, + return_bias=return_bias, + disable_tp=disable_tp, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: self.tp_size and self.tp_rank has been initialized in LinearBase.__init__ + ''' + # Divide the weight matrix along the last dimension. 
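+ # (The upstream rank/size lookups are left commented out below for
+ # reference; the values are now populated in LinearBase.__init__.)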
+ # self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + # self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 + ''' + ================= + End of MLU Hijack + ================= + ''' + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) for output_size in self.output_sizes + ] + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add tp_group in create_weights + ''' + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + tp_group=self.tp_group, + ) + ''' + ================= + End of MLU Hijack + ================= + ''' + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + self.update_param_tp_status() +''' +================= +End of MLU Hijack +================= +''' + + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add smooth_quant_scale and use_tp_weight parameters. +''' +def vllm__module_executor__layers__linear__ColumnParallelLinear__forward( + self, + input_, + smooth_quant_scale: torch.Tensor | None = None, + use_tp_weight: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add input_scale and use_tp_weight parameter. + ''' + kwargs = {'bias': bias} + if use_tp_weight: + kwargs['use_tp_weight'] = use_tp_weight + if smooth_quant_scale is not None: + kwargs['input_scale'] = smooth_quant_scale + output_parallel = self.quant_method.apply(self, input_, **kwargs) + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.gather_output and self.tp_size > 1: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add tp_group param to tensor_model_parallel_all_gather + ''' + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group) + ''' + ================= + End of MLU Hijack + ================= + ''' + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias +''' +================= +End of MLU Hijack +================= +''' + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add tp_group and keep_full_weights parameters. 
+''' +def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + *, + tp_group: Any = None, + keep_full_weights: bool = False, + return_bias: bool = True, + disable_tp: bool = False, +): + self.output_sizes = output_sizes + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: checkout output_sizes after init to get self.tp_world_size + @brief: add keep_full_weights for dp parallelize shared expert + ''' + super(MergedColumnParallelLinear, self).__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + output_sizes=self.output_sizes, + prefix=prefix, + tp_group=tp_group, + keep_full_weights=keep_full_weights, + return_bias=return_bias, + disable_tp=disable_tp, + ) + + assert all(output_size % self.tp_size == 0 for output_size in output_sizes) + + if self.keep_full_weights: + tp_size = self.tp_world_size_org + if isinstance(self.quant_method, UnquantizedLinearMethod): + out_dim, in_dim = self.weight.shape + out_dim_tp = divide(out_dim, tp_size) + self.tp_weight = Parameter( + self.weight.data.new_empty((out_dim_tp, in_dim)), + requires_grad=False, + ) + elif (isinstance(self.quant_method, SmoothQuantLinearMethod) + and quant_config.input_quant_method == "per_token"): + out_dim, in_dim = self.qweight.shape + out_dim_tp = divide(out_dim, tp_size) + self.tp_qweight = Parameter( + self.qweight.data.new_empty((out_dim_tp, in_dim)), + requires_grad=False, + ) + self.tp_per_channel_scale = Parameter( + self.per_channel_scale.data.new_empty((out_dim_tp)), + requires_grad=False, + ) + else: + raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token") + ''' + ================= + End of MLU Hijack + ================= + ''' +''' +================= +End of MLU Hijack +================= +''' + + +def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: int | None = None, +): + loaded_weight_orig = loaded_weight + output_dim = getattr(param, "output_dim", None) + + vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org( + self=self, + param=param, + loaded_weight=loaded_weight, + loaded_shard_id=loaded_shard_id, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add keep_full_weights for dp parallelize shared expert + ''' + # load into tp weight + if self.keep_full_weights: + tp_size = self.tp_world_size_org + tp_rank = self.tp_rank_org + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + start_idx = tp_rank * shard_size + if isinstance(self.quant_method, UnquantizedLinearMethod): + tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size) + tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size) + tp_weight_shard.copy_(tp_weight) + elif isinstance(self.quant_method, SmoothQuantLinearMethod): + if output_dim is None: + return + tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size) + if 
loaded_weight_orig.ndim == 1: + tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size) + elif loaded_weight_orig.ndim == 2: + tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size) + else: + raise ValueError("only support rank 1 and 2 when using tp_weight") + + tp_weight_shard.copy_(tp_weight) + else: + raise TypeError(f"quant method is expected to be either unquantized or smoothquant") + ''' + ================= + End of MLU Hijack + ================= + ''' + + +def vllm__module_executor__layers__linear__RowParallelLinear____init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + tp_group: Any = None, + keep_full_weights: bool = False, + return_bias: bool = True, + disable_tp: bool = False, +): + super(RowParallelLinear, self).__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + tp_group=tp_group, + keep_full_weights=keep_full_weights, + return_bias=return_bias, + disable_tp=disable_tp, + ) + + # Divide the weight matrix along the last dimension + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add tp_group in create_weights + ''' + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + tp_group=self.tp_group, + ) + ''' + ================= + End of MLU Hijack + ================= + ''' + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add keep_full_weights for dp parallelize shared expert + ''' + if self.keep_full_weights: + tp_size = self.tp_world_size_org + if isinstance(self.quant_method, UnquantizedLinearMethod): + out_dim, in_dim = self.weight.data.shape + in_dim_tp = divide(in_dim, tp_size) + self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)), + requires_grad=False) + elif (isinstance(self.quant_method, SmoothQuantLinearMethod) + and quant_config.input_quant_method == "per_token"): + out_dim, in_dim = self.qweight.data.shape + in_dim_tp = divide(in_dim, tp_size) + self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)), + requires_grad=False) + if hasattr(self, "smooth"): + assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor" + dim = 
self.smooth.shape[0] + dim_tp = divide(dim, tp_size) + self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp)), + requires_grad=False) + else: + raise TypeError("quant method expected to be unquantized or smoothquant per-token") + ''' + ================= + End of MLU Hijack + ================= + ''' + self.update_param_tp_status() + + +def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader( + self, param: Parameter, loaded_weight: torch.Tensor +): + input_dim = getattr(param, "input_dim", None) + loaded_weight_orig = loaded_weight + + vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org( + self=self, + param=param, + loaded_weight=loaded_weight, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add keep_full_weights for dp parallelize shared expert + ''' + if self.keep_full_weights: + if input_dim is None: + return + tp_size = self.tp_world_size_org + tp_rank = self.tp_rank_org + shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size) + start_idx = tp_rank * shard_size + if isinstance(self.quant_method, UnquantizedLinearMethod): + shard_view = self.weight.narrow(input_dim, start_idx, shard_size) + self.tp_weight.copy_(shard_view) + elif isinstance(self.quant_method, SmoothQuantLinearMethod): + if loaded_weight_orig.ndim == 1: + shard_view = self.smooth.narrow(input_dim, start_idx, shard_size) + self.tp_smooth.copy_(shard_view) + elif loaded_weight_orig.ndim == 2: + shard_view = self.qweight.narrow(input_dim, start_idx, shard_size) + self.tp_qweight.copy_(shard_view) + else: + raise ValueError("only rank 1 and 2 is supported for tp_weight") + else: + raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant") + ''' + ================= + End of MLU Hijack + ================= + ''' + + +''' +============================= +Modify by vllm_mlu +============================= +@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters. +''' +def vllm__module_executor__layers__linear__RowParallelLinear__forward( + self, + input_, + residual: torch.Tensor | None = None, + smooth_quant_scale: torch.Tensor | None = None, + use_tp_weight: bool = False, + output: torch.Tensor | None = None, +) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add additional matmul parameters. 
+ ''' + residual_ = None if self.tp_rank > 0 else residual + kwargs = {'bias': bias_, 'residual': residual_} + if use_tp_weight: + kwargs['use_tp_weight'] = use_tp_weight + if smooth_quant_scale is not None: + kwargs['input_scale'] = smooth_quant_scale + if output is not None: + kwargs['output'] = output + output_parallel = self.quant_method.apply(self, input_parallel, **kwargs) + ''' + ================= + End of MLU Hijack + ================= + ''' + + if self.reduce_results and self.tp_size > 1: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add tensor_model_parallel_all_reduce() with self.tp_group + ''' + output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group) + ''' + ================= + End of MLU Hijack + ================= + ''' + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias +''' +================= +End of MLU Hijack +================= +''' + + +MluHijackObject.apply_hijack(UnquantizedLinearMethod, + UnquantizedLinearMethod.apply, + vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply) +MluHijackObject.apply_hijack(LinearBase, + LinearBase.__init__, + vllm__module_executor__layers__linear__LinearBase____init__) +MluHijackObject.apply_hijack(ColumnParallelLinear, + ColumnParallelLinear.__init__, + vllm__module_executor__layers__linear__ColumnParallelLinear____init__) +MluHijackObject.apply_hijack(ColumnParallelLinear, + ColumnParallelLinear.forward, + vllm__module_executor__layers__linear__ColumnParallelLinear__forward) +MluHijackObject.apply_hijack(MergedColumnParallelLinear, + MergedColumnParallelLinear.__init__, + vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__) +MluHijackObject.apply_hijack(MergedColumnParallelLinear, + MergedColumnParallelLinear.weight_loader, + vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader) +MluHijackObject.apply_hijack(RowParallelLinear, + RowParallelLinear.__init__, + vllm__module_executor__layers__linear__RowParallelLinear____init__) +MluHijackObject.apply_hijack(RowParallelLinear, + RowParallelLinear.weight_loader, + vllm__module_executor__layers__linear__RowParallelLinear__weight_loader) +MluHijackObject.apply_hijack(RowParallelLinear, + RowParallelLinear.forward, + vllm__module_executor__layers__linear__RowParallelLinear__forward) diff --git a/vllm_mlu/model_executor/layers/longcat_sparse_moe_mlp.py b/vllm_mlu/model_executor/layers/longcat_sparse_moe_mlp.py new file mode 100644 index 0000000..57d5c91 --- /dev/null +++ b/vllm_mlu/model_executor/layers/longcat_sparse_moe_mlp.py @@ -0,0 +1,744 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +"""Inference-only MOE model.""" +from typing import Optional, Any, List, Dict + +import torch +from torch import nn + +from vllm.distributed import ( + divide, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu._mlu_utils import * +from vllm_mlu.distributed.parallel_state import( + cnclep_dispatch, cnclep_combine) +from vllm_mlu.model_executor.layers.sparse_moe_mlp 
import SparseMoeMlp + +class LongCatSparseMoeMlp(SparseMoeMlp): + """ + sparse moe mlp layer specific to longcat model + """ + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + up_proj_name: str, + is_gated: bool, + down_proj_name: str, + has_bias: bool, + skip_bias_add: bool = False, + renormalize:bool = False, + hidden_act: str = "silu", + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + is_use_fused_moe: bool = False, + expert_group: Optional[int] = 1, + topk_group: Optional[int] = 1, + scoring_func: str = "softmax", + topk_method: str = "", + routed_scaling_factor: float = 1.0, + tp_group: Any = None, + use_all2all: bool = False, + num_zero_experts: int = 0, + ): + super().__init__( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + up_proj_name=up_proj_name, + is_gated=is_gated, + down_proj_name=down_proj_name, + has_bias=has_bias, + skip_bias_add=skip_bias_add, + renormalize=renormalize, + hidden_act=hidden_act, + params_dtype=params_dtype, + quant_config=quant_config, + is_use_fused_moe=is_use_fused_moe, + expert_group=expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + topk_method=topk_method, + routed_scaling_factor=routed_scaling_factor, + tp_group=tp_group, + use_all2all=use_all2all, + init_avg_moe=False, + ) + self.num_zero_experts = num_zero_experts + self.total_experts_including_zero = self.num_total_experts + self.num_zero_experts + self.use_quant_all2all = use_all2all and quant_config is not None + self.zero_expert_size = divide(self.num_zero_experts, self.moe_ep_size) + self.start_zero_expert_id = ( + self.num_total_experts + self.moe_ep_rank * ((self.num_zero_experts + self.moe_ep_size - 1) // self.moe_ep_size) + ) + + if VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg: + n_tokens = SparseMoeMlp.max_batched_token * self.dp_size + expert_group = self.moe_ep_size + val = 1.0 / float(self.total_experts_including_zero) + SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32) + if VLLM_RANDOM_MOE_EN: + import numpy as np + # example deepseekv2: experts 160 topk 6 + # avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,...... + array = np.stack([np.random.permutation(self.total_experts_including_zero)[:top_k] for _ in range(n_tokens)]) + table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32) + else: + # example deepseekv2: experts 160 + # avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,...... + import math + batch_table = math.ceil(n_tokens * top_k / self.total_experts_including_zero) * self.total_experts_including_zero + hi_val = batch_table // self.total_experts_including_zero + table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view( + hi_val, expert_group, num_experts // expert_group).transpose(1, 2) + if self.num_zero_experts > 0: + # Longcat model, for avg expert, we choose eight non-zero experts and four zero + # experts for each token accorrding to the paper. 
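+                # Illustrative layout: each group of top_k (12) consecutive slots holds
+                # 8 routed expert ids followed by 4 zero-expert ids (ids >= num_experts),
+                # e.g. [e0..e7, z0..z3], repeated until n_tokens * top_k entries are filled.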
+ assert num_experts == 512 and num_zero_experts == 256 and top_k == 12 + assert num_zero_experts % expert_group == 0 + non_zero_expert_num_per_token = 8 + zero_expert_num_per_token = 4 + zero_expert_table = torch.arange( + num_experts, num_experts + num_zero_experts, dtype=table.dtype, device=table.device).view( + expert_group, num_zero_experts // expert_group).transpose(0, 1).flatten() + non_zero_expert_table = table[0].flatten() + token_expert_list = [] + for idx in range(0, num_experts // non_zero_expert_num_per_token): + token_expert_list.append(non_zero_expert_table[ + idx * non_zero_expert_num_per_token: + idx * non_zero_expert_num_per_token + non_zero_expert_num_per_token]) + token_expert_list.append(zero_expert_table[ + idx * zero_expert_num_per_token: + idx * zero_expert_num_per_token + zero_expert_num_per_token]) + avg_expert_table = torch.cat(token_expert_list) + table = avg_expert_table.repeat(hi_val) + SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k) + SparseMoeMlp.is_expert_avg = True + + + def forward_experts_nofused_longcat( + self, hidden_states, total_num_experts, total_num_experts_per_rank, + topk_indices=None, topk_weights=None, residual_=None): + assert self.moe_ep_size == 1 + assert not self.use_all2all + expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = mlu_ops.moe_gen_idx( + topk_indices.to(torch.int32), total_num_experts) + # no expert is routed, then expand_gather_idx, expand_scatter_idx has no item, + # expand_token_count and expand_cusum_token_count has item but the value is all zero + # so this rank should only return final_hidden_states with zero value + if cusum_token_count[-1] == 0: + final_hidden_states = torch.zeros_like(hidden_states, + dtype=hidden_states.dtype, + device=hidden_states.device) + return final_hidden_states + + expand_hidden_states = mlu_ops.moe_expand_input( + hidden_states, expand_gather_idx, cusum_token_count, + start_expert_id=self.start_expert_id, + expert_size=self.end_expert_id - self.start_expert_id) + expand_hidden_states_zero = mlu_ops.moe_expand_input( + hidden_states, expand_gather_idx, cusum_token_count, + start_expert_id=self.start_zero_expert_id, + expert_size=self.zero_expert_size) + + expand_output_list = [] + expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id + + 1] - cusum_token_count[self.start_expert_id] + + for expert_idx, num_tokens_per_expert in enumerate(expand_token_count[:self.num_total_experts]): + if num_tokens_per_expert > 0: + expert_hidden_states = expand_hidden_states[ + expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]] + if expert_idx < self.num_total_experts: + expert_output = self.experts[expert_idx](expert_hidden_states) + else: + expert_output = expert_hidden_states + expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output + expand_output_list.append(expert_output) + expand_output = torch.cat(expand_output_list, dim=0) + num_normal_tokens = cusum_token_count[self.num_total_experts] + expand_hidden_states[:num_normal_tokens] = expand_output + # reduce normal experts + final_hidden_states = mlu_ops.moe_combine_result( + expand_hidden_states, topk_weights, scatter_idx, + residual_, cusum_token_count, start_expert_id=self.start_expert_id, + expert_size=self.end_expert_id - self.start_expert_id, bias=None) + # reduce zero experts + if self.moe_ep_size > 1 or self.moe_tp_rank == 0: + final_hidden_states = mlu_ops.moe_combine_result( + 
expand_hidden_states_zero, topk_weights, scatter_idx, + final_hidden_states, cusum_token_count, start_expert_id=self.start_zero_expert_id, + expert_size=self.zero_expert_size, bias=None, + output=final_hidden_states) + return final_hidden_states + + # no compute-communication parallel, for prototyping only, not in actual use. + # subject to becoming stale + def forward_all2all_int8_longcat( + self, hidden_states, total_num_experts, total_num_experts_per_rank, + topk_indices=None, topk_weights=None, residual_=None): + ori_input_shape = hidden_states.shape + dtype = hidden_states.dtype + self.pack_params() + self.pack_params_after_loading() + w1=self.w13 + w2=self.w2 + bias2=self.b2 + input_smooth=self.a13_scale_all_experts + act_smooth=self.a2_scale + w1_scale=self.w13_scale + w2_scale=self.w2_scale + act_mode=self.hidden_act + quant_input=None + + max_m = hidden_states.shape[0] + + reduce_weight = topk_weights + expert_id = topk_indices + + expand_idx, combine_idx, token_count, cusum_token_count \ + = mlu_ops.moe_gen_idx(expert_id, total_num_experts) + + num_token_expand = hidden_states.shape[0] * self.top_k + dispatch_bytes = num_token_expand * self.dispatch_token_size + + dispatch_send_token_tensor = ( + self.dispatch_send_buffer[:dispatch_bytes] + .view(num_token_expand, self.dispatch_token_size) + ) + + quant_size = self.hidden_size + quant_input = dispatch_send_token_tensor[:, : quant_size] + input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32) + quant_input, input_scale = mlu_ops.moe_quantize( + hidden_states, input_smooth, None, token_count[:self.num_total_experts], + expand_idx, None, + output=quant_input, + output_scale=input_scale) + expand_hidden_states_zero = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, + start_expert_id=self.num_total_experts, + expert_size=self.num_zero_experts) + + dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout( + token_count[:self.num_total_experts], self.moe_ep_size) + + cnclep_dispatch(self.dispatch_token_size, + num_token_expand, + dispatch_send_layout, + token_count[:self.num_total_experts], + self.dispatch_recv_layout, + self.dispatch_recv_token_num) + + recv_token_num = self.dispatch_recv_token_num.view( + self.moe_ep_size, self.num_experts_per_rank) + pad_num = self.max_num_tokens_per_rank + + ( + gather_by_expert_index, + gather_by_rank_index, + tokens_per_local_expert, + token_sum + ) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num) + + max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size + dispatch_recv_token_tensor = ( + self.dispatch_recv_buffer[:max_tokens_bytes_recv] + .view(self.max_num_tokens_recv, self.dispatch_token_size)) + + mlu_ops.gather_split(dispatch_recv_token_tensor, + gather_by_expert_index, + token_sum, + self.quant_input_recv, + self.input_scale_recv) + + max_m = self.max_num_tokens_per_expert + gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1, + tokens_per_local_expert, + None, None, None, None, + self.input_scale_recv.view(torch.float32).flatten(), + w1_scale, dtype, max_m) + + # continue reusing self.quant_input_recv and self.input_scale_recv + quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2] + input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]] + quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None, + tokens_per_local_expert, + output=quant_input, + output_scale=input_scale_fp32, + act_mode=act_mode, + is_gated=self.is_gated) + 
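+        # Down projection: the re-quantized int8 activations and their per-token scales
+        # feed the second grouped GEMM against w2, which dequantizes back to the
+        # original dtype.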
+ gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2, + tokens_per_local_expert, + None, None, None, None, input_scale, w2_scale, dtype, max_m) + + combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype) + mlu_ops.gather_split(gemm_out, + gather_by_rank_index, + token_sum, + combine_send_token_tensor, + None) + + combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size) + combine_recv_layout = self.dispatch_recv_layout + + # combine + combine_args = dict( + token_byte=self.hidden_size * 2, + token_num=num_token_expand, + send_src_layout=combine_send_layout, + send_dst_layout=combine_recv_layout, + send_token=None, + recv_token=None) + + cnclep_combine(**combine_args) + + numel_recv = num_token_expand * self.hidden_size + recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv] + .view(num_token_expand, self.hidden_size)) + + residual_ = None + output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx, + residual_, cusum_token_count, start_expert_id=0, + expert_size=self.num_total_experts, bias=bias2, output=hidden_states) + assert self.moe_ep_size > 1 + # zero expert reduce + output = mlu_ops.moe_combine_result( + expand_hidden_states_zero, reduce_weight, combine_idx, + output, cusum_token_count, self.num_total_experts, + self.num_zero_experts, output=hidden_states) + + return output.view(ori_input_shape) + + # no compute-communication parallel, for prototyping only, not in actual use. + # subject to becoming stale + def forward_all2all_bf16_longcat( + self, hidden_states, total_num_experts, total_num_experts_per_rank, + topk_indices=None, topk_weights=None, residual_=None): + is_fp8_quant = isinstance(self.quant_config, Fp8Config) + ori_input_shape = hidden_states.shape + dtype = hidden_states.dtype + self.pack_params() + self.pack_params_after_loading() + w1=self.w13 + w2=self.w2 + bias1=self.b13 + bias2=self.b2 + gated=self.is_gated + act_mode=self.hidden_act + + max_m = hidden_states.shape[0] + reduce_weight = topk_weights + expert_id = topk_indices + + # gen_idx + expand_idx, combine_idx, token_count, cusum_token_count = \ + mlu_ops.moe_gen_idx(expert_id, total_num_experts) + num_token_expand = hidden_states.shape[0] * self.top_k + dispatch_bytes = num_token_expand * self.dispatch_token_size + + dispatch_send_token_tensor = ( + self.dispatch_send_buffer[:dispatch_bytes] + .view(num_token_expand, self.dispatch_token_size) + .view(hidden_states.dtype) + ) + + expand_hidden_states = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, start_expert_id=0, + expert_size=self.num_total_experts) + expand_hidden_states_zero = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, + start_expert_id=self.num_total_experts, + expert_size=self.num_zero_experts) + + dispatch_send_token_tensor.copy_(expand_hidden_states) + + dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout( + token_count[:self.num_total_experts], self.moe_ep_size) + + cnclep_dispatch(self.dispatch_token_size, + num_token_expand, + dispatch_send_layout, + token_count[:self.num_total_experts], + self.dispatch_recv_layout, + self.dispatch_recv_token_num, + use_quant_dispatch=False, + ) + + recv_token_num = self.dispatch_recv_token_num.view( + self.moe_ep_size, self.num_experts_per_rank) + pad_num = self.max_num_tokens_per_rank + + ( + gather_by_expert_index, + gather_by_rank_index, + tokens_per_local_expert, + token_sum + ) = 
mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num) + + max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size + dispatch_recv_token_tensor = ( + self.dispatch_recv_buffer[:max_tokens_bytes_recv] + .view(self.max_num_tokens_recv, self.dispatch_token_size) + .view(hidden_states.dtype) + ) + + + self.quant_input_recv = self.quant_input_recv.view(hidden_states.dtype) + mlu_ops.gather_split(dispatch_recv_token_tensor, + gather_by_expert_index, + token_sum, + self.quant_input_recv) + + max_m = self.max_num_tokens_per_expert + gemm_out = mlu_ops.group_gemm( + self.quant_input_recv, w1, tokens_per_local_expert, + None, None, None, None, max_m) + act_out = mlu_ops.moe_active( + gemm_out, act_mode, gated) + gemm_out = mlu_ops.group_gemm( + act_out, w2, tokens_per_local_expert, + None, None, None, None, max_m) + + combine_send_token_tensor = self.combine_send_buffer.view( + self.max_num_tokens_recv, -1).view(hidden_states.dtype) + mlu_ops.gather_split(gemm_out, + gather_by_rank_index, + token_sum, + combine_send_token_tensor, + None) + + combine_send_layout = mlu_ops.moe_all2all_gen_send_layout( + self.dispatch_recv_token_num, self.moe_ep_size) + combine_recv_layout = self.dispatch_recv_layout + + combine_args = dict( + token_byte=self.hidden_size * 2, + token_num=num_token_expand, + send_src_layout=combine_send_layout, + send_dst_layout=combine_recv_layout, + send_token=None, + recv_token=None, + use_quant_dispatch=False, + ) + + cnclep_combine(**combine_args) + + numel_recv = num_token_expand * self.hidden_size + recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv] + .view(num_token_expand, self.hidden_size)) + + residual_ = None + output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx, + residual_, cusum_token_count, start_expert_id=0, + expert_size=self.num_total_experts, bias=bias2, output=hidden_states) + # zero expert reduce + output = mlu_ops.moe_combine_result( + expand_hidden_states_zero, reduce_weight, combine_idx, + output, cusum_token_count, self.num_total_experts, + self.num_zero_experts, output=hidden_states) + return output.view(ori_input_shape) + + def forward_before_dispatch(self, hidden_states: torch.Tensor, + topk_indices: torch.Tensor): + # gate and softmax topk is called in router for longcat + # other models can do these operations here + expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx( + topk_indices, self.total_experts_including_zero) + + num_token_expand = hidden_states.shape[0] * self.top_k + dispatch_bytes = num_token_expand * self.dispatch_token_size + dispatch_send_token_tensor = ( + self.dispatch_send_buffer[:dispatch_bytes] + .view(num_token_expand, self.dispatch_token_size) + ) + if self.use_quant_all2all: + hidden_states_stride = self.hidden_size + quant_input = dispatch_send_token_tensor[:, : hidden_states_stride] + input_scale = dispatch_send_token_tensor[:, hidden_states_stride :].view(torch.float32) + # expand input + quantize + quant_input, input_scale = mlu_ops.moe_quantize( + hidden_states, self.a13_scale_all_experts, None, + token_count[:self.num_total_experts], + expand_idx, None, + output=quant_input, + output_scale=input_scale) + # expand input of zero-expert + expand_hidden_states_zero = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, + start_expert_id=self.num_total_experts, + expert_size=self.num_zero_experts) + else: + expand_hidden_states = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, 
start_expert_id=0, + expert_size=self.num_total_experts) + dispatch_send_token_tensor = dispatch_send_token_tensor.view( + hidden_states.dtype) + dispatch_send_token_tensor.copy_(expand_hidden_states) + del expand_hidden_states + expand_hidden_states_zero = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, + start_expert_id=self.num_total_experts, + expert_size=self.num_zero_experts) + + dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout( + token_count[:self.num_total_experts], self.moe_ep_size) + + return combine_idx, token_count, cusum_token_count, dispatch_send_layout, expand_hidden_states_zero + + def forward_dispatch(self, token_num: int, dispatch_send_layout: torch.Tensor, + token_count: torch.Tensor): + num_token_expand = token_num * self.top_k + cnclep_dispatch(self.dispatch_token_size, + num_token_expand, + dispatch_send_layout, + token_count[:self.num_total_experts], + self.dispatch_recv_layout, + self.dispatch_recv_token_num, + use_quant_dispatch=self.use_quant_all2all) + + def forward_before_combine(self, hidden_states_dtype: torch.dtype): + recv_token_num = self.dispatch_recv_token_num.view( + self.moe_ep_size, self.num_experts_per_rank) + + ( + gather_by_expert_index, + gather_by_rank_index, + tokens_per_local_expert, + token_sum, + cusum_token_count + ) = mlu_ops.moe_all2all_gen_gather_index( + recv_token_num, self.max_num_tokens_per_rank, + return_cusum_token_count=True) + + max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size + dispatch_recv_token_tensor = ( + self.dispatch_recv_buffer[:max_tokens_bytes_recv] + .view(self.max_num_tokens_recv, self.dispatch_token_size)) + + max_m = self.max_num_tokens_per_expert + if self.use_quant_all2all: + mlu_ops.gather_split(dispatch_recv_token_tensor, + gather_by_expert_index, + token_sum, + self.quant_input_recv, + self.input_scale_recv) + # OPT: input_scale_recv_flatten can reuse self.input_scale_recv + input_scale_recv_flatten = self.input_scale_recv.view(torch.float32).flatten() + gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, self.w13, + tokens_per_local_expert, + None, None, None, None, + input_scale_recv_flatten, + self.w13_scale, hidden_states_dtype, max_m) + + quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2] + input_scale_fp32 = input_scale_recv_flatten[:gemm_out.shape[0]] + quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, self.a2_scale, None, + tokens_per_local_expert, + output=quant_input, + output_scale=input_scale_fp32, + act_mode=self.hidden_act, + is_gated=self.is_gated) + + gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, self.w2, tokens_per_local_expert, + None, None, None, None, input_scale, self.w2_scale, + hidden_states_dtype, max_m) + else: + dispatch_recv_token_tensor = dispatch_recv_token_tensor.view(hidden_states_dtype) + self.input_recv = self.input_recv.view(hidden_states_dtype) + mlu_ops.gather_split(dispatch_recv_token_tensor, + gather_by_expert_index, + token_sum, + self.input_recv) + gemm_out = mlu_ops.group_gemm( + self.input_recv, self.w13, tokens_per_local_expert, + None, None, None, None, max_m) + act_out = self.input_recv[:, :gemm_out.shape[-1] // 2] + act_out = mlu_ops.moe_active( + gemm_out, self.hidden_act, self.is_gated, output=act_out, + bias=None, cusum_token_count=cusum_token_count, + start_expert_id=0, expert_size=self.num_experts_per_rank) + gemm_out = mlu_ops.group_gemm( + act_out, self.w2, tokens_per_local_expert, + None, None, None, None, max_m) + + combine_send_token_tensor = 
self.combine_send_buffer.view( + self.max_num_tokens_recv, -1).view(hidden_states_dtype) + mlu_ops.gather_split(gemm_out, + gather_by_rank_index, + token_sum, + combine_send_token_tensor, + None) + + combine_send_layout = mlu_ops.moe_all2all_gen_send_layout( + self.dispatch_recv_token_num, self.moe_ep_size) + + return combine_send_layout + + def forward_combine(self, token_num: int, combine_send_layout: torch.Tensor): + num_token_expand = token_num * self.top_k + # combine_recv_layout(self.dispatch_recv_layout) is calculated when cnclep_dispatch + # because dispatch and combine are inverse operation + cnclep_combine(token_byte=self.hidden_size * 2, + token_num=num_token_expand, + send_src_layout=combine_send_layout, + send_dst_layout=self.dispatch_recv_layout, + send_token=None, + recv_token=None, + use_quant_dispatch=self.use_quant_all2all) + + def forward_after_combine(self, token_num: int, + reduce_weight: torch.Tensor, + combine_idx: torch.Tensor, + cusum_token_count: torch.Tensor, + expand_hidden_states_zero: torch.Tensor, + output_tensor_dtype: torch.dtype, + output_tensor: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None): + num_token_expand = token_num * self.top_k + numel_recv = num_token_expand * self.hidden_size + recv_token = (self.combine_recv_buffer.view(output_tensor_dtype)[:numel_recv] + .view(num_token_expand, self.hidden_size)) + + output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx, + residual, cusum_token_count, start_expert_id=0, + expert_size=self.num_total_experts, bias=self.b2, output=output_tensor) + output = mlu_ops.moe_combine_result( + expand_hidden_states_zero, reduce_weight, combine_idx, + output, cusum_token_count, self.num_total_experts, + self.num_zero_experts, output=output_tensor) + + return output + + # no compute-communication parallel, for prototyping only, not in actual use. + # subject to becoming stale + def forward_group_experts_longcat( + self, hidden_states, total_num_experts, total_num_experts_per_rank, + topk_indices=None, topk_weights=None, residual_=None, + expand_idx=None, combine_idx=None, token_count=None, cusum_token_count=None): + is_fp8_quant = isinstance(self.quant_config, Fp8Config) + ori_input_shape = hidden_states.shape + dtype = hidden_states.dtype + self.pack_params() + self.pack_params_after_loading() + w1=self.w13 + w2=self.w2 + bias1=self.b13 + bias2=self.b2 + input_smooth=self.a13_scale + act_smooth=self.a2_scale + w1_scale=self.w13_scale + w2_scale=self.w2_scale + gated=self.is_gated + act_mode=self.hidden_act + quant_input=None + + start_expert_id=self.start_expert_id + expert_size = w1.size(0) + max_m = hidden_states.shape[0] + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None + # Check smooth quant parameters. 
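+        # per_token_sq is enabled only when input_smooth, act_smooth, w1_scale and
+        # w2_scale are all present; a mix of present and missing tensors is rejected
+        # below as an invalid configuration.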
+ per_token_sq = False + if not is_fp8_quant: + check_list = [input_smooth, act_smooth, w1_scale, w2_scale] + if all(x is not None for x in check_list): + per_token_sq = True + + if not (all(x is None for x in check_list) or all(x is not None for x in check_list)): + raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present " + "and absent at the same time.") + + expert_id = topk_indices + reduce_weight = topk_weights + + # gen_idx + if expert_id is not None: + expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, total_num_experts) + + # check quant + if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token': + raise NotImplementedError + elif per_token_sq: + expand_hidden_states = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, + start_expert_id=start_expert_id, + expert_size=expert_size) + expand_hidden_states_zero = mlu_ops.moe_expand_input( + hidden_states, expand_idx, cusum_token_count, + start_expert_id=self.start_zero_expert_id, + expert_size=self.zero_expert_size) + quant_input, input_scale = mlu_ops.moe_quantize( + expand_hidden_states, input_smooth, None, + token_count[start_expert_id:start_expert_id+expert_size]) + else: + expand_hidden_states = mlu_ops.moe_expand_input(hidden_states, expand_idx, + cusum_token_count, start_expert_id, expert_size) + expand_hidden_states_zero = mlu_ops.moe_expand_input(hidden_states, expand_idx, + cusum_token_count, self.start_zero_expert_id, self.zero_expert_size) + + if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq: + gemm_out = mlu_ops.smooth_quant_group_gemm( + quant_input, w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, input_scale, w1_scale, dtype, max_m) + else: + gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, max_m) + + # add_bias_active + if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token': + raise NotImplementedError + elif per_token_sq: + quant_input = quant_input[:, :gemm_out.shape[-1] // 2] + input_scale = input_scale[:gemm_out.shape[0]] + quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None, + token_count[start_expert_id:start_expert_id+expert_size], + output=quant_input, + output_scale=input_scale, + act_mode=act_mode, + is_gated=self.is_gated) + + if ((is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') + or per_token_sq): + # Remove the reference to gemm_out tensor. 
+ # If that was the only reference, the tensor’s memory becomes eligible for deallocation + # So that we can reuse this memory for the new allocation of next gemm operation + # del gemm_out + gemm_out = mlu_ops.smooth_quant_group_gemm( + quant_input, w2, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, input_scale, w2_scale, dtype, max_m, + output=expand_hidden_states) + else: + act_out = mlu_ops.moe_active( + gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2], + bias1, cusum_token_count, start_expert_id, expert_size) + gemm_out = mlu_ops.group_gemm( + act_out, w2, token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, max_m, + output=expand_hidden_states) + + output = mlu_ops.moe_combine_result( + gemm_out, reduce_weight, combine_idx, + residual_, cusum_token_count, start_expert_id, + expert_size, bias2) + if self.moe_ep_size > 1 or self.moe_tp_rank == 0: + output = mlu_ops.moe_combine_result( + expand_hidden_states_zero, reduce_weight, combine_idx, + output, cusum_token_count, self.start_zero_expert_id, + self.zero_expert_size, bias2, + output=output) + return output.view(ori_input_shape) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/quantization/__init__.py b/vllm_mlu/model_executor/layers/quantization/__init__.py new file mode 100644 index 0000000..c60025c --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/__init__.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.model_executor.layers.quantization import ( + QUANTIZATION_METHODS, register_quantization_config +) + +MLU_QUANTIZATION_METHODS= [ + "smoothquant", + "weightonly", + "awq_mlu", + "gptq_mlu", +] + + +def register_fake_mlu_quantization_methods(): + for quant_method in MLU_QUANTIZATION_METHODS: + if quant_method not in QUANTIZATION_METHODS: + QUANTIZATION_METHODS.append(quant_method) + + +def remove_fake_mlu_quantization_methods(): + for quant_method in MLU_QUANTIZATION_METHODS: + if quant_method in QUANTIZATION_METHODS: + QUANTIZATION_METHODS.remove(quant_method) + + +def register_real_mlu_quantization_methods(): + remove_fake_mlu_quantization_methods() + from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig + from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig + from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig + from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig + register_quantization_config("weightonly")(WeightOnlyConfig) + register_quantization_config("smoothquant")(SmoothQuantConfig) + register_quantization_config("awq_mlu")(AWQMluConfig) + register_quantization_config("gptq_mlu")(GPTQMluConfig) diff --git a/vllm_mlu/model_executor/layers/quantization/awq_mlu.py b/vllm_mlu/model_executor/layers/quantization/awq_mlu.py new file mode 100644 index 0000000..59c5011 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/awq_mlu.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import register_quantization_config +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from 
vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.scalar_type import ScalarType, scalar_types
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+from vllm_mlu import _mlu_ops as mlu_ops
+
+logger = init_logger(__name__)
+
+MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
+
+# We only support GPTQ and AWQ on the 300 series and above, and only int4 and int8 precision.
+def query_mlu_supported_quant_types(has_zp: bool,
+                                    device_capability: Optional[int] = None
+                                    ):
+    if device_capability is None:
+        major, minor = current_platform.get_device_capability()
+        device_capability = major * 10 + minor
+
+    if has_zp:
+        # AWQ style, unsigned + zero-point
+        return [scalar_types.uint4, scalar_types.uint8]
+    else:
+        # GPTQ style, unsigned + symmetric bias
+        return [scalar_types.uint4b8, scalar_types.uint8b128]
+
+
+def check_mlu_supported(
+        quant_type: ScalarType,
+        group_size: Optional[int],
+        has_zp: bool,
+        device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
+
+    if device_capability is None:
+        major, minor = current_platform.get_device_capability()
+        device_capability = major * 10 + minor
+
+    supported_types = query_mlu_supported_quant_types(
+        has_zp, device_capability)
+
+    if quant_type not in supported_types:
+        return (False, f"MLU does not support quant_type = {quant_type}. "
+                f"Only types = {supported_types} "
+                f"are supported (for group_size = {group_size}, "
+                f"device_capability = {device_capability}, zp = {has_zp}).")
+    if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
+        return (False, f"MLU does not support group_size = {group_size}. "
+                f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
+
+    return True, None
+
+
+# @register_quantization_config("awq_mlu")
+class AWQMluConfig(QuantizationConfig):
+    """Config class for AWQMlu. 
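+    At load time the AutoAWQ-packed weights are unpacked, reordered, and either
+    fully dequantized or repacked for the MLU weight-only matmul kernel.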
+ + Reference: https://arxiv.org/abs/2306.00978 + """ + + # num_bits -> type + TYPE_MAP = { + 4: { + False: scalar_types.uint4b8, + True: scalar_types.uint4, + }, + 8: { + False: scalar_types.uint8b128, + True: scalar_types.uint8, + } + } + + VERSION = ["gemm"] + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + version: str = "gemm", + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.lm_head_quantized = lm_head_quantized + self.pack_factor = 32 // self.weight_bits + self.version = version + self.support_scale_zeros = False + + if self.weight_bits not in [4, 8]: + raise ValueError( + "Currently, only 4/8-bit weight quantization is supported for " + f"AWQMlu, but got {self.weight_bits} bits.") + if self.version not in self.VERSION: + raise ValueError( + "Currently, only gemm, gemv version is supported for " + f"AWQMlu, but got verion:{self.version}.") + + if self.version in ["gemm"]: + self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]} + self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]} + else: + self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]} + self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]} + + def __repr__(self) -> str: + return (f"AWQMluConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}), " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> str: + return "awq_mlu" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16, torch.float32] + + @staticmethod + def get_config_filenames() -> List[str]: + return ["quant_config.json", "quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + version = cls.get_from_keys_or(config, ["version"], + default="gemm") + return cls(weight_bits, group_size, zero_point, lm_head_quantized, version) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AWQMluLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return AWQMluLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None or user_quant == "awq" + or user_quant == "awq_mlu") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info("Detected that the model can run with awq_mlu" + ", however you specified quantization=awq explicitly," + " so forcing awq. 
Use quantization=awq_mlu for" + " faster inference") + return None + + @classmethod + def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + has_zp = quant_config.get("zero_point", None) + version = quant_config.get("version", "gemm") + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or has_zp is None): + return False + + if num_bits not in cls.TYPE_MAP: + return False + + if version not in cls.VERSION: + return False + + return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp], + group_size=group_size, + has_zp=has_zp) + +class AWQMluLinearMethod(LinearMethodBase): + """Linear method for AWQMlu. + + Args: + quant_config: The AWQMlu quantization config. + """ + + def __init__(self, quant_config: AWQMluConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + weight_loader = extra_weight_attrs.get("weight_loader") + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + qzeros = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + packed_qweight, scale_zeros = self.extract_autoawq(layer) + if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros): + layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False) + layer.qzeros = None + layer.scales = None + else: + layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False) + if scale_zeros is not None: + layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False) + else: + layer.qzeros = None + layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False) + + def 
apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.quant_config.zero_point and not self.quant_config.support_scale_zeros: + output = mlu_ops.matmul(x, layer.qweight, bias) + if residual is not None: + output = output + residual + else: + output = mlu_ops.weight_only_quant_matmul(x, + layer.qweight, + layer.scales, + layer.qzeros, + bias, + residual, + "none", + self.quant_config.weight_bits) + + return output + + def extract_autoawq(self, layer: torch.nn.Module): + qweight = layer.qweight.data + qzeros = layer.qzeros.data + scales = layer.scales.data + bits = self.quant_config.weight_bits + group_size = self.quant_config.group_size + + # Unpack the qweight and qzeros tensors + iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits) + # Reverse the order of the iweight and izeros tensors + iweight, izeros = self.reverse_awq_order(iweight, izeros, bits) + + # overflow checks + iweight = torch.bitwise_and(iweight, (2**bits) - 1) + if izeros is not None: + izeros = torch.bitwise_and(izeros, (2**bits) - 1) + + if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros): + scales = scales.repeat_interleave(group_size, dim=0) + if izeros is not None: + izeros = izeros.repeat_interleave(group_size, dim=0) + fweight = (iweight - izeros) * scales + else: + fweight = iweight * scales + # transpose [ci, co] -> [co, ci] + fweight = fweight.transpose(0, 1) + + return fweight, None + + if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None: + scale_zeros = izeros.to(scales.dtype) * -1 * scales + # transpose [ci, co] -> [co, ci] + scale_zeros = scale_zeros.transpose(0, 1) + else: + scale_zeros = None + + # transpose [ci, co] -> [co, ci] + iweight = iweight.to(torch.int8).transpose(0, 1) + + if bits == 4: + higher_bit_tensor = iweight[:, 1::2] + lower_bit_tensor = iweight[:, 0::2] + packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor) + else: + packed_qweight = iweight + + return packed_qweight, scale_zeros + + def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qweight.device) + dtype = torch.int16 if bits == 8 else torch.int8 + # unpacking columnwise + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype) + iweights = iweights.view(iweights.shape[0], -1) + if not self.quant_config.zero_point or self.quant_config.support_scale_zeros: + iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1) + + # unpacking columnwise + if qzeros is not None: + izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype) + izeros = izeros.view(izeros.shape[0], -1) + if not self.quant_config.zero_point: + izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1) + else: + izeros = None + + return iweights, izeros + + def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int): + reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]] + reverse_order_tensor = reverse_order_tensor.view(-1) + + rweights = iweights[:, reverse_order_tensor] + if izeros is not None: + rzeros = izeros[:, reverse_order_tensor] + + return 
rweights, rzeros
+
+    def combine_low_bits(self, tensor_a, tensor_b):
+        """
+        Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
+
+        Args:
+            tensor_a (torch.Tensor): First tensor of type int8.
+            tensor_b (torch.Tensor): Second tensor of type int8.
+
+        Returns:
+            torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
+        """
+        # Ensure both inputs are int8
+        if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
+            raise ValueError("Both tensors must be of int8 type.")
+
+        # Extract the lower 4 bits of each tensor
+        low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
+        low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
+
+        # Shift the lower 4 bits of tensor_a left by 4
+        shifted_low_bits_a = low_bits_a << 4
+
+        # Combine the lower 4 bits of the two tensors
+        combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
+
+        return combined
diff --git a/vllm_mlu/model_executor/layers/quantization/fp8.py b/vllm_mlu/model_executor/layers/quantization/fp8.py
new file mode 100644
index 0000000..a590a21
--- /dev/null
+++ b/vllm_mlu/model_executor/layers/quantization/fp8.py
@@ -0,0 +1,753 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
+
+import functools
+from functools import partial
+import importlib.util
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from typing import Any, Dict, List, Optional, Callable
+from vllm import envs
+from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+from vllm.model_executor.layers.quantization.fp8 import (
+    get_flashinfer_moe_backend,
+    ACTIVATION_SCHEMES,
+    Fp8Config,
+    Fp8LinearMethod,
+    Fp8MoeBackend,
+    Fp8MoEMethod,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    FlashinferMoeBackend,
+    flashinfer_cutlass_moe_fp8,
+    get_flashinfer_moe_backend,
+)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    W8A8BlockFp8LinearOp,
+    create_fp8_input_scale,
+    create_fp8_scale_parameter,
+    create_fp8_weight_parameter,
+    validate_fp8_block_shape
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
+    maybe_create_device_identity, Fp8LinearOp)
+from vllm.model_executor.parameter import (
+    BlockQuantScaleParameter, ChannelQuantScaleParameter,
+    ModelWeightParameter, PerTensorScaleParameter)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import (
+    is_deep_gemm_e8m0_used,
+    is_deep_gemm_supported,
+)
+from vllm.utils.flashinfer import has_flashinfer_moe
+from vllm.utils.import_utils import has_deep_gemm
+
+from vllm_mlu.mlu_hijack_utils import MluHijackObject
+from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
+import vllm_mlu._mlu_ops as mlu_ops
+
+
+logger = init_logger(__name__)
+
+
+def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+    """
+    Select the primary 
FP8 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. + """ + # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. + if ( + current_platform.is_cuda() + and ( + current_platform.is_device_capability(100) + or current_platform.is_device_capability(90) + ) + and envs.VLLM_USE_FLASHINFER_MOE_FP8 + and has_flashinfer_moe() + ): + backend = get_flashinfer_moe_backend() + if backend == FlashinferMoeBackend.TENSORRT_LLM: + logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") + return Fp8MoeBackend.FLASHINFER_TRTLLM + else: + if block_quant and current_platform.is_device_capability(100): + raise ValueError( + "FlashInfer FP8 MoE throughput backend does not " + "support block quantization. Please use " + "VLLM_FLASHINFER_MOE_BACKEND=latency " + "instead." + ) + logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100") + return Fp8MoeBackend.FLASHINFER_CUTLASS + + # weight-only path for older GPUs without native FP8 + use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: disable marlin for MLU backend. + ''' + if current_platform.is_rocm() or current_platform.is_out_of_tree(): + use_marlin = False + ''' + ================== + End of MLU Hijack + ================== + ''' + if use_marlin: + logger.info_once("Using Marlin backend for FP8 MoE") + return Fp8MoeBackend.MARLIN + + # deepGEMM on supported platforms with block-quantized weights + if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant: + if not has_deep_gemm(): + logger.warning_once("DeepGEMM backend requested but not available.") + elif is_deep_gemm_supported(): + logger.info_once("Using DeepGEMM backend for FP8 MoE") + return Fp8MoeBackend.DEEPGEMM + + # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and block_quant + ): + logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") + return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM + + # default to Triton + logger.info_once("Using Triton backend for FP8 MoE") + return Fp8MoeBackend.TRITON + + +Fp8Config____init____org = Fp8Config.__init__ + +def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: list[str] | None = None, + weight_block_size: list[int] | None = None, + activation_quant_method: Optional[str] = None, + weight_quant_method: Optional[str] = None, +) -> None: + super(Fp8Config, self).__init__() + + Fp8Config____init____org( + self, + is_checkpoint_fp8_serialized, + activation_scheme, + ignored_layers, + weight_block_size + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add class members activation_quant_method and weight_quant_method to + indicate the granularity of quantization. + ''' + self.activation_quant_method = activation_quant_method + self.weight_quant_method = weight_quant_method + + assert (self.weight_block_size or \ + self.activation_quant_method == "per_token" and self.weight_quant_method == "per_channel" + and self.activation_scheme == "dynamic"), "Only support block-wise quantization, or "\ + "input dynamic per-token weight per-channel quantization yet." 
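+    # i.e. either weight_block_size is set (block-wise fp8), or the checkpoint uses
+    # dynamic per-token activation quantization with per-channel weight scales.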
+ ''' + ================== + End of MLU Hijack + ================== + ''' + + +@classmethod +def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config( + cls, config: Dict[str, Any] +) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add config members activation_quant_method and weight_quant_method to + indicate the granularity of quantization. + ''' + activation_quant_method = cls.get_from_keys_or(config, + ["activation_quant_method"], + 'per_token') + weight_quant_method = cls.get_from_keys_or(config, + ["weight_quant_method"], + None) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + activation_quant_method=activation_quant_method, + weight_quant_method=weight_quant_method) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, +): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add tp_group. + ''' + tp_group = extra_weight_attrs.get("tp_group", None) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add tp_group. + ''' + # WEIGHT + if self.quant_config.is_checkpoint_fp8_serialized: + weight = create_fp8_weight_parameter( + output_size_per_partition, input_size_per_partition, weight_loader + ) + else: + # For non-serialized checkpoints, use original dtype + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + tp_group=tp_group, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. 
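+    # On MLU, per-channel weight scales keep one fp32 value per output channel,
+    # the per-tensor fallback keeps one value per logical (fused) weight, and
+    # block-quant checkpoints register weight_scale_inv with the configured block shape.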
+ if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + if not self.block_quant: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Support weight per channel quantization. + @brief: Add tp_group to enable custom split. + ''' + if self.weight_per_channel: + scale = ChannelQuantScaleParameter( + data=torch.empty(sum(output_partition_sizes), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + tp_group=tp_group, + ) + else: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, {"scale_type": "weight_scale"}) + layer.register_parameter("weight_scale", scale) + ''' + ================== + End of MLU Hijack + ================== + ''' + else: + assert not self.act_q_static + assert self.weight_block_size is not None + scale = create_fp8_scale_parameter( + BlockQuantScaleParameter, + output_partition_sizes, + input_size_per_partition, + self.weight_block_size, + weight_loader, + ) + set_weight_attrs(scale, {"scale_type": "weight_scale"}) + # The weight_scale_inv name is intentional for deepseekv3 + layer.register_parameter("weight_scale_inv", scale) + + # INPUT ACTIVATION SCALE + if self.act_q_static: + scale = create_fp8_input_scale(output_partition_sizes, weight_loader) + set_weight_attrs(scale, {"scale_type": "input_scale"}) + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + +def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__( + self, + quant_config: Fp8Config +): + self.quant_config = quant_config + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.out_dtype = torch.get_default_dtype() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + if vllm_is_batch_invariant(): + self.use_marlin = False + + # AITER is only supported on ROCm and only for FP8_FNUZ + # and at the moment are MI300 series + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() + self.use_deep_gemm = is_deep_gemm_supported() + + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant = self.weight_block_size is not None + if self.block_quant: + # Marlin doesn't support block-wise fp8 + self.use_marlin = False + + self.act_q_static = self.quant_config.activation_scheme == "static" + if self.weight_block_size: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + else: + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add config members activation_quant_method and weight_quant_method to + indicate the granularity of quantization. 
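+    @note: the Marlin fallback is disabled when per-channel weight and per-token
+           activation quantization are both selected.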
+ ''' + self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel') + self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token') + if self.weight_per_channel and self.activation_per_token: + self.use_marlin = False + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.block_quant: + assert not self.act_q_static + assert self.weight_block_size is not None + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape, + ) + + +Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading + + +def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading( + self, + layer: Module, +) -> None: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: For dynamic activation and channel-wise weight quantization, + additional processing is not needed. + ''' + if (self.quant_config.is_checkpoint_fp8_serialized + and self.weight_per_channel + and self.quant_config.activation_scheme == "dynamic"): + return + ''' + ================== + End of MLU Hijack + ================== + ''' + Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer) + + +def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert residual is None, "Fp8Linear residual is not supported yet." + + # if batch invariant mode is enabled, prefer DeepGEMM FP8 path + # we will use BF16 dequant when DeepGEMM is not supported. 
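    # "BF16 dequant" below means materializing
    #   weight_bf16 = weight_fp8.to(torch.bfloat16) * weight_scale
    # (scalar broadcast for per-tensor scales, row-wise broadcast when there
    # is one scale per output channel) and running a plain bf16 GEMM, trading
    # memory and bandwidth for bitwise-reproducible results.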
+ if vllm_is_batch_invariant(): + if self.block_quant: + assert self.weight_block_size is not None + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) + else: + # per-tensor/channel: dequant to BF16 and run GEMM + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) + if weight_scale.numel() == 1: + # Per-tensor: simple scalar multiplication + weight_bf16 = weight_fp8 * weight_scale + else: + # Multiple scales (fused modules like QKV) + # Try to infer correct broadcasting + # weight is [K, N], scale could be [num_logical_weights] + # Need to figure out how to broadcast - for now just try + # direct multiplication + if ( + weight_scale.dim() == 1 + and weight_scale.shape[0] == weight_fp8.shape[0] + ): + # Per-row scaling + weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1) + else: + # Fallback + weight_bf16 = weight_fp8 * weight_scale + return torch.nn.functional.linear(x, weight_bf16.t(), bias) + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + if self.block_quant: + assert self.weight_block_size is not None + from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import ( + apply_w8a8_block_fp8_linear) + return apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Use activation per token quantization based on quantization config. 
+ ''' + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + weight_per_channel=self.weight_per_channel, + activation_per_token=self.activation_per_token) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__( + self, + quant_config: Fp8Config, + layer: torch.nn.Module +): + super(Fp8MoEMethod, self).__init__(layer.moe_config) + self.layer = layer + self.quant_config = quant_config + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant: bool = self.weight_block_size is not None + self.fp8_backend = get_fp8_moe_backend(self.block_quant) + + self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN + self.flashinfer_moe_backend: FlashinferMoeBackend | None = None + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM + elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS + if self.block_quant: + assert self.weight_block_size == [128, 128], ( + f"Only support weight_block_size == [128, 128], " + f"got {self.weight_block_size}" + ) + self.flashinfer_moe_fn = partial( + flashinfer_cutlass_moe_fp8, + moe=self.moe, + use_deepseek_fp8_block_scale=self.block_quant, + ) + + self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM + self.allow_cutlass_block_scaled_grouped_gemm = ( + self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: In mlu, always set self.use_marlin as False. 
+ ''' + self.use_marlin = False + ''' + ================== + End of MLU Hijack + ================== + ''' + + +def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, +) -> torch.Tensor: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts + ''' + from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts + if scoring_func == "softmax": + topk_weights, topk_ids = mlu_ops.moe_softmax_topk( + router_logits, + top_k, + renormalize, + num_expert_group, + topk_group, + route_scale=routed_scaling_factor, + ) + elif scoring_func == "sigmoid": + topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk( + router_logits, + top_k, + renormalize, + num_expert_group, + topk_group, + routed_scaling_factor, + e_score_correction_bias, + ) + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + # gen_idx + ori_input_shape = x.shape + x = x.reshape(-1, x.size(-1)) + router_logits = router_logits.reshape(-1, router_logits.size(-1)) + expert_num = router_logits.size(-1) + tokens_num = x.size(0) + expert_size = layer.w13_weight.size(0) + expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx( + topk_ids, expert_num + ) + + expand_hidden_states = mlu_ops.moe_expand_input( + x, expand_idx, cumsum_token_count, 0, expert_size + ) + quant_input, input_scale = _fp8_quantize( + expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size + ) + gemm1_out = mlu_ops.smooth_quant_group_gemm( + quant_input, + layer.w13_weight, + token_count, + expand_idx=None, + c=None, + alpha=None, + beta=None, + a_scale=input_scale.T.contiguous(), + b_scale=layer.w13_weight_scale_inv, + dtype=x.dtype, + max_m=tokens_num, + ) + + act_out = mlu_ops.active(gemm1_out, activation, is_gated=True) + act_out_quantize, act_out_scale = _fp8_quantize( + act_out, A_scale=None, block_shape=self.quant_config.weight_block_size + ) + + gemm2_out = mlu_ops.smooth_quant_group_gemm( + act_out_quantize, + layer.w2_weight, + token_count, + expand_idx=None, + c=None, + alpha=None, + beta=None, + a_scale=act_out_scale.T.contiguous(), + b_scale=layer.w2_weight_scale_inv, + dtype=x.dtype, + max_m=tokens_num, + ) + + output = mlu_ops.moe_combine_result( + gemm2_out, + topk_weights, + combine_idx, + residual=None, + cusum_token_count=cumsum_token_count, + start_expert_id=0, + expert_size=expert_size, + bias=None, + ) + return output.view(ori_input_shape) + + """ + ================== + End of MLU Hijack + ================== + """ + 
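# Reference sketch (illustrative only, not part of this patch): the routing
# step that mlu_ops.moe_softmax_topk is assumed to perform for the simple
# case without expert grouping -- softmax over the router logits, top-k
# selection, and optional renormalization of the selected weights.
def _reference_softmax_topk(router_logits: torch.Tensor, top_k: int,
                            renormalize: bool):
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids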
+ +MluHijackObject.apply_hijack( + Fp8LinearMethod, + Fp8LinearMethod.apply, + vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply +) +MluHijackObject.apply_hijack( + Fp8Config, + Fp8Config.__init__, + vllm__model_executor__layers__quantization__fp8__Fp8Config____init__ +) +MluHijackObject.apply_hijack( + Fp8Config, + Fp8Config.from_config, + vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config +) +MluHijackObject.apply_hijack( + Fp8LinearMethod, + Fp8LinearMethod.create_weights, + vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights +) +MluHijackObject.apply_hijack( + Fp8LinearMethod, + Fp8LinearMethod.__init__, + vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__ +) +MluHijackObject.apply_hijack( + Fp8LinearMethod, + Fp8LinearMethod.process_weights_after_loading, + vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading +) +MluHijackObject.apply_hijack( + Fp8MoEMethod, + Fp8MoEMethod.__init__, + vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__ +) +MluHijackObject.apply_hijack( + Fp8MoEMethod, + Fp8MoEMethod.apply, + vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply +) diff --git a/vllm_mlu/model_executor/layers/quantization/gptq_mlu.py b/vllm_mlu/model_executor/layers/quantization/gptq_mlu.py new file mode 100644 index 0000000..4cae5d3 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/gptq_mlu.py @@ -0,0 +1,440 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from fractions import Fraction +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import register_quantization_config +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from vllm_mlu import _mlu_ops as mlu_ops + +logger = init_logger(__name__) + +MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512] + +# We only support gptq and awq over 300 serials and only support int4 and int8 precision +def query_mlu_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + major, minor = current_platform.get_device_capability() + device_capability = major * 10 + minor + + if has_zp: + # AWQ style, unsigned + zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def check_mlu_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + major, minor = current_platform.get_device_capability() + device_capability = major * 10 + minor + + supported_types = query_mlu_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"Mlu does not support weight_bits = {quant_type}. 
" + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES): + return (False, f"Mlu does not support group_size = {group_size}. " + f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True + + +# @register_quantization_config("gptq_mlu") +class GPTQMluConfig(QuantizationConfig): + """Config class for GPTQMlu. + + Reference: https://arxiv.org/abs/2210.17323 + """ + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + (4, False): scalar_types.uint4b8, + (8, False): scalar_types.uint8b128, + } + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.lm_head_quantized = lm_head_quantized + self.pack_factor = Fraction(32, self.weight_bits) + self.support_scale_zeros = False + self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros) + + if self.weight_bits not in [4, 8]: + raise ValueError( + "Currently, only 4/8-bit weight quantization is " + f"supported for GPTQMlu, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQMluConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})," + f"lm_head_quantized={self.lm_head_quantized}") + + @classmethod + def get_name(cls) -> str: + return "gptq_mlu" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16, torch.float32] + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quant_config.json", "quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMluLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return GPTQMluLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + @classmethod + def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + sym = quant_config.get("sym", None) + desc_act = quant_config.get("desc_act", None) + + if quant_method != "gptq": + return False + + # If we cannot find the info needed in the config, cannot convert. 
+ if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size, has_zp=False) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_mlu") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + +class GPTQMluLinearMethod(LinearMethodBase): + """Linear method for GPTQMlu. + + Args: + quant_config: The GPTQMlu quantization config. + """ + + def __init__(self, quant_config: GPTQMluConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) + if (output_size_per_partition % self.quant_config.pack_factor.numerator + != 0): + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if (input_size != input_size_per_partition) and (self.quant_config.group_size != + -1) and (not self.quant_config.desc_act): + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + g_idx = RowvLLMParameter(data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + qzeros_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + 
packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.device = layer.qweight.data.device + packed_qweight, scale_zeros = self.extract_autogptq(layer) + if self.quant_config.use_native: + layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False) + layer.qzeros = None + layer.scales = None + else: + layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False) + if scale_zeros is not None: + layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False) + else: + layer.qzeros = None + layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False) + + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.quant_config.use_native: + output = mlu_ops.matmul(x, layer.qweight, bias) + if residual is not None: + output = output + residual + else: + output = mlu_ops.weight_only_quant_matmul(x, + layer.qweight, + layer.scales, + layer.qzeros, + bias, + residual, + "none", + self.quant_config.weight_bits) + + return output + + + def extract_autogptq(self, layer: torch.nn.Module): + scales = layer.scales.data + bits = self.quant_config.weight_bits + group_size = self.quant_config.group_size + # Unpack the qweight and qzeros tensors + iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits) + izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits) + + if self.quant_config.use_native: + if self.quant_config.desc_act: + scales = torch.index_select(scales, 0, layer.g_idx) + if izeros is not None: + izeros = torch.index_select(izeros, 0, layer.g_idx) + else: + scales = scales.repeat_interleave(group_size, dim=0) + if izeros is not None: + izeros = izeros.repeat_interleave(group_size, dim=0) + + if izeros is not None: + fweight = (iweight - izeros) * scales + else: + fweight = iweight * scales + # transpose [ci, co] -> [co, ci] + fweight = fweight.transpose(0, 1) + + return fweight, None + + if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None: + scale_zeros = izeros.to(scales.dtype) * -1 * scales + # transpose [ci, co] -> [co, ci] + scale_zeros = scale_zeros.transpose(0, 1) + else: + # for is_sym is true now, so make iweight to sign value and ignore qzeros + iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1) + scale_zeros = None + + # transpose [ci, co] -> [co, ci] + iweight = iweight.to(torch.int8).transpose(0, 1) + + if bits == 4: + higher_bit_tensor = iweight[:, 1::2] + lower_bit_tensor = iweight[:, 0::2] + packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor) + else: + packed_qweight = iweight + + return packed_qweight, scale_zeros + + + def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0) + dtype = torch.int16 if bits == 8 else torch.int8 + weight = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), + shifts.unsqueeze(-1), + ).to(dtype) + weight = torch.bitwise_and(weight, (2**bits) - 1) + weight = weight.reshape(-1, weight.shape[-1]) + + return weight + + + def 
unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
+        shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
+        dtype = torch.int16 if bits == 8 else torch.int8
+        zeros = torch.bitwise_right_shift(
+            torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
+            shifts.unsqueeze(0),
+        ).to(dtype)
+
+        zeros = zeros + 1
+
+        zeros = torch.bitwise_and(zeros, (2**bits) - 1)
+        zeros = zeros.reshape(qzeros.shape[0], -1)
+
+        return zeros
+
+
+    def combine_low_bits(self, tensor_a, tensor_b):
+        """
+        Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
+
+        Args:
+            tensor_a (torch.Tensor): First tensor of type int8.
+            tensor_b (torch.Tensor): Second tensor of type int8.
+
+        Returns:
+            torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
+        """
+        # Ensure both inputs are int8 tensors
+        if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
+            raise ValueError("Both tensors must be of int8 type.")
+
+        # Extract the low 4 bits of each tensor
+        low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the low 4 bits of tensor_a
+        low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the low 4 bits of tensor_b
+
+        # Shift the low 4 bits of tensor_a left by 4
+        shifted_low_bits_a = low_bits_a << 4
+
+        # Combine the low 4 bits of both tensors
+        combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
+
+        return combined
diff --git a/vllm_mlu/model_executor/layers/quantization/smoothquant.py b/vllm_mlu/model_executor/layers/quantization/smoothquant.py
new file mode 100755
index 0000000..b22e577
--- /dev/null
+++ b/vllm_mlu/model_executor/layers/quantization/smoothquant.py
@@ -0,0 +1,337 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
+
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
+                                               set_weight_attrs)
+from vllm.model_executor.layers.quantization import register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           RowvLLMParameter)
+from vllm_mlu import _mlu_ops as mlu_ops
+from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
+                                                                             str_dtype_to_bits,
+                                                                             is_fp8_str_dtype)
+
+
+# @register_quantization_config("smoothquant")
+class SmoothQuantConfig(QuantizationConfig):
+    """Config class for SmoothQuant.
+ """ + + def __init__( + self, + quant_mode: str, # smoothquant + input_quant_method: str, # per token/per tensor + group_size: int, + weight_precision: str, + activation_precision: str, + only_expert_per_group: bool, + expert_weight_precision: str, + expert_activation_precision: str, + force_use_weightonly_except_expert: bool, + ) -> None: + super().__init__() + self.quant_mode = quant_mode + self.input_quant_method = input_quant_method + self.group_size = group_size + self.weight_precision = weight_precision + self.activation_precision = activation_precision + self.only_expert_per_group = only_expert_per_group + self.expert_weight_precision = expert_weight_precision + self.expert_activation_precision = expert_activation_precision + self.force_use_weightonly_except_expert = force_use_weightonly_except_expert + + if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"): + raise ValueError( + "Currently, only per_token or per_tensor input quantization is supported for " + f"SmoothQuant, but got {self.input_quant_method}.") + + self.weight_bits = str_dtype_to_bits(self.weight_precision) + self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision) + if self.weight_precision == 'int4': + self.weight_dtype = torch.int8 + else: + self.weight_dtype = str_dtype_to_torch(self.weight_precision) + if self.expert_weight_precision == 'int4': + self.expert_weight_dtype = torch.int8 + else: + self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision) + self.is_fp8 = is_fp8_str_dtype(self.weight_precision) + self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision) + + self.pack_factor = 8 // self.weight_bits + self.expert_pack_factor = 8 // self.expert_weight_bits + + def __repr__(self) -> str: + return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, " + f"quant_mode={self.quant_mode}, " + f"group_size={self.group_size}, " + f"weight_precision={self.weight_precision}, " + f"activation_precision={self.activation_precision}, " + f"only_expert_per_group={self.only_expert_per_group}, " + f"expert_weight_precision={self.expert_weight_precision}, " + f"expert_activation_precision={self.expert_activation_precision}, " + f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})") + + @classmethod + def get_name(self) -> str: + return "SmoothQuant" + + @classmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @staticmethod + def get_config_filenames() -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + quant_mode = cls.get_from_keys(config, ["quant_mode"]) + input_quant_method = cls.get_from_keys(config, ["input_quant_method"]) + group_size = cls.get_from_keys_or(config, ["group_size"], 1) + weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8") + activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8") + only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False) + expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None) + expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None) + force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False) + + if expert_weight_precision is None: + expert_weight_precision = weight_precision + 
if group_size > 1 and only_expert_per_group and weight_precision == 'int4': + weight_precision = 'int8' + + if expert_activation_precision is None: + expert_activation_precision = activation_precision + + return cls(quant_mode=quant_mode, + input_quant_method=input_quant_method, + group_size=group_size, + weight_precision=weight_precision, + activation_precision=activation_precision, + only_expert_per_group=only_expert_per_group, + expert_weight_precision=expert_weight_precision, + expert_activation_precision=expert_activation_precision, + force_use_weightonly_except_expert=force_use_weightonly_except_expert) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["SmoothQuantLinearMethod"]: + if isinstance(layer, LinearBase): + return SmoothQuantLinearMethod(self, prefix) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + +class SmoothQuantLinearMethod(LinearMethodBase): + """Linear method for SmoothQuant. + + Args: + quant_config: The SmoothQuant quantization config. + """ + + def __init__(self, quant_config: SmoothQuantConfig, prefix: str): + self.quant_config = quant_config + # for per-tensor case, we can skip quant input for the first attn|ffn linear + # and fusion this step in layernorm to get better performance + self.skip_quant_input = False + self.compute_dtype = torch.get_default_dtype() + self.is_expert = 'expert' in prefix and "shared_expert" not in prefix + self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype + self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor + self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8 + + if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1: + self.is_group_quant = True + elif quant_config.only_expert_per_group is False and quant_config.group_size > 1: + self.is_group_quant = True + else: + self.is_group_quant = False + self.has_smooth = self.quant_config.input_quant_method == "per_token" and ( + self.quant_config.force_use_weightonly_except_expert is False or self.is_expert) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + if (output_size_per_partition % self.quant_config.pack_factor != 0): + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + weight_loader = extra_weight_attrs.get("weight_loader") + group_num = 1 + if self.is_group_quant: + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + f"The input size {input_size_per_partition} is not aligned with the quantized " + f"weight shape. This can be caused by too large " + f"tensor parallel size. 
group_size: {self.quant_config.group_size}.") + + group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size + if input_size_per_partition != input_size: + group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size + + qweight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.pack_factor, + device="mlu", + dtype=self.weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + if self.is_group_quant: + per_channel_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + group_num, + device="mlu", + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + else: + per_channel_scale = ChannelQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + device="mlu", + dtype=torch.float32, + ), + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("per_channel_scale", per_channel_scale) + + if self.has_smooth: + smooth = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + device="mlu", + dtype=torch.float32, + ), + input_dim=0, + weight_loader=weight_loader, + ) + set_weight_attrs(smooth, { + "ignore_warning": True, + }) + layer.register_parameter("smooth", smooth) + if self.quant_config.input_quant_method == "per_tensor": + scale_to_int = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + device="mlu", + dtype=torch.float32, + ), + input_dim=0, + weight_loader=weight_loader, + ) + set_weight_attrs(scale_to_int, { + "ignore_warning": True, + }) + layer.register_parameter("scale_to_int", scale_to_int) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if self.has_smooth and layer.smooth.dtype != torch.float: + layer.smooth = layer.smooth.to(torch.float) + if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float: + layer.scale_to_int = layer.scale_to_int.to(torch.float) + if layer.per_channel_scale.dtype != torch.float: + layer.per_channel_scale = layer.per_channel_scale.to(torch.float) + + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False) + if self.has_smooth: + layer.smooth = Parameter(layer.smooth.data, requires_grad=False) + if self.quant_config.input_quant_method == "per_tensor": + layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + use_tp_weight : bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + layer_smooth = layer.smooth if self.has_smooth else None + layer_qweight = layer.qweight + layer_per_channel_scale = layer.per_channel_scale + if use_tp_weight: + if hasattr(layer, 'tp_smooth'): + layer_smooth = layer.tp_smooth + if hasattr(layer, 'tp_qweight'): + layer_qweight = layer.tp_qweight + if hasattr(layer, 'tp_per_channel_scale'): + layer_per_channel_scale = layer.tp_per_channel_scale + + quant_input = None + if self.skip_quant_input: + quant_input = x + elif self.quant_config.input_quant_method == "per_token": + if self.is_fp8: + quant_input, input_scale = mlu_ops.scaled_quantize(x, + layer_smooth, + quant_type=self.weight_dtype, + 
quant_mode='dynamic_per_token') + else: + quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None) + elif self.quant_config.input_quant_method == "per_tensor": + quant_input = mlu_ops.quantize(x, layer.scale_to_int, None) + else: + raise ValueError( + "Currently, only per_token or per_tensor input quantization is supported for " + f"SmoothQuant, but got {self.input_quant_method}.") + quant_input_shape = quant_input.shape + if len(quant_input_shape) > 2: + quant_input = quant_input.view(-1, quant_input_shape[-1]) + input_scale = input_scale.view(-1) + if residual is not None and len(residual.shape) > 2: + residual = residual.view(-1, residual.shape[-1]) + if self.is_fp8: + out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale, + layer_per_channel_scale, + self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype, + bias, + c=residual, act_mode="none",quant_bit_size=8, + alpha=1.0, beta=1.0, use_hp_active=False, + a_quant_bit_size=8, a_calib=None, b_calib=None) + if output is not None: + out = out.view(output.shape) + output.copy_(out) + out = output + else: + if output is not None: + out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight, + layer_per_channel_scale, self.compute_dtype, bias, residual, output=output) + else: + out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight, + layer_per_channel_scale, self.compute_dtype, bias, residual) + if len(quant_input_shape) > 2: + out = out.view(*quant_input_shape[:-1], out.shape[-1]) + return out diff --git a/vllm_mlu/model_executor/layers/quantization/utils/__init__.py b/vllm_mlu/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/model_executor/layers/quantization/utils/common_utils.py b/vllm_mlu/model_executor/layers/quantization/utils/common_utils.py new file mode 100644 index 0000000..7f3ea07 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/utils/common_utils.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz'] +INTERGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short, + torch.int32, torch.int, torch.int64, torch.long] +FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16, + torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half] +FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] +FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz'] +GEMM_GROUP_SIZE = [64, 128, 256, 512] + +_STR_TO_TORCH_DTYPE_DICT = dict( + bfloat16=torch.bfloat16, + float16=torch.float16, + float32=torch.float32, + int64=torch.int64, + int32=torch.int32, + int8=torch.int8, + bool=torch.bool, + e4m3fn=torch.float8_e4m3fn, + e4m3fnuz=torch.float8_e4m3fnuz, + e5m2=torch.float8_e5m2, + e5m2fnuz=torch.float8_e5m2fnuz, +) + +TORCH_DTYPE_TO_STR_DICT = { + torch.bfloat16: "bfloat16", + torch.float16: "float16", + torch.float32: "float32", + torch.int64: "int64", + torch.int32: "int32", + torch.int8: "int8", + torch.bool: "bool", 
+ torch.float8_e4m3fn: "e4m3fn", + torch.float8_e4m3fnuz: "e4m3fnuz", + torch.float8_e5m2: "e5m2", + torch.float8_e5m2fnuz: "e5m2fnuz", +} + +STR_DTYPE_TO_BITS_DICT = { + "bfloat16": 16, + "float16": 16, + "float32": 32, + "int64": 64, + "int32": 32, + "int8": 8, + 'int4': 4, + "bool": 1, + "e4m3fn": 8, + "e4m3fnuz": 8, + "e5m2": 8, + "e5m2fnuz": 8, +} + + +def str_dtype_to_torch(str_dtype: str): + ''' + convert torch dytpe to str dtype + ''' + ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype) + dtype = ret if ret is not None else torch.float16 + return dtype + + +def torch_dtype_to_str(dtype: torch.dtype): + ''' + convert torch dytpe to str dtype + ''' + ret = TORCH_DTYPE_TO_STR_DICT.get(dtype) + str_dtype = ret if ret is not None else "float16" + return str_dtype + + +def str_dtype_to_bits(str_dtype): + ''' + convert torch dtype to bits size + ''' + ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype) + bits = ret if ret is not None else 8 + return bits + + +def is_integer_dtype(dtype: torch.dtype): + ''' + check whether is integer or not + ''' + return dtype in INTERGER_DTYPES + + +def is_float_dtype(dtype: torch.dtype): + ''' + check whether is float or not + ''' + return dtype in FLOAT_DTYPES + + +def is_fp8_dtype(dtype: torch.dtype): + ''' + judge fp8 torch dtype + ''' + return dtype in FP8_DTYPE + + +def is_fp8_str_dtype(str_dtype: str): + ''' + judge fp8 str dtype + ''' + return str_dtype in FP8_STR_DTYPE diff --git a/vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py b/vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py new file mode 100644 index 0000000..8943703 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py @@ -0,0 +1,424 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import functools +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor) +from vllm.platforms import current_platform + +from vllm_mlu import _mlu_ops as mlu_ops + +logger = init_logger(__name__) + +''' +============================= +Modify by vllm_mlu +============================= +@brief: get total core for split triton kernel +''' + +import triton.backends.mlu.driver as driver + +_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device()) +TOTAL_CLUSTER_NUM = _devprob.get("cluster_num") +TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster") + +''' +================== +End of MLU Hijack +================== +''' + + +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + shape_supported_by_cutlass = (weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0) + 
if current_platform.is_rocm(): + # TODO this is never used, as cutlass_block_fp8_supported is False + scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + + input_2d.shape[:-1])[::-1] + scale_b_shape = (weight_scale.view(-1, 1) + if weight_scale.dim() <= 1 else weight_scale.T).shape + ar, ac = scale_a_shape + br, bc = scale_b_shape + if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) + or br not in (1, weight.shape[0])): + shape_supported_by_cutlass = False + if cutlass_block_fp8_supported and shape_supported_by_cutlass: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True) + output = ops.cutlass_scaled_mm(q_input, + weight.T, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale.T) + else: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=False) + output = w8a8_block_fp8_matmul(q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input.dtype) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. 
+ """ + dtype = current_platform.fp8_dtype() if dtype is None else dtype + assert (x.shape[-1] % group_size == 0), ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: split for limit the memory usage(65536) + ''' + group_per_block = 1 + while M >= 65536: + group_per_block *= 2 + M = x.numel() // (group_size * group_per_block) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if column_major_scales: + shape = (x.shape[-1] // group_size, ) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, + dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: set num_warps to 1 for triton-mlu + ''' + num_warps = 1 + num_stages = 1 + ''' + ================== + End of MLU Hijack + ================== + ''' + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: replaced the 'scaled_quantize' kernel from the 'tmo' library with + '_per_token_group_quant_fp8' kernel + ''' + # Check if x is contiguous, if not, create a new tensor for contiguous x + if not x.is_contiguous(): + x = x.contiguous() + x_origin_shape = x.shape + x = x.reshape(*x.shape[:-1], -1, group_size) + x_q, x_s = mlu_ops.scaled_quantize(x, + None, + quant_type=dtype, + quant_mode='dynamic_per_token') + x_q = x_q.reshape(x_origin_shape) + ''' + ================== + End of MLU Hijack + ================== + ''' + + return x_q, x_s + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. 
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: split for limit the memory usage(65536) + ''' + num_block_size_all = num_pid_m * num_pid_n + num_block_size_per = num_block_size_all // tl.num_programs(axis=0) + num_block_size_rem = num_block_size_all % tl.num_programs(axis=0) + + core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem) + core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid) + + for pid_i in range(0, core_deal_num_block_size): + pid_in_core_deal_block = core_deal_num_block_start + pid_i + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_in_core_deal_block // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m) + pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m + + ''' + ================== + End of MLU Hijack + ================== + ''' + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. 
+ """ + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: replaced the 'scaled_matmul' kernel from the 'tmo' library with + '_w8a8_block_fp8_matmul' kernel + ''' + + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert B.ndim == 2 and Bs.ndim == 2 + + if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0): + C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none", + quant_bit_size=8, alpha=1, beta=1, use_hp_active=False, + a_quant_bit_size=8, a_calib=None, b_calib=None) + else: + # NOTE(wulingchao): scaled_matmul 底层算子只支持n和k是128的倍数 + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + # Default config + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] + # BLOCK_SIZE_K must be divisible by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 1, + } + + def grid(META): + return (TOTAL_CORE_NUM, ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + + ''' + ================== + End of MLU Hijack + ================== + ''' + + return C diff --git a/vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py new file mode 100644 index 0000000..4cb5b73 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Optional, Callable +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm, + flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm, + torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm, + torch_channelwise_w8a8_scaled_mm) +from vllm.platforms import current_platform + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def mlu_w8a8_scaled_mm( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: list, **kwargs +) -> torch.Tensor: + output = mlu_ops.scaled_matmul( + qinput, # a + weight, # b + scale_a, # a_scale + scale_b, # b_scale + out_dtype, # output_dtype + bias, # bias + c=None, act_mode="none",quant_bit_size=8, alpha=1, beta=1, use_hp_active=False, + a_quant_bit_size=8, a_calib=None, b_calib=None + ) + return output.view(*output_shape) + + +def dispatch_w8a8_scaled_mm( + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool, + weight_per_channel: bool, activation_per_token: bool +) -> Callable[..., torch.Tensor]: + if per_tensor_weights and per_tensor_activations: + if 
preferred_backend == "rocm": + return rocm_per_tensor_w8a8_scaled_mm + if preferred_backend == "flashinfer": + return flashinfer_w8a8_scaled_mm + if preferred_backend == "cutlass": + return cutlass_w8a8_scaled_mm + return torch_per_tensor_w8a8_scaled_mm + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if preferred_backend == "cutlass" or preferred_backend == "flashinfer": + return cutlass_w8a8_scaled_mm + + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) + if ( + not per_tensor_weights + and not per_tensor_activations + and USE_ROWWISE_TORCH_SCALED_MM + ): + return torch_per_token_w8a8_scaled_mm + # Normally, torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: dispatch to mlu_w8a8_scaled_mm + ''' + if weight_per_channel and activation_per_token: + return mlu_w8a8_scaled_mm + ''' + ================== + End of MLU Hijack + ================== + ''' + return torch_channelwise_w8a8_scaled_mm + + +def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype | None = None, + input_scale: torch.Tensor | None = None, + input_scale_ub: torch.Tensor | None = None, + bias: torch.Tensor | None = None, + weight_per_channel: bool = True, + activation_per_token: bool = True, +) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add mlu_fp8_supported + ''' + self.mlu_fp8_supported = False + if weight_per_channel and activation_per_token: + self.mlu_fp8_supported = True + ''' + ================== + End of MLU Hijack + ================== + ''' + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] + + if out_dtype is None: + out_dtype = input.dtype + + if self.mlu_fp8_supported: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add support for activation-per-token weight-per-channel quantization. + ''' + qinput, x_scale = mlu_ops.scaled_quantize( + input_2d,# x + None, # scale + None, # zero + None, # scale_ub + quant_type=torch.float8_e4m3fn, + quant_mode='dynamic_per_token' + ) + output_shape = [*input.shape[:-1], weight.shape[0]] + ''' + ================== + End of MLU Hijack + ================== + ''' + else: + # If input not quantized + # TODO(luka) remove this path if not used anymore + if input.dtype != current_platform.fp8_dtype(): + qinput, x_scale = self.quant_fp8( + input_2d, + input_scale, + input_scale_ub, + ) + else: + qinput, x_scale = input_2d, input_scale + + # Must have dim() conditions + # In per-token quant scenario, when the number of token is 1, + # the scale will only have 1 elements. + # Without checking the dim(), + # we cannot distingushes between per-tensor and per-token quant. + # Example: + # When the number of token is 1, per-token scale is [[1]] + # When per-tensor scale is [1] or (). 
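    # For the MLU path above (per-channel weight scales, per-token activation
    # scales) the fused kernel is assumed to compute, up to rounding,
    #   out ~= (qinput.float() * x_scale.view(-1, 1))
    #          @ (weight.float() * weight_scale.view(-1, 1)).T
    # i.e. one scale per token row and one scale per output channel.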
+ per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 + + # TODO(luka) do this dispatch during init (after ScaledMM refactor) + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( + self.preferred_backend, per_tensor_weights, per_tensor_activations, + weight_per_channel, activation_per_token) + return w8a8_scaled_mm_func( + qinput=qinput, + weight=weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + output_shape=output_shape, + ) + + +MluHijackObject.apply_hijack( + Fp8LinearOp, + Fp8LinearOp.apply, + vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply +) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/quantization/weightonly.py b/vllm_mlu/model_executor/layers/quantization/weightonly.py new file mode 100755 index 0000000..a3ebb58 --- /dev/null +++ b/vllm_mlu/model_executor/layers/quantization/weightonly.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter +from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization import register_quantization_config +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +from vllm_mlu import _mlu_ops as mlu_ops + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +# @register_quantization_config("weightonly") +class WeightOnlyConfig(QuantizationConfig): + """Config class for WeightOnly. + """ + + def __init__( + self, + weight_bits: int, + quant_mode: str, # weight_only + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.quant_mode = quant_mode + + if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4): + raise ValueError( + "Currently, only 8/4-bit weight quantization is supported for " + f"weight_only, but got {self.weight_bits} bits.") + self.pack_factor = 8 // self.weight_bits + + def __repr__(self) -> str: + return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, " + f"quant_mode={self.quant_mode})") + + def get_name(self) -> str: + return "WeightOnly" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @staticmethod + def get_config_filenames() -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + try: + quant_mode = cls.get_from_keys(config, ["quant_mode"]) + except Exception: + quant_mode = "WeightOnly" + return cls(weight_bits, quant_mode) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["WeightOnlyLinearMethod"]: + if isinstance(layer, LinearBase): + return WeightOnlyLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + +class WeightOnlyLinearMethod(LinearMethodBase): + """Linear method for WeightOnly. + + Args: + quant_config: The WeightOnly quantization config. 
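+
+    Example (illustrative): with weight_bits=4 the config's pack_factor is
+    8 // 4 = 2, so two int4 weights share one int8 element and qweight's input
+    dimension becomes input_size_per_partition // 2.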
+ """ + + def __init__(self, quant_config: WeightOnlyConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> Dict[str, Any]: + output_size_per_partition = sum(output_partition_sizes) + if self.quant_config.quant_mode == "WeightOnly": + scale_and_zero_input_dim = None + if output_size != output_size_per_partition: + scale_and_zero_input_dim = 0 + qweight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + device="mlu", + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs(qweight, { + "input_dim": 1, + "output_dim": 0, + }) + scales = Parameter( + torch.empty( + output_size_per_partition, + device="mlu", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": scale_and_zero_input_dim, + "output_dim": 0, + }) + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if layer.scales.dtype != torch.float: + layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + x_shape = x.shape + if len(x_shape) > 2: + x = x.view(-1, x_shape[-1]) + out = mlu_ops.weight_only_quant_matmul(x, + layer.qweight, + layer.scales, + None, + bias, + residual, + "none", + self.quant_config.weight_bits) + if len(x_shape) > 2: + out = out.view(*x_shape[:-1], out.shape[-1]) + return out diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/__init__.py b/vllm_mlu/model_executor/layers/rotary_embedding/__init__.py new file mode 100644 index 0000000..79b7455 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/__init__.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import math +from typing import Any + +import torch + +from vllm.logger import init_logger +import vllm.model_executor.layers.rotary_embedding as rotary_embedding +from vllm.model_executor.layers.rotary_embedding import ( + _ROPE_DICT, + RotaryEmbedding, +) +from vllm.model_executor.layers.rotary_embedding import ( + _ROPE_DICT, + DualChunkRotaryEmbedding, + DynamicNTKAlphaRotaryEmbedding, + DynamicNTKScalingRotaryEmbedding, + Llama4VisionRotaryEmbedding, + MRotaryEmbedding, + NTKScalingRotaryEmbedding, + Phi3LongRoPEScaledRotaryEmbedding, + YaRNScalingRotaryEmbedding, +) + +from .base import MLURotaryEmbedding +from .deepseek_scaling_rope import MLUDeepseekScalingRotaryEmbedding +from .dynamic_ntk_alpha_rope import MLUDynamicNTKAlphaRotaryEmbedding +from .dynamic_ntk_scaling_rope import MLUDynamicNTKScalingRotaryEmbedding +from .linear_scaling_rope import MLULinearScalingRotaryEmbedding +from .llama3_rope import MLULlama3RotaryEmbedding +from .mrope import MLUMRotaryEmbedding +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor): + if MLURotaryEmbedding.max_seq_len != None and \ + MLURotaryEmbedding.max_seq_len > 
max_position_embeddings * scaling_factor: + logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " + + f"max_position_embedding ({max_position_embeddings}) * scaling_factor ({scaling_factor}) " + + "from model's config.json, This may lead to incorrect model outputs or MLU errors. " + + f"Make sure the value is correct and within the model context size. " + + f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.") + return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor) + return max_position_embeddings + +def vllm__model_executor__layers__rotary_embedding__get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: float, + is_neox_style: bool = True, + rope_scaling: dict[str, Any] | None = None, + dtype: torch.dtype | None = None, + partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: dict[str, Any] | None = None, + inverse: bool = False +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dual_chunk_attention_args, + dtype, + inverse, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) + elif not rope_scaling: + rotary_emb = MLURotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype, + inverse=inverse, + ) + else: + scaling_type = rope_scaling["rope_type"] + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = MLULlama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) + elif scaling_type == "mllama4": + rotary_emb = Llama4VisionRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MLUMRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + ) + else: + rotary_emb = MLURotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + inverse=inverse, + ) + elif scaling_type == "linear": + 
scaling_factor = rope_scaling["factor"] + rotary_emb = MLULinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "ntk": + scaling_factor = rope_scaling["factor"] + mixed_b = rope_scaling.get('mixed_b', None) + rotary_emb = NTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + mixed_b, + ) + elif scaling_type == "dynamic": + if "alpha" in rope_scaling: + scaling_alpha = rope_scaling["alpha"] + rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_alpha, + dtype, + ) + elif "factor" in rope_scaling: + scaling_factor = rope_scaling["factor"] + rotary_emb = MLUDynamicNTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + else: + raise ValueError( + "Dynamic rope scaling must contain either 'alpha' or 'factor' field" + ) + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "apply_yarn_scaling", + ) + } + if "mrope_section" in rope_scaling: + extra_kwargs.pop("apply_yarn_scaling", None) + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + scaling_factor=scaling_factor, + **extra_kwargs, + ) + else: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: update original_max_position + ''' + original_max_position = get_long_max_model_max_position_emb( + original_max_position, scaling_factor, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: update original_max_position + ''' + original_max_position = get_long_max_model_max_position_emb( + original_max_position, scaling_factor, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + rotary_emb = MLUDeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + inverse, + **extra_kwargs, + ) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, + rotary_dim, + max_position, + 
original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb + +MluHijackObject.apply_hijack( + rotary_embedding, + rotary_embedding.get_rope, + vllm__model_executor__layers__rotary_embedding__get_rope, +) diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/base.py b/vllm_mlu/model_executor/layers/rotary_embedding/base.py new file mode 100644 index 0000000..c7fb103 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/base.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Tuple +import torch + +from vllm.config import get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.v1.attention.backends.utils import ( + get_common_metadata, + MLUCommonAttentionMetadata, +) +from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata +from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context + +logger = init_logger(__name__) + + +@CustomOp.register("rotary_embedding_mlu") +class MLURotaryEmbedding(RotaryEmbedding, CustomOp): + + cu_seq_lens : torch.Tensor = None + max_seq_len : int = None + max_model_len : int = None + is_prompt : bool = False + is_chunked : bool = False + positions_: torch.Tensor = None + chunked_prefill_enabled: bool = False + prefill_cu_seq_lens: torch.Tensor = None + prefill_max_seq_len: int = None + decode_cu_seq_lens: torch.Tensor = None + decode_max_seq_len: int = None + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + inverse: bool = False, + ) -> None: + CustomOp.__init__(self) + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + # TODO(mgoin): disabled for now due to failures + # Flashinfer only supports head_size=64, 128, 256, 512. + # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202 + # self.use_flashinfer = (self.enabled() + # and dtype in (torch.float16, torch.bfloat16) + # and current_platform.is_cuda() + # and has_flashinfer() + # and self.head_size in [64, 128, 256, 512]) + self.use_flashinfer = False + self.inverse = inverse + + # For vlm v1 + # 1. mlu rope run in eager mode + # 2. 
all layer use layer0's rope to inference + prefix = "global_rope" + vllm_config = get_current_vllm_config() + self.use_direct_call = False + if not self.use_direct_call: + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + pass + else: + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + + from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding + from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding + + if MLURotaryEmbedding.max_seq_len != None \ + and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \ + not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)): + logger.warning(f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) is different with " + + f"max_position_embedding ({max_position_embeddings}) from model's config.json, " + + f"This may lead to incorrect model outputs or MLU errors. " + + f"Make sure the value is correct and within the model context size. " + + f"Set max_position_embedding={MLURotaryEmbedding.max_seq_len}.") + self.max_position_embeddings = MLURotaryEmbedding.max_seq_len + cache = self._compute_cos_sin_cache() + + from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding + if isinstance(self, MLULinearScalingRotaryEmbedding): + logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition") + elif is_neox_style: + cache_pos = cache.shape[0] + cache = cache.reshape(cache_pos, 2, -1) + cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1) + else: + cache = cache.repeat_interleave(2, dim=-1) + + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.cos_, self.sin_ = self._get_cos_sin() + @classmethod + def set_mlu_var_v1( + cls, + common_metadata: MLUCommonAttentionMetadata + ) -> None: + cls.unset_mlu_var() + cls.cu_seq_lens = common_metadata.query_start_loc + cls.max_seq_len = common_metadata.max_query_len + cls.is_prompt = common_metadata.is_prefill_only + cls.is_chunked = common_metadata.is_chunked + + # for MLA + attn_metadata = get_forward_context().attn_metadata + if isinstance(attn_metadata, dict): + _, attn_metadata = next(iter(attn_metadata.items())) + if isinstance(attn_metadata, MLACommonMetadata): + prefill_metadata = attn_metadata.prefill + decode_metadata = attn_metadata.decode + if prefill_metadata: + cls.prefill_max_seq_len = prefill_metadata.max_query_len + cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc + else: + cls.prefill_max_seq_len = cls.max_seq_len + cls.prefill_cu_seq_lens = cls.cu_seq_lens + + if decode_metadata: + cls.decode_max_seq_len = decode_metadata.max_query_len + cls.decode_cu_seq_lens = decode_metadata.query_start_loc + else: + cls.decode_max_seq_len = cls.max_seq_len + cls.decode_cu_seq_lens = cls.cu_seq_lens + + # for sp + sp_context = get_sp_forward_context() + if sp_context is not None and sp_context.is_v32: + prefill_metadata = sp_context.sp_attn_metadata.prefill + cls.is_chunked = True + cls.prefill_max_seq_len = prefill_metadata.max_query_len + cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc + + @classmethod + def unset_mlu_var(cls): + cls.cu_seq_lens = None + cls.max_seq_len = None + cls.is_prompt = False + cls.is_chunked = False + cls.positions_ = None + cls.chunked_prefill_enabled = 
False + cls.prefill_cu_seq_lens = None + cls.prefill_max_seq_len = None + cls.decode_cu_seq_lens = None + cls.decode_max_seq_len = None + + def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = self.cos_sin_cache.chunk(2, dim=-1) + sin = sin.view(-1, self.rotary_dim) + cos = cos.view(-1, self.rotary_dim) + return cos, sin + + def _get_positions_with_offsets_mlu( + self, + positions: torch.Tensor, + offsets: torch.Tensor + ) -> torch.Tensor: + if offsets.numel() != positions.numel(): + raise Exception("rope offsets numel mismatch with positions, " + f"positions: {positions.numel()}, offsets: {offsets.numel()}") + return (positions + offsets).to(torch.int32) + + def forward_impl( + self, + positions: torch.Tensor, + x: torch.Tensor, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + common_metadata: MLUCommonAttentionMetadata = get_common_metadata() + if common_metadata is None: + num_tokens, head_num, head_size = x.shape + x = mlu_ops.rotary_embedding( + x.view(1, num_tokens, head_num, head_size), + self.sin_, + self.cos_, + positions, + None, + not self.is_neox_style, + True, + False, + num_tokens + ) + return x + else: + cu_seq_lens_ = common_metadata.query_start_loc + + if offsets is not None: + if MLURotaryEmbedding.positions_ is None: + MLURotaryEmbedding.positions_ = ( + self._get_positions_with_offsets_mlu(positions, offsets)) + position_ids = MLURotaryEmbedding.positions_ + discrete = True + elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt: + position_ids = positions + discrete = True + else: + position_ids = None + discrete = False + + x = mlu_ops.rotary_embedding( + x, + self.sin_, + self.cos_, + position_ids, + cu_seq_lens_, + not self.is_neox_style, + discrete, + False, + MLURotaryEmbedding.max_seq_len + ) + return x + + def get_param(self, positions, discrete=False): + interleaved = True + if self.is_neox_style: + interleaved = False + + if discrete: + position_ids = positions + discrete = discrete + else: + if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt: + position_ids = positions + discrete = True + else: + position_ids = None + discrete = False + + return position_ids, interleaved, discrete + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.outer(t, inv_freq) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cos = freqs_cis.real + sin = freqs_cis.imag * (-1 if self.inverse else 1) + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + only_prefill: bool | None = False, + only_decode: bool | None = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.forward_impl(positions, query, offsets) + if key is not None: + self.forward_impl(positions, key, offsets) + return query, key + + +def rope_forward( + positions: torch.Tensor, + x: torch.Tensor, + layer_name: str, + offsets: torch.Tensor | None = None, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + + self.forward_impl(positions, x, offsets) + + +def rope_forward_fake( + positions: torch.Tensor, + x: torch.Tensor, + layer_name: str, + offsets: torch.Tensor | None = None, +) -> None: + return + + 
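A minimal illustrative sketch (not part of the diff) of the cache math above, assuming the standard RoPE inverse-frequency formula, made-up sizes, and inverse=False: `_compute_cos_sin_cache` takes the outer product of positions and inverse frequencies and reads cos/sin off `torch.polar`, which for unit magnitude is exactly cos(theta) + i*sin(theta).

import torch

# Hypothetical sizes, for illustration only.
rotary_dim, max_pos, base = 8, 4, 10000.0

# Standard RoPE inverse frequencies (assumed formula).
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_pos, dtype=torch.float)
freqs = torch.outer(t, inv_freq)                        # (max_pos, rotary_dim // 2)

# torch.polar(1, theta) == cos(theta) + i * sin(theta)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.cat((freqs_cis.real, freqs_cis.imag), dim=-1)   # [cos | sin]

assert torch.allclose(cache, torch.cat((freqs.cos(), freqs.sin()), dim=-1))

The MLU base class then retiles this cache before registering it: the neox-style branch duplicates the cos and sin halves, while the non-neox branch repeats every element twice via repeat_interleave.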
+direct_register_custom_op( + op_name="rope_forward", + op_func=rope_forward, + mutates_args=["x"], + fake_impl=rope_forward_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm_mlu/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py new file mode 100644 index 0000000..225dba2 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Tuple +import torch + +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import ( + DeepseekScalingRotaryEmbedding, + yarn_get_mscale, +) +from vllm.model_executor.layers.rotary_embedding.common import ( + rotate_gptj, + rotate_neox, + yarn_find_correction_range, + yarn_linear_ramp_mask, +) +from vllm.platforms import current_platform + +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding + + +class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + inverse: bool = False, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. 
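+        # (Added note) per the upstream helper, yarn_get_mscale(s, m) is 1.0 for
+        # s <= 1 and 0.1 * m * ln(s) + 1.0 otherwise, so the ratio below corrects
+        # the attention magnitude for DeepSeek-style yarn and folds in attn_factor.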
+ self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + self.inverse = inverse + MLURotaryEmbedding.__init__( + self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def forward_mlu_rot(self, input, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len): + """only one input rotary implementation""" + if input is None: + return None + if self.rotary_dim < self.head_size: + input_pass = input[..., self.rotary_dim:] + input_rot = input[..., :self.rotary_dim] + input_rot = mlu_ops.rotary_embedding( + input_rot, + self.sin_, + self.cos_, + position_ids, + cu_seq_lens, + interleaved, + discrete, + False, + max_seq_len + ) + + if self.rotary_dim < self.head_size: + input = torch.cat((input_rot, input_pass), dim=-1) + else: + input = input_rot + + return input + + def forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + only_prefill: bool | None = False, + only_decode: bool | None = False, + discrete: bool | None = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + position_ids, interleaved, discrete = self.get_param(positions, discrete) + + cu_seq_lens = MLURotaryEmbedding.cu_seq_lens + max_seq_len = MLURotaryEmbedding.max_seq_len + + # for MLA + attn_metadata = get_forward_context().attn_metadata + if isinstance(attn_metadata, dict): + _, attn_metadata = next(iter(attn_metadata.items())) + if isinstance(attn_metadata, MLACommonMetadata): + if only_prefill: + cu_seq_lens = MLURotaryEmbedding.prefill_cu_seq_lens + max_seq_len = MLURotaryEmbedding.prefill_max_seq_len + elif only_decode: + cu_seq_lens = MLURotaryEmbedding.decode_cu_seq_lens + max_seq_len = MLURotaryEmbedding.decode_max_seq_len + + query = self.forward_mlu_rot(query, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len) + key = self.forward_mlu_rot(key, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len) + + return query, key + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange( + 0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type, + ) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + device = current_platform.device_type + inv_freq_mask = (( + 1 + - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor).to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale * (-1 if self.inverse else 1) + cache = torch.cat((cos, sin), dim=-1) + return cache + + forward = MLURotaryEmbedding.forward + forward_native = forward_oot \ No newline at 
end of file diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py b/vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py new file mode 100644 index 0000000..002c96e --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding + +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding + + +class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_alpha: float, + dtype: torch.dtype, + ) -> None: + self.scaling_alpha = scaling_alpha + MLURotaryEmbedding.__init__( + self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py b/vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py new file mode 100644 index 0000000..89e4fc5 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding + +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding + + +class MLUDynamicNTKScalingRotaryEmbedding(MLURotaryEmbedding, DynamicNTKScalingRotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + MLURotaryEmbedding.__init__( + self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/linear_scaling_rope.py b/vllm_mlu/model_executor/layers/rotary_embedding/linear_scaling_rope.py new file mode 100644 index 0000000..8450420 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/linear_scaling_rope.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Union + +import torch + +from vllm.platforms import current_platform +from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import LinearScalingRotaryEmbedding + +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding + +class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factors: list[float] | float, + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: list[float] = scaling_factors # noqa + 
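+        # All scaling factors share one concatenated cos/sin cache; the
+        # per-factor offsets (_scaling_factor_to_offset) are populated when
+        # MLURotaryEmbedding.__init__ below builds the cache via
+        # _compute_cos_sin_cache.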
MLURotaryEmbedding.__init__( + self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + # Lazy initialized. + self._scaling_factor_to_offset: dict[float, int] + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + device = current_platform.device_type + if self.is_neox_style: + half_dim = self.rotary_dim // 2 + inv_freq = 1.0 / ( + base + ** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device) + % half_dim * 2 / self.rotary_dim) + ) + else: + inv_freq = 1.0 / ( + base + ** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device) + // 2 * 2 / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: list[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: list[int] = [] + device = current_platform.device_type + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float, device=device) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/llama3_rope.py b/vllm_mlu/model_executor/layers/rotary_embedding/llama3_rope.py new file mode 100644 index 0000000..2f0a207 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/llama3_rope.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch + +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding + + +class MLULlama3RotaryEmbedding(MLURotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) \ No newline at end of file diff --git a/vllm_mlu/model_executor/layers/rotary_embedding/mrope.py b/vllm_mlu/model_executor/layers/rotary_embedding/mrope.py new file mode 100644 index 0000000..5a9d682 --- /dev/null +++ b/vllm_mlu/model_executor/layers/rotary_embedding/mrope.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + 
+import torch + +from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale +from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding + + +class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: list[int] | None = None, + mrope_interleaved: bool = False, + # YaRN parameters. + *, + scaling_factor: float | None = None, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + if self.scaling_factor is not None: + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) + else: + self.mscale = 1.0 + + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. We enlarge max_position_embeddings to 4 times to get + # a larger the cos and sin cache. + self.cache_max_position_num = max_position_embeddings * 4 + MLURotaryEmbedding.__init__( + self, + head_size, + rotary_dim, + self.cache_max_position_num, + base, + is_neox_style, + dtype, + ) + + self.mrope_section = mrope_section + self.mrope_interleaved = mrope_interleaved + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def _apply_mrope(self, positions): + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + num_section = len(self.mrope_section) + mrope_section = self.mrope_section * 2 + def _apply(x): + x = torch.cat([ + m[i % num_section] + for i, m in enumerate(x.split(mrope_section, dim=-1)) + ], + dim=-1) + return x + return _apply(cos), _apply(sin) + + def _apply_interleaved_mrope(self, positions): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. 
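+        Illustrative example: with mrope_section = [2, 1, 1], channel 0 keeps T,
+        channels congruent to 1 mod 3 below mrope_section[1] * 3 take H, channels
+        congruent to 2 mod 3 below mrope_section[2] * 3 take W, the remaining
+        channels keep T, and the same assignment is repeated at an offset of
+        rotary_dim // 2.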
+ """ + mrope_section = self.mrope_section + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + def _apply(x): + x_t = x[0].clone() + x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3] + x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3] + offset = self.rotary_dim // 2 + x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3] + x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3] + return x_t + return _apply(cos), _apply(sin) + + def precompute_sin_cos_cache( + self, + positions: torch.Tensor + ): + ''' + call this function before forward decoder layers + precompute sin/cos cache for mrope + ''' + if positions.ndim == 1: + return + assert positions.ndim == 2 + assert self.mrope_section + if self.mrope_interleaved: + cos, sin = self._apply_interleaved_mrope(positions) + else: + cos, sin = self._apply_mrope(positions) + self.mrope_cos_cache = cos + self.mrope_sin_cache = sin + self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32, device=positions.device) + num_tokens = positions.shape[-1] + self.mrope_cu_seq_lens[1] = num_tokens + + def forward_oot( + self, + positions: torch.Tensor, + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert positions.ndim == 1 or positions.ndim == 2 + if positions.ndim == 1: + return MLURotaryEmbedding.forward_oot(self, positions, x) + assert self.mrope_cos_cache is not None and self.mrope_sin_cache is not None,\ + "call precompute_sin_cos_cache first!" + num_tokens = positions.shape[-1] + x = mlu_ops.rotary_embedding(x, + self.mrope_sin_cache, + self.mrope_cos_cache, + None, + self.mrope_cu_seq_lens, + not self.is_neox_style, + False, + False, + num_tokens) + return x + + forward = MLURotaryEmbedding.forward diff --git a/vllm_mlu/model_executor/layers/sparse_moe_mlp.py b/vllm_mlu/model_executor/layers/sparse_moe_mlp.py new file mode 100644 index 0000000..3f4ddfa --- /dev/null +++ b/vllm_mlu/model_executor/layers/sparse_moe_mlp.py @@ -0,0 +1,1271 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +"""Inference-only MOE model.""" +from typing import Any, List, Optional, Dict, Tuple +from dataclasses import dataclass + +import torch +from torch import nn + +from vllm.config import get_current_vllm_config +from vllm.distributed import ( + get_moe_tensor_parallel_rank, + get_moe_tensor_parallel_world_size, + get_moe_tensor_parallel_group, + get_moe_expert_parallel_rank, + get_moe_expert_parallel_world_size, + get_moe_expert_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + get_dp_group, + divide, +) +from vllm.distributed.utils import divide +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.fused_moe import fused_grouped_topk +from vllm.utils.torch_utils import get_dtype_size +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.utils import maybe_disable_graph_partition +from vllm.platforms import current_platform + +from vllm_mlu import 
_mlu_ops as mlu_ops +from vllm_mlu._mlu_utils import * +from vllm_mlu.model_executor.layers.feed_forward import FeedForward +from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig +from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig +from vllm_mlu.distributed.parallel_state import( + CnclEP, cnclep_dispatch, cnclep_combine) + +from vllm_mlu.distributed.parallel_state import( + CnclEP, cnclep_dispatch, cnclep_combine) + +@dataclass +class MoeGroupInfo: + tp_rank: int + tp_size: int + dp_rank: int + dp_size: int + moe_tp_size: int + moe_tp_rank: int + moe_ep_size: int + moe_ep_rank: int + moe_group: Any + moe_kwargs: dict + + def __init__(self): + self.tp_rank = get_tp_group().rank_in_group + self.tp_size = get_tp_group().world_size + self.dp_rank = get_dp_group().rank_in_group + self.dp_size = get_dp_group().world_size + + self.moe_tp_size = get_moe_tensor_parallel_world_size() + self.moe_tp_rank = get_moe_tensor_parallel_rank() + self.moe_tp_group = get_moe_tensor_parallel_group() + self.moe_ep_size = get_moe_expert_parallel_world_size() + self.moe_ep_rank = get_moe_expert_parallel_rank() + self.moe_ep_group = get_moe_expert_parallel_group() + self.moe_group = self.moe_ep_group if self.moe_ep_size > 1 else self.moe_tp_group + self.moe_kwargs = {"tp_group": self.moe_tp_group} + + +class SqrtSoftPlusTopK(torch.nn.Module): + + def __init__(self, + score_func: str, + use_hash: bool, + n_routed_experts: int, + n_activated_experts: int, + route_scale: float, + vocab_size: int, + prefix: str = ""): + super().__init__() + self.topk = n_activated_experts + self.n_activated_experts = n_activated_experts + self.score_func = score_func + self.route_scale = route_scale + self.use_hash = use_hash + self.n_routed_experts = n_routed_experts + self.vocab_size = vocab_size + if self.use_hash: + self.tid2eid = nn.Parameter( + torch.randint(0, + self.n_activated_experts, + (self.vocab_size, self.n_activated_experts), + dtype=torch.int32), + requires_grad=False, + ) + self.bias = None + else: + self.tid2eid = None + self.bias = nn.Parameter(torch.empty(self.n_routed_experts, dtype=torch.float32), requires_grad=False) + + def forward(self, scores: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.score_func == "sqrtsoftplus" + return mlu_ops.moe_softplus_topk( + scores, + self.topk, + input_ids, + self.tid2eid, + self.bias, + self.route_scale, + ) + + +# This is used by the Deepseek-V2 and Deepseek-V3 model +''' +============================= +Modify by vllm_mlu +============================= +@brief: comment out decorator torch.compiler to avoid triton bug for torch_mlu 2.9.1 +''' +# @torch.compile( +# dynamic=True, +# backend=current_platform.simple_compile_backend, +# options=maybe_disable_graph_partition(current_platform.simple_compile_backend), +# ) +''' +================== +End of MLU Hijack +================== +''' +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if ( + envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + and current_platform.is_cuda() + and num_expert_group <= 32 + and topk <= 32 + and e_score_correction_bias is not None + ): + return fused_grouped_topk( + hidden_states=hidden_states, + 
gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + ) + + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.size(0) + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=use_sorted + ) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +class SparseMoeMlp(nn.Module): + """ + Tensor Parallel evenly splits each expert's weight and distributes them to different ranks, + which means each rank holds partial weight of all experts. + While Expert Parallel evenly distributes some of the experts' full weight to different ranks, + which means each rank holds part of the experts' full weight. + + As a result, each rank in the Tensor Parallel group receives all tokens' hidden states for all experts, + then computes using the partial weights, while for Expert Parallel, each rank only receives + part of tokens' hidden states for experts on this rank, then computes using the full weights. + + When both Tensor Parallel and Expert Parallel are enabled, each rank handles + a portion of the expert weights matrices (as in EP mode) and these weights are further sliced + across ranks (as in TP mode). This hybrid approach aims to balance the workload more evenly across ranks, + enhancing efficiency and reducing the likelihood of bottlenecks associated with EP mode alone. 
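+
+    For example (illustrative): with 16 experts, moe_ep_size=4 and moe_tp_size=2,
+    each EP rank owns 4 experts, and each of those experts' weight matrices is
+    additionally split in half across the 2 TP ranks.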
+ """ + reduce_weight : torch.Tensor = None + expert_id : torch.Tensor = None + is_expert_avg : bool = False + max_batched_token : int = 2048 + random_idx : int = 0 + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + up_proj_name: str, + is_gated: bool, + down_proj_name: str, + has_bias: bool, + skip_bias_add: bool = False, + renormalize:bool = False, + hidden_act: str = "silu", + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + is_use_fused_moe: bool = False, + expert_group: int | None = 1, + topk_group: int | None = 1, + scoring_func: str = "softmax", + topk_method: str = "", + routed_scaling_factor: float = 1.0, + tp_group: Any = None, + use_all2all: bool = False, + use_hash: bool = False, + vocab_size: int = 0, + prefix: str = "", + init_avg_moe: bool = True, + ): + super().__init__() + if tp_group is None: + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + else: + self.tp_rank = tp_group.rank_in_group + self.tp_size = tp_group.world_size + self.tp_group = tp_group + self.use_hash = use_hash + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj_name = up_proj_name + self.is_gated = is_gated + self.down_proj_name = down_proj_name + self.has_bias = has_bias + self.renormalize = renormalize + self.hidden_act = hidden_act + self.quant_config = quant_config + self.is_use_fused_moe = is_use_fused_moe + self.expert_group = expert_group + self.topk_group = topk_group + self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor + self.use_all2all = use_all2all + self.vocab_size = vocab_size + # fused_moe doesn't support weightonly quantization + if isinstance(quant_config, WeightOnlyConfig): + self.is_use_fused_moe = False + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + # [num_bytes_hidden_states, num_bytes_reduce_weights, num_bytes_expert_id] + self.precompute_dim_bytes_list: List[int] | None = None + # sum(self.precompute_dim_bytes_list) + self.precompute_dim_bytes = -1 + + moe_group_info = MoeGroupInfo() + self.moe_tp_size = moe_group_info.moe_tp_size + self.moe_tp_rank = moe_group_info.moe_tp_rank + self.moe_ep_size = moe_group_info.moe_ep_size + self.moe_ep_rank = moe_group_info.moe_ep_rank + self.dp_size = moe_group_info.dp_size + self.dp_rank = moe_group_info.dp_rank + self.moe_group = moe_group_info.moe_group + self.moe_kwargs = moe_group_info.moe_kwargs + + vllm_config = get_current_vllm_config() + model_config = getattr(vllm_config, "model_config", None) + hf_text_config = getattr(model_config, "hf_text_config", None) + self.model_type = getattr(hf_text_config, "model_type", "") + + if (init_avg_moe and + VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg): + n_tokens = SparseMoeMlp.max_batched_token * self.dp_size + expert_group = self.moe_ep_size + val = 1.0 / float(num_experts) + SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32) + import math + if VLLM_RANDOM_MOE_EN: + import numpy as np + # example deepseekv2: experts 160 topk 6 + # avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,...... 
+ array = np.stack([np.random.permutation(num_experts)[:top_k] for _ in range(n_tokens)]) + table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32) + else: + # example deepseekv2: experts 160 + # avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,...... + batch_table = math.ceil(n_tokens * top_k / num_experts) * num_experts + hi_val = batch_table // num_experts + table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view( + hi_val, expert_group, num_experts // expert_group).transpose(1, 2) + SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k) + SparseMoeMlp.is_expert_avg = True + # NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on all nodes the allreduce() would + # contain multiple copies of the bias. The bias on other node will be ignored, and may be set to nullptr + self.skip_bias_add = True if self.moe_tp_rank > 0 else False + + assert self.num_total_experts >= self.moe_ep_size, ( + f"need num_total_experts:{self.num_total_experts} >= moe_ep_size:{self.moe_ep_size}") + + assert self.intermediate_size % self.moe_tp_size == 0, ( + f"need intermediate_size:{self.intermediate_size} % moe_tp_size:{self.moe_tp_size} == 0") + + self.num_experts_per_rank = (self.num_total_experts + self.moe_ep_size - 1) // self.moe_ep_size + if self.moe_ep_rank + 1 == self.moe_ep_size and self.num_total_experts % self.moe_ep_size: + self.num_experts_per_rank = self.num_total_experts % self.moe_ep_size + + self.start_expert_id = self.moe_ep_rank * ((self.num_total_experts + self.moe_ep_size - 1) // self.moe_ep_size) + self.end_expert_id = self.start_expert_id + self.num_experts_per_rank + + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None, + ) + if self.is_deepseek_v4: + self.deepseekv4_topk = SqrtSoftPlusTopK( + score_func=self.scoring_func, + use_hash=self.use_hash, + n_routed_experts=self.num_total_experts, + n_activated_experts=self.top_k, + route_scale=self.routed_scaling_factor, + vocab_size=self.vocab_size, + prefix=f"{prefix}.topk", + ) + if topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(self.num_total_experts, device="mlu")) + else: + self.gate.e_score_correction_bias = None + self.is_fp8_block_wise = (isinstance(self.quant_config, Fp8Config) + and (self.quant_config.weight_block_size is not None)) + if self.is_fp8_block_wise: + self.experts = FusedMoE( + num_experts=self.num_experts_per_rank, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + reduce_results=False, + renormalize=self.renormalize, + quant_config=self.quant_config, + use_grouped_topk=True, + num_expert_group=self.expert_group, + topk_group=self.topk_group, + prefix=f"{prefix}.experts", + scoring_func=self.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + else: + self.experts = nn.ModuleList([ + FeedForward(hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + up_proj_name=self.up_proj_name, + is_gated=self.is_gated, + down_proj_name=self.down_proj_name, + bias=self.has_bias, + quant_config=self.quant_config, + skip_bias_add=self.skip_bias_add, + reduce_results=False, + prefix=f"experts.{idx}", + **self.moe_kwargs) for idx in range(self.num_experts_per_rank) + ]) + + self.init_pack_param() + + @property + def is_deepseek_v4(self): + return self.scoring_func == 'sqrtsoftplus' + + @property + def is_kimi_k2(self): + kimi_k2_scoring_func = "sigmoid" + kimi_k2_expert_group_num = 1 + kimi_k2_experts_num = 384 + return (self.scoring_func == kimi_k2_scoring_func + and self.expert_group == kimi_k2_expert_group_num + and self.num_total_experts == kimi_k2_experts_num) + + @property + def is_glm4_moe(self): + return self.model_type == "glm4_moe" + + def init_pack_param(self): + self.w13 = None + self.w2 = None + self.b13 = None + self.b2 = None + self.w13_scale = None + self.w2_scale = None + self.a13_scale = None + self.a13_scale_all_experts = None + self.a2_scale = None + self.pack_params_done = False + self.pack_params_after_loading_done = False + + + def map_param_data(self, param_list, is_use_first_data=False): + if len(param_list) == 0: + return None + + if is_use_first_data or len(param_list) == 1: + first_data = param_list[0].data + for param in param_list[1: -1]: + param.data = first_data + if is_use_first_data: + out_param = first_data.view_as(param_list[0]) + else: + out_param = first_data.view(len(param_list), *first_data.shape) + else: + packed_param = torch._utils._flatten_dense_tensors(param_list) + data_list = torch._utils._unflatten_dense_tensors(packed_param, param_list) + for data, param in zip(data_list, param_list): + param.data = data + out_param = packed_param.view(len(param_list), *data_list[0].shape) + + torch.mlu.empty_cache() + + return out_param + + + def pack_unquantized_params(self, w13, w2, b13, b2): + for expert in self.experts: + up_proj = getattr(expert, self.up_proj_name) + down_proj = getattr(expert, self.down_proj_name) + w13.append(up_proj.weight) + w2.append(down_proj.weight) + if self.has_bias: + 
b13.append(up_proj.bias) + b2.append(down_proj.bias) + + + def pack_smoothquant_params(self, w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale): + for expert in self.experts: + up_proj = getattr(expert, self.up_proj_name) + down_proj = getattr(expert, self.down_proj_name) + w13.append(up_proj.qweight) + w2.append(down_proj.qweight) + if self.has_bias: + b13.append(up_proj.bias) + b2.append(down_proj.bias) + w13_scale.append(up_proj.per_channel_scale) + w2_scale.append(down_proj.per_channel_scale) + if self.quant_config.input_quant_method == "per_token": + a13_scale.append(up_proj.smooth) + a2_scale.append(down_proj.smooth) + else: + a13_scale.append(up_proj.scale_to_int) + a2_scale.append(down_proj.scale_to_int) + + + def pack_weightonly_params(self, w13, w2, b13, b2, w13_scale, w2_scale): + for expert in self.experts: + up_proj = getattr(expert, self.up_proj_name) + down_proj = getattr(expert, self.down_proj_name) + w13.append(up_proj.qweight) + w2.append(down_proj.qweight) + if self.has_bias: + b13.append(up_proj.bias) + b2.append(down_proj.bias) + w13_scale.append(up_proj.scales) + w2_scale.append(down_proj.scales) + + def pack_fp8_params_without_activation_scheme(self, w13, w2, b13, b2, w13_scale, w2_scale): + for expert in self.experts: + up_proj = getattr(expert, self.up_proj_name) + down_proj = getattr(expert, self.down_proj_name) + w13.append(up_proj.weight) + w2.append(down_proj.weight) + if self.has_bias: + b13.append(up_proj.bias) + b2.append(down_proj.bias) + w13_scale.append(up_proj.weight_scale) + w2_scale.append(down_proj.weight_scale) + + + def pack_params(self): + if self.pack_params_done or self.is_fp8_block_wise: + return + + w13 = [] + w2 = [] + b13 = [] + b2 = [] + w13_scale = [] + w2_scale = [] + a13_scale = [] + a2_scale = [] + + if self.quant_config is None: + self.pack_unquantized_params(w13, w2, b13, b2) + elif isinstance(self.quant_config, SmoothQuantConfig): + self.pack_smoothquant_params(w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale) + elif isinstance(self.quant_config, WeightOnlyConfig): + self.pack_weightonly_params(w13, w2, b13, b2, w13_scale, w2_scale) + elif isinstance(self.quant_config, Fp8Config) and self.quant_config.activation_scheme == 'dynamic': + self.pack_fp8_params_without_activation_scheme(w13, w2, b13, b2, w13_scale, w2_scale) + else: + raise ValueError(f'Unsupported quantization:{self.quant_config}') + + # pack weight + self.w13 = self.map_param_data(w13) + self.w2 = self.map_param_data(w2) + + # pack bias + if self.has_bias: + self.b13 = self.map_param_data(b13) + # NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on all nodes the allreduce() would + # contain multiple copies of the bias. 
The bias on other node will be ignored, and may be set to nullptr + if self.skip_bias_add is False: + self.b2 = self.map_param_data(b2) + + + # pack weight scale + if len(w13_scale) > 0: + self.w13_scale = self.map_param_data(w13_scale) + if len(w2_scale) > 0: + self.w2_scale = self.map_param_data(w2_scale) + + # pack activate scale + if len(a13_scale) > 0: + self.a13_scale = self.map_param_data(a13_scale) + if len(a2_scale) > 0: + self.a2_scale = self.map_param_data(a2_scale) + + self.pack_params_done = True + + def pack_params_after_loading(self): + if self.pack_params_after_loading_done or self.is_fp8_block_wise: + return + + if isinstance(self.quant_config, SmoothQuantConfig) and self.quant_config.group_size > 1 and self.is_use_fused_moe: + assert self.w13_scale is not None and self.w2_scale is not None, "w13_scale and w2_scale must be not None" + self.w13_scale = self.w13_scale.permute(2, 0, 1).contiguous() + self.w2_scale = self.w2_scale.permute(2, 0, 1).contiguous() + + # pack smooth variables for moe_quantize if fp8 + # FIXME: replace smooth to None after tmo supports. + if isinstance(self.quant_config, Fp8Config): + expert_size = self.w13.shape[0] + fp8_smooth_2_hidden_size = self.w13.shape[1] // 2 if self.is_gated else self.w13.shape[1] + self.fp8_smooth_1 = torch.ones([expert_size, self.hidden_size], device=self.w13.device, dtype=torch.float32) + self.fp8_smooth_2 = torch.ones([expert_size, fp8_smooth_2_hidden_size], device=self.w13.device, dtype=torch.float32) + + self.pack_params_done = True + self.pack_params_after_loading_done = True + + def get_precompute_dim_bytes_list(self, hidden_states_dtype: torch.dtype) -> List[int]: + ''' + get the number of bytes of the hidden dimension corresponding to + hidden_states, reduce_weight, and expert_id, respectively. + ''' + if not self.precompute_dim_bytes_list: + hidden_states_size = self.hidden_size * get_dtype_size(hidden_states_dtype) + reduce_weights_size = self.top_k * get_dtype_size(torch.float) + expert_id_size = self.top_k * get_dtype_size(torch.int32) + self.precompute_dim_bytes_list = [ + hidden_states_size, reduce_weights_size, expert_id_size + ] + return self.precompute_dim_bytes_list + + def get_precompute_dim_bytes(self, hidden_states_dtype: torch.dtype) -> int: + ''' + get the hidden dimension in bytes for a packed hidden states that + include + [hidden_states | reduce_weights | expert_id] + ''' + if self.precompute_dim_bytes < 0: + self.precompute_dim_bytes = sum(self.get_precompute_dim_bytes_list(hidden_states_dtype)) + return self.precompute_dim_bytes + + def reduce_results(self, final_hidden_states: torch.Tensor, reduce_results: bool = True): + if reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) 
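+            # Each TP/EP rank holds only the partial sum produced by its local
+            # expert shard, so the partial outputs are summed across moe_group
+            # here. Semantics sketch (collective behaviour, not actual code):
+            #   final_hidden_states = sum(partial_output_r for each rank r in moe_group)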
+ final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states, self.moe_group) + return final_hidden_states + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor | None = None) -> torch.Tensor: + orig_hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # expert_logits: [num_tokens, self.num_experts_per_rank] + expert_logits, _ = self.gate(hidden_states) + final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual) + final_hidden_states = self.reduce_results(final_hidden_states) + output = final_hidden_states.view(orig_hidden_states_shape) + return output + + def precompute_weight_expert_id( + self, + packed: torch.Tensor, + ) -> torch.Tensor: + ''' + pre compute gate and softmax_topk/sigmoid_topk, and fill the weight and + expert_id part as below + in = [ hidden_states | ------ | --------- ] + [ bf16 | fp32 | int32 ] + + out = [ hidden_states | weight | expert_id ] + [ bf16 | fp32 | int32 ] + ''' + hidden_states_size, weight_size, expert_id_size = self.get_precompute_dim_bytes_list(packed.dtype) + packed_int8 = packed.view(torch.int8) + hidden_states = packed_int8[:, : hidden_states_size].view(packed.dtype) + router_logits, _ = self.gate(hidden_states) + topk=self.top_k + renormalized=self.renormalize + reduce_weight = packed_int8[:, hidden_states_size : hidden_states_size + weight_size].view(torch.float) + expert_id = packed_int8[:, hidden_states_size + weight_size :].view(torch.int32) + if self.scoring_func == "softmax": + reduce_weight, expert_id = mlu_ops.moe_softmax_topk(router_logits, topk, renormalized, self.expert_group, + self.topk_group, route_scale=self.routed_scaling_factor, + reduce_weight=reduce_weight, + expert_id=expert_id) + elif self.scoring_func == "sigmoid": + reduce_weight, expert_id = mlu_ops.moe_sigmoid_topk(router_logits, topk, renormalized, + self.expert_group, self.topk_group, + self.routed_scaling_factor, + self.gate.e_score_correction_bias, + reduce_weight=reduce_weight, + expert_id=expert_id) + else: + raise ValueError(f"Unsupported scoring function: {self.scoring_func}") + return packed + + def forward_experts(self, hidden_states, expert_logits, residual: torch.Tensor | None = None, + shared_output: torch.Tensor | None = None, + input_ids: torch.Tensor | None = None): + assert not (residual is not None and shared_output is not None) + residual_ = None if self.tp_rank > 0 else residual + + # change only for deepseek_model without residual_ + if shared_output is not None: + residual_ = shared_output + + if self.is_fp8_block_wise: + output = self.experts(hidden_states=hidden_states, + router_logits=expert_logits) * self.routed_scaling_factor + if residual_ is not None: + output = output + residual_ + return output + + use_forward_group_experts = (self.is_use_fused_moe + and ( + self.is_kimi_k2 + or self.is_glm4_moe + or self.is_deepseek_v4 + or self.expert_group != 1) + ) + if use_forward_group_experts: + final_hidden_states = self.forward_group_experts( + hidden_states, + expert_logits, + residual_, + input_ids=input_ids, + ) + elif self.is_use_fused_moe: + self.pack_params() + self.pack_params_after_loading() + final_hidden_states = mlu_ops.fused_moe(hidden_states=hidden_states, + gating_output=expert_logits, + w1=self.w13, + w2=self.w2, + bias1=self.b13, + bias2=self.b2, + residual=residual_, + input_smooth=self.a13_scale, + act_smooth=self.a2_scale, + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + topk=self.top_k, + 
renormalize=self.renormalize, + gated=self.is_gated, + act_mode=self.hidden_act, + start_expert_id=self.start_expert_id, + avg_moe=VLLM_AVG_MOE_EN, + class_reduce_weight=SparseMoeMlp.reduce_weight, + class_expert_id=SparseMoeMlp.expert_id, + ) + else: + final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits) + if residual_ is not None: + final_hidden_states = final_hidden_states + residual_ + return final_hidden_states + + + def forward_experts_nofused(self, hidden_states, expert_logits): + hidden_states_shape = hidden_states.shape + if self.scoring_func == "softmax": + topk_values, topk_indices = self.topk_softmax(expert_logits) + elif self.scoring_func == "sigmoid": + gating_output = expert_logits.to(torch.float32) + gating_output = gating_output.view(-1, gating_output.size(-1)) + topk_values, topk_indices = grouped_topk(hidden_states, gating_output, self.top_k, self.renormalize, + self.expert_group, self.topk_group, self.scoring_func, + self.routed_scaling_factor, self.gate.e_score_correction_bias) + topk_values = topk_values.to(hidden_states.dtype) + topk_indices = topk_indices.to(torch.int64) + expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx( + topk_indices) + # no expert is routed, then expand_gather_idx, expand_scatter_idx has no item, + # expand_token_count and expand_cusum_token_count has item but the value is all zero + # so this rank should only return final_hidden_states with zero value + if expand_gather_idx.numel() == 0: + final_hidden_states = torch.zeros_like(hidden_states, + dtype=hidden_states.dtype, + device=hidden_states.device) + return final_hidden_states + + expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx) + + expand_output_list = [] + expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id + + 1] - cusum_token_count[self.start_expert_id] + for expert_idx, num_tokens_per_expert in enumerate(expand_token_count): + if num_tokens_per_expert > 0: + expert_hidden_states = expand_hidden_states[ + expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]] + expert_output = self.experts[expert_idx](expert_hidden_states) + expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output + expand_output_list.append(expert_output) + expand_output = torch.cat(expand_output_list, dim=0) + final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape, + topk_values) + + return final_hidden_states + + def forward_group_experts(self, hidden_states, gating_output, residual_, input_ids: torch.Tensor | None = None): + # determine if hidden_states packs reduce_weight and expert_id in it, + # and if so, extract them. 
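+        # Packed layout (see get_precompute_dim_bytes_list), viewed as int8:
+        #   [ hidden_states | reduce_weight | expert_id ]
+        #   [ bf16          | fp32          | int32     ]
+        # Worked example with hypothetical sizes (hidden_size=7168, top_k=8, bf16):
+        #   hidden_states : 7168 * 2 = 14336 bytes
+        #   reduce_weight :    8 * 4 =    32 bytes
+        #   expert_id     :    8 * 4 =    32 bytes
+        #   packed_dim    =            14400 bytes per token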
+ orig_dtype = hidden_states.dtype + device = hidden_states.device + hidden_states_int8 = hidden_states.view(torch.int8) + hidden_states_size, weight_size, _ = self.get_precompute_dim_bytes_list(orig_dtype) + packed_dim = self.get_precompute_dim_bytes(orig_dtype) + is_precompute_weight_expert_id: bool = (hidden_states_int8.shape[1] == packed_dim) + if is_precompute_weight_expert_id: + assert gating_output is None + hidden_states = hidden_states_int8[:, : hidden_states_size].view(orig_dtype) + reduce_weight = hidden_states_int8[:, hidden_states_size : hidden_states_size + weight_size].view(torch.float) + expert_id = hidden_states_int8[:, hidden_states_size + weight_size :].view(torch.int32) + + is_fp8_quant = isinstance(self.quant_config, Fp8Config) + ori_input_shape = hidden_states.shape + dtype = hidden_states.dtype + self.pack_params() + self.pack_params_after_loading() + w1=self.w13.to(device) if self.w13 is not None else None + w2=self.w2.to(device) if self.w2 is not None else None + bias1=self.b13.to(device) if self.b13 is not None else None + bias2=self.b2.to(device) if self.b2 is not None else None + input_smooth=self.a13_scale.to(device) if self.a13_scale is not None else None + act_smooth=self.a2_scale.to(device) if self.a2_scale is not None else None + w1_scale=self.w13_scale.to(device) if self.w13_scale is not None else None + w2_scale=self.w2_scale.to(device) if self.w2_scale is not None else None + topk=self.top_k + renormalized=self.renormalize + gated=self.is_gated + act_mode=self.hidden_act + quant_input=None + + start_expert_id=self.start_expert_id + expert_size = w1.size(0) + max_m = hidden_states.shape[0] + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None + # Check smooth quant parameters. 
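+        # per_token_sq is only enabled when all four smooth-quant tensors
+        # (input_smooth, act_smooth, w1_scale, w2_scale) are present; a mix of
+        # present and absent tensors is rejected as a configuration error below.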
+ per_token_sq = False + if not is_fp8_quant: + check_list = [input_smooth, act_smooth, w1_scale, w2_scale] + if all(x is not None for x in check_list): + per_token_sq = True + + if not (all(x is None for x in check_list) or all(x is not None for x in check_list)): + raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present " + "and absent at the same time.") + + # softmax_topk + if not is_precompute_weight_expert_id: + gating_output = gating_output.view(-1, gating_output.size(-1)) + if self.scoring_func == "softmax": + reduce_weight, expert_id = mlu_ops.moe_softmax_topk(gating_output, topk, renormalized, self.expert_group, + self.topk_group, route_scale=self.routed_scaling_factor) + elif self.scoring_func == "sigmoid": + reduce_weight, expert_id = mlu_ops.moe_sigmoid_topk(gating_output, topk, renormalized, + self.expert_group, self.topk_group, + self.routed_scaling_factor, + self.gate.e_score_correction_bias) + elif self.scoring_func == "sqrtsoftplus": + assert hasattr(self,"deepseekv4_topk") + reduce_weight, expert_id = self.deepseekv4_topk( + gating_output, + input_ids, + ) + else: + raise ValueError(f"Unsupported scoring function: {self.scoring_func}") + + if VLLM_RANDOM_MOE_EN: + n_tokens = hidden_states.shape[0] + token_len = SparseMoeMlp.expert_id.size(0) + SparseMoeMlp.random_idx = 0 if token_len == n_tokens else (SparseMoeMlp.random_idx+1) % (token_len-n_tokens) + n_tokens = hidden_states.shape[0] + reduce_weight = SparseMoeMlp.reduce_weight[:n_tokens] + expert_id = SparseMoeMlp.expert_id[SparseMoeMlp.random_idx: SparseMoeMlp.random_idx + n_tokens] + elif VLLM_AVG_MOE_EN: + n_tokens = hidden_states.shape[0] + reduce_weight = SparseMoeMlp.reduce_weight[:n_tokens] + expert_id = SparseMoeMlp.expert_id[:n_tokens] + # gen_idx + expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, self.num_total_experts) + # check quant + if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token': + quant_input, input_scale = mlu_ops.moe_quantize( + hidden_states, + self.fp8_smooth_1, + zero=None, + token_count=token_count[start_expert_id:start_expert_id+expert_size], + gather_index=expand_idx, + gather_index_start_position=cusum_token_count[start_expert_id].unsqueeze(0), + output=None, + output_scale=None, + dynamic_quant=True, + quant_type=torch.float8_e4m3fn + ) + elif per_token_sq: + quant_input, input_scale = mlu_ops.moe_quantize(hidden_states, + input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx, + cusum_token_count[start_expert_id].unsqueeze(0)) + else: + expand_hidden_states = mlu_ops.moe_expand_input( + hidden_states, + expand_idx, + cusum_token_count, + start_expert_id, + expert_size, + ) + + if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq: + gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, + input_scale, w1_scale, dtype, max_m) + else: + gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, max_m) + # add_bias_active + if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token': + act_out = mlu_ops.moe_active(gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2], bias=bias1, cusum_token_count=cusum_token_count, start_expert_id=start_expert_id, expert_size=expert_size) + quant_input, input_scale = mlu_ops.moe_quantize( + 
act_out, + self.fp8_smooth_2, + zero=None, + token_count=token_count[start_expert_id:start_expert_id+expert_size], + gather_index=None, + gather_index_start_position=None, + output=quant_input[:,:act_out.shape[-1]], + output_scale=None, + dynamic_quant=True, + quant_type=torch.float8_e4m3fn + ) + elif per_token_sq: + quant_input = quant_input[:, :gemm_out.shape[-1] // 2] + input_scale = input_scale[:gemm_out.shape[0]] + quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None, + token_count[start_expert_id:start_expert_id+expert_size], + output=quant_input, + output_scale=input_scale, + act_mode=act_mode, + is_gated=self.is_gated) + + if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq: + # Remove the reference to gemm_out tensor. + # If that was the only reference, the tensor’s memory becomes eligible for deallocation + # So that we can reuse this memory for the new allocation of next gemm operation + del gemm_out + gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, input_scale, w2_scale, dtype, max_m) + else: + act_out = mlu_ops.moe_active(gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2], bias1, cusum_token_count, start_expert_id, expert_size) + gemm_out = mlu_ops.group_gemm(act_out, w2, + token_count[start_expert_id:start_expert_id+expert_size], + None, None, None, None, max_m) + + # we reuse the memory of hidden_states to store the output + output = mlu_ops.moe_combine_result( + gemm_out, reduce_weight, combine_idx, + residual_, cusum_token_count, start_expert_id, + expert_size, bias2, + output=hidden_states if not is_precompute_weight_expert_id else None) + return output.view(ori_input_shape) + + + def topk_softmax(self, expert_logits): + # expert_logits: [num_tokens, self.num_experts_per_rank] + # topk_values: [num_tokens, self.top_k] + # topk_indices: [num_tokens, self.top_k] + if self.renormalize: + topk_values, topk_indices = torch.topk(expert_logits, self.top_k, dim=-1) + topk_values = torch.softmax(topk_values, -1) + else: + router_probs = torch.softmax(expert_logits, -1) + topk_values, topk_indices = torch.topk(router_probs, self.top_k, dim=-1) + + return topk_values, topk_indices + + + def generate_gather_idx(self, topk_indices): + device = topk_indices.device + # gather_expand_idx: [num_tokens * self.top_k] + sorted_expert_id, indices = topk_indices.flatten().sort() + gather_idx = indices // self.top_k + + seqs = torch.arange(indices.numel(), dtype=indices.dtype, device=indices.device) + scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs) + + # token_count: [self.num_experts_per_rank] + partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True) + zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device) + token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count) + # cusum_token_count: [self.num_experts_per_rank + 1] + cusum_token_count = torch.cat( + [torch.tensor([0], dtype=token_count.dtype, device=device), + token_count.cumsum(dim=0)]) + + num_tokens_before_expert = cusum_token_count[self.start_expert_id] + num_tokens_including_expert = cusum_token_count[self.end_expert_id] + + expand_gather_idx = gather_idx[num_tokens_before_expert:num_tokens_including_expert] + expand_token_count = 
token_count[self.start_expert_id:self.end_expert_id] + + return expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count + + + def expand_input(self, hidden_states, expand_gather_idx): + expand_hidden_states = hidden_states[expand_gather_idx] + return expand_hidden_states + + + def combine_moe(self, expand_output, scatter_idx, cusum_token_count, hidden_states_shape, topk_values): + num_tokens, hidden_size = hidden_states_shape + num_tokens_before_expert = cusum_token_count[self.start_expert_id] + num_tokens_after_expert = cusum_token_count[-1] - cusum_token_count[self.end_expert_id] + + expand_output_before_expert = torch.zeros((num_tokens_before_expert, hidden_size), + dtype=expand_output.dtype, + device=expand_output.device) + expand_output_after_expert = torch.zeros((num_tokens_after_expert, hidden_size), + dtype=expand_output.dtype, + device=expand_output.device) + unscatted_output = torch.cat([expand_output_before_expert, expand_output, expand_output_after_expert], dim=0) + scatter_output = unscatted_output[scatter_idx] + hidden_states_weight = topk_values.flatten().unsqueeze(-1) + weighted_hidden_states = scatter_output * hidden_states_weight + unreduced_hidden_states = weighted_hidden_states.view(num_tokens, self.top_k, hidden_size) + final_hidden_states = unreduced_hidden_states.sum(dim=1) + + return final_hidden_states + + def prepare_for_cnclep(self, cnclep: CnclEP) -> None: + if cnclep.use_quant_dispatch: + self.prepare_for_cnclep_quant_dispatch(cnclep) + else: + self.prepare_for_cnclep_bf16(cnclep) + + def prepare_for_cnclep_bf16(self, cnclep: CnclEP) -> None: + # prepare buffers for the forward process + buffer = cnclep.buffer + self.dispatch_send_buffer = buffer.dispatch_send_token_tensor + self.dispatch_recv_buffer = buffer.dispatch_recv_token_tensor + self.combine_send_buffer = buffer.combine_send_token_tensor + self.combine_recv_buffer = buffer.combine_recv_token_tensor + self.max_num_tokens_per_rank = cnclep.max_num_tokens_per_rank + + # get sizes in bytes + self.dispatch_token_size = self.config.hidden_size * 2 + # [nranks, 2] + self.dispatch_recv_layout = torch.empty((self.moe_ep_size, 2), dtype=torch.int32, device="mlu") + # [num_total_experts] + self.dispatch_recv_token_num = torch.empty((self.num_total_experts), dtype=torch.int32, device="mlu") + + self.max_num_tokens_recv = self.max_num_tokens_per_rank * self.moe_ep_size + self.max_num_tokens_per_expert = divide(self.max_num_tokens_recv, self.top_k) + + # input to the first groupgemm, in which tokens are ordered by experts. + input_recv_size = self.max_num_tokens_recv * self.dispatch_token_size + self.input_recv = ( + self.combine_send_buffer[:input_recv_size] + .view(self.max_num_tokens_recv, self.dispatch_token_size) + ) + # kept for code without compute-communication parallel, which may have + # become stale. + self.quant_input_recv = self.input_recv + + def prepare_for_cnclep_quant_dispatch(self, cnclep: CnclEP) -> None: + # prepare smooth parameter for _all_ experts globally, which would be needed during + # input quantization before dispatch. 
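+        # Shape sketch of the gather below: each EP rank holds the smooth scales
+        # of its local experts, and all_gather_into_tensor concatenates them
+        # along dim 0 so any token can be quantized before dispatch:
+        #   a13_scale             : [num_experts_per_rank, hidden_size]
+        #   a13_scale_all_experts : [num_total_experts,    hidden_size]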
+ assert self.a13_scale is not None, "a13_scale has not been loaded" + self.a13_scale_all_experts = torch.zeros((self.num_total_experts, self.hidden_size), + dtype=self.a13_scale.dtype, + device=self.a13_scale.device) + torch.distributed.all_gather_into_tensor(self.a13_scale_all_experts, + self.a13_scale, + group=self.moe_group.device_group, + async_op=False) + + # prepare buffers for the forward process + buffer = cnclep.buffer + self.dispatch_send_buffer = buffer.dispatch_send_token_tensor + self.dispatch_recv_buffer = buffer.dispatch_recv_token_tensor + self.combine_send_buffer = buffer.combine_send_token_tensor + self.combine_recv_buffer = buffer.combine_recv_token_tensor + self.max_num_tokens_per_rank = cnclep.max_num_tokens_per_rank + + # get sizes in bytes + self.quant_size = self.hidden_size + self.scale_size = get_dtype_size(torch.float32) + self.dispatch_token_size = self.quant_size + self.scale_size + # [nranks, 2] + self.dispatch_recv_layout = torch.empty((self.moe_ep_size, 2), dtype=torch.int32, device="mlu") + # [num_total_experts] + self.dispatch_recv_token_num = torch.empty((self.num_total_experts), dtype=torch.int32, device="mlu") + + self.max_num_tokens_recv = self.max_num_tokens_per_rank * self.moe_ep_size + self.max_num_tokens_per_expert = divide(self.max_num_tokens_recv, self.top_k) + + quant_input_recv_size = self.max_num_tokens_recv * self.quant_size + input_scale_recv_size = self.max_num_tokens_recv * self.scale_size + self.quant_input_recv = ( + self.combine_send_buffer[:quant_input_recv_size] + .view(self.max_num_tokens_recv, self.quant_size)) + self.input_scale_recv = ( + self.combine_send_buffer[quant_input_recv_size : quant_input_recv_size + input_scale_recv_size] + .view(self.max_num_tokens_recv, self.scale_size)) + + def forward_all2all( + self, + hidden_states: torch.Tensor, + gate: ReplicatedLinear, + streams: Optional[Dict[str, torch.mlu.Stream]] = None, + shared_experts: Optional[nn.Module] = None, + ) -> torch.Tensor: + """forward with all2all.""" + ori_input_shape = hidden_states.shape + dtype = hidden_states.dtype + self.pack_params() + self.pack_params_after_loading() + w1=self.w13 + w2=self.w2 + bias2=self.b2 + input_smooth=self.a13_scale_all_experts + act_smooth=self.a2_scale + w1_scale=self.w13_scale + w2_scale=self.w2_scale + topk=self.top_k + renormalized=self.renormalize + act_mode=self.hidden_act + quant_input=None + + start_expert_id=self.start_expert_id + expert_size = w1.size(0) + max_m = hidden_states.shape[0] + gating_output, _ = gate(hidden_states) + gating_output = gating_output.view(-1, gating_output.size(-1)) + if self.scoring_func == "softmax": + reduce_weight, expert_id = mlu_ops.moe_softmax_topk(gating_output, topk, renormalized, self.expert_group, + self.topk_group, route_scale=self.routed_scaling_factor) + elif self.scoring_func == "sigmoid": + reduce_weight, expert_id = mlu_ops.moe_sigmoid_topk(gating_output, topk, renormalized, + self.expert_group, self.topk_group, + self.routed_scaling_factor, + self.gate.e_score_correction_bias) + else: + raise ValueError(f"Unsupported scoring function: {self.scoring_func}") + + if VLLM_AVG_MOE_EN: + # get dp rank + dp_rank = get_dp_group().rank_in_group + tp_rank = get_tp_group().rank_in_group + global_rank = dp_rank * get_tp_group().world_size + tp_rank + n_tokens = hidden_states.shape[0] + reduce_weight = SparseMoeMlp.reduce_weight[:n_tokens] + if self.use_all2all and VLLM_RANDOM_MOE_EN: + expert_id = SparseMoeMlp.expert_id[global_rank * n_tokens : (global_rank+1) * n_tokens] + elif 
self.use_all2all: + expert_id = SparseMoeMlp.expert_id[dp_rank * n_tokens: dp_rank * n_tokens + n_tokens] + else: + expert_id = SparseMoeMlp.expert_id[:n_tokens] + + expand_idx, combine_idx, token_count, cusum_token_count \ + = mlu_ops.moe_gen_idx(expert_id, self.num_total_experts) + + num_token_expand = hidden_states.shape[0] * self.top_k + dispatch_bytes = num_token_expand * self.dispatch_token_size + + dispatch_send_token_tensor = ( + self.dispatch_send_buffer[:dispatch_bytes] + .view(num_token_expand, self.dispatch_token_size) + ) + + quant_size = self.hidden_size + quant_input = dispatch_send_token_tensor[:, : quant_size] + input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32) + quant_input, input_scale = mlu_ops.moe_quantize( + hidden_states, input_smooth, None, token_count, expand_idx, None, + output=quant_input, + output_scale=input_scale) + + dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(token_count, self.moe_ep_size) + + cnclep_dispatch(self.dispatch_token_size, + num_token_expand, + dispatch_send_layout, + token_count, + self.dispatch_recv_layout, + self.dispatch_recv_token_num) + + recv_token_num = self.dispatch_recv_token_num.view(self.moe_ep_size, self.num_experts_per_rank) + pad_num = self.max_num_tokens_per_rank + + ( + gather_by_expert_index, + gather_by_rank_index, + tokens_per_local_expert, + token_sum + ) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num) + + max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size + dispatch_recv_token_tensor = ( + self.dispatch_recv_buffer[:max_tokens_bytes_recv] + .view(self.max_num_tokens_recv, self.dispatch_token_size)) + + mlu_ops.gather_split(dispatch_recv_token_tensor, + gather_by_expert_index, + token_sum, + self.quant_input_recv, + self.input_scale_recv) + + max_m = self.max_num_tokens_per_expert + gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1, + tokens_per_local_expert, + None, None, None, None, + self.input_scale_recv.view(torch.float32).flatten(), + w1_scale, dtype, max_m) + + # continue reusing self.quant_input_recv and self.input_scale_recv + quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2] + input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]] + quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None, + tokens_per_local_expert, + output=quant_input, + output_scale=input_scale_fp32, + act_mode=act_mode, + is_gated=self.is_gated) + + gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2, + tokens_per_local_expert, + None, None, None, None, input_scale, w2_scale, dtype, max_m) + + combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype) + mlu_ops.gather_split(gemm_out, + gather_by_rank_index, + token_sum, + combine_send_token_tensor, + None) + + combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size) + combine_recv_layout = self.dispatch_recv_layout + + # combine + combine_args = dict( + token_byte=self.hidden_size * 2, + token_num=num_token_expand, + send_src_layout=combine_send_layout, + send_dst_layout=combine_recv_layout, + send_token=None, + recv_token=None) + + shared_output = None + if shared_experts is not None: + parallelize_shared_expert = streams is not None + if parallelize_shared_expert: + compute_stream = streams['shared'] + comm_stream = streams['routed'] + curr_stream = torch.mlu.current_stream() + 
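+                # Overlap sketch: the shared expert runs on compute_stream while
+                # the EP combine runs on comm_stream; both streams fork from the
+                # current stream via the wait_stream() calls below and are joined
+                # again before the combined result is used.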
compute_stream.wait_stream(curr_stream) + comm_stream.wait_stream(curr_stream) + + with torch.mlu.stream(compute_stream): + shared_output = shared_experts(hidden_states, use_tp_weight=False) + + with torch.mlu.stream(comm_stream): + cnclep_combine(**combine_args) + + curr_stream.wait_stream(compute_stream) + curr_stream.wait_stream(comm_stream) + else: + shared_output = shared_experts(hidden_states, use_tp_weight=False) + cnclep_combine(**combine_args) + else: + cnclep_combine(**combine_args) + + numel_recv = num_token_expand * self.hidden_size + recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv] + .view(num_token_expand, self.hidden_size)) + + residual_ = shared_output + output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx, + residual_, None, start_expert_id, + expert_size, bias2, output=hidden_states) + + return output.view(ori_input_shape) diff --git a/vllm_mlu/model_executor/model_loader/__init__.py b/vllm_mlu/model_executor/model_loader/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/model_executor/model_loader/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/model_executor/model_loader/dummy_loader.py b/vllm_mlu/model_executor/model_loader/dummy_loader.py new file mode 100644 index 0000000..711fa48 --- /dev/null +++ b/vllm_mlu/model_executor/model_loader/dummy_loader.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +import torch +import torch.nn as nn +import numpy as np +from typing import List, Tuple +from tqdm import tqdm + +from vllm.config import ModelConfig +from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def initialize_dummy_weights_normal_dist( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + std: float = 0.5, + seed: int = 1234, +) -> None: + """ + Initialize the weights of a PyTorch model with values drawn from a normal distribution. + Floating point parameters are initialized with a normal distribution whose mean is randomly + sampled from [low, high] and standard deviation is fixed at 0.5. Integer parameters are + initialized with random integers in [floor(low), ceil(high)). The initialization is performed + in a batched and efficient way for both floating point and integer parameters. + + Optimized version: Uses shared pinned memory based on the largest parameter block size + to minimize H2D transfers, sacrificing global uniqueness for performance. + + Args: + model (torch.nn.Module): The model whose weights will be initialized. + low (float): Lower bound for sampling the mean of the normal distribution (for float params). + high (float): Upper bound for sampling the mean of the normal distribution (for float params). + std (float): Standard deviation for the normal distribution (for float params). + seed (int): Random seed for reproducibility. 
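+
+    Example (illustrative only; assumes a model already placed on the MLU device,
+    and ``build_model`` is a hypothetical helper):
+        model = build_model().to("mlu")
+        initialize_dummy_weights_normal_dist(model, low=-1e-3, high=1e-3, std=0.5)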
+ """ + # Randomly sample the mean for the normal distribution from [low, high] + rng = np.random.RandomState(seed) + mean = float(rng.uniform(low, high, 1).item()) + + # Create a CPU generator for reproducibility + cpu_gen = torch.Generator(device="cpu") + cpu_gen.manual_seed(seed) + + # Collect parameters: separate into floating point and integer types + float_params: List[Tuple[str, torch.Tensor]] = [] + int_params: List[Tuple[str, torch.Tensor]] = [] + + for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"): + if not isinstance(t, torch.Tensor): + continue + if torch.is_floating_point(t): + float_params.append((name, t)) + elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64): + int_params.append((name, t)) + + # -------- Floating point parameters: optimized shared memory initialization -------- + if float_params: + # Find the largest parameter block size + max_float_elems = max(p.numel() for _, p in float_params) + + # Create shared pinned memory buffer based on largest parameter + shared_float_buffer = torch.empty(max_float_elems, dtype=torch.float32, device="cpu", pin_memory=True) + shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen) + + # Copy shared buffer to device once + device_buffer = shared_float_buffer.to(next(iter(float_params))[1].device, non_blocking=True) + + for _, p in tqdm(float_params, desc="Gen dummy weights: Init float params"): + n = p.numel() + # Extract from device buffer (may reuse same values for different parameters) + view = device_buffer[:n].view(p.shape) + + # torch.normal_ does not support dtypes < fp16, so cast via fp16 if needed + if torch.finfo(p.dtype).bits < 16: + tmp = view.to(torch.float16) + tmp = tmp.to(p.dtype) + else: + tmp = view.to(p.dtype) + + # Copy from device buffer to parameter (D2D copy, much faster) + p.data.copy_(tmp) + + # -------- Integer parameters: optimized shared memory initialization -------- + if int_params: + # Find the largest parameter block size + max_int_elems = max(p.numel() for _, p in int_params) + + int_low = int(np.floor(low)) + int_high = int(np.ceil(high)) + if int_high == int_low: + int_high = int_low + 1 # Ensure at least one possible value + + # Create shared pinned memory buffer based on largest parameter + shared_int_buffer = torch.randint( + low=int_low, + high=int_high, + size=(max_int_elems,), + dtype=torch.int64, + generator=cpu_gen, + device="cpu", + pin_memory=True + ) + + # Copy shared buffer to device once + device_int_buffer = shared_int_buffer.to(next(iter(int_params))[1].device, non_blocking=True) + + for _, p in tqdm(int_params, desc="Gen dummy weights: Init int params"): + n = p.numel() + # Extract from device buffer (may reuse same values for different parameters) + view = device_int_buffer[:n].view(p.shape) + tmp = view.to(p.dtype) + # Copy from device buffer to parameter (D2D copy, much faster) + p.data.copy_(tmp) + + +SMOOTHQUANT_METHOD = "smoothquant" +MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"} +def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. 
+ ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use torch.normal_ instead of torch.uniform_ for distinguishable logits + std=0.5 is used for better distinguishable logits + ''' + + # === Default parameter setup (Original values as fallback) === + low_val = -1e-3 + high_val = 1e-3 + std_val = 0.5 + + # === Model and Quantization Check Logic === + quant_method = getattr(model_config, "quantization", None) + + # Attempt to get the architectures list from model_config + archs = getattr(model_config, "architectures", []) or [] + + # Determine if the model is multimodal (based on architecture names) + is_multimodal = any( + keyword in arch + for arch in archs + for keyword in MULTIMODAL_ARCH_KEYWORDS + ) + + # === Apply SmoothQuant + Multimodal Parameters === + if is_multimodal and quant_method == SMOOTHQUANT_METHOD: + # (smoothquant) + Multimodal specific values to mitigate NaN overflow + std_val = 1e-4 + + initialize_dummy_weights_normal_dist( + model, + low=low_val, + high=high_val, + std=std_val + ) + # add a sync to make sure the weights are initialized + torch.mlu.synchronize() + ''' + ================== + End of MLU Hijack + ================== + ''' + +MluHijackObject.apply_hijack( + DummyModelLoader, + DummyModelLoader.load_weights, + vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights +) \ No newline at end of file diff --git a/vllm_mlu/model_executor/model_loader/tensorizer.py b/vllm_mlu/model_executor/model_loader/tensorizer.py new file mode 100644 index 0000000..c842d13 --- /dev/null +++ b/vllm_mlu/model_executor/model_loader/tensorizer.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import time +import torch +from torch import nn +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union + +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, TensorDeserializer, TensorizerArgs, + _check_tensors_on_meta_device, _resize_lora_embeddings, + is_valid_deserialization_uri) +from vllm.platforms import current_platform +from vllm.logger import init_logger + +try: + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) + +except ImportError: + open_stream = tensorizer.placeholder_attr("stream_io.open_stream") + convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes") + get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") + no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") + +logger = init_logger(__name__) + + +def deserialize_tensorizer_model(model: nn.Module, + tensorizer_config: TensorizerConfig) -> None: + tensorizer_args = tensorizer_config._construct_tensorizer_args() + if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri): + raise ValueError( + f"{tensorizer_config.tensorizer_uri} is not a valid " + f"tensorizer URI. Please check that the URI is correct. 
" + f"It must either point to a local existing file, or have a " + f"S3, HTTP or HTTPS scheme.") + before_mem = get_mem_usage() + start = time.perf_counter() + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use mlu device + ''' + device = '' + if current_platform.is_out_of_tree(): + device = f'mlu:{torch.mlu.current_device()}' + elif current_platform.is_xpu(): + device = f'xpu:{torch.xpu.current_device()}' + else: + device = f'cuda:{torch.cuda.current_device()}' + with open_stream( + tensorizer_config.tensorizer_uri, + mode="rb", + **tensorizer_args.stream_kwargs) as stream, TensorDeserializer( + stream, + dtype=tensorizer_config.dtype, + device=device, + **tensorizer_args.deserialization_kwargs) as deserializer: + deserializer.load_into_module(model) + end = time.perf_counter() + ''' + ================== + End of MLU Hijack + ================== + ''' + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) + + _check_tensors_on_meta_device(model) + _resize_lora_embeddings(model) + del model.vllm_tensorized_marker + +def serialize_extra_artifacts( + tensorizer_args: TensorizerArgs, + served_model_name: Union[str, list[str], None]) -> None: + if not isinstance(served_model_name, str): + raise ValueError( + f"served_model_name must be a str for serialize_extra_artifacts, " + f"not {type(served_model_name)}.") + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use local file + ''' + import shutil + from pathlib import Path + local_model_path = Path(served_model_name) + if not local_model_path.exists() or not local_model_path.is_dir(): + raise ValueError( + f"served_model_name must be a valid local directory in offline mode, " + f"but got: {served_model_name}" + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + with tempfile.TemporaryDirectory() as tmpdir: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: copy local file + ''' + logger.info("Copying local model from %s to temporary directory %s", + local_model_path, tmpdir) + shutil.copytree(local_model_path, tmpdir, dirs_exist_ok=True) + ''' + ================== + End of MLU Hijack + ================== + ''' + + for artifact in os.scandir(tmpdir): + if not artifact.is_file(): + continue + with open(artifact.path, "rb") as f, open_stream( + f"{tensorizer_args.tensorizer_dir}/{artifact.name}", + mode="wb+", + **tensorizer_args.stream_kwargs) as stream: + logger.info("Writing artifact %s", artifact.name) + stream.write(f.read()) + diff --git a/vllm_mlu/model_executor/model_loader/tensorizer_loader.py b/vllm_mlu/model_executor/model_loader/tensorizer_loader.py new file mode 100644 index 0000000..21088aa --- /dev/null +++ b/vllm_mlu/model_executor/model_loader/tensorizer_loader.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from torch import nn + +from vllm.config import ModelConfig +from vllm.model_executor.model_loader.tensorizer import is_vllm_tensorized +from vllm.model_executor.model_loader.tensorizer_loader import 
TensorizerLoader + +from vllm_mlu.model_executor.model_loader.tensorizer import deserialize_tensorizer_model +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights( + self, + model: nn.Module, + model_config: ModelConfig +) -> None: + """Load serialized model weights with tensorizer. + + Expects a vLLM-tensorized model. See the + examples/others/tensorize_vllm_model.py example script + for serializing vLLM models.""" + if is_vllm_tensorized(self.tensorizer_config): + tensorizer_config = self._patch_tensorizer_config(model_config) + deserialize_tensorizer_model(model, tensorizer_config) + else: + model.load_weights(self._get_weights_iterator()) + + +MluHijackObject.apply_hijack( + TensorizerLoader, + TensorizerLoader.load_weights, + vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights +) \ No newline at end of file diff --git a/vllm_mlu/model_executor/models/__init__.py b/vllm_mlu/model_executor/models/__init__.py new file mode 100755 index 0000000..5ee3283 --- /dev/null +++ b/vllm_mlu/model_executor/models/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm import ModelRegistry + + +def register_model(): + from .deepseek_v4 import MLUDeepseekV4ForCausalLM # noqa: F401 + + ModelRegistry.register_model( + "DeepseekV4ForCausalLM", + "vllm_mlu.model_executor.models.deepseek_v4:MLUDeepseekV4ForCausalLM") diff --git a/vllm_mlu/model_executor/models/config.py b/vllm_mlu/model_executor/models/config.py new file mode 100644 index 0000000..7400103 --- /dev/null +++ b/vllm_mlu/model_executor/models/config.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from math import lcm +from typing import TYPE_CHECKING + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.models import ModelRegistry +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv, round_up +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm.model_executor.models.config import (HybridAttentionMambaModelConfig, + MambaModelConfig) +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +@classmethod +def vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config( + cls, + vllm_config: "VllmConfig" +) -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. 
+ + Args: + vllm_config: vLLM Config + """ + # Save the user input before it gets modified by MambaModelConfig + mamba_block_size = vllm_config.cache_config.mamba_block_size + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + # Attention backend constraints: + # - FlashAttention (FA) requires block size to be multiple of 16 + # - MLA (Multi-head Latent Attention) requires larger alignment: + # * CUTLASS_MLA backend: kernel_block_size 128 alignment + # * Other MLA backends: kernel_block_size 64 alignment + if model_config.use_mla: + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + kernel_block_alignment_size = 128 if use_cutlass_mla else 64 + attn_page_size_1_token = MLAAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + else: + kernel_block_alignment_size = 16 + if ( + current_platform.is_device_capability(100) + and model_config.get_head_size() == 256 + and ( + envs.VLLM_ATTENTION_BACKEND is None + or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER" + ) + ): + # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that` + # head size 256 and block size 16 is not supported on blackwell. + kernel_block_alignment_size = 32 + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + + model_cls, _ = ModelRegistry.resolve_model_cls( + model_config.architecture, + model_config=model_config, + ) + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), + block_size=model_config.max_model_len, + ).page_size_bytes + + # Model may be marked as is_hybrid + # but mamba is skipped via config, + # return directly + if mamba_page_size == 0: + return + + if cache_config.enable_prefix_caching: + # With prefix caching, select attention block size to + # optimize for mamba kernel performance + + # Mamba2 SSD kernel uses a chunk_size, e.g. 256 + # Align the block to the kernel: use lowest multiple of chunk_size + # of attention tokens that would fit mamba_page_size: + # e.g. for mamba page size = 788kB + # attn_1_token = 2kB -> fits ~394 tokens + # then round up to a mulitple of 256 -> 512 tokens + # End result: + # attn_block_size = 512 + # mamba_block_size = 512 (aligned to a multiple of chunk_size) + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. 
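+            # Tracing the example above with the same illustrative numbers
+            # (approximate; cdiv/lcm are the helpers imported at the top of this file):
+            #   attn_tokens_per_mamba_state = cdiv(788 kB, 2 kB)  ~= 394
+            #   chunk_size                  = lcm(256, 16)         = 256
+            #   attn_block_size             = 256 * cdiv(394, 256) = 512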
+ + base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() + attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) + chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) + attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) + cache_config.mamba_block_size = attn_block_size + else: + # Without prefix caching, select minimum valid attention block size + # to minimize mamba state padding + + # Calculate minimum attention block size that satisfies both: + # 1. Backend alignment requirements (kernel_block_alignment_size) + # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + if (vllm_config.mlu_config.enable_mamba_split_page_size): + vllm_config.mlu_config.mamba_to_attn_block_ratio = cdiv(attn_block_size, cache_config.block_size) + cache_config.mamba_page_size_padded = cache_config.block_size * attn_page_size_1_token + return + ''' + ================== + End of MLU Hijack + ================== + ''' + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. + if cache_config.block_size is None or cache_config.block_size < attn_block_size: + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size, + ) + + # compute new attention page size + attn_page_size = cache_config.block_size * attn_page_size_1_token + + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if ( + cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size + ): + cache_config.mamba_page_size_padded = attn_page_size + mamba_padding_pct = ( + 100 * (attn_page_size - mamba_page_size) / mamba_page_size + ) + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", + mamba_padding_pct, + ) + +MluHijackObject.apply_hijack(HybridAttentionMambaModelConfig, + HybridAttentionMambaModelConfig.verify_and_update_config, + vllm__module_executor__models__config__HybridAttentionMambaModelConfig__verify_and_update_config) diff --git a/vllm_mlu/model_executor/models/deepseek_v4.py b/vllm_mlu/model_executor/models/deepseek_v4.py new file mode 100644 index 0000000..6c9410f --- /dev/null +++ b/vllm_mlu/model_executor/models/deepseek_v4.py @@ -0,0 +1,1096 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import re +from typing import Iterable, Set, Tuple + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import ( + get_ep_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + get_data_parallel_group_world_size, + get_tp_group, +) +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_reduce, +) +from vllm.logger import 
init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.interfaces import SupportsEagle +from vllm.model_executor.models.utils import maybe_prefix +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.attention import AttentionMetadata +from vllm_mlu.model_executor.layers.feed_forward import FeedForward +from vllm_mlu.v1.attention.backends.utils import ( + MLUCommonAttentionMetadata, + get_common_metadata, +) +from vllm_mlu.model_executor.layers.indexer import Indexer +from vllm_mlu.model_executor.layers.compressor import Compressor +from vllm_mlu import _mlu_ops as mlu_ops +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.attention.layer import MLAAttention +from vllm_mlu.model_executor.layers.sparse_moe_mlp import MoeGroupInfo, SparseMoeMlp +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +logger = init_logger(__name__) + + +class HCHead(torch.nn.Module): + + def __init__( + self, + hc_mult, + dim, + hc_eps, + norm_eps, + prefix: str = "", + ): + super().__init__() + self.hc_mult: int = hc_mult + self.dim: int = dim + self.hc_dim: int = hc_mult * dim + self.hc_eps = hc_eps + self.norm_eps = norm_eps + + self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, self.hc_dim, dtype=torch.float), requires_grad=False) + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float), requires_grad=False) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float), requires_grad=False) + + def forward(self, x: torch.Tensor): + shape, dtype = x.size(), x.dtype + x = x.flatten(-2).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) + mixes = F.linear(x, self.hc_head_fn) * rsqrt + pre = torch.sigmoid(mixes * self.hc_head_scale + self.hc_head_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=-2) + return y.to(dtype) + + +class HCPre(torch.nn.Module): + + def __init__( + self, + hc_mult, + dim, + hc_sinkhorn_iters, + hc_eps, + norm_eps, + prefix: str = "", + ): + super().__init__() + self.hc_mult: int = hc_mult + self.dim: int = dim + self.hc_dim: int = hc_mult * dim + self.hc_sinkhorn_iters = hc_sinkhorn_iters + self.hc_eps = hc_eps + mix_hc = (2 + hc_mult) * hc_mult + self.norm_eps = norm_eps + + self.hc_fn = nn.Parameter(torch.empty(mix_hc, self.hc_dim, dtype=torch.float), requires_grad=False) + self.hc_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float), requires_grad=False) + self.hc_scale = nn.Parameter(torch.empty(3, dtype=torch.float), requires_grad=False) + + def forward( + self, + x: torch.Tensor, + rsqrt: torch.Tensor | None = None, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(-2).float() + if rsqrt is None: + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) + + mixes = F.linear(x, self.hc_fn) + pre, post, comb = mlu_ops.hc_split_sinkhorn( + mixes.unsqueeze(0), + self.hc_scale, + self.hc_base, + rsqrt.squeeze(-1).unsqueeze(0), + self.hc_mult, + self.hc_sinkhorn_iters, + self.hc_eps, + ) + pre, post, comb = pre.squeeze(0), post.squeeze(0), comb.squeeze(0) + + + y = 
torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=-2) + return y.to(dtype), post, comb + + +class HCPost(torch.nn.Module): + def __init__( + self, + norm_eps: float, + prefix: str = "", + ): + self.norm_eps = norm_eps + super().__init__() + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, + compute_rms: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor| None]: + # x: [bs, dim], residual: [bs, hc, dim], post: [bs, hc], comb: [bs, hc, hc] + # return + # y: [bs, hc, dim] + # [bs, hc, 1] * [bs, 1, dim] + torch.sum([bs, hc, hc, 1] * [bs, hc, 1, dim], -2) + # rsqrt: Optional, [bs, 1] + use_tmo = True + if use_tmo: + y, rsqrt = mlu_ops.fused_mhc_post(x, residual, post, comb, compute_rms, self.norm_eps) + return y, (rsqrt.unsqueeze(-1) if rsqrt is not None else None) + + y = post.unsqueeze(-1) * x.unsqueeze(-2) + \ + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=-3) + + rsqrt = ( + torch.rsqrt(y.type_as(x).flatten(-2).float().square().mean(-1, keepdim=True) + self.norm_eps) + if compute_rms + else None + ) + + return y.type_as(x), rsqrt + + +class MLUDeepseekV4MoE(SparseMoeMlp): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ): + layer_id = int(prefix.split(sep=".")[-2]) + self.layer_id = layer_id + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + SparseMoeMlp.__init__( + self, + num_experts=config.n_routed_experts, + top_k=config.n_activated_experts, + hidden_size=config.dim, + intermediate_size=config.moe_inter_dim, + up_proj_name='w13', + is_gated=True, + down_proj_name='w2', + has_bias=False, + hidden_act='silu', + params_dtype=torch.float, + quant_config=quant_config, + is_use_fused_moe=True, + expert_group=1, + topk_group=1, + scoring_func=config.score_func, + topk_method='', + routed_scaling_factor=config.route_scale, + use_hash=(layer_id < config.n_hash_layers), + vocab_size=config.vocab_size, + prefix=prefix, + ) + + self.dim = config.dim + world_size = get_ep_group().world_size + self.world_size = world_size + assert config.n_routed_experts % world_size == 0, \ + f"Number of experts must be divisible by world size (world_size={world_size})" + self.n_routed_experts = config.n_routed_experts + self.n_local_experts = self.n_routed_experts // world_size + self.n_activated_experts = config.n_activated_experts + self.experts_start_idx = get_ep_group().rank_in_group * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + + assert config.n_shared_experts == 1 + self.shared_experts = FeedForward( + hidden_size=config.dim, + intermediate_size=config.moe_inter_dim, + hidden_act='silu', + up_proj_name='w13', + is_gated=True, + down_proj_name='w2', + bias=False, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + ) -> torch.Tensor: + shape = hidden_states.size() + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states.float()) + hidden_states = self.forward_experts( + hidden_states, + router_logits, + shared_output=shared_output, + input_ids=input_ids, + ) + hidden_states = self.reduce_results(hidden_states) + return hidden_states.view(shape) + + +class MLUDeepseekV4Attention(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + layer_id = 
int(prefix.split(sep=".")[-2]) + self.layer_id = layer_id + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + self.attn_data_parallel_size = get_data_parallel_group_world_size() + self.attn_tensor_parallel_size = get_tensor_model_parallel_world_size() + + self.num_heads = vllm_config.model_config.hf_config.n_heads + assert self.num_heads % self.tp_size == 0 + self.num_local_heads = self.num_heads // self.attn_tensor_parallel_size + self.model_type = config.model_type + self.use_indexer = hasattr(config, 'index_n_heads') + + self.hidden_size = config.dim + self.head_dim = config.head_dim + self.q_lora_rank = config.q_lora_rank + self.rope_head_dim = config.rope_head_dim + self.eps = config.norm_eps + self.o_groups = config.o_groups + self.o_local_groups = self.o_groups // self.attn_tensor_parallel_size + self.softmax_scale = self.head_dim ** -0.5 + self.compress_ratio = config.compress_ratios[layer_id] + self.window_size = config.window_size + self.max_model_len = vllm_config.model_config.max_model_len + self.original_seq_len = config.original_seq_len + self.index_topk = config.index_topk + + self.o_lora_rank = config.o_lora_rank + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + + tp_group = get_tp_group() + + # disable YaRN and use base rope_theta in pure sliding-window attention + if self.compress_ratio > 1: + max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", 65536) + self.rope_scaling["rope_type"] = 'deepseek_yarn' + else: + max_position_embeddings = 0 + self.rope_scaling["rope_type"] = 'default' + if self.rope_scaling is not None: + self.rope_scaling["original_max_position_embeddings"] = 0 + + self.rotary_emb = get_rope( + self.rope_head_dim, + rotary_dim=self.rope_head_dim, + max_position=max_position_embeddings, + base=config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta, + rope_scaling=self.rope_scaling, + is_neox_style=False, + ) + + self.output_rotary_emb = get_rope( + self.rope_head_dim, + rotary_dim=self.rope_head_dim, + max_position=max_position_embeddings, + base=config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta, + rope_scaling=self.rope_scaling, + dtype=torch.float32, + is_neox_style=False, + inverse=True, + ) + + if self.q_lora_rank is not None: + self.wq_a = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=None, + prefix=f"{prefix}.wq_a", + ) + self.q_norm = RMSNorm( + self.q_lora_rank, + eps=self.eps, + ) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + tp_group=tp_group, + ) + + self.wkv = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=None, + prefix=f"{prefix}.wkv", + ) + self.kv_norm = RMSNorm( + self.head_dim, + eps=self.eps, + ) + if get_tensor_model_parallel_world_size() <= self.o_groups: + self.wo_a = ColumnParallelLinear( + self.num_heads * self.head_dim // self.o_groups, + self.o_groups * self.o_lora_rank, + bias=False, + quant_config=None, + prefix=f"{prefix}.wo_a", + ) + self.wo_b = RowParallelLinear( + self.o_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.wo_b", + tp_group=tp_group, + ) + else: + self.wo_a = ReplicatedLinear( + self.num_heads * 
self.head_dim // self.o_groups, + self.o_groups * self.o_lora_rank, + bias=False, + quant_config=None, + prefix=f"{prefix}.wo_a", + ) + self.wo_b = ReplicatedLinear( + self.o_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wo_b", + ) + + self.attn = MLAAttention( + self.num_local_heads, # num_heads + self.softmax_scale, # scale + self.head_dim - self.rope_head_dim, # qk_nope_head_dim + self.rope_head_dim, # qk_rope_head_dim + self.head_dim, # v_head_dim + self.q_lora_rank, # q_lora_rank + self.head_dim, # kv_lora_rank + self.wkv, # kv_b_proj + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + # extra_impl_args + num_kv_heads=1, + prefix=f"{prefix}.attn", + use_fused_mla_qkv=False, + ) + + if self.compress_ratio: + self.compressor = Compressor(vllm_config, self.rotary_emb, self.compress_ratio, self.head_dim, False, f"{prefix}.compressor") + if self.compress_ratio == 4: + self.indexer = Indexer(vllm_config, self.rotary_emb, self.compress_ratio, f"{prefix}.indexer") + else: + self.indexer = None + + self.attn_sink = nn.Parameter(torch.empty(self.num_local_heads, dtype=torch.float32)) + + def forward_sparse_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + batch_to_kv_state: torch.Tensor, + window_compress_params: dict | None, + window_slot_mapping: torch.Tensor, + compressor_slot_mapping: dict | None, + ) -> torch.Tensor: + num_tokens = hidden_states.shape[0] + if self.q_lora_rank is not None: + q = self.wq_a(hidden_states)[0] + + q = self.q_norm(q) + qr = q + q = self.wq_b(q)[0].view(-1, self.num_local_heads, self.head_dim) + + q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps) + + _, q_pe = q.split([self.head_dim - self.rope_head_dim, self.rope_head_dim], dim=-1) + + kv = self.wkv(hidden_states)[0] + kv = self.kv_norm(kv) + kv = kv.unsqueeze(-2) + kv_pe = kv[..., -self.rope_head_dim :] + + q_pe, kv_pe = self.rotary_emb(positions, q_pe, kv_pe, only_prefill=False) + + common_metadata = get_common_metadata() + query_start_loc = common_metadata.query_start_loc + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + key_cache = kv_cache[0][0] + mlu_ops.reshape_paged_cache( + kv, + None, + key_cache, + None, + window_slot_mapping, + ) + + if self.compress_ratio: + offsets = query_lens if common_metadata.is_prefill_only else torch.full_like(query_lens, self.window_size) + if self.indexer is not None: + indexer_kv_cache = kv_cache[2] + compress_block_tables, compress_context_lens = self.indexer( + hidden_states, + qr, + positions, + offsets, + attn_metadata, + batch_to_kv_state, + indexer_kv_cache, + compressor_slot_mapping[(0, self.compress_ratio)], + ) + + if self.compress_ratio: + compress_kv = self.compressor( + hidden_states, + positions, + attn_metadata, + batch_to_kv_state, + key_cache, + self.window_size, + compressor_slot_mapping[(self.window_size, self.compress_ratio)], + ) + + if common_metadata.is_prefill_only: + kv = torch.cat([kv, compress_kv], dim=0) + + assert window_compress_params != None + if self.compress_ratio: + if self.indexer is not None: + window_block_tables = window_compress_params.get("window_block_tables", None) + window_context_lens = window_compress_params.get("window_context_lens", None) + new_block_tables = torch.empty([num_tokens, self.window_size + self.index_topk], dtype=torch.int32, device=hidden_states.device) + new_context_lens = torch.empty([num_tokens], 
dtype=torch.int32, device=hidden_states.device) + mlu_ops.concat_block_table( + window_block_tables, + window_context_lens, + compress_block_tables, + compress_context_lens, + new_block_tables, + new_context_lens, + ) + max_contxt_len = self.window_size + self.index_topk + else: + new_block_tables = window_compress_params.get("compress_block_tables", None) + new_context_lens = window_compress_params.get("compress_context_lens", None) + max_contxt_len = self.window_size + (self.max_model_len // self.compress_ratio) + else: + new_block_tables = window_compress_params.get("window_block_tables", None) + new_context_lens = window_compress_params.get("window_context_lens", None) + max_contxt_len = self.window_size + + + attn_output = torch.zeros_like(q) + total_token = q.size(0) + assert total_token == new_block_tables.size(0) + q_ = q.view(total_token, -1, self.num_local_heads, self.head_dim) + attn_output = attn_output.view(total_token, -1, self.num_local_heads, self.head_dim) + if common_metadata.is_prefill_only: + kv_cache_ = kv.unsqueeze(1) # insert block_size, [total_token, 1, head_dim] -> [total_token, 1, 1, head_dim] + else: + kv_cache_ = kv_cache[0].view(-1, 1, 1, self.head_dim) + + mlu_ops.single_query_cached_kv_attn( + q=q_, + k_cache=kv_cache_, + v_cache=None, + out=attn_output, + block_tables=new_block_tables, + context_lens=new_context_lens, + k_cache_quant_scale=None, + v_cache_quant_scale=None, + alibi_slopes=None, + max_contxt_len=max_contxt_len, + windows_size_left=-1, + windows_size_right=-1, + softmax_scale=self.softmax_scale, + compute_dtype=torch.float, + learnable_sink=self.attn_sink, + ) + + attn_output = attn_output.reshape(-1, self.num_local_heads, self.head_dim).to(torch.float) + + attn_output_pe = attn_output[..., -self.rope_head_dim:] + attn_output_pe, _ = self.output_rotary_emb(positions, attn_output_pe, None, only_prefill=False) + + attn_output = attn_output.to(dtype=torch.bfloat16) + + if get_tensor_model_parallel_world_size() <= self.o_groups: + attn_output = attn_output.reshape(num_tokens, self.o_local_groups, -1) + wo_a = self.wo_a.weight.view(self.o_local_groups, self.o_lora_rank, -1) + + o = torch.einsum("ngd,grd->ngr", attn_output, wo_a) + output = self.wo_b(o.flatten(-2))[0] + + output = tensor_model_parallel_all_reduce(output) + else: + # (token, 64/tp, head_dim) -> (64/tp, token, head_dim) + attn_output = attn_output.flatten(-2).contiguous() + attn_output = tensor_model_parallel_all_gather(attn_output, dim=-1) + # (token, 64 * head_dim) -> (token, 64, head_dim) + attn_output = attn_output.reshape(-1, self.num_heads, self.head_dim).contiguous() # t, 64 + wo_a = self.wo_a.weight.view(self.o_groups, self.o_lora_rank, -1) + attn_output = attn_output.reshape(num_tokens, self.o_groups, -1) + o = torch.einsum("ngd,grd->ngr", attn_output, wo_a) + output = self.wo_b(o.flatten(-2))[0] + return output + + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + batch_to_kv_state: torch.Tensor, + window_compress_params: dict | None, + window_slot_mapping: torch.Tensor, + compressor_slot_mapping: dict | None, + ) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return torch.empty_like(hidden_states) + + # self.attn and self.attn_decoder always have the same attn_metadata + # and share the same kv cache for each layer + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.attn.layer_name] + kv_cache = 
self.attn.kv_cache[forward_context.virtual_engine] + + output = self.forward_sparse_attn( + positions, + hidden_states, + kv_cache, + attn_metadata, + batch_to_kv_state, + window_compress_params, + window_slot_mapping, + compressor_slot_mapping, + ) + + return output + + +class MLUDeepseekV4DecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + config: PretrainedConfig | None = None, + ) -> None: + super().__init__() + + if config is None: + config = vllm_config.model_config.hf_config + self.config = config + + self.dim = config.dim + layer_idx = int(prefix.split(sep=".")[-1]) + self.layer_idx = layer_idx + + self.attn = MLUDeepseekV4Attention( + vllm_config=vllm_config, + prefix=f"{prefix}.attn", + ) + + self.hc_mult = config.hc_mult + self.mix_hc = (2 + self.hc_mult) * self.hc_mult + self.hc_dim = self.hc_mult * config.dim + self.norm_eps = config.norm_eps + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + + self.hc_attn_pre = HCPre( + self.hc_mult, + config.dim, + self.hc_sinkhorn_iters, + self.hc_eps, + self.norm_eps, + prefix=f"{prefix}.hc_attn_pre" + ) + self.hc_attn_post = HCPost( + self.norm_eps, + ) + + self.hc_ffn_pre = HCPre( + self.hc_mult, + config.dim, + self.hc_sinkhorn_iters, + self.hc_eps, + self.norm_eps, + prefix=f"{prefix}.hc_attn_pre" + ) + self.hc_ffn_post = HCPost( + self.norm_eps, + ) + + self.attn_norm = RMSNorm(config.dim, config.norm_eps) + + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.ffn = MLUDeepseekV4MoE( + vllm_config=vllm_config, + prefix=f"{prefix}.ffn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_ids: torch.Tensor | None, + residual: torch.Tensor | None, + batch_to_kv_state: torch.Tensor, + window_compress_params: dict | None = None, + hc_attn_pre_norm: torch.Tensor | None = None, + window_slot_mapping: torch.Tensor | None = None, + compressor_slot_mapping: dict | None = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, post, comb = self.hc_attn_pre(hidden_states, rsqrt=hc_attn_pre_norm) + hidden_states = self.attn_norm(hidden_states) + hidden_states = self.attn( + positions, + hidden_states, + batch_to_kv_state, + window_compress_params, + window_slot_mapping, + compressor_slot_mapping, + ) + hidden_states, hc_ffn_pre_norm = self.hc_attn_post( + hidden_states, + residual, + post, + comb, + compute_rms=True, + ) + residual = hidden_states + + is_last_layer = (self.layer_idx == self.config.n_layers - 1) + hidden_states, post, comb = self.hc_ffn_pre(hidden_states, rsqrt=hc_ffn_pre_norm) + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.ffn(hidden_states, input_ids) + hidden_states, hc_attn_pre_norm = self.hc_ffn_post( + hidden_states, + residual, + post, + comb, + compute_rms=(not is_last_layer), + ) + + return hidden_states, hc_attn_pre_norm + +@support_torch_compile +class MLUDeepseekV4Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.device = current_platform.device_type + + self.compress_ratio = 128 # only compressor layer 128 + self.window_size = config.window_size + self.max_model_len = vllm_config.model_config.max_model_len + + self.vocab_size = config.vocab_size + self.norm_eps = config.norm_eps + self.hc_eps = config.hc_eps + + config = 
vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.dim, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + + self.layers = nn.ModuleList() + for layer_id in range(config.n_layers): + self.layers.append(MLUDeepseekV4DecoderLayer( + vllm_config=vllm_config, + prefix=f"{prefix}.layers.{layer_id}", + config=config, + )) + + self.hc_mult = config.hc_mult + self.dim = config.dim + self.hc_head = HCHead( + self.hc_mult, + self.dim, + self.hc_eps, + self.norm_eps, + prefix=f"{prefix}.hc_head", + ) + + self.norm = RMSNorm(config.dim, self.norm_eps) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + batch_to_kv_state: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + window_slot_mapping: torch.Tensor | None = None, + compressor_slot_mapping: dict | None = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.config.hc_mult, 1) + + common_metadata = get_common_metadata() + if common_metadata is not None: + total_token_num = hidden_states.size(0) + window_block_tables = torch.empty([total_token_num, self.window_size], dtype=torch.int32, device=hidden_states.device) + window_context_lens = torch.empty([total_token_num], dtype=torch.int32, device=hidden_states.device) + kv_cache_size = self.window_size + (self.max_model_len // self.compress_ratio if self.compress_ratio else 0) + compress_block_tables = torch.empty([total_token_num, kv_cache_size], dtype=torch.int32, device=hidden_states.device) + compress_context_lens = torch.empty([total_token_num], dtype=torch.int32, device=hidden_states.device) + + mlu_ops.get_window_block_tables( + window_size = self.window_size, + block_size = 1, + seq_k_lens = common_metadata.seq_lens, + query_start_loc = common_metadata.query_start_loc, + block_table = common_metadata.block_table_tensor, + window_block_tables = window_block_tables, + window_context_lens = window_context_lens + ) + + # get_compress_block_tables + query_start_loc = common_metadata.query_start_loc + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + compress_lens = query_lens // self.compress_ratio + cu_compress_lens = torch.cat([ + torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device), + torch.cumsum(compress_lens, dim=0) + ]) + offsets = cu_compress_lens[: -1] + total_token_num if common_metadata.is_prefill_only else torch.full_like(query_lens, self.window_size) + + mlu_ops.get_compress_block_tables( + ratio = self.compress_ratio, + block_size = 1, + seq_k_lens = common_metadata.seq_lens, + query_start_loc = common_metadata.query_start_loc, + offset = offsets, + block_table = common_metadata.block_table_tensor, + compress_block_tables = compress_block_tables, + compress_context_lens = compress_context_lens, + ) + + win_comp_block_tables = torch.empty([total_token_num, kv_cache_size], dtype=torch.int32, device=hidden_states.device) + win_comp_context_lens = torch.empty([total_token_num], dtype=torch.int32, device=hidden_states.device) + mlu_ops.concat_block_table( + window_block_tables, + window_context_lens, + compress_block_tables, + compress_context_lens, + win_comp_block_tables, + 
win_comp_context_lens, + ) + + window_compress_params = { + "window_block_tables": window_block_tables, + "window_context_lens": window_context_lens, + "compress_block_tables": win_comp_block_tables, + "compress_context_lens": win_comp_context_lens, + } + else: + window_compress_params = None + + hc_attn_pre_norm = None + for layer in self.layers: + hidden_states, hc_attn_pre_norm = layer( + positions, + hidden_states, + input_ids, + None, + batch_to_kv_state, + window_compress_params, + hc_attn_pre_norm=hc_attn_pre_norm, + window_slot_mapping=window_slot_mapping, + compressor_slot_mapping=compressor_slot_mapping, + ) + hidden_states = self.hc_head(hidden_states) + hidden_states = self.norm(hidden_states).to(dtype=torch.float) + + return hidden_states + +class MLUDeepseekV4ForCausalLM(nn.Module, SupportsEagle): + packed_modules_mapping = { + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.model = MLUDeepseekV4Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.lm_head = ColumnParallelLinear( + config.dim, + config.vocab_size, + params_dtype=torch.float32, + quant_config=quant_config, + bias=False, + skip_bias_add=True, + return_bias=False, + ) + + def update_forward_args(self, args, kwargs): + window_size = self.config.window_size + + # Part 1. window slot mapping. + common_metadata: MLUCommonAttentionMetadata = get_common_metadata() + if common_metadata is None or common_metadata.block_table_tensor is None: + window_slot_mapping = None + elif common_metadata.is_prefill_only: + block_table = common_metadata.block_table_tensor + query_start_loc = common_metadata.query_start_loc + window_slot_mapping = torch.empty([query_start_loc[-1]], dtype=torch.int32, device=block_table.device) + window_slot_mapping.fill_(-1) + for i, seq_len in enumerate(common_metadata.seq_lens): + if seq_len < window_size: + window_slot_mapping[query_start_loc[i]: query_start_loc[i+1]].copy_(block_table[i, :seq_len]) + else: + # | <------- seqlen--------> | + # | other | window size | + # | other | tail | head | + # move head to the front of window, and move tail to the latter. + tail_pos = query_start_loc[i].item() + seq_len - window_size + head_size = seq_len % window_size + tail_size = window_size - head_size + window_slot_mapping[tail_pos: tail_pos + tail_size].copy_( + block_table[i, head_size:window_size], + ) + window_slot_mapping[tail_pos + tail_size: tail_pos + window_size].copy_( + block_table[i, :head_size] + ) + else: + block_table = common_metadata.block_table_tensor + window_pos = (common_metadata.seq_lens - 1) % window_size + window_slot_mapping = torch.gather(block_table, 1, window_pos.unsqueeze(1)).squeeze(1) + + kwargs["window_slot_mapping"] = window_slot_mapping + + # Part 2. compressor slot mapping + assert set(self.config.compress_ratios) == {0, 4, 128} + # The pairs <128, 128> <128, 4> <0, 4> contain all cases. + # <128, 128> and <128, 4> indicate attn.compressor, and + # <0, 4> indicates attn.indexer.compressor. 
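+        # Illustrative walk-through (derived from the loops below, not part of the
+        # original code): for a prefill request with query_len = 512 and
+        # compress_ratio = 4, 512 // 4 = 128 compressor slots are copied from
+        # block_tables[i, window_offset : window_offset + 128]; during decode a
+        # single slot per request is gathered at
+        # window_offset + (seq_len - query_len) // compress_ratio.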
+ window_offsets = [128, 128, 0] + compress_ratios = [128, 4, 4] + # dict key: (window_size, compress_ratio) + compressor_slot_mapping = dict() + if common_metadata is None or common_metadata.block_table_tensor is None: + pass + elif common_metadata.is_prefill_only: + block_tables = common_metadata.block_table_tensor + query_start_loc = common_metadata.query_start_loc + query_start_loc = common_metadata.query_start_loc + query_lens = (query_start_loc[1:] - query_start_loc[:-1]).tolist() + + for compress_ratio, window_offset in zip(compress_ratios, window_offsets): + slot_lens = [q // compress_ratio for q in query_lens] + cu_slot_lens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device='cpu'), + torch.cumsum(torch.tensor(slot_lens, dtype=torch.int32, device='cpu'), dim=0)], + ) + slot_mapping = torch.empty(sum(slot_lens), dtype=torch.int32, device=block_table.device) + for i in range(len(query_lens)): + slot_mapping[cu_slot_lens[i]: cu_slot_lens[i+1]] = \ + block_tables[i, window_offset: window_offset + slot_lens[i]] + compressor_slot_mapping[(window_offset, compress_ratio)] = slot_mapping + else: + block_tables = common_metadata.block_table_tensor + seq_lens = common_metadata.seq_lens + query_start_loc = common_metadata.query_start_loc + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + for compress_ratio, window_offset in zip(compress_ratios, window_offsets): + offset = window_offset + (seq_lens - query_lens) // compress_ratio + slot_mapping = torch.gather(block_tables, 1, offset.unsqueeze(1)).squeeze(1) + compressor_slot_mapping[(window_offset, compress_ratio)] = slot_mapping + + kwargs["compressor_slot_mapping"] = compressor_slot_mapping + + return args, kwargs + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + batch_to_kv_state: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + window_slot_mapping: torch.Tensor | None = None, + compressor_slot_mapping: dict | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, + positions, + batch_to_kv_state, + inputs_embeds, + window_slot_mapping, + compressor_slot_mapping, + ) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('w13', 'w1', 0), + ('w13', 'w3', 1), + ] + + for name, m in self.model.named_modules(): + if isinstance(m, SparseMoeMlp): + m.pack_params() + + moe_group_info = MoeGroupInfo() + moe_ep_size = moe_group_info.moe_ep_size + moe_ep_rank = moe_group_info.moe_ep_rank + num_total_experts = self.config.n_routed_experts + start_expert_id = moe_ep_rank * ((num_total_experts + moe_ep_size - 1) // moe_ep_size) + expert_num_per_rank = (num_total_experts + moe_ep_size - 1) // moe_ep_size + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + pattern = r'layers\.([0-9]*)\.' + match = re.search(pattern, name) + if match: + layer_id = int(match.group(1)) + if layer_id >= self.config.n_layers: + continue + + # The following parameters are not included yet. 
+ skiped_parameters = ['mtp'] + if any(param in name for param in skiped_parameters): + continue + + name = name.replace("embed.weight", "embed_tokens.weight") + name = "model." + name + name = name.replace("model.head.weight", "lm_head.weight") + + if "ffn.experts." in name: + expert_id = int(name.split(".")[-3]) + if expert_id < start_expert_id or expert_id >= start_expert_id + ((num_total_experts + moe_ep_size - 1) // moe_ep_size): + continue + new_expert_id = expert_id - start_expert_id + name = name.replace(f"experts.{expert_id}", f"experts.{new_expert_id}") + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if "w1.weight" not in name and \ + "w3.weight" not in name: + continue + + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # remap parameter name for hc pre + name = name.replace("hc_attn_base", "hc_attn_pre.hc_base") + name = name.replace("hc_attn_fn", "hc_attn_pre.hc_fn") + name = name.replace("hc_attn_scale", "hc_attn_pre.hc_scale") + name = name.replace("hc_ffn_base", "hc_ffn_pre.hc_base") + name = name.replace("hc_ffn_fn", "hc_ffn_pre.hc_fn") + name = name.replace("hc_ffn_scale", "hc_ffn_pre.hc_scale") + + # remap parameter name for hc head + name = name.replace("hc_head_base", "hc_head.hc_head_base") + name = name.replace("hc_head_fn", "hc_head.hc_head_fn") + name = name.replace("hc_head_scale", "hc_head.hc_head_scale") + + name = name.replace("gate.tid2eid", "deepseekv4_topk.tid2eid") + name = name.replace("ffn.gate.bias", "ffn.deepseekv4_topk.bias") + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + + if 'attn_sink' in name: + num_heads = self.config.n_heads + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + assert num_heads % tp_size == 0 + num_local_heads = num_heads // tp_size + loaded_weight = loaded_weight[tp_rank * num_local_heads: (tp_rank + 1) * num_local_heads] + + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if diff := set(params_dict.keys()) - loaded_params: + logger.error(f"The following params are not loaded: {diff}") + + for name, m in self.model.named_modules(): + if isinstance(m, SparseMoeMlp): + m.pack_params_after_loading() + + + return set(loaded_params) diff --git a/vllm_mlu/model_executor/models/dp_utils.py b/vllm_mlu/model_executor/models/dp_utils.py new file mode 100644 index 0000000..ed1919d --- /dev/null +++ b/vllm_mlu/model_executor/models/dp_utils.py @@ -0,0 +1,607 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import ( + Any, List, Tuple, Optional, Dict, Union, ClassVar, Literal, + Protocol, overload, runtime_checkable) +from typing_extensions import TypeIs + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.config import VllmConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_gather_into_list, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter, +) +from vllm.distributed import ( + get_tp_group, + get_pp_group, + get_dp_group, + get_data_parallel_group_rank, + get_data_parallel_group_world_size, + get_dense_mlp_tp_world_size, + get_tp_world_world_size, + 
get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_logits_tp_world_size, + get_parallel_rank_with_group, + get_tp_world_group, + get_tp_world_rank, + GroupCoordinator, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +from vllm_mlu.mlu_forward_context import MLUDPMetadata +from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp +from vllm_mlu.v1.attention.backends.utils import get_common_metadata + +logger = init_logger(__name__) + +# alias after refactor +DataParallelRuntimeParams = MLUDPMetadata + + +def enable_data_parallel(): + return get_dp_group().world_size > 1 + + +def enable_emb_logits_custom_parallel(): + return get_logits_tp_world_size() != get_tensor_model_parallel_world_size() + + +def enable_dense_mlp_custom_parallel(): + return get_dense_mlp_tp_world_size() != get_tp_world_world_size() + + +def get_runtime_infos_per_dp_group( + num_tokens: int, num_requests: int, all_prefill: bool, seq_lens: List[int], + device: torch.device, vllm_config: VllmConfig) -> Tuple[List[int], List[bool]]: + dp_tensor = torch.tensor([num_tokens, num_requests, int(all_prefill)]).to(device, non_blocking=True) + outputs = tensor_model_parallel_all_gather_into_list(dp_tensor, get_dp_group()) + outputs = torch.cat(outputs).tolist() # d2h + dp_world_size = get_data_parallel_group_world_size() + dp_is_prefill, dp_query_lens, dp_group_bs, seq_len_per_batch = [], [], [], [] + for i in range(0, 3 * dp_world_size, 3): + dp_query_lens.append(outputs[i]) + dp_group_bs.append(outputs[i + 1]) + dp_is_prefill.append(bool(outputs[i + 2])) + + # Only run communication if mcc is enabled and is prefill. + if vllm_config.mlu_config.is_dpsk_mcc_enabled and all(dp_is_prefill): + assert len(seq_lens) == num_requests + seq_len_per_batch = [torch.empty([bs], dtype=dp_tensor.dtype, device=device) for bs in dp_group_bs] + seq_lens_tensor = torch.tensor(seq_lens, dtype=dp_tensor.dtype, device=device) + torch.distributed.all_gather(seq_len_per_batch, seq_lens_tensor, group=get_dp_group().device_group) + seq_len_per_batch=torch.cat(seq_len_per_batch).tolist() + else: + seq_len_per_batch = [0] * sum(dp_group_bs) + + return dp_query_lens, dp_group_bs, dp_is_prefill, seq_len_per_batch + + +def get_deepseek_layer_split_list( + dp_query_lens: List[int], dp_group_bs: List[int] +) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]: + if len(dp_query_lens) != len(dp_group_bs) or len(dp_query_lens) != get_data_parallel_group_world_size(): + logger.warning(f"dp_query_lens length: {len(dp_query_lens)} != dp_group_bs length: {len(dp_group_bs)}, " + f"disable deepseek layer split") + return None, None, None + emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None + all_dp_query_lens, all_dp_group_bs = [], [] + for i in range(len(dp_query_lens)): + all_dp_query_lens.extend([dp_query_lens[i]] * get_tensor_model_parallel_world_size()) + all_dp_group_bs.extend([dp_group_bs[i]] * get_tensor_model_parallel_world_size()) + if get_logits_tp_world_size() != get_tensor_model_parallel_world_size(): + slice_start = get_tp_world_rank() // get_logits_tp_world_size() * get_logits_tp_world_size() + slice_end = slice_start + get_logits_tp_world_size() + emb_query_lens = all_dp_query_lens[slice_start:slice_end] + logits_batch_sizes = all_dp_group_bs[slice_start:slice_end] + if get_dense_mlp_tp_world_size() != get_tp_world_world_size(): + slice_start = get_tp_world_rank() // 
get_dense_mlp_tp_world_size() * get_dense_mlp_tp_world_size() + slice_end = slice_start + get_dense_mlp_tp_world_size() + dense_attn_token_split_list = all_dp_query_lens[slice_start:slice_end] + return emb_query_lens, logits_batch_sizes, dense_attn_token_split_list + + +def get_dp_metadata( + num_tokens: int, + data_parallel_size: int, + data_parallel_rank: int, + tensor_parallel_size: int, + prefill_dispatch_use_RS_AG: bool, +) -> DataParallelRuntimeParams: + """ + Get dp params when dummy run or capture model graph. These two cases do not have + dp_params when forward call, because we do not want to hijack to much. + """ + dp_query_lens = [num_tokens] * data_parallel_size + in_prefill = get_forward_context().attn_metadata is None # dummy run + dp_is_prefill = [in_prefill] * data_parallel_size + emb_query_lens, logits_batch_sizes, dense_attn_token_split_list = None, None, None + if get_logits_tp_world_size() != get_tensor_model_parallel_world_size(): + emb_query_lens = [num_tokens] * get_logits_tp_world_size() + logits_batch_sizes = None # dummy run and capture model does not contain logits + if get_dense_mlp_tp_world_size() != get_tp_world_world_size(): + dense_attn_token_split_list = [num_tokens] * get_dense_mlp_tp_world_size() + + return MLUDPMetadata.make_oot(data_parallel_rank, + data_parallel_size, + tensor_parallel_size, + dp_query_lens, + dp_is_prefill, + prefill_dispatch_use_RS_AG, + emb_query_lens=emb_query_lens, + logits_batch_sizes=logits_batch_sizes, + dense_attn_token_split_list=dense_attn_token_split_list) + + +def remove_paddings_after_all_gather( + hidden_states: torch.Tensor, + padding_to_token_num: int, + token_num_list: List[int], +) -> torch.Tensor: + dp_group_tensors = [] + offset = 0 + for token_num in token_num_list: + if token_num != 0: + dp_group_tensors.append(hidden_states[offset:offset+token_num]) + offset += padding_to_token_num + if len(dp_group_tensors) == 1: + hidden_states = dp_group_tensors[0] + else: + hidden_states = torch.cat(dp_group_tensors) + return hidden_states + + +def tensor_model_parallel_all_gather_dp( + group_num_tokens: List[int], + rank: int, + hidden_states: Optional[torch.Tensor], + group: GroupCoordinator, + hidden_size: int = None, + dtype: torch.dtype = None, + device: torch.device = None) -> torch.Tensor: + """ + All gather in the group. + Input is a 2-D tensor, and can have different shape in the first dim, + for example, [4, 7, 5, 8], [2, 5, 4, 0]. + """ + num_tokens_equal = all(x == group_num_tokens[0] for x in group_num_tokens) + if num_tokens_equal: + hidden_states = tensor_model_parallel_all_gather( + input_=hidden_states, dim=0, tp_group=group) + else: + max_num_tokens = max(group_num_tokens) + num_padding = max_num_tokens - group_num_tokens[rank] + if num_padding > 0: + if hidden_states is None: + hidden_states = torch.empty((max_num_tokens, hidden_size), + dtype=dtype, device=device) + else: + hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding)) + hidden_states = tensor_model_parallel_all_gather( + input_=hidden_states, dim=0, tp_group=group) + hidden_states = remove_paddings_after_all_gather( + hidden_states, max_num_tokens, group_num_tokens) + return hidden_states + +def tensor_model_parallel_all_gather_op_v2( + input_: torch.Tensor, + dim_size_list: List[int], + group_coordinator: GroupCoordinator, + non_leading_dim_size: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + """ + All gather the input tensor across model parallel group with only communication ops. 
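+
+    Illustrative example (shapes assumed, not from the original docstring): with
+    dim_size_list = [4, 0, 7] and non_leading_dim_size = H, rank 0 contributes a
+    (4, H) tensor, rank 1 contributes nothing, and rank 2 contributes a (7, H)
+    tensor; every rank receives the same (11, H) output with the per-rank chunks
+    concatenated in rank order.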
+ + Note: compared to `tensor_model_parallel_all_gather_dp`, this method supports different + sizes in the first dim, and does not involve padding operation. + """ + all_size_equal = all([dim_size == dim_size_list[0] for dim_size in dim_size_list]) + + output_shape = (sum(dim_size_list), non_leading_dim_size) + output = torch.empty(output_shape, device=device, dtype=dtype) + + if input_ is None: + input_ = torch.empty((0, non_leading_dim_size), device=device, dtype=dtype) + + if all_size_equal: + torch.distributed.all_gather_into_tensor( + output, input_, group=group_coordinator.device_group) + else: + # Note: torch.split splits the tensor into chunks. And each chunk + # is a view of the original tensor. + tensor_list = torch.split(output, dim_size_list, dim=0) + torch.distributed.all_gather( + list(tensor_list), input_, group=group_coordinator.device_group) + return output + +def process_post_attention_communication( + hidden_states: Optional[torch.Tensor], + dp_params: DataParallelRuntimeParams, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + tp_group: Any = None, +): + """ + Processes distributed communication operations after attention computation. + + This function performs necessary communication operations after attention computation + to ensure data synchronization across different parallel groups. + Supports two modes: + 1. Tensor parallel mode: Uses tp_group for all-reduce and all-gather operations + 2. Data parallel mode: Uses reduce-scatter and all-gather for global synchronization + + Args: + hidden_states: Hidden states tensor after attention computation, can be None + dp_params: Data parallel runtime parameters containing token distribution and padding info + hidden_size: Dimension size of hidden states + dtype: Data type of the tensor + device: Device where the tensor is located + tp_group: Tensor parallel group, if None uses data parallel mode + + Returns: + Hidden states tensor after communication synchronization processing + + Note: + - When prefill_pad_to_token_num != -1, padding and unpadding operations will be performed + - Function selects optimal communication path based on token count and parallel strategy + """ + if tp_group is not None: + if dp_params.token_num != 0: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + hidden_states = tensor_model_parallel_all_gather_dp( + group_num_tokens=dp_params.dense_attn_token_split_list, + rank=get_parallel_rank_with_group(tp_group), + hidden_states=hidden_states, + group=tp_group, + ) + else: + if dp_params.prefill_pad_to_token_num != -1: + # pad hidden_states to use reduce_scatter and global all gather + pad_num = dp_params.prefill_pad_to_token_num - dp_params.token_num + if pad_num != 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num)) + + hidden_states = tensor_model_parallel_reduce_scatter( + hidden_states, dim=0) + + hidden_states = tensor_model_parallel_all_gather_dp( + group_num_tokens=dp_params.attn_token_split_list_reduce_scatter, + rank=get_tp_world_rank(), + hidden_states=hidden_states, + group=get_tp_world_group(), + ) + + # get origin hidden_states for moe compute + hidden_states = remove_paddings_after_all_gather( + hidden_states, dp_params.prefill_pad_to_token_num, + dp_params.token_split_list) + else: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + + all_gather_group = get_dp_group() + all_gather_rank = get_data_parallel_group_rank() + hidden_states = tensor_model_parallel_all_gather_dp( + dp_params.token_split_list, all_gather_rank, 
hidden_states, + all_gather_group, hidden_size, dtype, device) + + return hidden_states + +def dp_model_forward( + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], + dp_params: DataParallelRuntimeParams, + embedding_layer: nn.Module, + model_norm_layer: nn.Module, + start_layer: int, + end_layer: int, + layers: List[nn.Module], + layer_input_norm_name: str, + prefill_dispatch_use_RS_AG: bool, + streams: Optional[Dict[str, torch.mlu.Stream]] = None, +) -> Union[torch.Tensor, IntermediateTensors]: + """run model with dp.""" + if dp_params is None: + dp_params = get_dp_metadata(positions.numel(), + get_data_parallel_group_world_size(), + get_data_parallel_group_rank(), + get_tensor_model_parallel_world_size(), + prefill_dispatch_use_RS_AG) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + if embedding_layer.__class__.__name__ == "DPVocabParallelEmbedding": + hidden_states = embedding_layer(input_ids, dp_params=dp_params) + else: + hidden_states = embedding_layer(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(start_layer, end_layer): + is_first_layer = (i == start_layer) + is_last_layer = (i == end_layer - 1) + next_input_layernorm = None + if not is_last_layer: + next_input_layernorm = getattr(layers[i+1], layer_input_norm_name) + hidden_states, residual = layers[i]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + dp_params=dp_params, + is_first_layer=is_first_layer, + is_last_layer=is_last_layer, + streams=streams, + next_input_layernorm=next_input_layernorm, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = model_norm_layer(hidden_states) + return hidden_states + +def dp_layer_forward( + input_norm: nn.Module, + self_attn: nn.Module, + post_norm: nn.Module, + mlp: nn.Module, + mlp_kwargs: List[Dict[str, Any]], + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + dp_params: DataParallelRuntimeParams, + hidden_size: int, + hidden_states_dtype: torch.dtype, + is_first_layer: bool = False, + is_last_layer: bool = False, + next_input_layernorm: Optional[nn.Module] = None, + enable_all2all: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + run layer with dp. dispatch all2all or rs+ag or common. + + For mlp_kwargs, because all2all forward args is often different with common mlp args. + So here we decide that the mlp_kwargs[-1] is always all2all kwargs. For example: + + Deepseek enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}, {mlp all2all kwargs}]. + Deepseek does not enable all2all, mlp_kwargs will be: [{mlp common forward kwargs}]. 
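+
+    Dispatch summary (added for clarity, mirroring the body below): when
+    dp_params.layer_use_reduce_scatter is set, decode-only batches on a SparseMoeMlp
+    take the all2all path if enable_all2all is True, otherwise the rs+ag path;
+    without layer_use_reduce_scatter, the common path is used.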
+ """ + if dp_params.layer_use_reduce_scatter: + common_metadata = get_common_metadata() + is_decode_only = common_metadata is not None and common_metadata.is_decode_only + use_all2all = enable_all2all and is_decode_only and isinstance(mlp, SparseMoeMlp) + forward_func = _dp_forward_layer_all2all if use_all2all else _dp_forward_layer_rs_ag + hidden_states, residual = forward_func(input_norm, + self_attn, + post_norm, + mlp, + mlp_kwargs, + positions, + hidden_states, + residual, + dp_params, + is_first_layer, + is_last_layer, + next_input_layernorm) + else: + hidden_states, residual = _dp_forward_layer_common(input_norm, + self_attn, + post_norm, + mlp, + mlp_kwargs, + positions, + hidden_states, + residual, + dp_params, + hidden_size, + hidden_states_dtype) + return hidden_states, residual + +def _dp_forward_layer_rs_ag( + input_norm: nn.Module, + self_attn: nn.Module, + post_norm: nn.Module, + mlp: nn.Module, + mlp_kwargs: List[Dict[str, Any]], + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + dp_params: DataParallelRuntimeParams, + is_first_layer: bool, + is_last_layer: bool, + next_input_layernorm: List[Optional[nn.Module]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """run layer with rs+ag.""" + if residual is None: + residual = hidden_states + + # We move the input_layernorm of i+1 layer to the end of i layer. + # But for the first layer, we need to do input_layernorm first. + if is_first_layer: + hidden_states = input_norm(hidden_states) + + # Self Attention + hidden_states = self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # add residual here for the first layer + if is_first_layer and get_tensor_model_parallel_rank() == 0: + hidden_states = hidden_states + residual + + hidden_states = tensor_model_parallel_reduce_scatter( + hidden_states, dim=0) + + # move norm between rs and ag + if is_first_layer: + residual = hidden_states + hidden_states = post_norm(hidden_states) + else: + hidden_states, residual = post_norm(hidden_states, residual) + + hidden_states = tensor_model_parallel_all_gather_dp( + group_num_tokens=dp_params.attn_token_split_list_reduce_scatter, + rank=get_tp_world_rank(), + hidden_states=hidden_states, + group=get_tp_world_group(), + ) + + # mlp, use all cards + hidden_states = mlp(hidden_states, **mlp_kwargs[0]) + + hidden_states = tensor_model_parallel_reduce_scatter( + hidden_states, dim=0, tp_group=get_tp_world_group()) + + if is_last_layer: + hidden_states = hidden_states + residual + residual = None + else: + # To reduce layernorm computation, we move the layernorm of i+1 layer to + # the end of i layer. Besides, we fuse residual addition into layernorm. 
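+        # `next_input_layernorm` is layer i+1's input layernorm, passed in by
+        # dp_model_forward; calling it with (hidden_states, residual) fuses the
+        # residual addition into the norm and returns the normalized activations
+        # together with the updated residual.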
+ assert next_input_layernorm is not None + hidden_states, residual = next_input_layernorm(hidden_states, residual) + + hidden_states = tensor_model_parallel_all_gather_dp( + group_num_tokens=dp_params.moe_token_split_list_reduce_scatter, + rank=get_tensor_model_parallel_rank(), + hidden_states=hidden_states, + group=get_tp_group(), + ) + + return hidden_states, residual + +def _dp_forward_layer_all2all( + input_norm: nn.Module, + self_attn: nn.Module, + post_norm: nn.Module, + mlp: nn.Module, + mlp_kwargs: List[Dict[str, Any]], + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + dp_params: DataParallelRuntimeParams, + is_first_layer: bool, + is_last_layer: bool, + next_input_layernorm: List[Optional[nn.Module]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """run layer with all2all.""" + if residual is None: + residual = hidden_states + + # We move the input_layernorm of i+1 layer to the end of i layer. + # But for the first layer, we need to do input_layernorm first. + if is_first_layer: + hidden_states = input_norm(hidden_states) + + # Self Attention + hidden_states = self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # add residual here for the first layer + if is_first_layer and get_tensor_model_parallel_rank() == 0: + hidden_states = hidden_states + residual + + hidden_states = tensor_model_parallel_reduce_scatter( + hidden_states, dim=0) + + # move norm between rs and ag + if is_first_layer: + residual = hidden_states + hidden_states = post_norm(hidden_states) + else: + # add residual in norm for other layers + hidden_states, residual = post_norm(hidden_states, residual) + + hidden_states = mlp.forward_all2all(hidden_states, **mlp_kwargs[-1]) + + if is_last_layer: + hidden_states = hidden_states + residual + residual = None + else: + # To reduce layernorm computation, we move the layernorm of i+1 layer to + # the end of i layer. Besides, we fuse residual addition into layernorm. 
+ assert next_input_layernorm is not None + hidden_states, residual = next_input_layernorm(hidden_states, residual) + + hidden_states = tensor_model_parallel_all_gather_dp( + group_num_tokens=dp_params.moe_token_split_list_reduce_scatter, + rank=get_tensor_model_parallel_rank(), + hidden_states=hidden_states, + group=get_tp_group(), + ) + + return hidden_states, residual + +def _dp_forward_layer_common( + input_norm: nn.Module, + self_attn: nn.Module, + post_norm: nn.Module, + mlp: nn.Module, + mlp_kwargs: List[Dict[str, Any]], + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + dp_params: DataParallelRuntimeParams, + hidden_size: int, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """run layer with common.""" + if residual is None: + residual = hidden_states + + hidden_states = input_norm(hidden_states) + hidden_states = self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # add residual here + if get_tensor_model_parallel_rank() == 0: + hidden_states = hidden_states + residual + + hidden_states = process_post_attention_communication( + hidden_states, dp_params, hidden_size, dtype, positions.device, None + ) + + residual = hidden_states[dp_params.token_num_offset: + dp_params.token_num_offset + dp_params.token_num] + + hidden_states = post_norm(hidden_states) + + hidden_states = mlp(hidden_states, **mlp_kwargs[0]) + + hidden_states = tensor_model_parallel_all_reduce( + hidden_states, tp_group=get_tp_world_group()) + + # add residual here + hidden_states = hidden_states[dp_params.token_num_offset: + dp_params.token_num_offset+dp_params.token_num] + hidden_states = hidden_states + residual + residual = hidden_states + + return hidden_states, residual diff --git a/vllm_mlu/model_executor/models/layer_utils.py b/vllm_mlu/model_executor/models/layer_utils.py new file mode 100755 index 0000000..3a6542e --- /dev/null +++ b/vllm_mlu/model_executor/models/layer_utils.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +from typing import Callable, Optional, List, Union, Tuple + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm.attention import AttentionMetadata +from vllm.sequence import IntermediateTensors +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from transformers import PretrainedConfig + + +def hunyuan_decoder_layer_forward_base( + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_layernorm: Callable, + self_attn: Callable, + post_layernorm: Callable, + mlp: Callable, + kv_states: Optional[Tuple[torch.Tensor]] = None, + apply_residual_connection_post_layernorm: bool = False, + position_name: str = 'positions', + input_norm_fuse_en: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + smooth_quant_scale = None + if input_norm_fuse_en: + layernorm_output, smooth_quant_scale = input_layernorm(hidden_states) + else: + layernorm_output = input_layernorm(hidden_states) + smooth_quant_scale = None + if apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self Attention + attention_output, ori_kv_states = self_attn( + **{position_name: positions}, + hidden_states=layernorm_output, + residual=residual, + kv_states=kv_states, + smooth_quant_scale=smooth_quant_scale, + ) + + layernorm_output = 
post_layernorm(attention_output) + if apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # Fully Connected + hidden_states = mlp(layernorm_output, residual) + return hidden_states, ori_kv_states + + +def decoder_layer_forward_base( + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_layernorm: Callable, + self_attn: Callable, + post_layernorm: Callable, + mlp: Callable, + apply_residual_connection_post_layernorm: bool = False, + position_name: str = 'positions', + input_norm_fuse_en: bool = False, + post_norm_fuse_en: bool = False, +) -> torch.Tensor: + if input_norm_fuse_en: + layernorm_output, smooth_quant_scale = input_layernorm(hidden_states) + else: + layernorm_output = input_layernorm(hidden_states) + smooth_quant_scale = None + + if apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self Attention + attention_output = self_attn( + **{position_name: positions}, + hidden_states=layernorm_output, + residual=residual, + smooth_quant_scale=smooth_quant_scale, + ) + + if post_norm_fuse_en: + layernorm_output, smooth_quant_scale = post_layernorm(attention_output) + else: + layernorm_output = post_layernorm(attention_output) + smooth_quant_scale = None + + if apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # Fully Connected + kwargs = dict() + if post_norm_fuse_en: + kwargs['smooth_quant_scale'] = smooth_quant_scale + hidden_states = mlp(layernorm_output, residual, **kwargs) + return hidden_states + + +def decoder_model_forward_base( + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + layers: torch.nn.ModuleList, + embed_input_ids: Callable, + norm: Callable +) -> torch.Tensor: + hidden_states = embed_input_ids(input_ids) + for i in range(len(layers)): + layer = layers[i] + hidden_states = layer( + positions, + hidden_states, + ) + hidden_states = norm(hidden_states) + return hidden_states + + +def hunyuan_decoder_model_forward_base_pp( + config: PretrainedConfig, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + layers: torch.nn.ModuleList, + start_layer: int, + end_layer: int, + embed_input_ids: Callable, + norm: Callable, + inputs_embeds: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = embed_input_ids(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + cla_factor = getattr(config, "cla_share_factor", 1) + prev_kv_states = None + for i in range(start_layer, end_layer): + layer = layers[i] + hidden_states, kv_states = layer( + positions, + hidden_states, + prev_kv_states, + ) + if (i - start_layer) % cla_factor == 0: + prev_kv_states = kv_states + else: + prev_kv_states = None + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + + hidden_states = norm(hidden_states) + return hidden_states + + +def decoder_model_forward_base_pp( + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + layers: torch.nn.ModuleList, + start_layer: int, + end_layer: int, + embed_input_ids: Callable, + norm: Callable, + inputs_embeds: Optional[torch.Tensor] = None, +) -> 
Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = embed_input_ids(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(start_layer, end_layer): + layer = layers[i] + hidden_states = layer( + positions, + hidden_states, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + + hidden_states = norm(hidden_states) + return hidden_states + + +def is_smoothquant(quant_config: QuantizationConfig) -> bool: + return (quant_config is not None + and quant_config.get_name() == "SmoothQuant") + + +def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool: + return (is_smoothquant(quant_config) + and quant_config.input_quant_method == "per_token") + +def compute_in_loop(func: Callable, + input: torch.Tensor, + chunk_size: int, + feature_size: Optional[int] = None, + **kwargs): + """ + divides input into chunks in the leading dimension (dimension 0), and + compute the chunks in a loop, instead of in a batch at once. + + arg: + feature_size: size of output feature dimension. Provide it when the + the output's feature dimension would differ from the input's + feature dimension. + """ + + total = input.shape[0] + # directly compute if there is only one chunk + if chunk_size >= total: + return func(input, **kwargs) + + feature_size = feature_size or input.shape[1] + output = input.new_empty(total, feature_size) + num_chunks = (total + chunk_size - 1) // chunk_size + + for i in range(num_chunks): + start = i * chunk_size + end = min((i + 1) * chunk_size, total) + output[start : end] = func(input[start : end], **kwargs) + + return output \ No newline at end of file diff --git a/vllm_mlu/model_executor/models/partition_utils.py b/vllm_mlu/model_executor/models/partition_utils.py new file mode 100644 index 0000000..29a97bd --- /dev/null +++ b/vllm_mlu/model_executor/models/partition_utils.py @@ -0,0 +1,507 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import itertools +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm_mlu.mlu_forward_context import MLUDPMetadata +from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams +from vllm_mlu.v1.attention.backends.mla.flashmla import ( + FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata +) +from vllm_mlu.v1.attention.backends.utils import ( + COMMON_METADATA_STR, + MLUCommonAttentionMetadata, +) + +SEQUENCE_DIM_PARITION_THRESHOLD = 1024 + +def get_common_and_layer_metadata( + attn_metadata: Optional[dict], +) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]: + """ + Returns the common metadata and layer metadata from the given attention metadata. + """ + if attn_metadata is None: + return None, None + + if isinstance(attn_metadata, dict): + assert COMMON_METADATA_STR in attn_metadata, ( + f"attn_metadata must contain {COMMON_METADATA_STR} key" + ) + assert len({id(v) for v in attn_metadata.values()}) == 2, ( + f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and " + f"the other for layers." 
+ ) + common_metadata = attn_metadata[COMMON_METADATA_STR] + layer_metadata = next((v for k, v in attn_metadata.items() if k != COMMON_METADATA_STR), None) + + return common_metadata, layer_metadata + return None, attn_metadata + +def should_skip_partition(layer_metadata, common_metadata) -> bool: + """Helper function to simplify partition condition check""" + is_layer_metadata_invalid = (layer_metadata is None + or layer_metadata.prefill is None + or layer_metadata.query_start_loc is None + or layer_metadata.query_start_loc.numel() == 0) + is_common_metadata_invalid = common_metadata is None or not common_metadata.is_prefill_only + return is_layer_metadata_invalid or is_common_metadata_invalid + + +def attn_mcc_plan( + attn_metadata: Any, + dp_params: DataParallelRuntimeParams, + parts_to_split: int, +) -> Tuple[int, int]: + """ + Returns the number of parts for batch size dimension and the number of parts for sequence length dimension. + """ + # In the precedure of dummy run, attn_metadata is an instance of MLACommonMetadata + if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))): + raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}") + + if isinstance(attn_metadata, dict): + common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata) + else: + common_metadata, layer_metadata = None, attn_metadata + + if dp_params is None: + # We don't support mcc with decode yet. + if should_skip_partition(layer_metadata, common_metadata): + return 1, 1 + + # The priority of batch size dimension to split is higher than sequence length dimension. + # And we ensure each subtask is not empty without dp. + num_prefills = layer_metadata.query_start_loc.numel() - 1 + if num_prefills > 1: + return min(parts_to_split, num_prefills), 1 + + try: + max_query_len = torch.diff(layer_metadata.query_start_loc).max().item() + except RuntimeError: + return 1, 1 + + if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD: + return 1, 1 + return 1, min(parts_to_split, max_query_len) + else: + if not all(is_prefill for is_prefill in dp_params.dp_is_prefill): + return 1, 1 + + max_bs = max(dp_params.batch_sizes) + if max_bs > 1: + # Ensure parts_to_split does not exceed max_bs to avoid unnecessary splits + if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD: + return 1, 1 + return min(parts_to_split, max_bs), 1 + else: + if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD: + return 1, 1 + return 1, parts_to_split + +def get_data_num_and_offset(total_size, parts_to_split): + """ + Get data size and offset for each. 
+ For example, total batch 11, parallel_num 4, result is [3, 3, 3, 2], offsets is [0, 3, 6, 9] + total batch 8, parallel_num 4, result is [2, 2, 2, 2], offsets is [0, 2, 4, 6] + """ + # Calculate the quotient and remainder of total_size divided by parts_to_split + quotient = total_size // parts_to_split + remainder = total_size % parts_to_split + data_num_list = [quotient + 1] * remainder + [quotient] * (parts_to_split - remainder) + offset_list = [0] + list(itertools.accumulate(data_num_list)) + return data_num_list, offset_list[:-1] + +def split_dp_params( + dp_params: DataParallelRuntimeParams, + bs_parts_to_split: int, + seq_parts_to_split: int, + attn_data_parallel_size: int, + attn_tensor_parallel_size: int, + prefill_dispatch_use_RS_AG: bool, + dp_rank_: int, +) -> List[DataParallelRuntimeParams]: + assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \ + "We don't support split batch and sequence dimensions concurrently." + + if dp_params is None: + return [None] * bs_parts_to_split * seq_parts_to_split + + if bs_parts_to_split * seq_parts_to_split == 1: + return list([dp_params]) + + if bs_parts_to_split == 1: + results : List[DataParallelRuntimeParams] = [] + dp_seq_lens = [] + for seq_len in dp_params.seq_lens: + tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split) + dp_seq_lens.append(tokens) + + query_lens_per_dp_rank = [] + + # For each dp rank, the batch size is 0 or 1. + bs_offset = 0 + for i in range(attn_data_parallel_size): + if dp_params.batch_sizes[i] > 0: + seq_len = dp_params.seq_lens[bs_offset] + tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split) + query_lens_per_dp_rank.append(tokens) + bs_offset += dp_params.batch_sizes[i] + else: + query_lens_per_dp_rank.append([0] * seq_parts_to_split) + + for i in range(seq_parts_to_split): + dp_is_prefill = [] + for dp_rank in range(attn_data_parallel_size): + dp_is_prefill.append(True) + + results.append(MLUDPMetadata.make_oot( + data_parallel_rank=dp_rank_, + data_parallel_size=attn_data_parallel_size, + tensor_parallel_size=attn_tensor_parallel_size, + dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)], + dp_is_prefill=dp_is_prefill, + prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG, + seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens], + batch_sizes=dp_params.batch_sizes, + )) + return results + + bs_per_dp = dp_params.batch_sizes # [bs_rank_0, bs_rank_1, ...] + seq_lens_per_dp = dp_params.seq_lens # [seq_len_bs_0, seq_len_bs_1,...] + + # [[bs_rank_0_part_0, bs_rank_0_part_1,...], [bs_rank_1_part_0, bs_rank_1_part_1,...], ...] + split_bs_per_dp = [] + # [[ + # [bs0_part_0_rank_0, bs1_part_0_rank_0, ...], + # [bs0_part_1_rank_0, bs1_part_1_rank_0, ...], + # ... + # ], + # [ + # [bs0_part_0_rank_1, bs1_part_0_rank_1, ...], + # [bs0_part_1_rank_1, bs1_part_1_rank_1, ...], + # ... 
+ # ], + # ] + split_query_lens_per_dp = [] + for dp_rank in range(attn_data_parallel_size): + _bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split) + split_bs_per_dp.append(_bs) + split_query_lens_per_dp.append([]) + for i in range(bs_parts_to_split): + start = sum(bs_per_dp[:dp_rank]) + _offset[i] + end = start + _bs[i] + split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end]) + + results : List[DataParallelRuntimeParams] = [] + for i in range(bs_parts_to_split): + dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)] + seq_lens = [] + for dp_rank in range(attn_data_parallel_size): + seq_lens += split_query_lens_per_dp[dp_rank][i] + batch_sizes = [] + for dp_rank in range(attn_data_parallel_size): + batch_sizes.append(split_bs_per_dp[dp_rank][i]) + + dp_is_prefill = [] + for dp_rank in range(attn_data_parallel_size): + dp_is_prefill.append(True) + results.append(MLUDPMetadata.make_oot( + data_parallel_rank=dp_rank_, + data_parallel_size=attn_data_parallel_size, + tensor_parallel_size=attn_tensor_parallel_size, + dp_token_nums=dp_query_lens, + dp_is_prefill=dp_is_prefill, + prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG, + seq_lens=seq_lens, + batch_sizes=batch_sizes, + )) + + return results + + +def split_input( + input: torch.Tensor, + bs_parts_to_split: int, + seq_parts_to_split: int, + attn_metadata_list: List[AttentionMetadata], +) -> List[torch.Tensor]: + assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \ + "We don't support split batch and sequence dimensions concurrently." + + if input is None: + return [None] * bs_parts_to_split * seq_parts_to_split + + if bs_parts_to_split * seq_parts_to_split == 1: + return list([input]) + + token_num_list = [0] * len(attn_metadata_list) + for i, metadata in enumerate(attn_metadata_list): + common_metadata, layer_metadata = get_common_and_layer_metadata(metadata) + if layer_metadata is not None: + token_num_list[i] = layer_metadata.num_actual_tokens + + # A special case for dummy run + if layer_metadata is None and i == 0: + token_num_list[i] = input.shape[0] + + results = list() + for i in range(bs_parts_to_split * seq_parts_to_split): + start = sum(token_num_list[:i]) + end = start + token_num_list[i] + results.append(input[start:end]) + return results + + +def split_positions( + positions: torch.Tensor, + bs_parts_to_split: int, + seq_parts_to_split: int, + attn_metadata: AttentionMetadata, +) -> List[torch.Tensor]: + if seq_parts_to_split == 1: + return [positions] * bs_parts_to_split + + common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata) + total_tokens = layer_metadata.num_actual_tokens if layer_metadata is not None else 0 + + tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split) + positions_list = [] + for i in range(seq_parts_to_split): + positions_list.append(positions[offsets[i]: offsets[i] + tokens[i]]) + + return positions_list + +def split_attn_metadata( + attn_metadata: dict, + bs_parts_to_split: int, + seq_parts_to_split: int, +) -> List[Any]: + """ attn_metdata is a dict, which contains common and layer metadata.""" + assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \ + "We don't support split batch and sequence dimensions concurrently." 
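    # Illustrative sketch (hypothetical numbers): for the sequence-dimension
    # split below, a single 11-token prefill with seq_parts_to_split=2 goes
    # through get_data_num_and_offset(11, 2), giving tokens=[6, 5] and
    # offsets=[0, 6]. Sub-metadata part 0 then covers query tokens [0, 6) with
    # cu_seqlens_kv ending at 6, and part 1 covers [6, 11) with cu_seqlens_kv
    # ending at 11, so each later chunk still attends to every earlier
    # key/value token of the same request.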
+ if bs_parts_to_split == 1 and seq_parts_to_split == 1: + return list([attn_metadata]) + + if attn_metadata is None: + return [None] * bs_parts_to_split * seq_parts_to_split + + if seq_parts_to_split > 1: + common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata) + if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'): + raise ValueError("common_metadata is invalid or missing num_actual_tokens") + num_prefill_tokens = common_metadata.num_actual_tokens + tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split) + device = common_metadata.seq_lens.device + sub_common_metadata, sub_layer_metadata = [], [] + for i in range(seq_parts_to_split): + # query_start_loc tensor, which indices positions in input. + query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc) + query_start_loc_tensor[0] = 0 + query_start_loc_tensor[1] = tokens[i] + # seq_lens tensor + seq_lens_tensor = torch.tensor( + [offsets[i] + tokens[i]], + dtype=common_metadata.seq_lens.dtype, + device=device + ) + # seq_start_loc tensor, which indicates positions in the sequence(kv cache). + seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc) + seq_start_loc_tensor[0] = offsets[i] + seq_start_loc_tensor[1] = offsets[i] + tokens[i] + # max_query_len scalar + max_query_len = tokens[i] + # num_actual_tokens scalar + num_actual_tokens = tokens[i] + # num_input_tokens scalar + num_input_tokens = num_actual_tokens + # infer_mode + infer_mode = common_metadata.infer_mode + # update common metadata + sub_common_metadata.append(MLUCommonAttentionMetadata( + query_start_loc=query_start_loc_tensor, + query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used + seq_lens=seq_lens_tensor, + seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used + num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used + num_reqs=common_metadata.num_reqs, # FIXME: split when used + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + max_seq_len=max_query_len, + block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used + slot_mapping=common_metadata.slot_mapping, # FIXME: split when used + seq_start_loc=seq_start_loc_tensor, + num_input_tokens=num_input_tokens, + infer_mode=infer_mode, + num_prefill_query_tokens=tokens[i], + num_prefill_kv_tokens=offsets[i] + tokens[i], + )) + # slot_mapping tensor + slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]] + # update layer metadata + REQUIRED_NUM_DECODES = 0 + REQUIRED_NUM_DECODE_TOKENS = 0 + REQUIRED_NUM_PREFILLS = 1 + if not hasattr(layer_metadata, 'num_prefills') or \ + layer_metadata.num_prefills is None: + raise ValueError("layer_metadata.num_prefills is required") + + assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \ + layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \ + layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, ( + f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, " + f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, " + f"{layer_metadata.num_prefills}." + ) + assert layer_metadata.prefill.chunked_context is None, ( + f"chunked_context is only available for prefill with chunked context, " + f"and it is not supported when enabling mcc." 
+ ) + prefill_metadata = FlashMLAPrefillMetadata( + block_table=layer_metadata.prefill.block_table, + query_start_loc=query_start_loc_tensor, + max_query_len=max_query_len, + chunked_context=None, + num_prefills=layer_metadata.prefill.num_prefills, + max_seq_len=layer_metadata.prefill.max_seq_len, + ) + # Note: for sequence dimension partition, we provide cu_seqlens_kv filed to + # indicates key/value size for flash attention operator. + prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc) + prefill_metadata.cu_seqlens_kv[0] = 0 + prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i] + + sub_layer_metadata.append(FlashMLAMetadata( + num_reqs=layer_metadata.num_reqs, + max_query_len=max_query_len, + max_seq_len=max_query_len, + num_actual_tokens=num_actual_tokens, + query_start_loc=query_start_loc_tensor, + slot_mapping=slot_mapping, + num_decodes=layer_metadata.num_decodes, + num_decode_tokens=layer_metadata.num_decode_tokens, + num_prefills=layer_metadata.num_prefills, + num_prefill_tokens=tokens[i], + head_dim=layer_metadata.head_dim, + decode=layer_metadata.decode, + prefill=prefill_metadata, + )) + + sub_attn_metadata_list = [] + for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata): + sub_attn_metadata_dict = {} + for key, value in attn_metadata.items(): + if key == COMMON_METADATA_STR: + sub_attn_metadata_dict[key] = common_meta + else: + sub_attn_metadata_dict[key] = layer_meta + sub_attn_metadata_list.append(sub_attn_metadata_dict) + return sub_attn_metadata_list + elif bs_parts_to_split > 1: + common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata) + if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None: + raise ValueError("layer_metadata.num_prefills is required") + total_batch = layer_metadata.num_prefills + batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split) + sub_common_metadata, sub_layer_metadata = [], [] + for i in range(bs_parts_to_split): + # query_start_loc tensor + start, end = offsets[i], offsets[i] + batch_sizes[i] + query_start_loc_tensor = common_metadata.query_start_loc[start:end+1].clone() + if i > 0: + query_start_loc_tensor -= common_metadata.query_start_loc[start] + # block_table + block_tables = torch.empty( + (batch_sizes[i], 0), + dtype=layer_metadata.prefill.block_table.dtype, + device=layer_metadata.prefill.block_table.device, + ) + # seq_lens tensor + seq_lens_tensor = common_metadata.seq_lens[start:end].clone() + # seq_start_loc tensor + seq_start_loc_tensor = query_start_loc_tensor + # max_query_len scalar + max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0 + # num_actual_tokens scalar + num_actual_tokens = seq_start_loc_tensor[-1].item() + # num_input_tokens scalar + num_input_tokens = num_actual_tokens + # infer_mode + infer_mode = common_metadata.infer_mode + # slot_mapping tensor + slot_mapping_start = 0 + for data in sub_common_metadata: + slot_mapping_start += data.num_actual_tokens + slot_mapping_tensor = layer_metadata.slot_mapping[ + slot_mapping_start:slot_mapping_start + num_actual_tokens + ] + # update common metadata + sub_common_metadata.append(MLUCommonAttentionMetadata( + query_start_loc=query_start_loc_tensor, + query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used + seq_lens=seq_lens_tensor, + seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used + num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: 
split when used + num_reqs=common_metadata.num_reqs, # FIXME: split when used + block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used + slot_mapping=common_metadata.slot_mapping, # FIXME: split when used + seq_start_loc=seq_start_loc_tensor, + max_query_len=max_query_len, + max_seq_len=max_query_len, + num_actual_tokens=num_actual_tokens, + num_input_tokens=num_input_tokens, + infer_mode=infer_mode, + num_prefill_query_tokens=num_actual_tokens, + num_prefill_kv_tokens=num_actual_tokens, + )) + # update layer_metadata + prefill_metadata = FlashMLAPrefillMetadata( + block_table=block_tables, + query_start_loc=query_start_loc_tensor, + max_query_len=max_query_len, + chunked_context=None, + num_prefills=batch_sizes[i], + max_seq_len=max_query_len, + ) + sub_layer_metadata.append(FlashMLAMetadata( + num_reqs=batch_sizes[i], + max_query_len=max_query_len, + max_seq_len=max_query_len, + num_actual_tokens=num_actual_tokens, + query_start_loc=query_start_loc_tensor, + slot_mapping=slot_mapping_tensor, + num_decodes=layer_metadata.num_decodes, # useless field + num_decode_tokens=0, # useless field + num_prefills=batch_sizes[i], + num_prefill_tokens=num_actual_tokens, + head_dim=layer_metadata.head_dim, + decode=layer_metadata.decode, + prefill=prefill_metadata, + )) + + sub_attn_metadata_list = [] + for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata): + sub_attn_metadata_dict = {} + for key, value in attn_metadata.items(): + if key == COMMON_METADATA_STR: + sub_attn_metadata_dict[key] = common_meta + else: + sub_attn_metadata_dict[key] = layer_meta + sub_attn_metadata_list.append(sub_attn_metadata_dict) + return sub_attn_metadata_list + + +def execute_with_updated_forward_context( + vllm_config: VllmConfig, + attn_metadata: AttentionMetadata, + func: Callable, + kwargs: Dict[str, Any], +): + with set_forward_context(attn_metadata, vllm_config): + return func(**kwargs) diff --git a/vllm_mlu/model_executor/models/registry.py b/vllm_mlu/model_executor/models/registry.py new file mode 100644 index 0000000..fd843b4 --- /dev/null +++ b/vllm_mlu/model_executor/models/registry.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Type, Union + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.model_executor.models.registry import ( + _LazyRegisteredModel, _RegisteredModel, _ModelRegistry) + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +def vllm__model_executor__models__registry___ModelRegistry__register_model( + self, + model_arch: str, + model_cls: Union[type[nn.Module], str], +) -> None: + """ + Register an external model to be used in vLLM. + + `model_cls` can be either: + + - A [`torch.nn.Module`][] class directly referencing the model. + - A string in the format `:` which can be used to + lazily import the model. This is useful to avoid initializing CUDA + when importing the model and thus the related error + `RuntimeError: Cannot re-initialize CUDA in forked subprocess`. 
+ """ + if not isinstance(model_arch, str): + msg = f"`model_arch` should be a string, not a {type(model_arch)}" + raise TypeError(msg) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: change mlu models register log level + ''' + if model_arch in self.models: + if isinstance(model_cls, str) and "MLU" in model_cls: + logger.debug( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls) + else: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if isinstance(model_cls, str): + split_str = model_cls.split(":") + if len(split_str) != 2: + msg = "Expected a string in the format `:`" + raise ValueError(msg) + + model = _LazyRegisteredModel(*split_str) + elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): + model = _RegisteredModel.from_model_cls(model_cls) + else: + msg = ("`model_cls` should be a string or PyTorch model class, " + f"not a {type(model_arch)}") + raise TypeError(msg) + + self.models[model_arch] = model + + +MluHijackObject.apply_hijack( + _ModelRegistry, + _ModelRegistry.register_model, + vllm__model_executor__models__registry___ModelRegistry__register_model +) \ No newline at end of file diff --git a/vllm_mlu/model_executor/models/utils.py b/vllm_mlu/model_executor/models/utils.py new file mode 100644 index 0000000..f1800b6 --- /dev/null +++ b/vllm_mlu/model_executor/models/utils.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import json +import os + +import torch + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.forward_context import get_forward_context + +def set_attn_compute_dtype_v1(attn_metadata, dtype: torch.dtype): + ''' + set attn compute_dtype for v1 + ''' + if isinstance(attn_metadata, dict): + for _, metadata in attn_metadata.items(): + metadata.compute_dtype = dtype + else: + metadata.compute_dtype = dtype + +def set_attn_compute_dtype(dtype: torch.dtype): + ''' + set attn compute_dtype. + + TODO: FA may standardize on half precision computation in the future + set_attn_compute_dtype might be deprecated and removed + ''' + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + set_attn_compute_dtype_v1(attn_metadata, dtype) + +def is_tie_word_embeddings( + model_config: ModelConfig, + org_tie_word_embeddings: bool +) -> bool: + ''' + Vllm language model config for multimodal model may have wrong tie_word_embeddings, + for example, InternVL3.5-38B, InternVL3.5-30B-A3B, etc. + + This function is a WorkAround. 
+ ''' + from vllm.lora.utils import get_adapter_absolute_path + + if not model_config.is_multimodal_model: + return org_tie_word_embeddings + + model_path = get_adapter_absolute_path(model_config.model) + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + return org_tie_word_embeddings + + tie_word_embeddings = org_tie_word_embeddings + with open(config_path) as f: + config = json.load(f) + # first, we find if tie_word_embeddings config is in overall config + if config.get("tie_word_embeddings") is not None: + tie_word_embeddings = config["tie_word_embeddings"] + # then, we find if tie_word_embeddings config is in language model config + if (config.get("llm_config") is not None + and config["llm_config"].get("tie_word_embeddings") is not None): + tie_word_embeddings = config["llm_config"]["tie_word_embeddings"] + + return tie_word_embeddings diff --git a/vllm_mlu/model_executor/parameter.py b/vllm_mlu/model_executor/parameter.py new file mode 100644 index 0000000..42e073c --- /dev/null +++ b/vllm_mlu/model_executor/parameter.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Callable, Any + +import torch + +from vllm.model_executor.parameter import BasevLLMParameter +from vllm.distributed import ( + get_parallel_rank_with_group, + get_parallel_world_size_with_group, +) + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +vllm__model_executor__parameter__BasevLLMParameter____init__org = BasevLLMParameter.__init__ + + +def vllm__model_executor__parameter__BasevLLMParameter____init__( + self, + data: torch.Tensor, + weight_loader: Callable, + tp_group: Any = None +): + vllm__model_executor__parameter__BasevLLMParameter____init__org( + self, data, weight_loader + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add self.tp_group, world_size and tp_rank + ''' + if tp_group is not None: + self.tp_group = tp_group + self.tp_world_size = get_parallel_world_size_with_group(self.tp_group) + self.tp_rank = get_parallel_rank_with_group(self.tp_group) + ''' + ================= + End of MLU Hijack + ================= + ''' + + +MluHijackObject.apply_hijack(BasevLLMParameter, + BasevLLMParameter.__init__, + vllm__model_executor__parameter__BasevLLMParameter____init__) diff --git a/vllm_mlu/model_executor/warmup/kernel_warmup.py b/vllm_mlu/model_executor/warmup/kernel_warmup.py new file mode 100644 index 0000000..0c620ac --- /dev/null +++ b/vllm_mlu/model_executor/warmup/kernel_warmup.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +""" +Warmup kernels used during model execution. +This is useful specifically for JIT'ed kernels as we don't want JIT'ing to +happen during model execution. 
+""" + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_worker import Worker + +logger = init_logger(__name__) + + +def kernel_warmup(worker: "Worker"): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: skip deep GEMM warmup, flashinfer autotune, and + flash infer attention warmup + ''' + + ''' + ================== + End of MLU Hijack + ================== + ''' + pass \ No newline at end of file diff --git a/vllm_mlu/multimodal/__init__.py b/vllm_mlu/multimodal/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/multimodal/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/multimodal/utils.py b/vllm_mlu/multimodal/utils.py new file mode 100644 index 0000000..37667ed --- /dev/null +++ b/vllm_mlu/multimodal/utils.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import Any, TypeVar + +from PIL import Image + +from vllm import multimodal +from vllm.logger import init_logger +from vllm.multimodal.utils import MediaConnector + + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__multimodal__utils__fetch_image( + image_url: str, + image_io_kwargs: dict[str, Any] | None = None, +) -> Image.Image: + """ + Args: + image_url: URL of the image file to fetch. + image_io_kwargs: Additional kwargs passed to handle image IO. + """ + media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs} + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: set 'allowed_local_media_path' as default + ''' + media_connector = MediaConnector(media_io_kwargs, + allowed_local_media_path=image_io_kwargs["allowed_local_media_path"] + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + return media_connector.fetch_image(image_url) + + +MluHijackObject.apply_hijack(multimodal, + multimodal.utils, + vllm__multimodal__utils__fetch_image) \ No newline at end of file diff --git a/vllm_mlu/platforms/__init__.py b/vllm_mlu/platforms/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/platforms/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/platforms/mlu.py b/vllm_mlu/platforms/mlu.py new file mode 100644 index 0000000..46ce7ed --- /dev/null +++ b/vllm_mlu/platforms/mlu.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import contextlib +from functools import lru_cache +from typing import TYPE_CHECKING, Optional, Tuple + +import os +import torch + +from vllm.logger import init_logger + +import vllm.envs as envs +from vllm.platforms.interface import ( + DeviceCapability, + Platform, + PlatformEnum, +) + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu.logger import logger + +if TYPE_CHECKING: + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import CacheDType + from vllm.utils.argparse_utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = object + + +envs.environment_variables.update({ + "MLU_VISIBLE_DEVICES": + 
lambda: os.environ.get("MLU_VISIBLE_DEVICES", None) +}) + +logger = init_logger(__name__) + + +class MLUPlatform(Platform): + _enum = PlatformEnum.OOT + device_name: str = "mlu" + device_type: str = "mlu" + dispatch_key: str = "MLU" + ray_device_key: str = "GPU" + device_control_env_var: str = "MLU_VISIBLE_DEVICES" + simple_compile_backend: str = "inductor" + dist_backend: str = "cncl" + + supported_quantization: list[str] = ["weightonly", "smoothquant", + "awq_mlu", "gptq_mlu", "fp8"] + additional_env_vars: list[str] = ["VLLM_LATENCY_DEBUG", + "VLLM_LATENCY_DEBUG_NO_DEVICE", + "MLU_GRAPH_CAPTURE_LIST", + "VLLM_LOGITS_USE_ALL_GATHER", + "VLLM_V1_USE_FULL_GRAPH", + "VLLM_MTP_FIXED_ACCEPTANCE_RATE"] + + @classmethod + def import_kernels(cls) -> None: + """Import any platform-specific C kernels.""" + try: + import torch_mlu_ops + except ImportError as e: + logger.warning("Failed to import from torch_mlu_ops with %r", e) + + @classmethod + def pre_register_and_update( + cls, parser: FlexibleArgumentParser | None = None + ) -> None: + from vllm_mlu.model_executor.layers.quantization import ( + register_real_mlu_quantization_methods + ) + register_real_mlu_quantization_methods() + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "AttentionBackendEnum", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + attn_type: str | None = None, + ) -> str: + + if use_mla: + logger.info(f"[MLU-V1][MLA] Select FlashMLABackend.") + return "vllm_mlu.v1.attention.backends.mla.flashmla.FlashMLABackend" + else: + logger.info(f"[MLU-V1] Select FlashAttentionBackend.") + return "vllm_mlu.v1.attention.backends.flash_attn.MLUFlashAttentionBackend" + + @classmethod + @lru_cache(maxsize=8) + def get_device_capability( + cls, + device_id: int = 0, + ) -> DeviceCapability | None: + try: + major, minor = torch.mlu.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + except RuntimeError: + return None + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.mlu.get_device_name(device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.mlu.get_device_properties(device_id) + return device_props.total_memory + + @classmethod + def set_device(cls, device: torch.device): + torch.mlu.set_device(device) + + @classmethod + def empty_cache(cls): + torch.mlu.empty_cache() + + @classmethod + def synchronize(cls): + torch.mlu.synchronize() + + @classmethod + def mem_get_info(cls) -> Tuple[int, int]: + return torch.mlu.mem_get_info() + + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + + @classmethod + def inference_mode(cls): + return torch.no_grad() + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + + @classmethod + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + model_config = vllm_config.model_config + speculative_config = vllm_config.speculative_config + kv_transfer_config = vllm_config.kv_transfer_config + mlu_config = vllm_config.mlu_config + + # Decode use full mlugraph: V1 mode + VLLM_V1_USE_FULL_GRAPH=true + use_full_mlugraph = mlu_envs.VLLM_V1_USE_FULL_GRAPH + + # Check 
compilation config + from vllm.config import CompilationMode, CUDAGraphMode + logger.info( + "[MLU] Force select CompilationMode.None, CUDAGraphMode.FULL_DECODE_ONLY." + ) + compilation_config.level = None + compilation_config.mode = CompilationMode.NONE + compilation_config.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY + + # Dispatch worker + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm_mlu.v1.worker.gpu_worker.MLUWorker" + + cls.simple_compile_backend = "inductor" + # Activate custom ops for v1. + compilation_config.custom_ops = ["all"] + if compilation_config.splitting_ops is None: + compilation_config.splitting_ops = [] + compilation_config.splitting_ops.extend(["vllm.rope_forward"]) + + # FIXME: support cascade attention in VLLM-1710 + model_config = vllm_config.model_config + if model_config: + model_config.disable_cascade_attn = True + + # Select v1 scheduler type + if scheduler_config: + if not scheduler_config.async_scheduling: + if (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED + and not scheduler_config.enable_chunked_prefill): + vllm_config.scheduler_config.scheduler_cls = \ + "vllm_mlu.v1.core.sched.scheduler.MLUUnchunkScheduler" + logger.info(f"[MLU-V1] Select UnchunkScheduler.") + else: + vllm_config.scheduler_config.scheduler_cls = \ + "vllm_mlu.v1.core.sched.scheduler.SchedulerWithProfiler" + logger.info(f"[MLU-V1] Select ChunkScheduler.") + else: + if (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED + and not scheduler_config.enable_chunked_prefill): + vllm_config.scheduler_config.scheduler_cls = \ + "vllm_mlu.v1.core.sched.async_scheduler.MLUUnchunkAsyncScheduler" + logger.info(f"[MLU-V1] Select UnchunkAsyncScheduler.") + + + # Check cache config + if cache_config: + logger.info( + f"[MLU] Select kv_cache_dtype={cache_config.cache_dtype}." + ) + if cache_config.block_size is None: + cache_config.block_size = 16 + + # Check mla config + if model_config and model_config.use_mla: + if (mlu_config.is_dpsk_mcc_enabled or not use_full_mlugraph): + scheduler_config.enable_chunked_prefill = False + scheduler_config.chunked_prefill_enabled = False + logger.warning( + "[MLA] Chunked prefill is disabled when deepseek mcc is enabled, " + "or not use full mlugraph.") + + if mlu_config.is_dpsk_mcc_enabled: + cache_config.enable_prefix_caching = False + logger.warning("[MLA] Prefix Caching is disabled when deepseek mcc is enabled.") + + + # For mlu benchmark, we allow max_num_batched_tokens < max_model_len + # in certain scenarios. + if ( + model_config + and scheduler_config.max_num_batched_tokens < model_config.max_model_len + and not scheduler_config.chunked_prefill_enabled + ): + msg = f"max_num_batched_tokens ({scheduler_config.max_num_batched_tokens}) is " + \ + f"smaller than max_model_len ({model_config.max_model_len}). " + \ + "This effectively limits the maximum sequence length to " + \ + "max_num_batched_tokens and makes vLLM reject longer " + \ + "sequences. Please increase max_num_batched_tokens or " + \ + "decrease max_model_len." + if not mlu_envs.VLLM_V1_BENCHMARK: + raise ValueError(msg) + else: + logger.warning(msg) + + if (mlu_config.dispatch_shared_expert_parallel + and parallel_config.data_parallel_size <= 1 + and not mlu_config.prefill_use_sequence_parallel): + mlu_config.dispatch_shared_expert_parallel = False + logger.info( + "Disabling `mlu_config.dispatch_shared_expert_parallel` when " + "data_parallel_size == 1 or not using sequence parallel." 
+ ) + + # Check kv_transfer config + if kv_transfer_config: + # Register mlu kv_connectors + import vllm_mlu.distributed.kv_transfer.kv_connector.factory + + @classmethod + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: + torch.mlu.reset_peak_memory_stats(device) + return torch.mlu.max_memory_allocated(device) + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm_mlu.lora.punica_wrapper.punica_mlu.PunicaWrapperMLU" + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm_mlu.distributed.device_communicators.mlu_communicator.MLUCommunicator" + + @classmethod + def use_all_gather(cls) -> bool: + return True + + @classmethod + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm_mlu.compilation.mlu_graph.MLUGraphWrapper" + + @classmethod + def can_update_inplace(cls) -> bool: + """ + Checks if the platform allows inplace memory updates + """ + return True + + def is_sleep_mode_available(self) -> bool: + return True + + @classmethod + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + + @classmethod + def support_static_graph_mode(cls) -> bool: + """ + Returns if the graph mode is supported by the current platform. + """ + return True diff --git a/vllm_mlu/profiler/mlu_profiler.py b/vllm_mlu/profiler/mlu_profiler.py new file mode 100644 index 0000000..6492c04 --- /dev/null +++ b/vllm_mlu/profiler/mlu_profiler.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class MluProfilerWrapper: + def __init__(self) -> None: + self._profiler_running = False + # Note: lazy import to avoid dependency issues if MLU is not available. 
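        # A minimal usage sketch (hypothetical caller, not shown in this file):
        # the wrapper is only constructed when profiling is requested, so the
        # torch_mlu dependency is touched here rather than at module import:
        #     profiler = MluProfilerWrapper()
        #     profiler.start()
        #     ...  # run the steps to profile
        #     profiler.stop()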
+ import torch.mlu.profiler as mlu_profiler + + self._mlu_profiler = mlu_profiler + + def start(self) -> None: + try: + self._mlu_profiler.start() + self._profiler_running = True + logger.info_once("Started MLU profiler") + except Exception as e: + logger.warning_once("Failed to start MLu profiler: %s", e) + + def stop(self) -> None: + if self._profiler_running: + try: + self._mlu_profiler.stop() + logger.info_once("Stopped MLU profiler") + except Exception as e: + logger.warning_once("Failed to stop MLU profiler: %s", e) + finally: + self._profiler_running = False + + def shutdown(self) -> None: + """Ensure profiler is stopped when shutting down.""" + self.stop() diff --git a/vllm_mlu/utils.py b/vllm_mlu/utils.py new file mode 100644 index 0000000..07a6d49 --- /dev/null +++ b/vllm_mlu/utils.py @@ -0,0 +1,301 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from __future__ import annotations + +import contextlib +import gc +import os +import time +import torch +from torch.library import Library +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Optional, Callable, Tuple, Generator + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.ray.lazy_utils import is_in_ray_actor +from vllm.utils import ( + torch_utils, + system_utils, +) +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + supports_custom_op, + vllm_lib, +) +from vllm.utils.mem_utils import GiB_bytes +from vllm.utils.platform_utils import ( + cuda_is_initialized, + xpu_is_initialized, +) +from vllm.logger import init_logger + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +STR_DTYPE_TO_TORCH_DTYPE["int8"] = torch.int8 + + +@dataclass +class MemorySnapshot: + """Memory snapshot.""" + torch_peak: int = 0 + free_memory: int = 0 + total_memory: int = 0 + mlu_memory: int = 0 + torch_memory: int = 0 + non_torch_memory: int = 0 + timestamp: float = 0.0 + auto_measure: bool = True + + def __post_init__(self): + if self.auto_measure: + self.measure() + + def measure(self): + # we measure the torch peak memory usage via allocated_bytes, + # rather than `torch.mlu.memory_reserved()` . + # After `torch.mlu.reset_peak_memory_stats()`, + # `torch.mlu.memory_reserved()` will keep growing, and only shrink + # when we call `torch.mlu.empty_cache()` or OOM happens. + self.torch_peak = torch.mlu.memory_stats().get( + "allocated_bytes.all.peak", 0) + + self.free_memory, self.total_memory = torch.mlu.mem_get_info() + self.mlu_memory = self.total_memory - self.free_memory + + # torch.mlu.memory_reserved() is how many bytes + # PyTorch gets from mlu (by calling mluMalloc, etc.) + # this is used to measure the non-torch memory usage + self.torch_memory = torch.mlu.memory_reserved() + + self.non_torch_memory = self.mlu_memory - self.torch_memory + self.timestamp = time.time() + + def __sub__(self, other: MemorySnapshot) -> MemorySnapshot: + return MemorySnapshot( + torch_peak=self.torch_peak - other.torch_peak, + free_memory=self.free_memory - other.free_memory, + total_memory=self.total_memory - other.total_memory, + mlu_memory=self.mlu_memory - other.mlu_memory, + torch_memory=self.torch_memory - other.torch_memory, + non_torch_memory=self.non_torch_memory - other.non_torch_memory, + timestamp=self.timestamp - other.timestamp, + auto_measure=False, + ) + + +@dataclass +class MemoryProfilingResult: + """Memory profiling result. All numbers are in bytes. 
+ """ + non_kv_cache_memory: int = 0 + torch_peak_increase: int = 0 + non_torch_increase: int = 0 + weights_memory: float = 0 + before_create: MemorySnapshot = field(default_factory=MemorySnapshot) + before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + profile_time: float = 0.0 + + def __repr__(self) -> str: + return (f"Memory profiling takes {self.profile_time:.2f} seconds. " + f"Total non KV cache memory: " + f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " + f"torch peak memory increase: " + f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " + f"non-torch forward increase memory: " + f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " + f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.") + + +@contextlib.contextmanager +def memory_profiling( + baseline_snapshot: MemorySnapshot, + weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: + """Memory profiling context manager. + baseline_snapshot: the memory snapshot before the current vLLM instance. + weights_memory: memory used by PyTorch when loading the model weights. + Note that, before loading the model weights, we also initialize the device + and distributed environment, which may consume some memory. This part is not + included in the weights_memory because PyTorch does not control it. + + The memory in one GPU can be classified into 3 categories: + 1. memory used by anything other than the current vLLM instance. + 2. memory used by torch in the current vLLM instance. + 3. memory used in the current vLLM instance, but not by torch. + + A quantitive example: + + Before creating the current vLLM instance: + category 1: 1 GiB + category 2: 0 GiB + category 3: 0 GiB + + After creating the current vLLM instance and loading the model, + (i.e. before profiling): + category 1: 1 GiB + category 2: 2 GiB (model weights take 2 GiB) + category 3: 0.5 GiB (memory used by NCCL) + + During profiling (peak): + category 1: 1 GiB + category 2: 4 GiB (peak activation tensors take 2 GiB) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + After profiling: + category 1: 1 GiB + category 2: 3 GiB (after garbage-collecting activation tensors) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + In this case, non-kv cache takes 5 GiB in total, including: + a. 2 GiB used by the model weights (category 2) + b. 2 GiB reserved for the peak activation tensors (category 2) + c. 1 GiB used by non-torch components (category 3) + + The memory used for loading weights (a.) is directly given from the argument `weights_memory`. + + The increase of `torch.mlu.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). + + The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). 
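
    A minimal usage sketch (the profiling step and weight size below are
    placeholders, not names defined in this module):

        baseline = MemorySnapshot()  # taken before creating the vLLM instance
        # ... initialize the instance and load the model weights ...
        with memory_profiling(baseline, weights_memory=weights_bytes) as result:
            run_profile_forward_pass()
        print(result)  # reports non-KV-cache, torch-peak and non-torch increases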
+ """ # noqa + gc.collect() + torch.mlu.empty_cache() + torch.mlu.reset_peak_memory_stats() + + result = MemoryProfilingResult() + + result.before_create = baseline_snapshot + # the part of memory used for holding the model weights + result.weights_memory = weights_memory + + result.before_profile.measure() + + yield result + + gc.collect() + torch.mlu.empty_cache() + + result.after_profile.measure() + + diff_profile = result.after_profile - result.before_profile + diff_from_create = result.after_profile - result.before_create + result.torch_peak_increase = diff_profile.torch_peak + result.non_torch_increase = diff_from_create.non_torch_memory + result.profile_time = diff_profile.timestamp + result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +@lru_cache(maxsize=8) +def _mlu_device_count_stateless( + mlu_visible_devices: Optional[str] = None) -> int: + + if mlu_visible_devices is None: + return torch.mlu.device_count() + if mlu_visible_devices == "": + return 0 + if "," not in mlu_visible_devices: + return 1 + return len(mlu_visible_devices.split(",")) + + +def mlu_device_count_stateless() -> int: + """Get number of MLU devices, caching based on the value of + MLU_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless MLU_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _mlu_device_count_stateless(os.environ.get("MLU_VISIBLE_DEVICES", "mlu")) + + +def vllm__utils_system_utils___maybe_force_spawn(): + """Check if we need to force the use of the `spawn` multiprocessing start + method. + """ + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn": + return + + reasons = [] + if is_in_ray_actor(): + # even if we choose to spawn, we need to pass the ray address + # to the subprocess so that it knows how to connect to the ray cluster. + # env vars are inherited by subprocesses, even if we use spawn. + import ray + + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address + reasons.append("In a Ray actor and can only be spawned") + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Force use spawn for MLU platform. + ''' + if cuda_is_initialized(): + reasons.append("CUDA is initialized") + elif xpu_is_initialized(): + reasons.append("XPU is initialized") + elif current_platform.is_out_of_tree(): + reasons.append("MLU is initialized") + ''' + ================== + End of MLU Hijack + ================== + ''' + + if reasons: + logger.warning( + "We must use the `spawn` multiprocessing start method. " + "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " + "See https://docs.vllm.ai/en/latest/getting_started/" + "troubleshooting.html#python-multiprocessing " + "for more information. 
Reason: %s", reasons) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +''' +============================= +Modify by vllm_mlu +============================= +@brief: change dispatch_key default value from 'CUDA' to 'MLU' +''' +vllm__utils__torch_utils__direct_register_custom_op_org = torch_utils.direct_register_custom_op +def vllm__utils__torch_utils__direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = [], + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str = "MLU", + tags: Tuple[torch.Tag, ...] = (), +): + vllm__utils__torch_utils__direct_register_custom_op_org( + op_name=op_name, + op_func=op_func, + mutates_args=mutates_args, + fake_impl=fake_impl, + target_lib=target_lib, + dispatch_key=dispatch_key, + tags=tags, + ) +''' +================== +End of MLU Hijack +================== +''' + + +MluHijackObject.apply_hijack(torch_utils, + torch_utils.direct_register_custom_op, + vllm__utils__torch_utils__direct_register_custom_op) +MluHijackObject.apply_hijack(system_utils, + system_utils._maybe_force_spawn, + vllm__utils_system_utils___maybe_force_spawn) \ No newline at end of file diff --git a/vllm_mlu/v1/__init__.py b/vllm_mlu/v1/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/attention/__init__.py b/vllm_mlu/v1/attention/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/attention/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/attention/backends/__init__.py b/vllm_mlu/v1/attention/backends/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/attention/backends/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/attention/backends/flash_attn.py b/vllm_mlu/v1/attention/backends/flash_attn.py new file mode 100644 index 0000000..0297777 --- /dev/null +++ b/vllm_mlu/v1/attention/backends/flash_attn.py @@ -0,0 +1,1050 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, ClassVar + +import numpy as np +import torch +import torch.nn.functional as F + +from vllm.attention.backends.abstract import (AttentionImpl, + AttentionMetadata, AttentionType, + is_quantized_kv_cache, + MultipleOf,) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.platforms import current_platform +from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.config.vllm import VllmConfig +from vllm.v1.worker.block_table import BlockTable + +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, FlashAttentionMetadata, + FlashAttentionMetadataBuilder, + _get_sliding_window_configs +) +from vllm.v1.attention.backends.mla.common import MLACommonMetadata 
+from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + split_decodes_and_prefills, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm_mlu.v1.worker.gpu_model_runner import MLUModelRunner + +if current_platform.is_cuda(): + from vllm.attention.utils.fa_utils import get_scheduler_metadata + +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.v1.attention.backends.utils import ( + MLUCommonAttentionMetadata, + MLUInferMode, + get_common_metadata, +) +from vllm_mlu.model_executor.layers.quantization.utils.common_utils import attn_str_dtype_to_torch + +logger = init_logger(__name__) + + +class MLUFlashAttentionBackend(FlashAttentionBackend): + + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1, 16, 32, 64] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576] + + @staticmethod + def get_impl_cls() -> type["MLUFlashAttentionImpl"]: + return MLUFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return MLUFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["MLUFlashAttentionMetadataBuilder"]: + return MLUFlashAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def get_kv_cache_scale_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + ) -> tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size) + + +@dataclass +class MLUChunkFlashAttentionMetadata: + """ + Chunked prefill metadata for MLU backend, which splits both + input and metadata into prefill and decode phases. With splitting, + the MLU backend can invoke FA and single_query_cached_kv_attn kerels + seperately, thus yields better performance. + """ + @dataclass + class ChunkContextMetadata: + """ + ChunkContextMetadata for prefill chunks and decode tokens. + """ + batch_size: int + num_actual_tokens: int + cu_seqlens_q: torch.Tensor + cu_seqlens_kv: torch.Tensor + max_query_len: int + max_seq_len: int + total_seqlens: int = 0 + + prefill_ctx: ChunkContextMetadata + decode_ctx: ChunkContextMetadata + + @classmethod + def build( + cls, + common_attn_metadata: MLUCommonAttentionMetadata, + uniform_decode_query_len: int = 1, + ): + assert common_attn_metadata.infer_mode.is_chunked + ( + num_decodes, + num_prefills, + num_decode_tokens, + num_prefill_tokens, + ) = split_decodes_and_prefills(common_attn_metadata, + uniform_decode_query_len, + require_uniform=True) + # split cu_seqlens_q and cu_seqlens_kv + query_start_loc = common_attn_metadata.query_start_loc + d_cu_seqlens_q = query_start_loc[:num_decodes + 1] + p_cu_seqlens_q = query_start_loc[num_decodes:] - query_start_loc[num_decodes] + seq_start_loc = common_attn_metadata.seq_start_loc + d_cu_seqlens_kv = seq_start_loc[:num_decodes + 1] + p_cu_seqlens_kv = seq_start_loc[num_decodes:] - seq_start_loc[num_decodes] + # compute max_query_len and max_seq_len after split + # NOTE: use cpu tensor to avoid d2h copy. 
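        # Illustrative numbers for the split above: decode requests are ordered
        # before prefills, so with per-request query lens [1, 1, 5, 8] and
        # num_decodes=2, query_start_loc is [0, 1, 2, 7, 15]; the slices give
        # d_cu_seqlens_q = [0, 1, 2] and p_cu_seqlens_q = [2, 7, 15] - 2 =
        # [0, 5, 13], i.e. the prefill offsets are rebased to start at zero.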
+ query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + query_len_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + seq_len_cpu = common_attn_metadata.seq_lens_cpu + d_max_query_len = 0 + d_max_seq_len = 0 + p_max_query_len = 0 + p_max_seq_len = 0 + p_total_seqlens = 0 + if num_decodes > 0: + d_max_query_len = query_len_cpu[:num_decodes].max().item() + d_max_seq_len = seq_len_cpu[:num_decodes].max().item() + if num_prefills > 0: + p_max_query_len = query_len_cpu[num_decodes:].max().item() + p_max_seq_len = seq_len_cpu[num_decodes:].max().item() + p_total_seqlens = seq_len_cpu[num_decodes:].sum().item() + + return MLUChunkFlashAttentionMetadata( + prefill_ctx=MLUChunkFlashAttentionMetadata. + ChunkContextMetadata( + batch_size=num_prefills, + num_actual_tokens=num_prefill_tokens, + cu_seqlens_q=p_cu_seqlens_q, + cu_seqlens_kv=p_cu_seqlens_kv, + max_query_len=p_max_query_len, + max_seq_len=p_max_seq_len, + total_seqlens=p_total_seqlens, + ), + decode_ctx=MLUChunkFlashAttentionMetadata. + ChunkContextMetadata( + batch_size=num_decodes, + num_actual_tokens=num_decode_tokens, + cu_seqlens_q=d_cu_seqlens_q, + cu_seqlens_kv=d_cu_seqlens_kv, + max_query_len=d_max_query_len, + max_seq_len=d_max_seq_len, + ), + ) + +@dataclass +class MLUFlashAttentionMetadata(FlashAttentionMetadata): + # For mlu infer + seq_start_loc: torch.Tensor | None = None + infer_mode: MLUInferMode | None = None + num_input_tokens: int = 0 # Number of tokens including padding. + compute_dtype: torch.dtype = torch.float32 + chunk_fa_metadata: MLUChunkFlashAttentionMetadata | None = None + + @property + def num_decode_tokens(self): + assert self.infer_mode is not None, ( + f"MLUFlashAttentionMetadata infer_mode is not set." + ) + + if self.infer_mode == MLUInferMode.PREFILL_ONLY: + return 0 + + if self.infer_mode == MLUInferMode.DECODE_ONLY: + return self.num_actual_tokens + + assert self.chunk_fa_metadata is not None, ( + f"chunk_fa_metadata must be set under chunked infer mode." 
+ ) + return self.chunk_fa_metadata.decode_ctx.num_actual_tokens + + +def pad_attn_metadata( + attn_metadata: MLACommonMetadata | FlashAttentionMetadata, + common_metadata: MLUCommonAttentionMetadata, + block_table: BlockTable, + runner: "MLUModelRunner", + num_scheduled_tokens: int, + num_input_tokens: int, + num_reqs: int, + num_paded_reqs: int, +) -> None: + is_mla = isinstance(attn_metadata, MLACommonMetadata) + if is_mla: + assert attn_metadata.prefill is None and attn_metadata.decode is not None + + pad_token_num = num_input_tokens - num_scheduled_tokens + pad_req_num = num_paded_reqs - num_reqs + if pad_token_num == 0: + return + + query_start_loc_cpu = runner.query_start_loc.cpu[:num_paded_reqs + 1] + query_start_loc = runner.query_start_loc.gpu[:num_paded_reqs + 1] + seq_lens_cpu = runner.seq_lens.cpu[:num_paded_reqs] + seq_lens = runner.seq_lens.gpu[:num_paded_reqs] + if pad_req_num > 0: + query_lens = torch.diff(query_start_loc_cpu[:num_reqs + 1]) + pad_lens = torch.full( + (pad_req_num,), + pad_token_num // pad_req_num, + dtype=query_lens.dtype, + device=query_lens.device) + query_lens = torch.cat([query_lens, pad_lens]) + torch.cumsum(query_lens, dim=0, out=query_start_loc_cpu[1:]) + query_start_loc.copy_(query_start_loc_cpu, non_blocking=True) + seq_lens_cpu[num_reqs:].fill_(common_metadata.max_query_len) + seq_lens[num_reqs:].fill_(common_metadata.max_query_len) + + seq_start_loc_cpu = runner.seq_start_loc.cpu[:(num_paded_reqs + 1)] + seq_start_loc = runner.seq_start_loc.gpu[:(num_paded_reqs + 1)] + torch.cumsum(seq_lens, dim=0, out=seq_start_loc[1:]) + torch.cumsum(seq_lens_cpu, dim=0, out=seq_start_loc_cpu[1:]) + + slot_mapping_org_num = attn_metadata.slot_mapping.numel() + slot_mapping = block_table.slot_mapping.gpu[:(slot_mapping_org_num + pad_token_num)] + slot_mapping[slot_mapping_org_num:] = PAD_SLOT_ID + + block_table = block_table.get_device_tensor(num_paded_reqs) + + attn_metadata.slot_mapping = slot_mapping + attn_metadata.query_start_loc = query_start_loc + if is_mla: + attn_metadata.decode.query_start_loc = query_start_loc + attn_metadata.decode.seq_lens = seq_lens + attn_metadata.decode.block_table = block_table + else: + attn_metadata.seq_lens = seq_lens + attn_metadata.seq_start_loc = seq_start_loc + attn_metadata.block_table = block_table + + common_metadata.num_input_tokens = num_input_tokens + common_metadata.seq_start_loc = seq_start_loc + common_metadata.seq_start_loc_cpu = seq_start_loc_cpu + common_metadata.query_start_loc = query_start_loc + common_metadata.query_start_loc_cpu = query_start_loc_cpu + common_metadata.seq_lens = seq_lens + common_metadata.seq_lens_cpu = seq_lens_cpu + common_metadata.num_reqs = num_paded_reqs + common_metadata.block_table_tensor = block_table + common_metadata.slot_mapping = slot_mapping + + +class MLUFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + cudagraph_support = ( + AttentionCGSupport.UNIFORM_BATCH + ) + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add class member - uniform_decode_query_len + ''' + self.uniform_decode_query_len = ( + 1 if not self.vllm_config.speculative_config + else 1 + self.vllm_config.speculative_config.num_speculative_tokens + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + def build( + 
self, + common_prefix_len: int, + common_attn_metadata: MLUCommonAttentionMetadata, + fast_build: bool = False, + ) -> MLUFlashAttentionMetadata: + """ + fast_build disables AOT scheduling, used when there will be few + iterations i.e. spec-decode + """ + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + causal = common_attn_metadata.causal + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add seq_start_loc for chunk fa + ''' + seq_start_loc = common_attn_metadata.seq_start_loc + ''' + ================== + End of MLU Hijack + ================== + ''' + + # the overhead of the aot schedule is not worth it for spec-decode + aot_schedule = self.aot_schedule and not fast_build + + if self.aot_sliding_window is None: + self.aot_sliding_window = (-1, -1) + # For the AOT scheduler we need the sliding window value to be + # constant for all layers to. We have to populate this on the first + # build() call so the layers are constructed (cannot populate) + # in __init__. + if aot_schedule: + sliding_window_configs = _get_sliding_window_configs(self.vllm_config) + if len(sliding_window_configs) == 1: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None: + self.aot_sliding_window = sliding_window_config + elif len(sliding_window_configs) > 1: + self.aot_schedule = False + aot_schedule = False + + max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible + if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
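+            # A fixed cap is used here because, when capturing full CUDA graphs,
+            # the launch configuration cannot rely on FA3's runtime heuristic
+            # (max_num_splits == 0 above is not CG compatible).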
+ max_num_splits = self.max_num_splits + + if vllm_is_batch_invariant(): + max_num_splits = 1 + + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): + cache_dtype = self.cache_config.cache_dtype + if cache_dtype.startswith("fp8"): + qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( + cache_dtype + ) + else: + qkv_dtype = self.kv_cache_dtype + if aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads_q * self.dcp_world_size, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + cache_seqlens=seqlens, + qkv_dtype=qkv_dtype, + cu_seqlens_q=cu_query_lens, + page_size=self.block_size, + causal=causal, + window_size=self.aot_sliding_window, + num_splits=max_num_splits, + ) + return None + + use_cascade = common_prefix_len > 0 + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu + + dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( + dcp_context_kv_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, + ) + dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) + max_dcp_context_kv_len = dcp_context_kv_lens.max().item() + + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=dcp_context_kv_lens, + max_seq_len=max_dcp_context_kv_len, + causal=False, + ) + elif use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( + self.device, non_blocking=True + ) + prefix_scheduler_metadata = schedule( + batch_size=1, + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) + else: + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) + # For FA3 + full cudagraph + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: 1. build MLUChunkFlashAttentionMetadata to split prefill and decode; + 2. replace metadata with MLUFlashAttnetionMetadta. 
+ ''' + chunk_fa_metadata = None + if common_attn_metadata.infer_mode.is_chunked: + chunk_fa_metadata = MLUChunkFlashAttentionMetadata.build( + common_attn_metadata, + self.uniform_decode_query_len, + ) + + attn_metadata = MLUFlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + prefix_scheduler_metadata=prefix_scheduler_metadata, + max_num_splits=max_num_splits, + causal=causal, + # For mlu infer + seq_start_loc=common_attn_metadata.seq_start_loc, + infer_mode=common_attn_metadata.infer_mode, + chunk_fa_metadata=chunk_fa_metadata, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + return attn_metadata + + +class MLUFlashAttentionImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + sinks: torch.Tensor | None = None, + **extra_impl_args, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: 1. move alibi_slopes to mlu, + 2. sliding_window_right only support -1. + 3. add self.use_fused_mla_qkv. + ''' + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32).mlu() + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + elif attn_type == AttentionType.ENCODER_ONLY: + self.sliding_window = (sliding_window - 1, sliding_window - 1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.is_mla = extra_impl_args.get("is_mla", False) + self.use_fused_mla_qkv = extra_impl_args.get("use_fused_mla_qkv", False) + + self.decoder_attn_dtype = extra_impl_args.get("decoder_attn_dtype", None) + ''' + ================== + End of MLU Hijack + ================== + ''' + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
+ logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.attn_type = attn_type + self.vllm_flash_attn_version = get_flash_attn_version() + # Cache the batch invariant result for use in forward passes + self.batch_invariant_enabled = vllm_is_batch_invariant() + + self.sinks = sinks + if self.sinks is not None: + assert flash_attn_supports_sinks(), ( + "Sinks are only supported in FlashAttention 3" + ) + assert self.sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + "heads in the layer" + ) + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: MLUFlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + kwargs: dict[str, Any] = {}, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for FlashAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: set mlu infer mode. + ''' + infer_mode = attn_metadata.infer_mode + assert not attn_metadata.use_cascade, ( + f"MLU not support use_cascade={attn_metadata.use_cascade}, " + + f"attn_metadata={attn_metadata}." + ) + assert self.dcp_world_size <= 1, ( + f"MLU not support dcp_world_size={self.dcp_world_size}." + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + attn_type = self.attn_type + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. 
+ + num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) + + # For decoder and cross-attention, use KV cache as before + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: kv_cache[0] is [key_cache, value_cache], and + kv_cache[1] is [key_cache_scale, value_cache_scale]. + ''' + key_cache, value_cache = kv_cache[0].unbind(0) + if is_quantized_kv_cache(self.kv_cache_dtype): + key_cache_scale, value_cache_scale = kv_cache[1].unbind(0) + else: + key_cache_scale = None + value_cache_scale = None + ''' + ================== + End of MLU Hijack + ================== + ''' + + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: skip store key/value to kv cache in mla prefill phase. + @brief: support value is None. + ''' + skip_process_cache = ( + self.is_mla + and (infer_mode.is_prefill_only or self.use_fused_mla_qkv) + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + and not skip_process_cache + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: store key/value cache with mlu ops. + ''' + if is_quantized_kv_cache(self.kv_cache_dtype): + mlu_ops.quant_to_paged_cache( + k=key[:num_actual_tokens], + v=(None if self.is_mla else value[:num_actual_tokens]), + k_cache=key_cache, + v_cache=value_cache, + k_cache_quant_scale=key_cache_scale, + v_cache_quant_scale=value_cache_scale, + slot_mapping=attn_metadata.slot_mapping.flatten(), + ) + else: + mlu_ops.reshape_paged_cache( + k=key[:num_actual_tokens], + v=(None if self.is_mla else value[:num_actual_tokens]), + k_cache=key_cache, + v_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping.flatten(), + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: skip cascade attention for mlu platform. + ''' + if attn_metadata.use_cascade: + raise RuntimeError( + f"mlu v1 not support use_cascade={attn_metadata.use_cascade}, " + + f"attn_metadata={attn_metadata}." 
+ ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + cu_seqlens_q = attn_metadata.query_start_loc + cu_seqlens_kv = attn_metadata.seq_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + alibi_slopes = ( + None if self.alibi_slopes is None + else self.alibi_slopes.repeat(seqused_k.shape[0], 1) + ) + head_size_v = value.shape[-1] if self.is_mla else self.head_size + q_quant_scale = kwargs.get("q_quant_scale", None) + + if infer_mode.is_prefill_only: + num_prefill_query_tokens = num_actual_tokens + num_prefill_kv_tokens = num_actual_tokens + mlu_ops.flash_attention( + q=query[:num_prefill_query_tokens], + k=key[:num_prefill_kv_tokens], + v=value[:num_prefill_kv_tokens], + out=output[:num_prefill_query_tokens], + cu_seq_lens_q=cu_seqlens_q, + cu_seq_lens_kv=cu_seqlens_kv, + alibi_slope=alibi_slopes, + attn_bias=None, + max_seq_len_q=max_seqlen_q, + max_seq_len_kv=max_seqlen_k, + softmax_scale=self.scale, + is_causal=True, + window_size_left=self.sliding_window[0], + window_size_right=self.sliding_window[1], + compute_dtype=attn_metadata.compute_dtype, + return_lse=False, + ) + elif infer_mode.is_chunked: + # prefill & decode mixed + # NOTE: Split prefill chunks and decode tokens will + # get better performance on MLU devices. + chunk_fa_metadata = attn_metadata.chunk_fa_metadata + prefill_ctx = chunk_fa_metadata.prefill_ctx + decode_ctx = chunk_fa_metadata.decode_ctx + num_decodes = decode_ctx.batch_size + num_decode_tokens = decode_ctx.num_actual_tokens + num_prefills = prefill_ctx.batch_size + if num_prefills > 0: + self._forward_prefill_chunk( + query=query[num_decode_tokens:], + key_cache=key_cache, + value_cache=value_cache, + output=output[num_decode_tokens:], + block_table=block_table[num_decodes:], + seqused_k=seqused_k[num_decodes:], + compute_dtype=attn_metadata.compute_dtype, + prefill_ctx=prefill_ctx, + alibi_slopes=alibi_slopes, + key_cache_scale=key_cache_scale, + value_cache_scale=value_cache_scale, + ) + if num_decodes > 0: + if q_quant_scale is not None: + q_quant_scale = q_quant_scale[:num_decode_tokens] + self._forward_decode_only( + query=query[:num_decode_tokens], + key_cache=key_cache, + value_cache=value_cache, + output=output[:num_decode_tokens], + block_table=block_table[:num_decodes], + seqused_k=seqused_k[:num_decodes], + max_seqlen_k=decode_ctx.max_seq_len, + head_size_v=head_size_v, + compute_dtype=attn_metadata.compute_dtype, + alibi_slopes=alibi_slopes, + key_cache_scale=key_cache_scale, + value_cache_scale=value_cache_scale, + q_quant_scale=q_quant_scale, + ) + else: + # decode only + if q_quant_scale is not None: + q_quant_scale = q_quant_scale[:num_actual_tokens] + self._forward_decode_only( + query=query[:num_actual_tokens], + key_cache=key_cache, + value_cache=value_cache, + output=output[:num_actual_tokens], + block_table=block_table, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + head_size_v=head_size_v, + compute_dtype=attn_metadata.compute_dtype, + alibi_slopes=alibi_slopes, + key_cache_scale=key_cache_scale, + value_cache_scale=value_cache_scale, + q_quant_scale=q_quant_scale, + ) + return output + + def _forward_prefill_chunk( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + block_table: torch.Tensor, + seqused_k: torch.Tensor, + compute_dtype: torch.dtype, + prefill_ctx: MLUChunkFlashAttentionMetadata.ChunkContextMetadata, + 
alibi_slopes: torch.Tensor | None = None,
+        key_cache_scale: torch.Tensor | None = None,
+        value_cache_scale: torch.Tensor | None = None,
+    ):
+        '''
+        Compute prefill chunks when chunked_prefill is enabled.
+        NOTE: If the kv_cache is quantized, it is first dequantized into
+        contiguous key and value buffers before calling flash attention.
+        '''
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            total_seqlens = prefill_ctx.total_seqlens
+            key_cache_dequant = torch.zeros(
+                (total_seqlens, self.num_kv_heads, self.head_size),
+                dtype=query.dtype,
+                device=key_cache.device
+            )
+            value_cache_dequant = None
+            if value_cache is not None:
+                value_cache_dequant = torch.zeros(
+                    (total_seqlens, self.num_kv_heads, self.head_size),
+                    dtype=query.dtype,
+                    device=key_cache.device
+                )
+            mlu_ops.dequant_from_paged_cache(
+                key=key_cache_dequant,
+                value=value_cache_dequant,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                key_cache_quant_scale=key_cache_scale,
+                value_cache_quant_scale=value_cache_scale,
+                context_lengths=seqused_k,
+                max_context_len=prefill_ctx.max_seq_len,
+                context_seq_offset=None,
+                block_tables=block_table,
+                quant_mode=1,
+                quant_bit=8
+            )
+            block_table_dequant = None
+        else:
+            key_cache_dequant = key_cache
+            value_cache_dequant = value_cache
+            block_table_dequant = block_table
+        mlu_ops.flash_attention(
+            q=query,
+            k=key_cache_dequant,
+            v=value_cache_dequant,
+            out=output,
+            cu_seq_lens_q=prefill_ctx.cu_seqlens_q,
+            cu_seq_lens_kv=prefill_ctx.cu_seqlens_kv,
+            alibi_slope=alibi_slopes,
+            attn_bias=None,
+            max_seq_len_q=prefill_ctx.max_query_len,
+            max_seq_len_kv=prefill_ctx.max_seq_len,
+            softmax_scale=self.scale,
+            is_causal=True,
+            window_size_left=self.sliding_window[0],
+            window_size_right=self.sliding_window[1],
+            compute_dtype=compute_dtype,
+            return_lse=False,
+            block_tables=block_table_dequant,
+        )
+
+    def _forward_decode_only(
+        self,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        block_table: torch.Tensor,
+        seqused_k: torch.Tensor,
+        max_seqlen_k: int,
+        head_size_v: int,
+        compute_dtype: torch.dtype,
+        alibi_slopes: torch.Tensor | None = None,
+        key_cache_scale: torch.Tensor | None = None,
+        value_cache_scale: torch.Tensor | None = None,
+        q_quant_scale: torch.Tensor | None = None,
+    ):
+        '''
+        Compute decode tokens only.
+        NOTE: Query only supports padded mode; be careful when using an MTP model.
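+        (Each request is expected to contribute the same, possibly padded,
+        number of query tokens, since the query is viewed as
+        (batch_size, -1, num_heads, head_size) before the kernel call.)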
+ ''' + batch_size = block_table.shape[0] + decode_query = query.view(batch_size, -1, self.num_heads, self.head_size) + decode_output = output.view(batch_size, -1, self.num_heads, head_size_v) + if q_quant_scale is not None: + q_quant_scale = q_quant_scale.view(batch_size, -1, self.num_heads) + mlu_ops.single_query_cached_kv_attn( + q=decode_query, + k_cache=key_cache, + v_cache=value_cache, + out=decode_output, + block_tables=block_table, + context_lens=seqused_k, + k_cache_quant_scale=key_cache_scale, + v_cache_quant_scale=value_cache_scale, + alibi_slopes=alibi_slopes, + max_contxt_len=max_seqlen_k, + windows_size_left=self.sliding_window[0], + windows_size_right=self.sliding_window[1], + softmax_scale=self.scale, + head_size_v=(-1 if not self.is_mla else head_size_v), + compute_dtype=compute_dtype, + q_quant_scale=q_quant_scale, + decoder_attn_dtype=self.decoder_attn_dtype, + ) + + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. + + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + """ + # For encoder attention, process FP8 quantization if needed + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "quantization is not supported for encoder attention" + ) + + # Use encoder-specific metadata for sequence information + cu_seqlens_q = attn_metadata.query_start_loc + cu_seqlens_k = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_query_len + + # Call flash attention directly on Q, K, V tensors + mlu_ops.flash_attention( + q=query, + k=key, + v=value, + out=output, + cu_seq_lens_q=cu_seqlens_q, + cu_seq_lens_kv=cu_seqlens_k, + alibi_slope=None, + attn_bias=None, + max_seq_len_q=max_seqlen_q, + max_seq_len_kv=max_seqlen_k, + softmax_scale=self.scale, + is_causal=False, # Encoder attention is bidirectional + window_size_left=self.sliding_window[0], + window_size_right=self.sliding_window[1], + compute_dtype=attn_metadata.compute_dtype, + return_lse=False, + ) + + return output + + + diff --git a/vllm_mlu/v1/attention/backends/gdn_attn.py b/vllm_mlu/v1/attention/backends/gdn_attn.py new file mode 100644 index 0000000..3297150 --- /dev/null +++ b/vllm_mlu/v1/attention/backends/gdn_attn.py @@ -0,0 +1,404 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import torch +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from collections import OrderedDict, deque + +from vllm.config import VllmConfig +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.gdn_attn import (GDNAttentionMetadataBuilder, + GDNAttentionMetadata, + ) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills,) + + +class DeviceAwareLocalIdMapper: + def __init__(self, batch_size: int): + if batch_size <= 0: + raise ValueError("batch_size must be positive") + self.batch_size = batch_size + 
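+        # global_to_local is effectively an LRU cache: a lookup hit is moved to
+        # the end of the OrderedDict, and once all batch_size local slots are in
+        # use the least-recently-used global id is evicted and its local slot is
+        # recycled through available_local_ids.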
self.global_to_local: OrderedDict[int, int] = OrderedDict() + self.local_to_global = {} + self.available_local_ids = deque(range(batch_size)) + + def batch_get_local_ids(self, global_id_tensor: torch.Tensor) -> torch.Tensor: + original_device = global_id_tensor.device + original_shape = global_id_tensor.shape + + flat_global_cpu = global_id_tensor.cpu().numpy().ravel() + num_elements = flat_global_cpu.size + local_ids_cpu = torch.empty(num_elements, dtype=global_id_tensor.dtype) + + g2l = self.global_to_local + unique_miss_set = set() + + # Pass 1: handle hits and collect unique misses + for i, gid in enumerate(flat_global_cpu): + if gid in g2l: + local_id = g2l[gid] + local_ids_cpu[i] = local_id + g2l.move_to_end(gid) + else: + local_ids_cpu[i] = -1 + unique_miss_set.add(gid) + + # Pass 2: assign local IDs to unique new global IDs + new_mappings = {} + available = self.available_local_ids + local_to_global = self.local_to_global + + for gid in unique_miss_set: + if len(g2l) >= self.batch_size: + old_gid, old_local = g2l.popitem(last=False) + available.append(old_local) + local_to_global.pop(old_local, None) + new_local = available.popleft() + g2l[gid] = new_local + local_to_global[new_local] = gid + new_mappings[gid] = new_local + + # Pass 3: fill in all miss positions + for i, gid in enumerate(flat_global_cpu): + if local_ids_cpu[i].item() == -1: + local_ids_cpu[i] = new_mappings[gid] + + return local_ids_cpu.to(original_device).view(original_shape) + + def reset(self): + self.global_to_local.clear() + self.local_to_global.clear() + self.available_local_ids = deque(range(self.batch_size)) + +def vllm__v1__attention__bachends__GDNAttentionMetadataBuilder____init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, +): + assert isinstance(kv_cache_spec, MambaSpec) + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.speculative_config = vllm_config.speculative_config + self.kv_cache_spec = kv_cache_spec + if self.speculative_config: + self.num_spec = self.speculative_config.num_speculative_tokens + else: + self.num_spec = 0 + self.use_spec_decode = self.num_spec > 0 + self._init_reorder_batch_threshold(1, self.use_spec_decode) + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), + self.compilation_config.max_cudagraph_capture_size, + ) + + self.spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, self.num_spec + 1), + dtype=torch.int32, + device=device, + ) + self.non_spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.spec_sequence_masks = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.bool, + device=device, + ) + self.spec_token_indx = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), + dtype=torch.int32, + device=device, + ) + self.non_spec_token_indx = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), + dtype=torch.int32, + device=device, + ) + self.spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1,), + dtype=torch.int32, + device=device, + ) + self.non_spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1,), + dtype=torch.int32, + device=device, + ) + self.num_accepted_tokens = torch.empty( + (self.decode_cudagraph_max_bs,), + 
dtype=torch.int32, + device=device, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + self.mapper = DeviceAwareLocalIdMapper(self.vllm_config.mlu_config.mamba_support_max_batch_size) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +def vllm__v1__attention__bachends__GDNAttentionMetadataBuilder__build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + num_accepted_tokens: torch.Tensor | None = None, + num_decode_draft_tokens_cpu: torch.Tensor | None = None, + fast_build: bool = False, +) -> GDNAttentionMetadata: + m = common_attn_metadata + + query_start_loc = m.query_start_loc + context_lens = m.num_computed_tokens_cpu + context_lens_tensor = context_lens.to(query_start_loc.device) + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + + if ( + not self.use_spec_decode + or num_decode_draft_tokens_cpu is None + or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0] + .sum() + .item() + == 0 + ): + spec_sequence_masks = None + num_spec_decodes = 0 + else: + spec_sequence_masks = num_decode_draft_tokens_cpu >= 0 + num_spec_decodes = spec_sequence_masks.sum().item() + if num_spec_decodes == 0: + spec_sequence_masks = None + else: + spec_sequence_masks = spec_sequence_masks.to( + query_start_loc.device, non_blocking=True + ) + + if spec_sequence_masks is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(m, decode_threshold=1) + ) + num_spec_decode_tokens = 0 + spec_token_indx = None + non_spec_token_indx = None + spec_state_indices_tensor = None + non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + spec_query_start_loc = None + non_spec_query_start_loc = query_start_loc + num_accepted_tokens = None + else: + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + non_spec_query_lens = query_lens[~spec_sequence_masks] + num_decodes = (non_spec_query_lens == 1).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes + num_decode_tokens = num_decodes + num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens + num_spec_decode_tokens = ( + query_lens.sum().item() - num_prefill_tokens - num_decode_tokens + ) + + if num_prefills == 0 and num_decodes == 0: + spec_token_size = min( + num_spec_decodes * (self.num_spec + 1), + query_start_loc[-1].item(), + ) + spec_token_indx = torch.arange( + spec_token_size, + dtype=torch.int32, + device=query_start_loc.device, + ) + non_spec_token_indx = torch.empty( + 0, dtype=torch.int32, device=query_start_loc.device + ) + spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] + non_spec_state_indices_tensor = None + spec_query_start_loc = query_start_loc + non_spec_query_start_loc = None + else: + spec_token_masks = torch.repeat_interleave( + spec_sequence_masks, query_lens + ) + index = torch.argsort(spec_token_masks) + num_non_spec_tokens = num_prefill_tokens + num_decode_tokens + non_spec_token_indx = index[:num_non_spec_tokens] + spec_token_indx = index[num_non_spec_tokens:] + + spec_state_indices_tensor = m.block_table_tensor[ + spec_sequence_masks, : self.num_spec + 1 + ] + non_spec_state_indices_tensor = m.block_table_tensor[ + ~spec_sequence_masks, 0 + ] + + spec_query_start_loc = torch.zeros( + num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] + ) + 
non_spec_query_start_loc = torch.zeros( + query_lens.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:], + ) + + assert num_accepted_tokens is not None + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + + if num_prefills > 0: + has_initial_state = context_lens_tensor > 0 + if spec_sequence_masks is not None: + has_initial_state = has_initial_state[~spec_sequence_masks] + nums_dict, batch_ptr, token_chunk_offset_ptr = ( + compute_causal_conv1d_metadata(non_spec_query_start_loc) + ) + else: + has_initial_state = None + num_actual_tokens = ( + num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens + ) + + # prepare tensors for cudagraph + # + # With speculative decoding, the xgrammar backend may rollback tokens + # and causing some sequences has less draft tokens than self.num_spec. + # + # In above cases, the max possible batch size for n tokens, can be + # min(n, cudagraph_max_bs). + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and num_spec_decode_tokens <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) + batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) + + self.spec_state_indices_tensor[:num_spec_decodes].copy_( + spec_state_indices_tensor, non_blocking=True + ) + spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] + spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + + self.spec_sequence_masks[:num_spec_decodes].copy_( + spec_sequence_masks, non_blocking=True + ) + spec_sequence_masks = self.spec_sequence_masks[:batch_size] + spec_sequence_masks[num_spec_decodes:].fill_(False) + + assert non_spec_token_indx is not None and spec_token_indx is not None + self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_( + non_spec_token_indx, non_blocking=True + ) + non_spec_token_indx = self.non_spec_token_indx[ + : non_spec_token_indx.size(0) + ] + + self.spec_token_indx[: spec_token_indx.size(0)].copy_( + spec_token_indx, non_blocking=True + ) + spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)] + + self.spec_query_start_loc[: num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True + ) + spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1] + spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens) + + self.num_accepted_tokens[:num_spec_decodes].copy_( + num_accepted_tokens, non_blocking=True + ) + num_accepted_tokens = self.num_accepted_tokens[:batch_size] + num_accepted_tokens[num_spec_decodes:].fill_(1) + + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) + batch_size = num_actual_tokens + + self.non_spec_state_indices_tensor[:num_decodes].copy_( + non_spec_state_indices_tensor, non_blocking=True + ) + non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ + :batch_size + ] + non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + + self.non_spec_query_start_loc[: num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True + ) + non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: 
ignore[index] + non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] + non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + non_spec_state_indices_tensor = self.mapper.batch_get_local_ids(non_spec_state_indices_tensor) + ''' + ================== + End of MLU Hijack + ================== + ''' + attn_metadata = GDNAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_spec_decodes=num_spec_decodes, + num_spec_decode_tokens=num_spec_decode_tokens, + num_actual_tokens=num_actual_tokens, + has_initial_state=has_initial_state, + spec_query_start_loc=spec_query_start_loc, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_state_indices_tensor=spec_state_indices_tensor, + non_spec_state_indices_tensor=non_spec_state_indices_tensor, + spec_sequence_masks=spec_sequence_masks, + spec_token_indx=spec_token_indx, + non_spec_token_indx=non_spec_token_indx, + num_accepted_tokens=num_accepted_tokens, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, + ) + return attn_metadata + +MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder, + GDNAttentionMetadataBuilder.__init__, + vllm__v1__attention__bachends__GDNAttentionMetadataBuilder____init__) + +MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder, + GDNAttentionMetadataBuilder.build, + vllm__v1__attention__bachends__GDNAttentionMetadataBuilder__build) diff --git a/vllm_mlu/v1/attention/backends/mla/flashmla.py b/vllm_mlu/v1/attention/backends/mla/flashmla.py new file mode 100644 index 0000000..502af97 --- /dev/null +++ b/vllm_mlu/v1/attention/backends/mla/flashmla.py @@ -0,0 +1,934 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils.math_utils import cdiv, round_down +from vllm.attention.backends.utils import MLADims +from vllm.config import ModelConfig +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, MLACommonPrefillMetadata, + MLACommonDecodeMetadata, MLACommonMetadata, + MLACommonMetadataBuilder, M, QueryLenSupport, + use_cudnn_prefill, use_flashinfer_prefill, + use_trtllm_ragged_deepseek_prefill, + FlashInferPrefillMetadata, + CudnnPrefillMetadata, + MLACommonImpl, + CUDNN_WORKSPACE_SIZE +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, split_decodes_and_prefills, + infer_global_hyperparameters, get_per_layer_parameters, + ) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + MLAAttentionImpl, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu import _mlu_ops as mlu_ops +from vllm_mlu.v1.attention.backends.flash_attn import MLUFlashAttentionImpl +from vllm_mlu.v1.attention.backends.utils import ( + MLUCommonAttentionMetadata, get_common_metadata, + MLUInferMode) +from 
vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.platforms import current_platform +from vllm import envs +try: + from flashinfer import BatchPrefillWithRaggedKVCacheWrapper + from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401 + + flashinfer_available = True +except ImportError: + BatchPrefillWithRaggedKVCacheWrapper = object + + flashinfer_available = False + +logger = init_logger(__name__) + + + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +class MLACommonBackend_MluHijack(MLACommonBackend): + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576, 512] + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + if model_config.hf_text_config.model_type == "deepseek_v4": + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.head_dim, + qk_nope_head_dim=hf_text_config.head_dim - hf_text_config.rope_head_dim, + qk_rope_head_dim=hf_text_config.rope_head_dim, + v_head_dim=hf_text_config.head_dim, + ) + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) + + +class MLACommonMetadataBuilder_MluHijack(MLACommonMetadataBuilder): + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[M] | None = None, + supports_dcp_with_varlen: bool = False, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else MLACommonMetadata + ) + self.kv_cache_spec = kv_cache_spec + scheduler_config = vllm_config.scheduler_config + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config + self.device = device + + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.aot_schedule = current_platform.is_cuda() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size + self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size + + # Don't try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.kv_cache_spec.block_size + + self.chunked_prefill_workspace_size = ( + self.determine_chunked_prefill_workspace_size(vllm_config) + ) + + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is triggered, + # an additional kvcache allgather across the DCP group is therefore + # required, so the workspace has to be enlarged by 1/DCP relative + # to the original TP allocation. 
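+            # e.g. (hypothetical sizes) with dcp_world_size = 4 and a 128k-row
+            # workspace, an extra 128k / 4 = 32k rows are reserved below for the
+            # allgathered remote context.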
+ assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + ( + self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size(), + ), + dtype=self.model_config.dtype, + device=device, + ) + else: + self.chunked_prefill_workspace = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.model_config.get_head_size(), + ), + dtype=self.model_config.dtype, + device=device, + ) + + self._use_cudnn_prefill = use_cudnn_prefill() + self._use_fi_prefill = use_flashinfer_prefill() + self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill() + self.prefill_metadata_cls = ( + FlashInferPrefillMetadata + if self._use_fi_prefill + else CudnnPrefillMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata + ) + + if self._use_fi_prefill: + self._workspace_buffer = torch.empty( + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + + self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None + self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] + + self._global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) + ) + + if self._use_trtllm_ragged_prefill: + self._workspace_buffer = torch.empty( + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + + if self._use_cudnn_prefill: + self.cudnn_workspace = torch.empty( + CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, + dtype=torch.int8, + device=device, + ) + + supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY + self._init_reorder_batch_threshold( + self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen + ) + + # Validate consistency between query_len_support and reorder_batch_threshold + if self.query_len_support == QueryLenSupport.SINGLE_ONLY: + assert self.reorder_batch_threshold == 1, ( + f"reorder_batch_threshold must be 1 when query_len_support is " + f"SINGLE_ONLY, got {self.reorder_batch_threshold}" + ) + + +MluHijackObject.apply_hijack(MLACommonBackend, + MLACommonBackend.get_supported_head_sizes, + MLACommonBackend_MluHijack.get_supported_head_sizes) +MluHijackObject.apply_hijack(MLACommonMetadataBuilder, + MLACommonMetadataBuilder.__init__, + MLACommonMetadataBuilder_MluHijack.__init__) + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (1, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def get_kv_cache_scale_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + ) -> tuple[int, ...]: + return (1, num_blocks, num_kv_heads, block_size) + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576, 512] + + +@dataclass +class FlashMLAPrefillMetadata(MLACommonPrefillMetadata): + num_prefills: int = -1 # for gather_cache + max_seq_len: int = -1 # for 
attn forward + + @property + def block_tables(self): + return self.block_table + + @property + def context_chunk_cu_seq_lens(self): + if self.chunked_context is None: + return None + return self.chunked_context.cu_seq_lens + + @property + def context_chunk_starts(self): + if self.chunked_context is None: + return None + return self.chunked_context.starts + + @property + def context_chunk_seq_tot(self): + if self.chunked_context is None: + return None + return self.chunked_context.seq_tot + + @property + def context_chunk_max_seq_lens(self): + if self.chunked_context is None: + return None + return self.chunked_context.max_seq_lens + + @property + def context_chunk_workspace(self): + if self.chunked_context is None: + return None + return self.chunked_context.workspace + + +@dataclass +class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: torch.Tensor + num_splits: torch.Tensor + + # add for mlu rope and attn forward + query_start_loc: torch.Tensor # for rope + max_query_len: int # for rope + max_seq_len:int = -1 # for attn forward + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + num_prefill_tokens: Optional[int] = None + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM + reorder_batch_threshold: int = 128 # process small prefills with decode pathway + # ^ TODO(matt): tune this + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata + ) + + self.num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + + self.cg_buf_tile_scheduler_metadata = None + self.cg_buf_num_splits = None + self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") + + self.cg_buf_tile_scheduler_metadata = None + self.cg_buf_num_splits = None + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: 1. set decoder_query_len for mtp + @brief: 2. init chunk workspace for prefix_caching only + @brief: 3. set prefill_metadata_cls + @brief: 4. 
add deepseek v3.2 infos + ''' + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config + speculative_config = vllm_config.speculative_config + self.num_speculative_tokens = (speculative_config.num_speculative_tokens + if speculative_config is not None else 0) + self.decoder_query_len = 1 + self.num_speculative_tokens + + self.max_model_len = self.model_config.max_model_len + self.is_deepseek_v32 = self.model_config.hf_text_config.model_type == "deepseek_v32" + + self.enable_caching = cache_config.enable_prefix_caching + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if (not self.is_deepseek_v32 and not self.chunked_prefill_enabled and + (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and self.enable_caching)): + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * self.model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + + self.prefill_metadata_cls = FlashMLAPrefillMetadata + ''' + ================== + End of MLU Hijack + ================== + ''' + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + # mlu v1 mtp forces decoder_query_len = 1 for k > 1, so we should set again + self.decoder_query_len = 1 + self.num_speculative_tokens + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the TritonMLA._forward_decode only supports + # num_tokens = 1 + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: record prefill and decode requests and token nums to call + chunked fa and single-query attn respectively in forward. + @Notes: decodes need all prompt tokens are computed. 
+ ''' + req_index = input_batch.req_id_to_index.get(req_id) + all_prompt_tokens_has_computed = ( + input_batch.num_computed_tokens_cpu[req_index] >= + input_batch.num_prompt_tokens[req_index]) + if num_tokens <= self.decoder_query_len and all_prompt_tokens_has_computed: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + ''' + ================== + End of MLU Hijack + ================== + ''' + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + return modified_batch + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor, + query_start_loc: torch.Tensor, + max_query_len: int, + max_seq_len: int, + ) -> FlashMLADecodeMetadata: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: set tile_scheduler_metadata and num_splits to None. + @brief: set dcp_tot_seq_lens_device. + ''' + return FlashMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens, + tile_scheduler_metadata=None, + num_splits=None, + dcp_tot_seq_lens=None, + # for mlu + max_seq_len=max_seq_len, + query_start_loc=query_start_loc, + max_query_len=max_query_len + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + def build_for_cudagraph_capture( + self, common_attn_metadata: MLUCommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with MLA. + """ + m = common_attn_metadata + if m.infer_mode == MLUInferMode.DECODE_ONLY: + assert m.num_reqs * m.max_query_len == m.num_actual_tokens, \ + "MLA only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + return self.build(0, m) + + def build(self, + common_prefix_len: int, + common_attn_metadata: MLUCommonAttentionMetadata, + fast_build: bool = False, + input_batch: "InputBatch" = None) -> M: + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
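+        # (Calling .item()/.tolist() on device tensors would force exactly such a
+        # sync, so the scalar sizes below are taken from the CPU-side copies
+        # where possible.)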
+ device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - + query_seq_lens_cpu) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support normal and mtp input split + ''' + if input_batch is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, + self.decoder_query_len) + else: + num_decodes, num_prefills = input_batch.split_decodes_and_prefills() + num_decode_tokens = common_attn_metadata.query_start_loc_cpu[num_decodes].item() + num_prefill_tokens = num_tokens - num_decode_tokens + ''' + ================== + End of MLU Hijack + ================== + ''' + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: avoid buffer missing when prefill_only + mlugraph + ''' + if num_decodes > 0: + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + else: + prefill_query_start_loc= query_start_loc + ''' + ================== + End of MLU Hijack + ================== + ''' + + chunked_context_metadata = None + if ((self.chunked_prefill_enabled or + (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and + self.enable_caching and + common_attn_metadata.is_chunked) + ) and num_prefills > 0 and max_context_len_cpu > 0): + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + if self.is_deepseek_v32: + max_context_chunk = self.max_model_len + else: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. 
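                # Worked example of the chunking math below, using the numbers
                # from the comment above (max_context_chunk = 256,
                # num_chunks = 3, num_prefills = 4) and made-up context
                # lengths [100, 300, 600, 700]:
                #   chunk_starts   = [[  0,   0,   0,   0],
                #                     [256, 256, 256, 256],
                #                     [512, 512, 512, 512]]
                #   chunk_ends     = [[100, 256, 256, 256],
                #                     [100, 300, 512, 512],
                #                     [100, 300, 600, 700]]
                #   chunk_seq_lens = [[100, 256, 256, 256],
                #                     [  0,  44, 256, 256],
                #                     [  0,   0,  88, 188]]
                #   cu_seq_lens    = [[0, 100, 356, 612, 868],
                #                     [0,   0,  44, 300, 556],
                #                     [0,   0,   0,  88, 276]]
                # i.e. row i of cu_seq_lens gives the per-request offsets into
                # the gathered context workspace for chunk i.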
+ chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) \ + * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + + chunked_context_metadata_cls = \ + FlashMLAPrefillMetadata.ChunkedContextMetadata + + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=getattr(self, "chunked_prefill_workspace", None), + ) + + if not self.is_deepseek_v32: + assert max(chunked_context_metadata.max_seq_lens) <= \ + self.chunked_prefill_workspace_size + + prefill_metadata = self.prefill_metadata_cls( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + # for mlu + num_prefills=num_prefills, + max_seq_len=common_attn_metadata.seq_lens_cpu[reqs_start:].max().item(), + ) + + decode_metadata = None + if num_decodes > 0: + decode_metadata = self._build_decode( + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens=seq_lens[:num_decodes], + query_start_loc=query_start_loc[:num_decodes + 1], + max_query_len=query_seq_lens_cpu[:num_decodes].max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu[:num_decodes].max().item(), + ) + + attn_metadata = self.metadata_cls( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=num_tokens, + query_start_loc=query_start_loc, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=num_decodes, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + return attn_metadata + + def can_run_in_cudagraph( + self, common_attn_metadata: MLUCommonAttentionMetadata) -> bool: + return common_attn_metadata.max_query_len == self.decoder_query_len + + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + + +class FlashMLAImpl(MLUFlashAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + 
"encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + kwargs: Optional[dict[str, Any]] = {}, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + + if attn_metadata is None: + # Profiling run. + return output + + out_lse = None + + # use default common metadata if kwargs does not have common_metadata + common_metadata: MLUCommonAttentionMetadata = kwargs.get("common_metadata", None) + if common_metadata is None: + common_metadata = get_common_metadata() + + only_prefill = kwargs.get("only_prefill", False) + only_decode = kwargs.get("only_decode", False) + attn_bias = kwargs.get("attn_bias", None) + + assert only_prefill != only_decode, "only_prefill and only_decode cannot be True and False at the same time." + + if only_prefill: + cu_seqlens_q = attn_metadata.prefill.query_start_loc + cu_seqlens_kv = common_metadata.query_start_loc + seqused_k = common_metadata.seq_lens[attn_metadata.num_decodes:] + max_seqlen_q = attn_metadata.prefill.max_query_len + max_seqlen_k = attn_metadata.prefill.max_seq_len + block_table = attn_metadata.prefill.block_table + num_actual_tokens = attn_metadata.num_prefill_tokens + else: + cu_seqlens_q = None # nouse + cu_seqlens_kv = None # nouse + seqused_k = common_metadata.seq_lens[:attn_metadata.num_decodes] + max_seqlen_q = None # nouse + max_seqlen_k = common_metadata.max_seq_len + block_table = attn_metadata.decode.block_table + num_actual_tokens = attn_metadata.num_decode_tokens + + skip_process_cache = ((self.use_mla + and (common_metadata.is_prefill_only + or self.use_fused_mla_qkv + or only_prefill)) + or self.kv_sharing_target_layer_name is not None) + + kv_cache_, kv_cache_scale_, kv_cache_index_ = kv_cache + key_cache = kv_cache_[0] + value_cache = None if self.use_mla else kv_cache_[1] + key_cache_scale, value_cache_scale = None, None + if kv_cache_scale_.numel() > 0: + key_cache_scale = kv_cache_scale_[0] + value_cache_scale = None if self.use_mla else kv_cache_scale_[1] + if not skip_process_cache: + if is_quantized_kv_cache(self.kv_cache_dtype): + mlu_ops.quant_to_paged_cache( + k=key[:num_actual_tokens], + v=(None if self.use_mla else value[:num_actual_tokens]), + k_cache=key_cache, + v_cache=value_cache, + k_cache_quant_scale=key_cache_scale, + v_cache_quant_scale=value_cache_scale, + slot_mapping=attn_metadata.slot_mapping.flatten(), + ) + else: + mlu_ops.reshape_paged_cache( + k=key[:num_actual_tokens], + v=(None if self.use_mla else value[:num_actual_tokens]), + k_cache=key_cache, + v_cache=value_cache, + slot_mapping=attn_metadata.slot_mapping.flatten() + ) + + alibi_slopes = None if self.alibi_slopes is None else \ + self.alibi_slopes.repeat(seqused_k.shape[0], 1) + + if kwargs.get("model_type", "") == "deepseek_v32": + from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context + sp_context = get_sp_forward_context() + if sp_context is not None and sp_context.is_v32: + num_actual_tokens = sp_context.sp_attn_metadata.num_prefill_tokens + decode_query = query[:num_actual_tokens].view(-1, self.num_heads, self.head_size) + head_size_v = value.shape[-1] if 
self.use_mla else self.head_size + decode_output = output[:num_actual_tokens].view(-1, self.num_heads, head_size_v) + decode_query = query.unsqueeze(1) # see tokens as batch dim + decode_output = decode_output.unsqueeze(1) + q_quant_scale = kwargs.get("q_quant_scale", None) + if q_quant_scale is not None: + q_quant_scale = q_quant_scale[:num_actual_tokens].view(-1, self.num_heads) + q_quant_scale = q_quant_scale.unsqueeze(1) + mlu_ops.single_query_cached_kv_attn( + q=decode_query, + k_cache=key_cache, + v_cache=value_cache, + out=decode_output, + block_tables=kwargs.get("new_block_tables", None), + context_lens=kwargs.get("new_context_lens", None), + k_cache_quant_scale=key_cache_scale, + v_cache_quant_scale=value_cache_scale, + alibi_slopes=alibi_slopes, + max_contxt_len=kwargs.get("index_topk", None), + windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]), + windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]), + softmax_scale=self.scale, + head_size_v=(-1 if not self.use_mla else head_size_v), + compute_dtype=compute_dtype, + q_quant_scale=q_quant_scale, + decoder_attn_dtype=self.decoder_attn_dtype, + mask=attn_bias, + ) + return output + + if common_metadata.is_prefill_only or only_prefill: + # prefill only + prefill_causal = kwargs.get("prefill_causal", True) + cu_seqlens_q = kwargs.get("cu_seq_lens_q", cu_seqlens_q) + cu_seqlens_kv = kwargs.get("cu_seq_lens_kv", cu_seqlens_kv) + max_seqlen_q = kwargs.get("max_seq_len_q", max_seqlen_q) + max_seqlen_k = kwargs.get("max_seq_len_kv", max_seqlen_k) + return_lse = kwargs.get("return_lse", False) + num_prefill_query_tokens = common_metadata.num_prefill_query_tokens + num_prefill_kv_tokens = common_metadata.num_prefill_kv_tokens + use_f32 = attn_bias is not None and attn_bias.dtype == torch.float32 + if use_f32: + f32_output = torch.empty_like(output, dtype=torch.float32) + attn_output_list = mlu_ops.flash_attention( + q=query[:num_prefill_query_tokens].to(torch.float32) if use_f32 else query[:num_prefill_query_tokens], + k=key[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else key[:num_prefill_kv_tokens], + v=value[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else value[:num_prefill_kv_tokens], + out=f32_output[:num_prefill_query_tokens] if use_f32 else output[:num_prefill_query_tokens], + cu_seq_lens_q=cu_seqlens_q, + cu_seq_lens_kv=cu_seqlens_kv, + alibi_slope=alibi_slopes, + attn_bias=attn_bias, + max_seq_len_q=max_seqlen_q, + max_seq_len_kv=max_seqlen_k, + softmax_scale=self.scale, + is_causal=prefill_causal, + window_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]), + window_size_right=(-1 if self.sliding_window is None else self.sliding_window[1]), + compute_dtype=self.prefill_compute_dtype, + return_lse=return_lse, + q_quant_dtype=self.prefill_q_dtype, + k_quant_dtype=self.prefill_k_dtype, + v_quant_dtype=self.prefill_v_dtype + ) + if use_f32: + output[:num_prefill_query_tokens].copy_(f32_output[:num_prefill_query_tokens]) + + if return_lse: + out_lse = attn_output_list[1] + else: + batch_size = block_table.shape[0] + # decode only + decode_query = query[:num_actual_tokens].view(batch_size, -1, self.num_heads, self.head_size) + head_size_v = value.shape[-1] if self.use_mla else self.head_size + decode_output = output[:num_actual_tokens].view(batch_size, -1, self.num_heads, head_size_v) + q_quant_scale = kwargs.get("q_quant_scale", None) + if q_quant_scale is not None: + q_quant_scale = q_quant_scale[:num_actual_tokens].view(batch_size, 
-1, self.num_heads) + mlu_ops.single_query_cached_kv_attn( + q=decode_query, + k_cache=key_cache, + v_cache=value_cache, + out=decode_output, + block_tables=block_table, + context_lens=seqused_k, + k_cache_quant_scale=key_cache_scale, + v_cache_quant_scale=value_cache_scale, + alibi_slopes=alibi_slopes, + max_contxt_len=max_seqlen_k, + windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]), + windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]), + softmax_scale=self.scale, + head_size_v=(-1 if not self.use_mla else head_size_v), + compute_dtype=attn_metadata.decode.compute_dtype, + q_quant_scale=q_quant_scale, + decoder_attn_dtype=self.decoder_attn_dtype, + mask=attn_bias, + ) + + return output if out_lse is None else (output, out_lse) diff --git a/vllm_mlu/v1/attention/backends/utils.py b/vllm_mlu/v1/attention/backends/utils.py new file mode 100644 index 0000000..5cba932 --- /dev/null +++ b/vllm_mlu/v1/attention/backends/utils.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import os +import numpy as np +import pandas as pd +import torch +from typing import TYPE_CHECKING, Union + +from dataclasses import dataclass +from enum import Enum + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm.forward_context import get_forward_context +from vllm.v1.attention.backends.utils import CommonAttentionMetadata + + +COMMON_METADATA_STR: str = "common_metadata" + + +class MLUInferMode(Enum): + CHUNKED = 1 + PREFILL_ONLY = 2 + DECODE_ONLY = 3 + + @classmethod + def build( + cls, + max_query_len, + max_computed_tokens, + uniform_decode_query_len: int = 1, + ) -> Enum: + if max_query_len <= uniform_decode_query_len: + return MLUInferMode.DECODE_ONLY + elif max_computed_tokens == 0: + return MLUInferMode.PREFILL_ONLY + else: + return MLUInferMode.CHUNKED + + @property + def is_prefill_only(self): + return self == MLUInferMode.PREFILL_ONLY + + @property + def is_decode_only(self): + return self == MLUInferMode.DECODE_ONLY + + @property + def is_chunked(self): + return self == MLUInferMode.CHUNKED + + +@dataclass +class MLUCommonAttentionMetadata(CommonAttentionMetadata): + """ + Attention metadata attributes that can be shared by layers in different KV + cache groups and thus having different block table. 
+ """ + seq_start_loc: torch.Tensor | None = None + seq_start_loc_cpu: torch.Tensor | None = None + """(batch_size + 1,), the start location of each request in the input key/value sequence.""" + num_input_tokens: int = 0 + """Number of query tokens with padding.""" + num_prefill_query_tokens: int = 0 + """Number of query tokens in prefill phase.""" + num_prefill_kv_tokens: int = 0 + """Number of key/value tokens in prefill phase.""" + infer_mode: MLUInferMode | None = None + """Inference mode for flash attention.""" + + @property + def is_prefill_only(self): + return self.infer_mode == MLUInferMode.PREFILL_ONLY + + @property + def is_decode_only(self): + return self.infer_mode == MLUInferMode.DECODE_ONLY + + @property + def is_chunked(self): + return self.infer_mode == MLUInferMode.CHUNKED + + @classmethod + def build( + cls, + query_start_loc, query_start_loc_cpu, + seq_lens, seq_lens_cpu, + num_computed_tokens_cpu, + num_reqs, num_actual_tokens, max_query_len, + block_table_tensor, slot_mapping, + seq_start_loc, is_start_loc_match, + num_input_tokens: int = 0, + num_speculative_tokens: int = 0, + has_prefill_reqs: bool = False + ): + """Build attention metadata for MLU inference. + + Args: + has_prefill_reqs: Whether there are pending prefill requests with chunked. + """ + infer_mode = None + if is_start_loc_match: + infer_mode = MLUInferMode.PREFILL_ONLY + elif max_query_len <= (1 + num_speculative_tokens) and (not has_prefill_reqs): + infer_mode = MLUInferMode.DECODE_ONLY + else: + infer_mode = MLUInferMode.CHUNKED + num_input_tokens = ( + num_actual_tokens if num_input_tokens == 0 + else num_input_tokens + ) + max_seq_len = int(seq_lens_cpu.max()) + return cls(query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + seq_start_loc=seq_start_loc, + seq_start_loc_cpu=seq_start_loc.to("cpu", non_blocking=True), + num_input_tokens=num_input_tokens, + infer_mode=infer_mode, + num_prefill_query_tokens=num_actual_tokens, + num_prefill_kv_tokens=num_actual_tokens) + + def save(self, infer_phase: str): + csv_path = os.getenv("VLLM_STEP_INPUT_CSV_PATH", None) + if not csv_path: + return + + header = [ + "infer_phase", "infer_mode", "num_reqs", "num_actual_tokens", + "max_query_len", "max_seq_len", "query_start_loc", "seq_lens" + ] + data = [ + infer_phase, self.infer_mode, self.num_reqs, + self.num_actual_tokens, self.max_query_len, self.max_seq_len, + str(self.query_start_loc_cpu.tolist()), + str(self.seq_lens_cpu.tolist()) + ] + data_dict = dict(zip(header, data)) + df_csv = pd.DataFrame(data_dict, index=[0]) + + if infer_phase == "RealInfer": + print(df_csv.to_string()) + + try: + if dir_path := os.path.dirname(csv_path): + os.makedirs(dir_path, exist_ok=True) + append = False + if os.path.isfile(csv_path): + try: + df_old = pd.read_csv(csv_path) + append = (df_old.columns.tolist() == header) + except Exception as e: + raise RuntimeError(f"Existing {csv_path} failed to be read and will be overwritten") + if append: + df_csv.to_csv(csv_path, mode='a', header=False, index=False) + else: + df_csv.to_csv(csv_path, index=False) + except Exception as e: + raise RuntimeError(f"Invalid VLLM_STEP_INPUT_CSV_PATH: {csv_path} to dump step inputs, Error: {e}") + + +def 
get_common_metadata_from_attn_metadata( + attn_metadata) -> Union[MLUCommonAttentionMetadata, None]: + """ + Get MLUCommonAttentionMetadata for MLU-V1 inference. + Use outside of set_forward_context(). + """ + if attn_metadata is None: + return + + assert (isinstance(attn_metadata, dict) + and COMMON_METADATA_STR in attn_metadata), \ + f"MLU-V1 only support type(attn_metadata)=dict, and " + \ + f"{COMMON_METADATA_STR} in attn_metadata. Now, type(attn_metadata)=" + \ + f"{type(attn_metadata)}, or {COMMON_METADATA_STR} not in attn_metadata." + return attn_metadata[COMMON_METADATA_STR] + + +def get_common_metadata() -> Union[MLUCommonAttentionMetadata, None]: + """ + Get MLUCommonAttentionMetadata for MLU-V1 inference. + Use inside of set_forward_context(). + """ + attn_metadata = get_forward_context().attn_metadata + return get_common_metadata_from_attn_metadata(attn_metadata) + + +def unpad_common_attn_metadata( + common_metadata: MLUCommonAttentionMetadata, + num_reqs: int, + num_scheduled_tokens: int, +): + """ + Unpad MLUCommonAttentionMetadata by given num_reqs and num_scheduled_tokens. + """ + common_metadata.num_reqs = num_reqs + common_metadata.num_input_tokens = num_scheduled_tokens + common_metadata.query_start_loc = common_metadata.query_start_loc[:num_reqs + 1] + common_metadata.query_start_loc_cpu = common_metadata.query_start_loc_cpu[:num_reqs + 1] + common_metadata.seq_start_loc = common_metadata.seq_start_loc[:num_reqs + 1] + common_metadata.seq_lens = common_metadata.seq_lens[:num_reqs] + common_metadata.seq_lens_cpu = common_metadata.seq_lens_cpu[:num_reqs] + common_metadata.block_table_tensor = common_metadata.block_table_tensor[:num_reqs] + +def reorder_batch_to_split_decodes_and_prefills( + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill and decode requests; places all + requests with <= decode_threshold tokens at the front of the batch. + + Returns: + True if the batch was modified, False otherwise. + """ + # We now want to reorder the batch into decode → extend → prefill order + # where: + # decode: request with num_scheduled_tokens <= decode_threshold + # extend: non-decode request with existing context + # prefill: non-decode request with no existing context + # NOTE for now we loosely use "decode" to mean requests where attention is + # likely memory-bound and "prefill" to mean requests where attention is + # likely compute-bound, + num_reqs = len(input_batch.req_ids) + num_scheduled_tokens = [ + scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids + ] + num_scheduled_tokens_np = np.array(num_scheduled_tokens) + num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs] + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: enhence decode mode condition that all prompt tokens are computed. 
+ ''' + # is_decode = num_scheduled_tokens_np <= decode_threshold + is_decode = ( + (num_scheduled_tokens_np <= decode_threshold) + & (num_computed_tokens_np >= input_batch.num_prompt_tokens[:num_reqs]) + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + is_extend = (~is_decode) & (num_computed_tokens_np > 0) + is_prefill = (~is_decode) & (num_computed_tokens_np == 0) + + # Desired order: decode → extend → prefill + req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default + req_regions[is_extend] = 1 + req_regions[is_prefill] = 2 + + num_decodes = int(is_decode.sum()) + num_extends = int(is_extend.sum()) + + target_regions = np.zeros(num_reqs, dtype=np.int32) + target_regions[num_decodes : num_decodes + num_extends] = 1 + target_regions[num_decodes + num_extends :] = 2 + + needs_swap = req_regions != target_regions + + if not needs_swap.any(): + return False + + # Extract indices that need swapping and sort by target region + orig_indices = np.where(needs_swap)[0] + sorted_order = np.argsort(req_regions[needs_swap], kind="stable") + src_indices = orig_indices[sorted_order] + + src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)} + + for src in src_dest_map: + dst = src_dest_map[src] + while src != dst: + input_batch.swap_states(src, dst) + # Mark dst as done by updating its destination to itself + next_dst = src_dest_map.get(dst, dst) + src_dest_map[dst] = dst + dst = next_dst + + return True diff --git a/vllm_mlu/v1/core/__init__.py b/vllm_mlu/v1/core/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/core/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/core/kv_cache_manager.py b/vllm_mlu/v1/core/kv_cache_manager.py new file mode 100644 index 0000000..a151e3d --- /dev/null +++ b/vllm_mlu/v1/core/kv_cache_manager.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import itertools +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, overload + +from vllm.distributed.kv_events import KVCacheEvent +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator +from vllm.v1.core.kv_cache_utils import KVCacheBlock +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.request import Request + +logger = init_logger(__name__) + +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +class KVCacheManager_MluHijack(KVCacheManager): + + def allocate_slots( + self, + request: Request, + num_new_tokens: int, + num_new_computed_tokens: int = 0, + new_computed_blocks: KVCacheBlocks | None = None, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, + fixed_window_tokens: int = 0, + ) -> KVCacheBlocks | None: + """Add slots for a request with new tokens to append. + + Args: + request: The request to allocate slots. + num_new_tokens: The number of tokens to allocate, including external + tokens. Note that this does not include tokens that have + already been computed locally (i.e. new_computed_blocks). 
+ num_new_computed_tokens: The number of new computed tokens just + hitting the prefix caching, excluding external tokens. + new_computed_blocks: The cached blocks for the above new computed + tokens. + num_lookahead_tokens: The number of speculative tokens to allocate. + This is used by spec decode proposers with kv-cache such + as eagle. + delay_cache_blocks: Whether to skip caching the blocks. This is + used by P/D when allocating blocks used in a KV transfer + which will complete in a future step. + + Blocks layout: + ``` + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + ``` + The following *_blocks are illustrated in this layout. + + Returns: + A list of new allocated blocks. + """ + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") + + if new_computed_blocks is not None: + new_computed_block_list = new_computed_blocks.blocks + else: + new_computed_block_list = self.empty_kv_cache_blocks.blocks + + # Free the blocks that are skipped during the attention computation + # (e.g., tokens outside the sliding window). + # We can do this even if we cannot schedule this request due to + # insufficient free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + self.coordinator.remove_skipped_blocks( + request.request_id, request.num_computed_tokens + ) + + # The number of computed tokens is the number of computed tokens plus + # the new prefix caching hits + num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens + num_tokens_need_slot = min( + num_computed_tokens + num_new_tokens + num_lookahead_tokens + fixed_window_tokens, + self.max_model_len, + ) + + num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_tokens_need_slot, + new_computed_blocks=new_computed_block_list, + num_encoder_tokens=num_encoder_tokens, + ) + + if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + # Cannot allocate new blocks + return None + + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self.block_pool.touch(new_computed_block_list) + else: + assert not any(new_computed_block_list), ( + "Computed blocks should be empty when prefix caching is disabled" + ) + + if new_computed_block_list is not self.empty_kv_cache_blocks.blocks: + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + self.coordinator.save_new_computed_blocks( + request.request_id, new_computed_block_list + ) + + new_blocks = self.coordinator.allocate_new_blocks( + request.request_id, num_tokens_need_slot, num_encoder_tokens + ) + + # P/D: delay caching blocks if we have to recv from + # remote. Update state for locally cached blocks. + if not self.enable_caching or delay_cache_blocks: + return self.create_kv_cache_blocks(new_blocks) + + # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + + # num_new_tokens, but must exclude "non-committable" tokens (e.g., + # draft tokens that could be rejected). Therefore, we cap the number + # at `request.num_tokens`, ensuring only "finalized" tokens are cached. 
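        # Small numeric illustration of the cap below (hypothetical numbers):
        # a request with num_tokens = 105 finalized tokens and
        # num_computed_tokens = 100, scheduled for 8 new tokens of which 3 are
        # speculative drafts, caches min(100 + 8, 105) = 105 tokens; the 3
        # draft tokens are not committed to the prefix cache until accepted.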
+ num_tokens_to_cache = min( + num_computed_tokens + num_new_tokens, request.num_tokens + ) + self.coordinator.cache_blocks(request, num_tokens_to_cache) + + return self.create_kv_cache_blocks(new_blocks) + + +MluHijackObject.apply_hijack(KVCacheManager, + KVCacheManager.allocate_slots, + KVCacheManager_MluHijack.allocate_slots) \ No newline at end of file diff --git a/vllm_mlu/v1/core/kv_cache_utils.py b/vllm_mlu/v1/core/kv_cache_utils.py new file mode 100644 index 0000000..a5ce217 --- /dev/null +++ b/vllm_mlu/v1/core/kv_cache_utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.logger import init_logger +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm.config import VllmConfig +from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + UniformTypeKVCacheSpecs, +) +from vllm.v1.core import kv_cache_utils +from vllm.v1.core.kv_cache_utils import (may_override_num_blocks, + get_uniform_page_size, + get_num_blocks) + +logger = init_logger(__name__) + +def vllm__v1__core__kv_cache_utils__get_kv_cache_config_from_groups( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + kv_cache_specs: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: + """ + Generate the KV cache configuration from the KV cache groups and spec + of each layer. + + Args: + vllm_config: The global VllmConfig + kv_cache_groups: The KV cache groups + kv_cache_specs: The KV cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes + Returns: + The generated KVCacheConfig + """ + if len(kv_cache_groups) == 0: + # Attention free models do not have KV cache. + # Return num_blocks=1 as BlockPool always needs a null_block. + return KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=kv_cache_groups, + ) + + # Determine how model runners should initialize the KV cache tensors. + if len(kv_cache_groups) == 1 and isinstance( + kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs + ): + # Special case: all layers have the same type of KV cache but with + # different hidden size. Allocate different amount of memory for each + # layer based on its hidden size. + num_blocks = ( + available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes + ) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs + kv_cache_tensors = [ + KVCacheTensor( + size=per_layer_specs[layer_name].page_size_bytes * num_blocks, + shared_by=[layer_name], + ) + for layer_name in kv_cache_groups[0].layer_names + ] + else: + # General case: + # We will have group_size memory pools, each is shared by one layer from + # each group. As layers of different groups have different block table, + # they will use different parts of the shared Tensor. 
+ # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), + # (sw.1, padding) will be: (group_size = 2) + # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 + # full.1, sw.2: share another Tensor with size=available_memory//2 + group_size = max(len(group.layer_names) for group in kv_cache_groups) + + page_size = get_uniform_page_size(kv_cache_specs) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + if (vllm_config.mlu_config.enable_mamba_split_page_size): + # Note(wulingchao): 预留出linear attention的内存不参与系统调度 + # 当前的 page size是小page,需要扩展到完整的linear attention的page + mamba_page_size = (page_size \ + * vllm_config.mlu_config.mamba_to_attn_block_ratio + * vllm_config.mlu_config.mamba_support_max_batch_size \ + * group_size * 3) + logger.warning(f"all available memory {available_memory}, mamba mem used {mamba_page_size}") + available_memory = available_memory - mamba_page_size + ''' + ================== + End of MLU Hijack + ================== + ''' + assert group_size > 0, "group_size must be greater than 0" + num_blocks = get_num_blocks( + vllm_config, group_size, available_memory, page_size + ) + kv_cache_tensors = [] + for i in range(group_size): + shared_by = [] + for j in range(len(kv_cache_groups)): + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) + kv_cache_tensors.append( + KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) + ) + + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=kv_cache_groups, + ) + +MluHijackObject.apply_hijack(kv_cache_utils, + kv_cache_utils.get_kv_cache_config_from_groups, + vllm__v1__core__kv_cache_utils__get_kv_cache_config_from_groups) + + diff --git a/vllm_mlu/v1/core/sched/__init__.py b/vllm_mlu/v1/core/sched/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/core/sched/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/core/sched/async_scheduler.py b/vllm_mlu/v1/core/sched/async_scheduler.py new file mode 100644 index 0000000..a6e2c25 --- /dev/null +++ b/vllm_mlu/v1/core/sched/async_scheduler.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +from vllm_mlu.v1.core.sched.scheduler import MLUUnchunkScheduler, SchedulerWithProfiler + +logger = init_logger(__name__) + + +class AsyncScheduler(SchedulerWithProfiler): + def _update_after_schedule( + self, + scheduler_output: SchedulerOutput, + ) -> None: + super()._update_after_schedule(scheduler_output) + pending_structured_output_tokens = False + spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + for req_id in scheduler_output.num_scheduled_tokens: + request = self.requests[req_id] + pending_structured_output_tokens |= ( + request.use_structured_output and request.num_output_placeholders > 0 + ) + cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ())) + if ( + request.num_computed_tokens + == request.num_tokens + + request.num_output_placeholders + + cur_num_spec_tokens + ): + # The request will generate a new token plus num_spec_tokens + # in this scheduling step. 
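                # Illustrative bookkeeping with made-up numbers
                # (num_spec_tokens = 2): placeholders grow by 1 + 2 = 3 here;
                # if the worker later rejects one draft, the placeholder count
                # is reduced by the 1 rejected token and by the 2 tokens that
                # were actually emitted, bringing it back to 0.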
+ request.num_output_placeholders += 1 + cur_num_spec_tokens + # Add placeholders for the new tokens in spec_token_ids. + # Wwe will update the actual spec token ids in the worker process. + request.spec_token_ids = [-1] * self.num_spec_tokens + + scheduler_output.pending_structured_output_tokens = ( + pending_structured_output_tokens + ) + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + status_before_update = request.status + new_token_ids, stopped = super()._update_request_with_output( + request, new_token_ids + ) + + # Update the number of output placeholders. + request.num_output_placeholders -= len(new_token_ids) + assert request.num_output_placeholders >= 0 + + # Cache the new tokens. Preempted requests should be skipped. + if status_before_update == RequestStatus.RUNNING: + self.kv_cache_manager.cache_blocks( + request, request.num_computed_tokens - request.num_output_placeholders + ) + return new_token_ids, stopped + +class MLUUnchunkAsyncScheduler(MLUUnchunkScheduler): + def _update_after_schedule( + self, + scheduler_output: SchedulerOutput, + ) -> None: + super()._update_after_schedule(scheduler_output) + spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + for req_id in scheduler_output.num_scheduled_tokens: + request = self.requests[req_id] + cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, [])) + if ( + request.num_computed_tokens + == request.num_tokens + + request.num_output_placeholders + + cur_num_spec_tokens + ): + # The request will generate a new token plus num_spec_tokens + # in this scheduling step. + request.num_output_placeholders += 1 + cur_num_spec_tokens + # Add a placeholder for the new token in spec_token_ids. + # because the actual token id is not known yet. so just use -1 + # as a placeholder and the length of spec_token_ids is set to + # self.num_spec_tokens. we will update the actual spec token id + # in worker process. + request.spec_token_ids = [-1] * self.num_spec_tokens + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + status_before_update = request.status + new_token_ids, stopped = super()._update_request_with_output( + request, new_token_ids) + + # num_output_placeholders = 0 happend when a request is preempted. + # a preempted request will be added to waiting queue again and + # num_output_placeholders is reset to 0, + # so don't need to revert num_output_placeholders for this situation. + if request.num_output_placeholders > 0: + # Update the number of output placeholders. + request.num_output_placeholders -= len(new_token_ids) + assert request.num_output_placeholders >= 0 + + # Cache the new tokens. Preempted requests should be skipped. + if status_before_update == RequestStatus.RUNNING: + self.kv_cache_manager.cache_blocks( + request, + request.num_computed_tokens - request.num_output_placeholders) + return new_token_ids, stopped + + + def _update_computed_tokens_after_speculation( + self, request: Request, num_rejected: int + ): + """Update the computed tokens for each request, which is necessary + for spec decoding. In sync scheduler, we need to revert + num_computed_tokens by num_rejected tokens, + but in async scheduler, we also need to revert num_output_placeholders + by num_rejected tokens for spec decoding. + """ + # num_computed_tokens = 0 happend when a request is preempted. 
+ # a preempted request will be added to waiting queue again and + # num_computed_tokens is reset to 0, + # so don't need to revert num_computed_tokens for this situation. + if request.num_computed_tokens > 0: + # when spec decoding is enabled, num_output_placeholders + # is increased by num_spec_tokens in _update_after_schedule. + # update num_output_placeholders here to reflect the actual number + # of accepted output tokens. + request.num_output_placeholders -= num_rejected + super()._update_computed_tokens_after_speculation(request, num_rejected) diff --git a/vllm_mlu/v1/core/sched/output.py b/vllm_mlu/v1/core/sched/output.py new file mode 100644 index 0000000..287dd74 --- /dev/null +++ b/vllm_mlu/v1/core/sched/output.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING + +from typing_extensions import deprecated + +from vllm._bc_linter import bc_linter_include + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + import torch + + from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata + from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata + from vllm.lora.request import LoRARequest + from vllm.multimodal.inputs import MultiModalFeatureSpec + from vllm.pooling_params import PoolingParams + from vllm.sampling_params import SamplingParams + from vllm.v1.request import Request +else: + ECConnectorMetadata = object + KVConnectorMetadata = object + LoRARequest = object + MultiModalFeatureSpec = object + PoolingParams = object + SamplingParams = object + Request = object + +''' +============================= +Modify by vllm_mlu +============================= +@brief: Add new_toked_ids to pass the first token generated +by the prefiller to the decoder's model_runner. +''' +@bc_linter_include +@dataclass +class NewRequestData: + req_id: str + prompt_token_ids: list[int] | None + mm_features: list[MultiModalFeatureSpec] + sampling_params: SamplingParams | None + pooling_params: PoolingParams | None + block_ids: tuple[list[int], ...] 
+ num_computed_tokens: int + lora_request: LoRARequest | None + new_token_ids: list[list[int]] + prompt_embeds: "torch.Tensor | None" = None + + @classmethod + def from_request( + cls, + request: Request, + block_ids: tuple[list[int], ...], + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + mm_features=request.mm_features, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + block_ids=block_ids, + num_computed_tokens=request.num_computed_tokens, + lora_request=request.lora_request, + prompt_embeds=request.prompt_embeds, + new_token_ids=request._output_token_ids, + ) + + def __repr__(self) -> str: + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids={self.prompt_token_ids}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}," + f"new_token_ids={self.new_token_ids}" + ")" + ) + + # Version of __repr__ with the prompt data obfuscated + def anon_repr(self) -> str: + prompt_token_ids_len = ( + len(self.prompt_token_ids) if self.prompt_token_ids is not None else None + ) + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids_len={prompt_token_ids_len}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" + ) +''' +================== +End of MLU Hijack +================== +''' \ No newline at end of file diff --git a/vllm_mlu/v1/core/sched/scheduler.py b/vllm_mlu/v1/core/sched/scheduler.py new file mode 100644 index 0000000..c05c04a --- /dev/null +++ b/vllm_mlu/v1/core/sched/scheduler.py @@ -0,0 +1,1723 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import time +from dataclasses import dataclass +from collections import defaultdict +from enum import Enum +import pandas as pd +import matplotlib.pyplot as plt + +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata +from vllm.distributed.kv_events import KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm_mlu.v1.core.sched.output import NewRequestData +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.request_queue import (SchedulingPolicy, + create_request_queue) +from vllm.v1.engine import EngineCoreEventType +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request, RequestStatus +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, 
EngineCoreOutputs +from vllm.v1.core.sched.utils import check_stop, remove_all +from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.v1.outputs import ModelRunnerOutput + +import vllm_mlu._mlu_utils as mlu_envs + +logger = init_logger(__name__) + + +class SchedMode(Enum): + CHUNK = 1 + UNCHUNK = 2 + + +@dataclass +class MLUSchedMetric: + # step index + step: int + # request queue + waiting_reqs: int + running_reqs: int + finished_reqs: int + # queue transfer + wait_to_run_reqs: int + run_to_wait_reqs: int + # usage + token_usage: float + batch_usage: float + block_usage: float + # sched info + total_num_scheduled_batchs: int + total_num_scheduled_tokens: int + total_num_scheduled_seqlens: int + + @classmethod + def build( + cls, + step: int, + waiting_reqs: int, + running_reqs: int, + finished_reqs: int, + scheduled_new_reqs: list[Request], + scheduled_resumed_reqs: list[Request], + scheduled_running_reqs: list[Request], + preempted_reqs: list[Request], + total_num_scheduled_tokens: int, + max_num_scheduled_tokens: int, + max_num_running_reqs: int, + block_usage: float, + ): + wait_to_run_reqs = ( + len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + ) + run_to_wait_reqs = len(preempted_reqs) + total_num_scheduled_batchs = ( + wait_to_run_reqs + len(scheduled_running_reqs) + ) + total_num_scheduled_seqlens = 0 + for req in (scheduled_new_reqs + scheduled_resumed_reqs + + scheduled_running_reqs): + total_num_scheduled_seqlens += req.num_computed_tokens + + return cls( + step=step, + waiting_reqs=waiting_reqs, + running_reqs=running_reqs, + finished_reqs=finished_reqs, + wait_to_run_reqs=wait_to_run_reqs, + run_to_wait_reqs=run_to_wait_reqs, + token_usage=(total_num_scheduled_tokens / max_num_scheduled_tokens), + batch_usage=(total_num_scheduled_batchs / max_num_running_reqs), + block_usage=block_usage, + total_num_scheduled_batchs=total_num_scheduled_batchs, + total_num_scheduled_tokens=total_num_scheduled_tokens, + total_num_scheduled_seqlens=total_num_scheduled_seqlens, + ) + + def as_df(self): + df = pd.DataFrame( + data=[ + [ + self.waiting_reqs, + self.running_reqs, + self.finished_reqs, + self.wait_to_run_reqs, + self.run_to_wait_reqs, + self.token_usage, + self.batch_usage, + self.block_usage, + self.total_num_scheduled_batchs, + self.total_num_scheduled_tokens, + self.total_num_scheduled_seqlens, + ] + ], + columns=[ + 'waiting', + 'running', + 'finished', + 'wait_to_run', + 'run_to_wait', + 'token_usage', + 'batch_usage', + 'block_usage', + 'sched_batchs', + 'sched_tokens', + 'sched_seqlens', + ], + index=[str(self.step)] + ) + return df + + @classmethod + def dump_scheduler_metric( + cls, + df: pd.DataFrame, + schd_mode: SchedMode, + dp_rank: str, + save_dir: str = "./", + ) -> None: + plt.rcParams.update({'font.size': 8}) + figure = plt.figure(figsize=(6.4, 11.2)) + gs = figure.add_gridspec(6, hspace=0) + axes = gs.subplots(sharex=True, sharey=False) + scheduler_type = ("Unchunk" if schd_mode == SchedMode.UNCHUNK + else "Chunk") + figure.suptitle(f"Cambricon vLLM {scheduler_type} Scheduler View {dp_rank}") + # requst queue + df.plot(ax=axes[0], y=['waiting', 'running', 'finished']) + axes[0].set_xlabel('vLLM-Engine-Step', loc='left') + axes[0].set_ylabel('Request-Queue', loc='top') + # queue transfer + df.plot(ax=axes[1], y=['wait_to_run', 'run_to_wait']) + axes[1].set_xlabel('vLLM-Engine-Step', loc='left') + axes[1].set_ylabel('Queue-Transfer', loc='top') + # usage + 
df.plot(ax=axes[2], y=['token_usage', 'batch_usage', 'block_usage']) + axes[2].set_xlabel('vLLM-Engine-Step', loc='left') + axes[2].set_ylabel('Usage(%)', loc='top') + # batch + df.plot(ax=axes[3], y=['sched_batchs']) + axes[3].set_xlabel('vLLM-Engine-Step', loc='left') + axes[3].set_ylabel('Usage(%)', loc='top') + # token + df.plot(ax=axes[4], y=['sched_tokens']) + axes[4].set_xlabel('vLLM-Engine-Step', loc='left') + axes[4].set_ylabel('Usage(%)', loc='top') + # seqlen + df.plot(ax=axes[5], y=['sched_seqlens']) + axes[5].set_xlabel('vLLM-Engine-Step', loc='left') + axes[5].set_ylabel('Usage(%)', loc='top') + for ax in axes: + ax.label_outer() + ax.legend(loc='upper right') + figure.tight_layout() + figure.savefig( + os.path.join(save_dir, f"vllm_scheduler_view_{dp_rank}.svg"), + dpi=300, + format='svg', + ) + plt.close(figure) + + sched_df = df.copy(deep=True) + max_, mean_, min_ = sched_df.max(), sched_df.mean(), sched_df.min() + sched_df.loc["Max"] = max_ + sched_df.loc["Mean"] = mean_ + sched_df.loc["Min"] = min_ + with pd.option_context('display.max_rows', None, + 'display.max_columns', None, + 'display.max_colwidth', None, + 'display.float_format', '{:^6,.2f}'.format, + 'expand_frame_repr', False): + logger.info(sched_df.loc[["Max", "Mean", "Min"]]) + sched_df.astype(str).to_csv( + os.path.join(save_dir, f"vllm_scheduler_view_{dp_rank}.csv"), mode="w") + + +class SchedulerWithProfiler(Scheduler): + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + block_size: int, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + super().__init__( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + structured_output_manager=structured_output_manager, + block_size=block_size, + mm_registry=mm_registry, + include_finished_set=include_finished_set, + log_stats=log_stats, + ) + self.enable_sched_profiler = mlu_envs.VLLM_SCHEDULER_PROFILE + if self.enable_sched_profiler: + self.start_scheduler_profile() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + + self.draft_token_ids = None + + def __del__(self): + if self.enable_sched_profiler: + self.stop_scheduler_profile() + + @property + def sched_mode(self): + return SchedMode.CHUNK + + def start_scheduler_profile(self): + if not self.enable_sched_profiler: + return + + logger.info(f"VLLM-V1 scheduler profiling started.") + self.sched_metrics: list[MLUSchedMetric] = [] + self.sched_step = 0 + + def stop_scheduler_profile(self): + if not self.enable_sched_profiler: + return + + logger.info(f"VLLM-V1 scheduler profiling stopped.") + assert len(self.sched_metrics) > 0, \ + "Profiling scheduler failed, cannot find any scheduler metrics." + df = pd.concat([m.as_df() for m in self.sched_metrics]) + MLUSchedMetric.dump_scheduler_metric(df, self.sched_mode, f"dp{self.dp_rank}") + + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. 
This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_compute_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len or + # request's max_tokens. + # This is necessary when using spec decoding and/or async scheduling. + max_total_tokens = min( + request.num_prompt_tokens + request.max_tokens, self.max_model_len + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: modify to ensure same token counts per batch during MTP. + ''' + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + external_load_encoder_input: list[int] = [] + new_encoder_compute_budget = encoder_compute_budget + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when + # (1) PP>1 and we have already scheduled all prompt tokens + # but they are not finished yet. + # (2) Async scheduling and the request has reached to either + # its max_total_tokens or max_model_len. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + # Schedule newly needed KV blocks for the request. + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + + if new_blocks is not None: + # The request can be scheduled. + break + + if mlu_envs.VLLM_V1_BENCHMARK: + raise RuntimeError( + "V1 benchmark does not support recompute. Please increase " + "gpu-memory-utilization or make input smaller.") + + # The request cannot be scheduled. + # Preempt the lowest-priority request. 
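                    # Victim selection sketch (hypothetical requests): with
                    # priority scheduling, the request with the numerically
                    # largest (priority, arrival_time) pair is preempted, e.g.
                    # among (prio=0, t=1.0), (prio=2, t=3.0), (prio=2, t=2.0)
                    # the (2, 3.0) request (lowest priority, latest arrival)
                    # is chosen; under FCFS the most recently added running
                    # request is simply popped from the back of self.running.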
+ if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + scheduled_spec_decode_tokens.pop( + preempted_req.request_id, None + ) + preempted_encoder_inputs = scheduled_encoder_inputs.pop( + preempted_req.request_id, None + ) + if preempted_encoder_inputs: + # Restore encoder compute budget if the preempted + # request had encoder inputs scheduled in this step. + num_tokens_to_restore = sum( + preempted_req.get_num_encoder_tokens(i) + for i in preempted_encoder_inputs + ) + encoder_compute_budget += num_tokens_to_restore + req_index -= 1 + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break + + if new_blocks is None: + # Cannot schedule this request. + break + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = ( + num_new_tokens + + request.num_computed_tokens + - request.num_tokens + - request.num_output_placeholders + ) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids + ) + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. + request.spec_token_ids = [] + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule + ) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. 
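        # Sketch of the skip-and-requeue pattern used in the loop below
        # (hypothetical request ids): a request that cannot be scheduled yet,
        # e.g. one still in WAITING_FOR_REMOTE_KVS, is popped from the waiting
        # queue and prepended to skipped_waiting_requests; after the loop the
        # skipped requests are put back at the head of the waiting queue, so
        # their original order is preserved for the next scheduling step.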
+ if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if mlu_envs.VLLM_DISAGG_FAKE_DECODER: + if self.vllm_config.speculative_config is not None: + request.kv_transfer_params = { + "first_tok": 0, + "first_spec_tok": [1], + } + else: + request.kv_transfer_params = { + "first_tok": 0, + } + + if mlu_envs.VLLM_DISAGG_TRANS_ALL_BLOCKS: + kv_transfer_params = request.kv_transfer_params + if kv_transfer_params and kv_transfer_params.get("first_tok", None) is not None: + first_tok = kv_transfer_params.pop("first_tok") + request.append_output_token_ids(first_tok) + if kv_transfer_params and kv_transfer_params.get("first_spec_tok", None) is not None: + first_spec_tok = kv_transfer_params.pop("first_spec_tok") + request.spec_token_ids = first_spec_tok + + ''' + ================== + End of MLU Hijack + ================== + ''' + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id, + ) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + ext_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) + ) + + if ext_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = ext_tokens + + # Total computed tokens (local + external). + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) + else: + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. 
+ new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + external_load_encoder_input = [] + new_encoder_compute_budget = encoder_compute_budget + + if load_kv_async: + # KVTransfer: loading remote KV, do not allocate for new work. + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + else: + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if mlu_envs.VLLM_DISAGG_TRANS_ALL_BLOCKS: + if len(request.spec_token_ids) > 0: + num_new_tokens = request.num_tokens_with_spec - num_computed_tokens + ''' + ================== + End of MLU Hijack + ================== + ''' + + threshold = self.scheduler_config.long_prefill_token_threshold + if 0 < threshold < num_new_tokens: + num_new_tokens = threshold + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if ( + not self.scheduler_config.enable_chunked_prefill + and num_new_tokens > token_budget + ): + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + num_encoder_tokens = ( + self.scheduler_config.max_num_encoder_input_tokens + ) + else: + num_encoder_tokens = 0 + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=effective_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, + ) + + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. 
+ if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + self._update_connector_prefix_cache_stats( + request, num_external_computed_tokens + ) + + # Request was already popped from self.waiting + # unless it was re-added above due to new_blocks being None. + request = self.waiting.pop_request() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.prepend_request(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if mlu_envs.VLLM_DISAGG_TRANS_ALL_BLOCKS: + if len(request.spec_token_ids) > 0: + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + ''' + ================== + End of MLU Hijack + ================== + ''' + + + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id) + ) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule + ) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + # Allocate for external load encoder cache + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. 
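+        # One entry per KV cache group. Only the first running request needs to be
+        # queried: the common prefix is, by definition, shared by every running
+        # request, so any of them yields the same block counts.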
+        num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
+        with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
+            if self.running:
+                any_request = self.running[0]
+                num_common_prefix_blocks = (
+                    self.kv_cache_manager.get_num_common_prefix_blocks(
+                        any_request.request_id
+                    )
+                )
+
+        # Construct the scheduler output.
+        new_reqs_data = [
+            NewRequestData.from_request(
+                req, req_to_new_blocks[req.request_id].get_block_ids()
+            )
+            for req in scheduled_new_reqs
+        ]
+        with record_function_or_nullcontext("schedule: make_cached_request_data"):
+            cached_reqs_data = self._make_cached_request_data(
+                scheduled_running_reqs,
+                scheduled_resumed_reqs,
+                num_scheduled_tokens,
+                scheduled_spec_decode_tokens,
+                req_to_new_blocks,
+            )
+
+        # Record the request ids that were scheduled in this step.
+        self.prev_step_scheduled_req_ids.clear()
+        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
+
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=new_reqs_data,
+            scheduled_cached_reqs=cached_reqs_data,
+            num_scheduled_tokens=num_scheduled_tokens,
+            total_num_scheduled_tokens=total_num_scheduled_tokens,
+            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
+            num_common_prefix_blocks=num_common_prefix_blocks,
+            # finished_req_ids is an existing state in the scheduler,
+            # instead of being newly scheduled in this step.
+            # It contains the request IDs that are finished in between
+            # the previous and the current steps.
+            finished_req_ids=self.finished_req_ids,
+            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
+        )
+
+        '''
+        =============================
+        Modify by vllm_mlu
+        =============================
+        @brief: profiling scheduler each step
+        '''
+        if self.enable_sched_profiler:
+            # advance schedule step
+            self.sched_step += 1
+            # calculate schedule metrics
+            metrics = MLUSchedMetric.build(
+                step=self.sched_step,
+                waiting_reqs=len(self.waiting),
+                running_reqs=len(self.running),
+                finished_reqs=len(self.finished_req_ids),
+                scheduled_new_reqs=scheduled_new_reqs,
+                scheduled_resumed_reqs=scheduled_resumed_reqs,
+                scheduled_running_reqs=scheduled_running_reqs,
+                preempted_reqs=preempted_reqs,
+                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                max_num_scheduled_tokens=self.max_num_scheduled_tokens,
+                max_num_running_reqs=self.max_num_running_reqs,
+                block_usage=self.kv_cache_manager.usage,
+            )
+            self.sched_metrics.append(metrics)
+            if len(self.running) < 10:
+                self.stop_scheduler_profile()
+        '''
+        ==================
+        End of MLU Hijack
+        ==================
+        '''
+
+        # NOTE(Kuntai): this function is designed for multiple purposes:
+        # 1. Plan the KV cache store
+        # 2. Wrap up all the KV cache load / save ops into an opaque object
+        # 3.
Clear the internal states of the connector + if self.connector is not None: + meta: KVConnectorMetadata = self.connector.build_connector_meta( + scheduler_output + ) + scheduler_output.kv_connector_metadata = meta + + # Build the connector meta for ECConnector + if self.ec_connector is not None: + ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( + scheduler_output + ) + scheduler_output.ec_connector_metadata = ec_meta + + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) + return scheduler_output + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output + + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats: KVConnectorStats | None = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) + if kv_connector_stats and self.connector: + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) + + failed_kv_load_req_ids = None + if kv_connector_output and kv_connector_output.invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + failed_kv_load_req_ids = self._handle_invalid_blocks( + kv_connector_output.invalid_block_ids + ) + + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + continue + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids: list[int] = ( + sampled_token_ids[req_index].tolist() if sampled_token_ids else [] + ) + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) + if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens. 
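+                # Worked example: 3 draft tokens are scheduled and sampling returns
+                # 3 tokens (2 accepted drafts plus one newly sampled token), so
+                # num_accepted = 3 - 1 = 2, num_rejected = 1, and num_computed_tokens
+                # is rolled back by 1 so the rejected position is recomputed later.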
+ if request.num_computed_tokens > 0: + request.num_computed_tokens -= num_rejected + # If async scheduling, num_output_placeholders also includes + # the scheduled spec tokens count and so is similarly adjusted. + if request.num_output_placeholders > 0: + request.num_output_placeholders -= num_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted, + ) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + status_before_stop = request.status + + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output( + request, new_token_ids + ) + + # Stop checking for pooler models. + pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + stopped = check_stop(request, self.max_model_len, pooler_output) + + if stopped: + kv_transfer_params = self._free_request(request) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if mlu_envs.VLLM_DISAGG_TRANS_ALL_BLOCKS: + params = request.kv_transfer_params + if params and params.get("ret_first_tok", False) and self.draft_token_ids is not None: + try: + req_index = self.draft_token_ids.req_ids.index(req_id) + spec_token_id = self.draft_token_ids.draft_token_ids[req_index] + kv_transfer_params["first_spec_tok"] = spec_token_id + except ValueError: + raise ValueError("failed to put spec_token_id to kv_transfer_params") + ''' + ================== + End of MLU Hijack + ================== + ''' + + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance(request): + struct_output_request = request.structured_output_request + assert struct_output_request is not None + assert struct_output_request.grammar is not None + struct_output_request.grammar.accept_tokens(req_id, new_token_ids) + + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None or kv_transfer_params: + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + num_nans_in_logits=request.num_nans_in_logits, + ) + ) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. 
+ if stopped_running_reqs: + self.running = remove_all(self.running, stopped_running_reqs) + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. + if kv_connector_output: + self._update_from_kv_xfer_finished(kv_connector_output) + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set + ) + finished_req_ids.clear() + + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: + # Return stats to only one of the front-ends. + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. + engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats + + return engine_core_outputs + + +class MLUUnchunkScheduler(SchedulerWithProfiler): + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + block_size: int, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + super().__init__( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + structured_output_manager=structured_output_manager, + block_size=block_size, + mm_registry=mm_registry, + include_finished_set=include_finished_set, + log_stats=log_stats, + ) + self.last_log_time = time.time() + + @property + def sched_mode(self): + return SchedMode.UNCHUNK + + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. 
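+        # Unlike the chunked scheduler above, this variant admits WAITING requests
+        # first and never splits a prompt across steps (a prefill that does not fit
+        # in the remaining token budget simply waits). RUNNING requests are scheduled
+        # only when no new or resumed request was picked up in this step, and
+        # prefills are held back until at least VLLM_V1_MIN_PREFILL_BATCH requests
+        # are waiting.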
+ + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_compute_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + ########### Schedule waiting ########## + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + req_index = len(self.running) + # First, schedule the WAITING requests. + waiting_prefills = len(self.waiting) + is_prefill_batch_met = (waiting_prefills >= mlu_envs.VLLM_V1_MIN_PREFILL_BATCH) + while self.waiting and token_budget > 0: + if not is_prefill_batch_met: + logger.debug( + f"Skip prefill scheduling, " + + f"VLLM_V1_MIN_PREFILL_BATCH({mlu_envs.VLLM_V1_MIN_PREFILL_BATCH})" + + f" > waiting({waiting_prefills}).") + break + + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id, + ) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) + + # Get externally-cached tokens if using a KVConnector. 
+ if self.connector is not None: + ext_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) + ) + + if ext_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = ext_tokens + + # Total computed tokens (local + external). + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) + else: + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + external_load_encoder_input = [] + new_encoder_compute_budget = encoder_compute_budget + + if load_kv_async: + # KVTransfer: loading remote KV, do not allocate for new work. + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + else: + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + threshold = self.scheduler_config.long_prefill_token_threshold + if 0 < threshold < num_new_tokens: + num_new_tokens = threshold + + # num_new_tokens = min(num_new_tokens, token_budget) + if num_new_tokens > token_budget: + # The request cannot be scheduled. + break + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + num_encoder_tokens = ( + self.scheduler_config.max_num_encoder_input_tokens + ) + else: + num_encoder_tokens = 0 + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=effective_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, + fixed_window_tokens=getattr(self.vllm_config.model_config.hf_config, "window_size", 0), + ) + + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. 
+ if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + self._update_connector_prefix_cache_stats( + request, num_external_computed_tokens + ) + + # Request was already popped from self.waiting + # unless it was re-added above due to new_blocks being None. + request = self.waiting.pop_request() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.prepend_request(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id) + ) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule + ) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + # Allocate for external load encoder cache + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # Next, schedule the RUNNING requests. + if (len(scheduled_new_reqs) == 0 and len(scheduled_resumed_reqs) == 0): + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len or + # request's max_tokens. + # This is necessary when using spec decoding and/or async scheduling. + max_total_tokens = min( + request.num_prompt_tokens + request.max_tokens, self.max_model_len + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: modify to ensure same token counts per batch during MTP. 
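+                @note: num_new_tokens is capped at
+                       max_total_tokens + num_speculative_tokens - 1 - num_computed_tokens,
+                       i.e. up to num_speculative_tokens - 1 positions beyond
+                       max_total_tokens, which appears intended to keep the scheduled
+                       token count uniform across an MTP decode batch.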
+ ''' + speculative_config = self.vllm_config.speculative_config + num_speculative_tokens = 0 + if speculative_config is not None: + num_speculative_tokens = speculative_config.num_speculative_tokens + num_new_tokens = min( + num_new_tokens, + max_total_tokens + num_speculative_tokens - 1 - request.num_computed_tokens) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + external_load_encoder_input: list[int] = [] + new_encoder_compute_budget = encoder_compute_budget + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when + # (1) PP>1 and we have already scheduled all prompt tokens + # but they are not finished yet. + # (2) Async scheduling and the request has reached to either + # its max_total_tokens or max_model_len. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + # Schedule newly needed KV blocks for the request. + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + fixed_window_tokens=getattr(self.vllm_config.model_config.hf_config, "window_size", 0), + ) + + if new_blocks is not None: + # The request can be scheduled. + break + + if mlu_envs.VLLM_V1_BENCHMARK: + raise RuntimeError( + "V1 benchmark does not support recompute. Please increase " + "gpu-memory-utilization or make input smaller.") + + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + scheduled_spec_decode_tokens.pop( + preempted_req.request_id, None + ) + preempted_encoder_inputs = scheduled_encoder_inputs.pop( + preempted_req.request_id, None + ) + if preempted_encoder_inputs: + # Restore encoder compute budget if the preempted + # request had encoder inputs scheduled in this step. 
+ num_tokens_to_restore = sum( + preempted_req.get_num_encoder_tokens(i) + for i in preempted_encoder_inputs + ) + encoder_compute_budget += num_tokens_to_restore + req_index -= 1 + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break + + if new_blocks is None: + # Cannot schedule this request. + break + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = ( + num_new_tokens + + request.num_computed_tokens + - request.num_tokens + - request.num_output_placeholders + ) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids + ) + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. + request.spec_token_ids = [] + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule + ) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) + + if (mlu_envs.VLLM_V1_UNCHUNK_SCHED_LOG + and time.time() - self.last_log_time >= 5): + logger.info( + f"MLUUnchunkScheduler: waiting={len(self.waiting)}, running={len(self.running)}, " + f"scheduled_new_reqs={len(scheduled_new_reqs)}, " + f"scheduled_resumed_reqs={len(scheduled_resumed_reqs)}, " + f"scheduled_running_reqs={len(scheduled_running_reqs)}, " + f"num_scheduled_tokens={num_scheduled_tokens}") + self.last_log_time = time.time() + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. 
+ num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id + ) + ) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids() + ) + for req in scheduled_new_reqs + ] + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + + # Record the request ids that were scheduled in this step. + self.prev_step_scheduled_req_ids.clear() + self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta: KVConnectorMetadata = self.connector.build_connector_meta( + scheduler_output + ) + scheduler_output.kv_connector_metadata = meta + + # Build the connector meta for ECConnector + if self.ec_connector is not None: + ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( + scheduler_output + ) + scheduler_output.ec_connector_metadata = ec_meta + + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) + return scheduler_output \ No newline at end of file diff --git a/vllm_mlu/v1/core/single_type_kv_cache_manager.py b/vllm_mlu/v1/core/single_type_kv_cache_manager.py new file mode 100644 index 0000000..01dde67 --- /dev/null +++ b/vllm_mlu/v1/core/single_type_kv_cache_manager.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.v1.core.single_type_kv_cache_manager import ( + FullAttentionManager, + SlidingWindowManager, + spec_manager_map, +) + +from vllm_mlu.v1.kv_cache_interface import ( + MLUFullAttentionSpec, + MLUMLAAttentionSpec, + MLUSlidingWindowSpec, +) + + +spec_manager_map.update({ + MLUFullAttentionSpec: FullAttentionManager, + MLUSlidingWindowSpec: SlidingWindowManager, + MLUMLAAttentionSpec: FullAttentionManager, +}) \ No newline at end of file diff --git a/vllm_mlu/v1/engine/__init__.py b/vllm_mlu/v1/engine/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/engine/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/engine/async_llm.py b/vllm_mlu/v1/engine/async_llm.py new file mode 100644 index 0000000..da4b568 --- /dev/null +++ b/vllm_mlu/v1/engine/async_llm.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.v1.engine.async_llm import AsyncLLM + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +class AsyncLLM_MluHijack(AsyncLLM): + + async def start_scheduler_profile(self) -> None: + await self.engine_core.start_scheduler_profile() + + async def stop_scheduler_profile(self) -> None: + await self.engine_core.stop_scheduler_profile() + + +MluHijackObject.apply_hijack(AsyncLLM, + "start_scheduler_profile", + AsyncLLM_MluHijack.start_scheduler_profile) +MluHijackObject.apply_hijack(AsyncLLM, + "stop_scheduler_profile", + AsyncLLM_MluHijack.stop_scheduler_profile) \ No newline at end of file diff --git a/vllm_mlu/v1/engine/core.py b/vllm_mlu/v1/engine/core.py new file mode 100644 index 0000000..02082bc --- /dev/null +++ b/vllm_mlu/v1/engine/core.py @@ -0,0 +1,566 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +from collections import deque +import signal +from typing import Any, Callable, cast +from concurrent.futures import Future + +from vllm.config import ParallelConfig, VllmConfig +from vllm.logger import logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import engine_receiver_cache_from_config +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.utils.gc_utils import freeze_gc_heap +from vllm.utils.hashing import get_hash_fn_by_name +from vllm.utils.system_utils import decorate_logs, 
set_process_title +from vllm.v1.core.kv_cache_utils import BlockHash, get_request_block_hasher, init_none_hash +from vllm.v1.engine import EngineCoreOutputs +from vllm.v1.engine.core import ( + EngineCore, + EngineCoreProc, + DPEngineCoreProc, +) +from vllm.v1.executor.abstract import Executor +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager +from vllm.version import __version__ as VLLM_VERSION +from logging import DEBUG + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm_mlu.mlu_metric import LLMMetric + + +class EngineCore_MluHijack(EngineCore): + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Callable | None = None, + ): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: load_general_plugins in run_engine_core + ''' + # # plugins need to be loaded at the engine/scheduler level too + # from vllm.plugins import load_general_plugins + # load_general_plugins() + ''' + ================== + End of MLU Hijack + ================== + ''' + + self.vllm_config = vllm_config + if vllm_config.parallel_config.data_parallel_rank == 0: + logger.info( + "Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, + vllm_config, + ) + + self.log_stats = log_stats + + # Setup Model. + self.model_executor = executor_class(vllm_config) + if executor_fail_callback is not None: + self.model_executor.register_failure_callback(executor_fail_callback) + + self.available_gpu_memory_for_kv_cache = -1 + + # Setup KV Caches and update CacheConfig after profiling. + num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( + vllm_config + ) + + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks + vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) + + self.structured_output_manager = StructuredOutputManager(vllm_config) + + # Setup scheduler. + Scheduler = vllm_config.scheduler_config.get_scheduler_cls() + + if len(kv_cache_config.kv_cache_groups) == 0: + # Encoder models without KV cache don't support + # chunked prefill. But do SSM models? 
+ logger.info("Disabling chunked prefill for model without KVCache") + vllm_config.scheduler_config.enable_chunked_prefill = False + + scheduler_block_size = ( + vllm_config.cache_config.block_size + * vllm_config.parallel_config.decode_context_parallel_size + ) + + self.scheduler: SchedulerInterface = Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + structured_output_manager=self.structured_output_manager, + include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, + log_stats=self.log_stats, + block_size=scheduler_block_size, + ) + self.use_spec_decode = vllm_config.speculative_config is not None + if self.scheduler.connector is not None: # type: ignore + self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore + + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = engine_receiver_cache_from_config( + vllm_config, mm_registry + ) + + # If a KV connector is initialized for scheduler, we want to collect + # handshake metadata from all workers so the connector in the scheduler + # will have the full context + kv_connector = self.scheduler.get_kv_connector() + if kv_connector is not None: + # Collect and store KV connector xfer metadata from workers + # (after KV cache registration) + xfer_handshake_metadata = ( + self.model_executor.get_kv_connector_handshake_metadata() + ) + + if xfer_handshake_metadata: + # xfer_handshake_metadata is list of dicts from workers + # Each dict already has structure {tp_rank: metadata} + # Merge all worker dicts into a single dict + content: dict[int, Any] = {} + for worker_dict in xfer_handshake_metadata: + if worker_dict is not None: + content.update(worker_dict) + kv_connector.set_xfer_handshake_metadata(content) + + # Setup batch queue for pipeline parallelism. + # Batch queue for scheduled batches. This enables us to asynchronously + # schedule and execute batches, and is required by pipeline parallelism + # to eliminate pipeline bubbles. + self.batch_queue_size = self.model_executor.max_concurrent_batches + self.batch_queue: ( + deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None + ) = None + if self.batch_queue_size > 1: + logger.info("Batch queue is enabled with size %d", self.batch_queue_size) + self.batch_queue = deque(maxlen=self.batch_queue_size) + + self.ec_producer = ( + vllm_config.ec_transfer_config is not None + and vllm_config.ec_transfer_config.is_ec_producer + ) + self.is_pooling_model = vllm_config.model_config.runner_type == "pooling" + + self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None + if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None: + caching_hash_fn = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo + ) + init_none_hash(caching_hash_fn) + + self.request_block_hasher = get_request_block_hasher( + scheduler_block_size, caching_hash_fn + ) + + self.step_fn = ( + self.step if self.batch_queue is None else self.step_with_batch_queue + ) + self.async_scheduling = vllm_config.scheduler_config.async_scheduling + + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. 
+ freeze_gc_heap() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: v1 support offline benchmark + ''' + self.step_latency = [] + self.model_exec_latency = [] + self.mm_encoder_latency = [] + self.num_gpu_blocks = num_gpu_blocks + self.num_cpu_blocks = num_cpu_blocks + ''' + ================== + End of MLU Hijack + ================== + ''' + + def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: + """Schedule, execute, and make output. + + Returns tuple of outputs and a flag indicating whether the model + was executed. + """ + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: v1 support offline benchmark + ''' + if mlu_envs.VLLM_LATENCY_DEBUG_EN: + step_start = LLMMetric.get_mlu_cost_time() + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Check for any requests remaining in the scheduler - unfinished, + # or finished and not yet removed from the batch. + if not self.scheduler.has_requests(): + return {}, False + scheduler_output = self.scheduler.schedule() + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) + + if self.use_spec_decode and \ + self.vllm_config.kv_transfer_config is not None and \ + self.vllm_config.kv_transfer_config.kv_role == "kv_producer": + draft_token_ids = self.model_executor.take_draft_token_ids() + self.scheduler.draft_token_ids = draft_token_ids + + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: v1 support offline benchmark + ''' + has_sched_reqs = (scheduler_output.total_num_scheduled_tokens > 0) + if mlu_envs.VLLM_LATENCY_DEBUG_EN and has_sched_reqs: + self.step_latency.append(LLMMetric.get_mlu_cost_time() - step_start) + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and has_sched_reqs: + self.model_exec_latency.append(self.get_model_exec_latency()) + mm_encoder_latency = self.get_mm_encoder_latency() + if mm_encoder_latency: + self.mm_encoder_latency.append(mm_encoder_latency) + ''' + ================== + End of MLU Hijack + ================== + ''' + return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 + + def step_with_batch_queue( + self, + ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]: + """Schedule and execute batches with the batch queue. + Note that if nothing to output in this step, None is returned. + + The execution flow is as follows: + 1. Try to schedule a new batch if the batch queue is not full. + If a new batch is scheduled, directly return an empty engine core + output. In other words, fulfilling the batch queue has a higher priority + than getting model outputs. + 2. If there is no new scheduled batch, meaning that the batch queue + is full or no other requests can be scheduled, we block until the first + batch in the job queue is finished. + 3. Update the scheduler from the output. + """ + batch_queue = self.batch_queue + assert batch_queue is not None + + # Try to schedule a new batch if the batch queue is not full, but + # the scheduler may return an empty batch if all requests are scheduled. + # Note that this is not blocking. 
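+        # With a batch queue of depth batch_queue_size (> 1 only with pipeline
+        # parallelism), scheduling of the next batch overlaps with execution of the
+        # batches already in flight; we block on the oldest in-flight batch only
+        # when the queue is full or nothing new can be scheduled.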
+ assert len(batch_queue) < self.batch_queue_size + + model_executed = False + deferred_scheduler_output = None + if self.scheduler.has_requests(): + scheduler_output = self.scheduler.schedule() + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) + if not self.ec_producer: + model_executed = scheduler_output.total_num_scheduled_tokens > 0 + + if self.is_pooling_model or not model_executed: + # No sampling required (no requests scheduled). + future = cast(Future[ModelRunnerOutput], exec_future) + else: + exec_future.add_done_callback(self._log_err_callback(scheduler_output)) + + if not scheduler_output.pending_structured_output_tokens: + # We aren't waiting for any tokens, get any grammar output + # and sample immediately. + grammar_output = self.scheduler.get_grammar_bitmask( + scheduler_output + ) + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) + else: + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + + if not deferred_scheduler_output: + # Add this step's future to the queue. + batch_queue.appendleft((future, scheduler_output)) + if ( + model_executed + and len(batch_queue) < self.batch_queue_size + and not batch_queue[-1][0].done() + ): + # Don't block on next worker response unless the queue is full + # or there are no more requests to schedule. + return None, True + + elif not batch_queue: + # Queue is empty. We should not reach here since this method should + # only be called when the scheduler contains requests or the queue + # is non-empty. + return None, False + + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + with self.log_error_detail(scheduler_output): + model_output = future.result() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if self.use_spec_decode and \ + self.vllm_config.kv_transfer_config is not None and \ + self.vllm_config.kv_transfer_config.kv_role == "kv_producer": + draft_token_ids = self.model_executor.take_draft_token_ids() + self.scheduler.draft_token_ids = draft_token_ids + ''' + ================== + End of MLU Hijack + ================== + ''' + + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) + + # NOTE(nick): We can either handle the deferred tasks here or save + # in a field and do it immediately once step_with_batch_queue is + # re-called. The latter slightly favors TTFT over TPOT/throughput. + if deferred_scheduler_output: + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. 
+ grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens(grammar_output, non_block=True) + batch_queue.appendleft((future, deferred_scheduler_output)) + + return engine_core_outputs, model_executed + + def get_model_exec_latency(self): + latency = self.model_executor.get_latency() + return latency + + def get_mm_encoder_latency(self): + return self.model_executor.get_mm_encoder_latency() + + def get_hfu_info(self, batch, input_len, output_len): + return self.model_executor.get_hfu_info(batch, input_len, output_len) + + def get_latency(self): + return (self.step_latency, self.model_exec_latency, self.mm_encoder_latency) + + def get_memory_usage(self): + peak_memory, block_memory = self.model_executor.get_memory_usage() + return (peak_memory, block_memory, + self.num_gpu_blocks, self.num_cpu_blocks) + + def recapture_model(self, + prefill_enable_mlugraph: bool, + batch_size: int, + input_len: int): + self.model_executor.recapture_model( + prefill_enable_mlugraph, batch_size, input_len) + + def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int): + self.step_latency = [] + self.model_exec_latency = [] + self.mm_encoder_latency = [] + mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED = use_unchunk_sched + mlu_envs.VLLM_V1_MIN_PREFILL_BATCH = min_prefill_batch + + def start_scheduler_profile(self): + self.scheduler.start_scheduler_profile() + + def stop_scheduler_profile(self): + self.scheduler.stop_scheduler_profile() + + def response_remote_alloc_once(self): + self.model_executor.response_remote_alloc_once() + + +class EngineCoreProc_MluHijack(EngineCoreProc): + + @staticmethod + def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): + """Launch EngineCore busy loop in background process.""" + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: load_general_plugins for mp backend engine + ''' + # plugins need to be loaded at the engine/scheduler level too + from vllm.plugins import load_general_plugins + load_general_plugins() + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Signal handler used for graceful termination. + # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + + # Ensure we can serialize transformer config after spawning + maybe_register_config_serialize_by_value() + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the engine_core + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + engine_core: EngineCoreProc | None = None + try: + parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config + if parallel_config.data_parallel_size > 1 or dp_rank > 0: + set_process_title("EngineCore", f"DP{dp_rank}") + decorate_logs() + # Set data parallel rank for this engine process. 
+ parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank + engine_core = DPEngineCoreProc(*args, **kwargs) + else: + set_process_title("EngineCore") + decorate_logs() + engine_core = EngineCoreProc(*args, **kwargs) + + engine_core.run_busy_loop() + + except SystemExit: + logger.debug("EngineCore exiting.") + raise + except Exception as e: + if engine_core is None: + logger.exception("EngineCore failed to start.") + else: + logger.exception("EngineCore encountered a fatal error.") + engine_core._send_engine_dead() + raise e + finally: + if engine_core is not None: + engine_core.shutdown() + + def _process_input_queue(self): + """Exits when an engine step needs to be performed.""" + + waited = False + while ( + not self.engines_running + and not self.scheduler.has_requests() + and not self.batch_queue + ): + if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): + logger.debug("EngineCore waiting for work.") + waited = True + + if self.vllm_config.kv_transfer_config is not None and \ + self.vllm_config.kv_transfer_config.kv_role == "kv_consumer": + self.response_remote_alloc_once() + if self.input_queue.empty(): + continue + req = self.input_queue.get_nowait() + self._handle_client_request(*req) + else: + req = self.input_queue.get() + self._handle_client_request(*req) + + if waited: + logger.debug("EngineCore loop active.") + + if self.vllm_config.kv_transfer_config is not None and \ + self.vllm_config.kv_transfer_config.kv_role == "kv_consumer": + self.response_remote_alloc_once() + + # Handle any more client requests. + while not self.input_queue.empty(): + req = self.input_queue.get_nowait() + self._handle_client_request(*req) + + +MluHijackObject.apply_hijack(EngineCore, + "get_mm_encoder_latency", + EngineCore_MluHijack.get_mm_encoder_latency) +MluHijackObject.apply_hijack(EngineCore, + "get_model_exec_latency", + EngineCore_MluHijack.get_model_exec_latency) +MluHijackObject.apply_hijack(EngineCore, + "get_hfu_info", + EngineCore_MluHijack.get_hfu_info) +MluHijackObject.apply_hijack(EngineCore, + "get_latency", + EngineCore_MluHijack.get_latency) +MluHijackObject.apply_hijack(EngineCore, + "get_memory_usage", + EngineCore_MluHijack.get_memory_usage) +MluHijackObject.apply_hijack(EngineCore, + "recapture_model", + EngineCore_MluHijack.recapture_model) +MluHijackObject.apply_hijack(EngineCore, + "init_metric", + EngineCore_MluHijack.init_metric) +MluHijackObject.apply_hijack(EngineCore, + "start_scheduler_profile", + EngineCore_MluHijack.start_scheduler_profile) +MluHijackObject.apply_hijack(EngineCore, + "stop_scheduler_profile", + EngineCore_MluHijack.stop_scheduler_profile) +MluHijackObject.apply_hijack(EngineCore, + EngineCore.__init__, + EngineCore_MluHijack.__init__) +MluHijackObject.apply_hijack(EngineCore, + EngineCore.step, + EngineCore_MluHijack.step) +MluHijackObject.apply_hijack(EngineCore, + "response_remote_alloc_once", + EngineCore_MluHijack.response_remote_alloc_once) +MluHijackObject.apply_hijack(EngineCore, + EngineCore.step_with_batch_queue, + EngineCore_MluHijack.step_with_batch_queue) +MluHijackObject.apply_hijack(EngineCoreProc, + EngineCoreProc.run_engine_core, + EngineCoreProc_MluHijack.run_engine_core) +MluHijackObject.apply_hijack(EngineCoreProc, + EngineCoreProc._process_input_queue, + EngineCoreProc_MluHijack._process_input_queue) diff --git a/vllm_mlu/v1/engine/core_client.py b/vllm_mlu/v1/engine/core_client.py new file mode 100644 index 0000000..22b5be8 --- /dev/null +++ 
b/vllm_mlu/v1/engine/core_client.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.engine.core_client import ( + EngineCoreClient, + InprocClient, + SyncMPClient, + AsyncMPClient, + DPAsyncMPClient, + DPLBAsyncMPClient, +) +from vllm.v1.engine import EngineCoreRequest +from vllm.config import VllmConfig +from vllm.v1.executor import Executor + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +class EngineCoreClient_MluHiack(EngineCoreClient): + + @staticmethod + def make_async_mp_client( + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: dict[str, str] | None = None, + client_count: int = 1, + client_index: int = 0, + ) -> "MPClient": + parallel_config = vllm_config.parallel_config + client_args = ( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: disagg use DPAsyncMPClient instead of DPLBAsyncMPClient. + ''' + if parallel_config.data_parallel_size > 1: + if parallel_config.data_parallel_external_lb or vllm_config.kv_transfer_config is not None: + # External load balancer - client per DP rank. + return DPAsyncMPClient(*client_args) + # Internal load balancer - client balances to all DP ranks. + return DPLBAsyncMPClient(*client_args) + ''' + ================== + End of MLU Hijack + ================== + ''' + return AsyncMPClient(*client_args) + + +class InprocClient_MluHiack(InprocClient): + + def get_hfu_info(self, batch, input_len, output_len): + return self.engine_core.get_hfu_info(batch, input_len, output_len) + + def get_latency(self): + return self.engine_core.get_latency() + + def get_memory_usage(self): + return self.engine_core.get_memory_usage() + + def recapture_model( + self, + prefill_enable_mlugraph: bool, + batch_size: int, + input_len: int, + ): + return self.engine_core.recapture_model( + prefill_enable_mlugraph, batch_size, input_len + ) + + def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int): + return self.engine_core.init_metric( + use_unchunk_sched, min_prefill_batch, + ) + + def start_scheduler_profile(self): + self.engine_core.start_scheduler_profile() + + def stop_scheduler_profile(self): + self.engine_core.stop_scheduler_profile() + + def response_remote_alloc_once(self) -> None: + self.engine_core.response_remote_alloc_once() + + +class SyncMPClient_MluHiack(SyncMPClient): + + def get_hfu_info(self, batch, input_len, output_len): + try: + return self.call_utility("get_hfu_info", batch, input_len, output_len) + except Exception as e: + raise RuntimeError(f"Failed to get HFU info: {str(e)}") + + def get_latency(self): + return self.call_utility("get_latency") + + def get_memory_usage(self): + return self.call_utility("get_memory_usage") + + def recapture_model(self, + prefill_enable_mlugraph: bool, + batch_size: int, + input_len: int): + return self.call_utility("recapture_model", + prefill_enable_mlugraph, batch_size, input_len) + + def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int): + return self.call_utility("init_metric", + use_unchunk_sched, + min_prefill_batch) + + def start_scheduler_profile(self): + self.call_utility("start_scheduler_profile") + + def stop_scheduler_profile(self): + self.call_utility("stop_scheduler_profile") + + def response_remote_alloc_once(self) -> None: + 
self.call_utility("response_remote_alloc_once") + + +class AsyncMPClient_MluHijack(AsyncMPClient): + + async def start_scheduler_profile(self) -> None: + await self.call_utility_async("start_scheduler_profile") + + async def stop_scheduler_profile(self) -> None: + await self.call_utility_async("stop_scheduler_profile") + + async def response_remote_alloc_once(self) -> None: + await self.call_utility_async("response_remote_alloc_once") + + +class DPAsyncMPClient_MluHijack(DPAsyncMPClient): + + def get_core_engine_for_request(self, request: EngineCoreRequest): + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: disagg need proxy to assign dp_rank + ''' + if request.data_parallel_rank is not None: + # engines are already in rank order + return self.core_engines[request.data_parallel_rank] + ''' + ================== + End of MLU Hijack + ================== + ''' + + return self.core_engine + + +MluHijackObject.apply_hijack(EngineCoreClient, + EngineCoreClient.make_async_mp_client, + EngineCoreClient_MluHiack.make_async_mp_client) +MluHijackObject.apply_hijack(InprocClient, + "get_hfu_info", + InprocClient_MluHiack.get_hfu_info) +MluHijackObject.apply_hijack(InprocClient, + "get_latency", + InprocClient_MluHiack.get_latency) +MluHijackObject.apply_hijack(InprocClient, + "get_memory_usage", + InprocClient_MluHiack.get_memory_usage) +MluHijackObject.apply_hijack(InprocClient, + "recapture_model", + InprocClient_MluHiack.recapture_model) +MluHijackObject.apply_hijack(InprocClient, + "init_metric", + InprocClient_MluHiack.init_metric) +MluHijackObject.apply_hijack(InprocClient, + "start_scheduler_profile", + InprocClient_MluHiack.start_scheduler_profile) +MluHijackObject.apply_hijack(InprocClient, + "stop_scheduler_profile", + InprocClient_MluHiack.stop_scheduler_profile) +MluHijackObject.apply_hijack(InprocClient, + "response_remote_alloc_once", + InprocClient_MluHiack.response_remote_alloc_once) +MluHijackObject.apply_hijack(SyncMPClient, + "get_hfu_info", + SyncMPClient_MluHiack.get_hfu_info) +MluHijackObject.apply_hijack(SyncMPClient, + "get_latency", + SyncMPClient_MluHiack.get_latency) +MluHijackObject.apply_hijack(SyncMPClient, + "get_memory_usage", + SyncMPClient_MluHiack.get_memory_usage) +MluHijackObject.apply_hijack(SyncMPClient, + "recapture_model", + SyncMPClient_MluHiack.recapture_model) +MluHijackObject.apply_hijack(SyncMPClient, + "init_metric", + SyncMPClient_MluHiack.init_metric) +MluHijackObject.apply_hijack(SyncMPClient, + "start_scheduler_profile", + SyncMPClient_MluHiack.start_scheduler_profile) +MluHijackObject.apply_hijack(SyncMPClient, + "stop_scheduler_profile", + SyncMPClient_MluHiack.stop_scheduler_profile) +MluHijackObject.apply_hijack(SyncMPClient, + "response_remote_alloc_once", + SyncMPClient_MluHiack.response_remote_alloc_once) +MluHijackObject.apply_hijack(AsyncMPClient, + "start_scheduler_profile", + AsyncMPClient_MluHijack.start_scheduler_profile) +MluHijackObject.apply_hijack(AsyncMPClient, + "stop_scheduler_profile", + AsyncMPClient_MluHijack.stop_scheduler_profile) +MluHijackObject.apply_hijack(AsyncMPClient, + "response_remote_alloc_once", + AsyncMPClient_MluHijack.response_remote_alloc_once) +MluHijackObject.apply_hijack(DPAsyncMPClient, + DPAsyncMPClient.get_core_engine_for_request, + DPAsyncMPClient_MluHijack.get_core_engine_for_request) diff --git a/vllm_mlu/v1/engine/llm_engine.py b/vllm_mlu/v1/engine/llm_engine.py new file mode 100644 index 0000000..c45d48e --- /dev/null +++ 
b/vllm_mlu/v1/engine/llm_engine.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.engine.llm_engine import LLMEngine +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__engine__llm_engine__LLMEngine__get_hfu_info(self, batch, input_len, output_len): + return self.engine_core.get_hfu_info(batch, input_len, output_len) + + +def vllm__engine__llm_engine__LLMEngine__get_latency(self): + return self.engine_core.get_latency() + + +def vllm__engine__llm_engine__LLMEngine__get_memory_usage(self): + return self.engine_core.get_memory_usage() + + +def vllm__engine__llm_engine__LLMEngine__start_scheduler_profile(self): + self.engine_core.start_scheduler_profile() + + +def vllm__engine__llm_engine__LLMEngine__stop_scheduler_profile(self): + self.engine_core.stop_scheduler_profile() + + +MluHijackObject.apply_hijack(LLMEngine, + "get_hfu_info", + vllm__engine__llm_engine__LLMEngine__get_hfu_info) +MluHijackObject.apply_hijack(LLMEngine, + "get_latency", + vllm__engine__llm_engine__LLMEngine__get_latency) +MluHijackObject.apply_hijack(LLMEngine, + "get_memory_usage", + vllm__engine__llm_engine__LLMEngine__get_memory_usage) +MluHijackObject.apply_hijack(LLMEngine, + "start_scheduler_profile", + vllm__engine__llm_engine__LLMEngine__start_scheduler_profile) +MluHijackObject.apply_hijack(LLMEngine, + "stop_scheduler_profile", + vllm__engine__llm_engine__LLMEngine__stop_scheduler_profile) diff --git a/vllm_mlu/v1/executor/__init__.py b/vllm_mlu/v1/executor/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/executor/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/executor/abstract.py b/vllm_mlu/v1/executor/abstract.py new file mode 100644 index 0000000..f00f29f --- /dev/null +++ b/vllm_mlu/v1/executor/abstract.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.v1.executor.abstract import Executor + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm__v1__executor__abstract__Executor__get_hfu_info(self, batch, input_len, output_len): + output = self.collective_rpc("get_hfu_info", args=([batch, input_len, output_len])) + return max(output) + +def vllm__v1__executor__abstract__Executor__get_mm_encoder_latency(self): + output = self.collective_rpc("get_mm_encoder_latency") + return None if any(item is None for item in output) else max(output) + +def vllm__v1__executor__abstract__Executor__get_latency(self): + output = self.collective_rpc("get_latency") + return max(output) + + +def vllm__v1__executor__abstract__Executor__get_memory_usage(self): + output = self.collective_rpc("get_memory_usage") + return output[0] + + +def vllm__v1__executor__abstract__Executor__recapture_model( + self, prefill_enable_mlugraph: bool, batch_size: int, input_len: int): + self.collective_rpc("recapture_model", + args=(prefill_enable_mlugraph, batch_size, input_len)) + + +MluHijackObject.apply_hijack( + Executor, + "get_hfu_info", + vllm__v1__executor__abstract__Executor__get_hfu_info +) +MluHijackObject.apply_hijack( + Executor, + "get_latency", + vllm__v1__executor__abstract__Executor__get_latency +) +MluHijackObject.apply_hijack( + Executor, + "get_mm_encoder_latency", + 
vllm__v1__executor__abstract__Executor__get_mm_encoder_latency +) +MluHijackObject.apply_hijack( + Executor, + "get_memory_usage", + vllm__v1__executor__abstract__Executor__get_memory_usage +) +MluHijackObject.apply_hijack( + Executor, + "recapture_model", + vllm__v1__executor__abstract__Executor__recapture_model +) diff --git a/vllm_mlu/v1/executor/multiproc_executor.py b/vllm_mlu/v1/executor/multiproc_executor.py new file mode 100644 index 0000000..6ff0f4a --- /dev/null +++ b/vllm_mlu/v1/executor/multiproc_executor.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.v1.executor.multiproc_executor import MultiprocExecutor +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +class MultiprocExecutor_MluHijack(MultiprocExecutor): + + def response_remote_alloc_once(self) -> None: + self.collective_rpc("response_remote_alloc_once", unique_reply_rank=self.output_rank) + + +MluHijackObject.apply_hijack(MultiprocExecutor, + "response_remote_alloc_once", + MultiprocExecutor_MluHijack.response_remote_alloc_once) \ No newline at end of file diff --git a/vllm_mlu/v1/executor/ray_executor.py b/vllm_mlu/v1/executor/ray_executor.py new file mode 100644 index 0000000..802aab5 --- /dev/null +++ b/vllm_mlu/v1/executor/ray_executor.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +import os +from collections import defaultdict +from typing import TYPE_CHECKING, Any + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.ray.ray_env import get_env_vars_to_copy +from vllm.v1.executor.ray_executor import RayDistributedExecutor, RayWorkerMetaData +from vllm.v1.executor.ray_utils import ( + RayWorkerWrapper, + initialize_ray_cluster, + ray, +) +from vllm.utils.network_utils import ( + get_distributed_init_method, + get_ip, + get_open_port, +) + +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +class RayDistributedExecutor_MluHijack(RayDistributedExecutor): + + def _init_executor(self) -> None: + self.forward_dag: ray.dag.CompiledDAG | None = None + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: For MLU, avoid compiling NVIDIA's NCCL + ''' + # For TPU or XPU, avoid compiling NVIDIA's NCCL + if current_platform.is_tpu() or current_platform.is_xpu() or \ + current_platform.is_out_of_tree(): + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" + ''' + ================== + End of MLU Hijack + ================== + ''' + + assert self.uses_ray + initialize_ray_cluster(self.parallel_config) + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers. 
+ self._init_workers_ray(placement_group) + + # KV connector setup + self.has_connector = self.vllm_config.kv_transfer_config is not None + + self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and ( + self.vllm_config.ec_transfer_config is None + or not self.vllm_config.ec_transfer_config.is_ec_producer + ) + + self.scheduler_output: SchedulerOutput | None = None + + def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use default cnperf config. + ''' + runtime_env.update({ + # use default cnperf config + "nsight": "default" + }) + ''' + ================== + End of MLU Hijack + ================== + ''' + + return ray_remote_kwargs + + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): + num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: RayWorkerWrapper | None = None + # The remaining workers are the actual ray actors. + self.workers: list[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: list[list[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs + ) + + # Create the workers. + bundle_indices: list[int] + if envs.VLLM_RAY_BUNDLE_INDICES: + # Use the bundle indices specified by the user. + bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) + assert len(bundle_indices) == self.parallel_config.world_size, ( + "VLLM_RAY_BUNDLE_INDICES must have the same size" + f" as the world size, but got {bundle_indices=} " + f"and {self.parallel_config.world_size=}" + ) + assert len(set(bundle_indices)) == len(bundle_indices), ( + "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}" + ) + else: + # use the first N bundles that have GPU resources. 
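            # Hedged illustration of the scan below (bundle contents and device
            # key are assumptions, not from the patch): with
            # current_platform.ray_device_key == "MLU", world_size == 2 and
            # placement-group bundles
            #     [{"CPU": 1}, {"MLU": 1}, {"MLU": 1}, {"MLU": 1}]
            # bundles 1, 2 and 3 report an accelerator resource, and the slice
            # below keeps only the first world_size of them, so
            # bundle_indices == [1, 2].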
+ bundle_indices = [] + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if bundle.get(current_platform.ray_device_key, 0): + bundle_indices.append(bundle_id) + bundle_indices = bundle_indices[: self.parallel_config.world_size] + + worker_metadata: list[RayWorkerMetaData] = [] + driver_ip = get_ip() + for rank, bundle_id in enumerate(bundle_indices): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support ray + cnperf-cli + ''' + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs['runtime_env'].update({ + "nsight": { + "o": f"cnperf_rank_{rank}", + "force_overwrite": "true" + } + }) + if rank == 0: + ray_remote_kwargs['runtime_env'].update({ + "nsight": {} + }) + ''' + ================== + End of MLU Hijack + ================== + ''' + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + if current_platform.ray_device_key == "GPU": + # NV+AMD GPUs, and Intel XPUs + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote( # type: ignore[attr-defined] + vllm_config=self.vllm_config, rpc_rank=rank + ) + else: + worker = ray.remote( + num_cpus=0, + num_gpus=0, + resources={current_platform.ray_device_key: num_gpus}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote( # type: ignore[attr-defined] + vllm_config=self.vllm_config, rpc_rank=rank + ) + worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) + + worker_ips = ray.get( + [ + each.worker.get_node_ip.remote() # type: ignore[attr-defined] + for each in worker_metadata + ] + ) + + for each, ip in zip(worker_metadata, worker_ips): + each.ip = ip + + logger.debug("workers: %s", worker_metadata) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + + ip_counts: dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = item.ip + return 0 if ip == driver_ip else 1, ip_counts[ip], ip + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + sorted_worker_metadata = sorted( + worker_metadata, key=sort_by_driver_then_worker_ip + ) + for i, item in enumerate(sorted_worker_metadata): + item.adjusted_rank = i + self.workers = [item.worker for item in sorted_worker_metadata] + rerank_mapping = { + item.created_rank: item.adjusted_rank for item in sorted_worker_metadata + } + self.collective_rpc("adjust_rank", args=(rerank_mapping,)) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) + ) # type: ignore[attr-defined] + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" + " each node." + ) + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [ + { + current_platform.device_control_env_var: ",".join( + map(str, node_gpus[node_id]) + ), + } + for (node_id, _) in worker_node_and_gpu_ids + ] + + # Environment variables to copy from driver to workers + env_vars_to_copy = get_env_vars_to_copy( + exclude_vars=self.WORKER_SPECIFIC_ENV_VARS, + additional_vars=set(current_platform.additional_env_vars).union( + self.ADDITIONAL_ENV_VARS + ), + destination="workers", + ) + + # Copy existing env vars to each worker's args + for args in all_args_to_update_environment_variables: + # TODO: refactor platform-specific env vars + for name in env_vars_to_copy: + if name in os.environ: + args[name] = os.environ[name] + + self._env_vars_for_all_workers = all_args_to_update_environment_variables + + self.collective_rpc( + "update_environment_variables", args=(self._get_env_vars_to_be_updated(),) + ) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port() + ) + + # Initialize the actual workers inside worker wrapper. 
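        # Illustrative sketch of the kwargs assembled below (node names and ranks
        # are assumptions): with two nodes of four workers each, node_workers
        # might be {"nodeA": [0, 1, 2, 3], "nodeB": [4, 5, 6, 7]}, so global
        # rank 5 gets local_rank = node_workers["nodeB"].index(5) == 1, and with
        # tensor_parallel_size == 4 only ranks 0 and 4 are marked
        # is_driver_worker (rank % tensor_parallel_size == 0).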
+ all_kwargs = [] + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): + local_rank = node_workers[node_id].index(rank) + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self.collective_rpc("init_worker", args=(all_kwargs,)) + + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range(self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + +MluHijackObject.apply_hijack( + RayDistributedExecutor, + RayDistributedExecutor._configure_ray_workers_use_nsight, + RayDistributedExecutor_MluHijack._configure_ray_workers_use_nsight +) +MluHijackObject.apply_hijack( + RayDistributedExecutor, + RayDistributedExecutor._init_workers_ray, + RayDistributedExecutor_MluHijack._init_workers_ray +) +MluHijackObject.apply_hijack( + RayDistributedExecutor, + RayDistributedExecutor._init_executor, + RayDistributedExecutor_MluHijack._init_executor +) \ No newline at end of file diff --git a/vllm_mlu/v1/kv_cache_interface.py b/vllm_mlu/v1/kv_cache_interface.py new file mode 100644 index 0000000..13f277e --- /dev/null +++ b/vllm_mlu/v1/kv_cache_interface.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from dataclasses import dataclass +from typing_extensions import Self + +import torch + +from math import prod +from vllm.logger import init_logger +from vllm.utils.torch_utils import get_dtype_size +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MLAAttentionSpec, + SlidingWindowSpec, + MambaSpec, +) + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class MLUFullAttentionSpec(FullAttentionSpec): + + @property + def type_id(self) -> str: + return f"mlu_full_attention_{self.block_size}_{self.page_size_bytes}" + + @property + def cache_size_bytes(self) -> int: + return ( + 2 + * self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) + + @property + def scale_size_bytes(self) -> int: + scale_size_bytes = 0 + if self.dtype in [torch.int8, torch.uint8]: + scale_size_bytes = ( + 2 + * self.block_size + * self.num_kv_heads + * get_dtype_size(torch.float32) + ) + return scale_size_bytes + + @property + def page_size_bytes(self) -> int: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: caculate kv_cache_scale size when kv_cache_dtype=int8 + ''' + return self.cache_size_bytes + self.scale_size_bytes + ''' + ================== + End of MLU Hijack + ================== + ''' + + +@dataclass(frozen=True) +class MLUMLAAttentionSpec(MLAAttentionSpec): + # Use to record k_cache info for DSA indexer + index_head_dim: int = 0 + index_n_heads: int = 0 + + @property + def type_id(self) -> str: + return f"mlu_mla_attention_{self.block_size}_{self.page_size_bytes}" + + @property + def cache_size_bytes(self) -> int: + return ( + 
self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) + + @property + def scale_size_bytes(self) -> int: + scale_size_bytes = 0 + if self.dtype in [torch.int8, torch.uint8]: + scale_size_bytes = ( + self.block_size + * self.num_kv_heads + * get_dtype_size(torch.float32) + ) + return scale_size_bytes + + @property + def index_cache_size_bytes(self) -> int: + return ( + self.block_size + * self.index_n_heads + * self.index_head_dim + * get_dtype_size(self.dtype) + ) + + @property + def page_size_bytes(self) -> int: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: caculate kv_cache_scale size when kv_cache_dtype=int8 + @brief: caculate indexer cache size for deepseek v3.2 + ''' + return self.cache_size_bytes + self.scale_size_bytes + self.index_cache_size_bytes + ''' + ================== + End of MLU Hijack + ================== + ''' + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be MLAAttentionSpec." + ) + cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) + assert len(cache_dtype_str_set) == 1, ( + "All attention layers in the same KV cache group must use the same " + "quantization method." + ) + return cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + cache_dtype_str=cache_dtype_str_set.pop(), + index_head_dim=specs[0].index_head_dim, + index_n_heads=specs[0].index_n_heads, + ) + + +@dataclass(frozen=True) +class MLUSlidingWindowSpec(SlidingWindowSpec): + + @property + def type_id(self) -> str: + return f"mlu_sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa + + @property + def cache_size_bytes(self) -> int: + return ( + 2 + * self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) + + @property + def scale_size_bytes(self) -> int: + scale_size_bytes = 0 + if self.dtype in [torch.int8, torch.uint8]: + scale_size_bytes = ( + 2 + * self.block_size + * self.num_kv_heads + * get_dtype_size(torch.float32) + ) + return scale_size_bytes + + @property + def page_size_bytes(self) -> int: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: caculate kv_cache_scale size when kv_cache_dtype=int8 + ''' + return self.cache_size_bytes + self.scale_size_bytes + ''' + ================== + End of MLU Hijack + ================== + ''' + +@property +def vllm__v1__kv_cache_interface__MambaSpec__page_size_bytes(self) -> int: + page_size = sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) + if self.page_size_padded is not None: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + # assert self.page_size_padded >= page_size + ''' + ================== + End of MLU Hijack + ================== + ''' + return self.page_size_padded + return page_size + +MluHijackObject.apply_hijack(MambaSpec, + MambaSpec.page_size_bytes, + vllm__v1__kv_cache_interface__MambaSpec__page_size_bytes) diff --git a/vllm_mlu/v1/sample/rejection_sampler.py b/vllm_mlu/v1/sample/rejection_sampler.py new file mode 100644 index 0000000..e22d88f --- /dev/null +++ b/vllm_mlu/v1/sample/rejection_sampler.py @@ -0,0 +1,946 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional +import math +import torch +import triton +import triton.language as tl +import vllm +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample import rejection_sampler +from vllm.v1.sample.rejection_sampler import sample_recovered_tokens + +from vllm_mlu.mlu_hijack_utils import MluHijackObject +from vllm_mlu._mlu_utils import * +from vllm_mlu import _mlu_ops as mlu_ops + +PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 +GREEDY_TEMPERATURE: tl.constexpr = 0 +# Maximum number of speculative draft tokens allowed per request in a single +# step. This value is chosen to be large enough to handle typical use cases. +MAX_SPEC_LEN = 128 +''' +============================= +Modify by vllm_mlu +============================= +@brief: + - Limit maximum batch size due to NRAM memory constraints + - Add generate_recovered_uniform_probs function for tmo rejection sampler +''' +MAX_BATCH_SIZE = 65536 + +def generate_recovered_uniform_probs( + num_tokens: int, + vocab_size: int, + num_draft_tokens: list[int], + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: + q = torch.empty( + (num_tokens, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. + if num_draft_tokens[i] > 0: + q[i].exponential_(generator=generator) + return q + +''' +============================= +End of MLU Hijack +============================= +''' + +def vllm__v1__sample__rejection_sampler__expand_batch_to_tokens( + x: torch.Tensor, # [batch_size] + cu_num_tokens: torch.Tensor, # [batch_size] + num_tokens: int, + replace_from: int = 0, + replace_to: int = 0, +) -> torch.Tensor: + """Expand [batch_size] tensor to [num_tokens] tensor based on the number of + tokens per batch in cu_num_tokens. + + For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then + num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. + + Args: + x: [batch_size] tensor to expand. + cu_num_tokens: [batch_size] tensor containing the cumulative number of + tokens per batch. Each element represents the total number of + tokens up to and including that batch. + num_tokens: Total number of tokens. + replace_from: int = 0 + Value to be replaced if it is found in x. + replace_to: int = 0 + Value to replace with when replace_from is found. + Returns: + expanded_x: [num_tokens] tensor. + """ + batch_size = x.shape[0] + assert cu_num_tokens.shape[0] == batch_size + ''' + ============================= + Modify by vllm_mlu + ============================= + ''' + if batch_size > MAX_BATCH_SIZE: + raise ValueError(f"Rejection Sampler Not Supported: " + f"Batch size exceeds the maximum allowed value of {MAX_BATCH_SIZE}") + ''' + ================== + End of MLU Hijack + ================== + ''' + expanded_x = x.new_empty(num_tokens) + vllm__v1__sample__rejection_sampler__expand_kernel[(batch_size, )]( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. + ) + return expanded_x + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. 
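# A small usage sketch for the expand helper above, restating the docstring
# example (device and dtypes are assumptions; an MLU/triton-capable device is
# required for the kernel launch):
#
#     x = torch.tensor([10, 20, 30], device="mlu")
#     cu_num_tokens = torch.tensor([2, 5, 6], device="mlu", dtype=torch.int32)
#     out = vllm__v1__sample__rejection_sampler__expand_batch_to_tokens(
#         x, cu_num_tokens, num_tokens=6)
#     # out -> tensor([10, 10, 20, 20, 20, 30], device="mlu")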
+@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +def vllm__v1__sample__rejection_sampler__expand_kernel( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: # noqa: SIM108 + ''' + ============================= + Modify by vllm_mlu + ============================= + ''' + # Ensure data types are consistent + start_idx = tl.full((), 0, tl.int64) + ''' + ================== + End of MLU Hijack + ================== + ''' + else: + ''' + ============================= + Modify by vllm_mlu + ============================= + ''' + start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1).to(tl.int64) + ''' + ================== + End of MLU Hijack + ================== + ''' + end_idx = tl.load(cu_num_tokens_ptr + req_idx) + num_tokens = end_idx - start_idx + + src_val = tl.load(input_ptr + req_idx) + src_val = tl.where(src_val == replace_from, replace_to, src_val) + offset = tl.arange(0, MAX_NUM_TOKENS) + tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens) + + +@triton.jit +def vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, + BLOCK_VOCAB: tl.constexpr = 2048, +): + req_idx = tl.program_id(0) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions. + pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + ''' + ============================= + Modify by vllm_mlu + ============================= + ''' + max_score = -float("inf") + max_index = 0 + ''' + ================== + End of MLU Hijack + ================== + ''' + + if NO_DRAFT_PROBS: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + draft_token_id) + # Temporarily zero out the probability of the draft token. + # This is essentially the same as target_prob - draft_prob, except that + # n-gram does not have draft_prob. We regard it as 1. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + 0) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Replace with block loop due to ngram limitations + ''' + num_blocks = tl.cdiv(PADDED_VOCAB_SIZE, BLOCK_VOCAB) + + for i in tl.range(0, num_blocks): + offset = i * BLOCK_VOCAB + tl.arange(0, BLOCK_VOCAB) + mask = offset < vocab_size + + if NO_DRAFT_PROBS: + prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + offset, + mask=mask, + other=0 + ) + else: + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + offset, + mask=mask, + other=0 + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + offset, + mask=mask, + other=0 + ) + prob = tl.maximum(target_prob - draft_prob, 0) + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. 
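        # Reasoning sketch (added commentary, not part of the original patch):
        # q is filled with Exponential(1) samples in
        # generate_recovered_uniform_probs(), so taking argmax_i(prob_i / q_i)
        # is the standard exponential-clock trick: it selects index i with
        # probability prob_i / sum_j(prob_j), i.e. it draws from the residual
        # (recovered) distribution without ever normalizing it explicitly.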
+ q = tl.load(q_ptr + req_idx * vocab_size + offset, + mask=mask, + other=float("-inf")) + score = prob / q # Broadcasting elementwise + cur_max = tl.argmax(score, axis=0) + + cur_score = score[cur_max] + cur_index = offset[cur_max] + + # Manually maintain argmax. + if cur_score > max_score: + max_score = cur_score + max_index = cur_index + + tl.store(output_token_ids_ptr + start_idx + pos, max_index) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if NO_DRAFT_PROBS: + # Restore the original probability. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + orig_prob) + +""" +============================= +Modify by vllm_mlu +============================= +""" +def filter_with_acceptance_rate(output_token_ids, # [batch_size, max_spec_len + 1] + fixed_acceptance_rate): + """ + Filter speculative tokens based on a fixed acceptance rate using batch-level accept/reject decisions. + + This function implements an adaptive acceptance rate control mechanism that maintains a target + acceptance rate over time through error compensation and PID-style adjustments. + + Args: + output_token_ids (torch.Tensor): Input tensor of shape [batch_size, max_spec_len + 1] + where the first column contains base tokens and remaining columns contain speculative tokens + fixed_acceptance_rate (float or None): Target acceptance rate between 0.0 and 1.0 + If None, returns input tensor unchanged + + Returns: + torch.Tensor: Modified tensor where rejected batches have all speculative tokens + (columns 1 to max_spec_len) set to PLACEHOLDER_TOKEN_ID + + Algorithm Flow: + 1. **Initialization Phase**: + - Extract batch dimensions and device information + - Initialize static variables for tracking acceptance statistics: + * cumulative_error: Long-term error accumulation + * total_batches/accepted_batches: Global acceptance tracking + * acceptance_history: Sliding window for recent performance + * precision_adjustment: PID controller adjustment factor + * recent_adjustments: Error history for PID calculation + + 2. **Statistics Calculation**: + - Calculate global acceptance rate from all historical data + - Calculate sliding window acceptance rate from recent batches + - Compute combined error using weighted average of global and window errors + - Weight transitions from global-focused (early) to window-focused (later) + + 3. **PID Controller Adjustment** (after 50+ batches): + - Proportional term: Current error magnitude + - Integral term: Accumulated error over recent history + - Derivative term: Rate of error change + - Combines P, I, D terms to compute precision adjustment factor + - Limits adjustment range to prevent over-correction + + 4. **Error Correction**: + - Applies smooth nonlinear correction based on combined error magnitude + - Uses exponential decay mapping for gradual adjustment strength + - Handles boundary cases (0.0, 1.0, very low rates) specially + + 5. **Gap-based Adjustment**: + - Calculates difference between target and actual accepted batches + - Applies adaptive threshold-based corrections + - Uses exponential smoothing for adjustment strength + - Adjustment strength decreases as total batch count increases + + 6. **Random Perturbation** (after 100+ batches): + - Adds small random noise to prevent local minima + - Noise amplitude decreases over time for stability + + 7. **Batch Decision**: + - Generates random value and compares with adjusted acceptance rate + - Makes binary accept/reject decision for entire batch + + 8. 
**Token Modification**: + - If accepted: Leave all tokens unchanged + - If rejected: Set all speculative tokens (columns 1:) to PLACEHOLDER_TOKEN_ID + - This ensures token-level acceptance rate matches batch-level rate + + 9. **State Updates**: + - Update acceptance counters and history + - Update cumulative error using exponential moving average + - Prepare state for next function call + + Key Features: + - **Batch-level consistency**: All samples in a batch share the same accept/reject fate + - **Adaptive control**: Uses multiple feedback mechanisms (global, windowed, PID) + - **Error compensation**: Corrects for deviations from target rate over time + - **Stability mechanisms**: Includes smoothing, limits, and perturbation for robustness + - **Token-level alignment**: Ensures token acceptance rate matches batch acceptance rate + + Note: This function maintains internal state across calls through static variables, + so it will converge to the target acceptance rate over multiple invocations. + """ + if fixed_acceptance_rate is None: + return output_token_ids + else: + # Apply accept/reject decisions for the entire batch based on fixed_acceptance_rate + batch_size = output_token_ids.shape[0] + max_spec_len = output_token_ids.shape[1] - 1 # Get max_spec_len + device = output_token_ids.device + + assert fixed_acceptance_rate >= 0 and fixed_acceptance_rate <= 1 + + # Use error compensation method to track global acceptance rate + # These are static variables that persist between calls + if not hasattr(filter_with_acceptance_rate, "cumulative_error"): + filter_with_acceptance_rate.cumulative_error = 0.0 + if not hasattr(filter_with_acceptance_rate, "total_batches"): + filter_with_acceptance_rate.total_batches = 0 + if not hasattr(filter_with_acceptance_rate, "accepted_batches"): + filter_with_acceptance_rate.accepted_batches = 0 + if not hasattr(filter_with_acceptance_rate, "window_size"): + filter_with_acceptance_rate.window_size = 1000 # Sliding window size + if not hasattr(filter_with_acceptance_rate, "acceptance_history"): + filter_with_acceptance_rate.acceptance_history = [] # Track recent accept/reject history + if not hasattr(filter_with_acceptance_rate, "precision_adjustment"): + filter_with_acceptance_rate.precision_adjustment = 0.0 # Precision adjustment factor + if not hasattr(filter_with_acceptance_rate, "recent_adjustments"): + filter_with_acceptance_rate.recent_adjustments = [] # Recent adjustment history + if not hasattr(filter_with_acceptance_rate, "target_rate"): + filter_with_acceptance_rate.target_rate = fixed_acceptance_rate # Record target acceptance rate + else: + # If target acceptance rate changes, reset adjustment state + if filter_with_acceptance_rate.target_rate != fixed_acceptance_rate: + filter_with_acceptance_rate.precision_adjustment = 0.0 + filter_with_acceptance_rate.recent_adjustments = [] + filter_with_acceptance_rate.target_rate = fixed_acceptance_rate + + # Update batch count + filter_with_acceptance_rate.total_batches += 1 + + # Calculate current global acceptance rate + global_rate = (filter_with_acceptance_rate.accepted_batches / + filter_with_acceptance_rate.total_batches if + filter_with_acceptance_rate.total_batches > 0 else 0.0) + + # Calculate sliding window acceptance rate (focusing on recent performance) + filter_with_acceptance_rate.acceptance_history.append(0) # Default to reject + if len(filter_with_acceptance_rate.acceptance_history) > filter_with_acceptance_rate.window_size: + filter_with_acceptance_rate.acceptance_history.pop(0) # Remove 
oldest record + + window_rate = sum(filter_with_acceptance_rate.acceptance_history) / len(filter_with_acceptance_rate.acceptance_history) + + # Enhance precision for small batches - use smoother weight function + batch_weight_factor = 1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 30.0) # Exponential smooth transition + + # Dynamically adjust error weights: rely more on global error for fewer batches, + # more on sliding window error as batch count increases + window_size = len(filter_with_acceptance_rate.acceptance_history) + window_significance = min(window_size / 100.0, 0.9) # Window significance depends on historical data volume + window_weight = window_significance * batch_weight_factor + global_weight = 1.0 - window_weight + + # Consider both global error and window error + combined_error = (global_weight * (global_rate - fixed_acceptance_rate) + + window_weight * (window_rate - fixed_acceptance_rate)) + + # Update precision adjustment factor - use PID controller style adjustment + if filter_with_acceptance_rate.total_batches > 50: + # Only perform precision adjustment when there's enough data + current_error = global_rate - fixed_acceptance_rate + + # Save recent adjustment history + filter_with_acceptance_rate.recent_adjustments.append(current_error) + if len(filter_with_acceptance_rate.recent_adjustments) > 20: # Keep recent 20 errors + filter_with_acceptance_rate.recent_adjustments.pop(0) + + # PID controller parameters + kp = 0.05 # Proportional coefficient + ki = 0.001 # Integral coefficient + kd = 0.01 # Derivative coefficient + + # Proportional term - current error + p_term = current_error + + # Integral term - accumulated error + i_term = sum(filter_with_acceptance_rate.recent_adjustments) + + # Derivative term - error change rate + d_term = 0.0 + if len(filter_with_acceptance_rate.recent_adjustments) >= 2: + d_term = filter_with_acceptance_rate.recent_adjustments[-1] - filter_with_acceptance_rate.recent_adjustments[-2] + + # Calculate PID adjustment + pid_adjustment = kp * p_term + ki * i_term + kd * d_term + + # Update precision adjustment factor + filter_with_acceptance_rate.precision_adjustment = pid_adjustment + + # Limit adjustment factor range to prevent over-adjustment + max_adjustment = 0.02 + 0.03 * (1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 500.0)) + filter_with_acceptance_rate.precision_adjustment = max(-max_adjustment, min(max_adjustment, filter_with_acceptance_rate.precision_adjustment)) + + # Calculate acceptance rate correction factor + error_magnitude = abs(combined_error) + correction_factor = 1.0 + + # Use more refined error correction logic - use smooth nonlinear correction function + if error_magnitude > 0.0005: # Correct even smaller errors + # Use smooth correction function instead of piecewise function + base_strength = 2.0 + error_scale = 1.0 - math.exp(-error_magnitude * 50.0) # Exponential decay mapping to [0,1] + correction_strength = base_strength + error_scale * 1.5 # Range from 2.0 to 3.5 + + # Smooth correction + sign = 1 if combined_error > 0 else -1 + correction_factor = 1.0 + (correction_strength * error_magnitude * sign) + + # Handle boundary cases to avoid division by zero + if correction_factor == 0.0: + correction_factor = 1.0 + + # Apply correction factor + adjusted_rate = max(0.0, min(1.0, fixed_acceptance_rate * (1.0 / correction_factor))) + + # Apply precision adjustment factor + adjusted_rate = max(0.0, min(1.0, adjusted_rate - filter_with_acceptance_rate.precision_adjustment)) + + # More precise 
boundary case handling + if fixed_acceptance_rate > 0 and fixed_acceptance_rate < 0.05: + if filter_with_acceptance_rate.total_batches % int(1/fixed_acceptance_rate) == 0: + adjusted_rate = 1.0 # Periodically force accept to ensure accuracy in low acceptance rate scenarios + # If fixed_acceptance_rate is 0, directly reject + elif fixed_acceptance_rate == 0.0: + adjusted_rate = 0.0 + # If fixed_acceptance_rate is 1, directly accept + elif fixed_acceptance_rate == 1.0: + adjusted_rate = 1.0 + + # Make precise adjustments for cases with large remaining errors + target_accepted = int(filter_with_acceptance_rate.total_batches * fixed_acceptance_rate + 0.5) # Round to nearest + actual_accepted = filter_with_acceptance_rate.accepted_batches + acceptance_gap = target_accepted - actual_accepted + + # More aggressive gap adjustment strategy - use adaptive threshold and smooth adjustment + gap_relative = abs(acceptance_gap) / max(1, filter_with_acceptance_rate.total_batches) + gap_threshold = max(1, int(filter_with_acceptance_rate.total_batches * 0.005)) # Smaller dynamic threshold, at least 1 + + # Dynamically adjust acceptance rate based on the gap + if abs(acceptance_gap) >= gap_threshold: # Use dynamic threshold + # Use smooth adjustment strategy + if acceptance_gap > 0: # Need to accept more + # Use exponential function for smooth adjustment + gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1] + # Adjustment strength decreases as total batch count increases + strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0) + boost_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count + adjusted_rate = min(1.0, adjusted_rate + (1.0 - adjusted_rate) * boost_factor) + else: # Accepted too many, need to reject + # Use exponential function for smooth adjustment + gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1] + # Adjustment strength decreases as total batch count increases + strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0) + reduction_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count + adjusted_rate = max(0.0, adjusted_rate * (1.0 - reduction_factor)) + + # Add small random perturbation in fixed intervals to enhance convergence + if 0.01 < adjusted_rate < 0.99 and filter_with_acceptance_rate.total_batches > 100: + # Random perturbation amplitude decreases as batch count increases + noise_amplitude = 0.01 * math.exp(-filter_with_acceptance_rate.total_batches / 500.0) + noise = (torch.rand(1, device=device).item() * 2 - 1) * noise_amplitude + adjusted_rate = max(0.0, min(1.0, adjusted_rate + noise)) + + # Generate a random number to decide whether to accept the current batch + random_value = torch.rand(1, device=device).item() + accept_batch = random_value < adjusted_rate + + # Set some tokens to PLACEHOLDER_TOKEN_ID to achieve specified acceptance rate + # Support max_spec_len > 1 cases + if accept_batch: + # Accept batch - don't modify token_ids + filter_with_acceptance_rate.accepted_batches += 1 + filter_with_acceptance_rate.acceptance_history[-1] = 1 # Update the most recent acceptance status + else: + # Reject batch - set all speculative tokens (except first column) to PLACEHOLDER_TOKEN_ID + # This ensures token-level acceptance rate matches batch-level acceptance rate + output_token_ids[:, 1:] = PLACEHOLDER_TOKEN_ID + + # Note: acceptance rate calculation is still based on entire batch 
accept/reject, no modification needed + # But we can add a comment explaining how actual token-level acceptance rate is calculated + # Actual token-level acceptance rate = 1 - (number of PLACEHOLDER_TOKEN_ID in output_token_ids / max_spec_len) + + # Update cumulative error - use exponential moving average for smoother error adjustment + actual_rate = filter_with_acceptance_rate.accepted_batches / filter_with_acceptance_rate.total_batches + # Use EMA to smooth error updates - use adaptive EMA coefficient + alpha = 0.05 * math.exp(-filter_with_acceptance_rate.total_batches / 200.0) + 0.01 # EMA coefficient gradually decreases over time + filter_with_acceptance_rate.cumulative_error = (alpha * (actual_rate - fixed_acceptance_rate) + + (1 - alpha) * filter_with_acceptance_rate.cumulative_error) + + return output_token_ids +""" +============================= +End of MLU Hijack +============================= +""" + +def vllm__v1__sample__rejection_sampler__rejection_sample( + # [num_tokens] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + max_spec_len: int, + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: torch.Tensor | None, + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 + + batch_size = len(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use tmo rejection_sample for all random sampling requests + ''' + fixed_acceptance_rate = VLLM_MTP_FIXED_ACCEPTANCE_RATE + use_fusion_kernel = (sampling_metadata.all_random + and max_spec_len == 1 + and (num_draft_tokens is not None + and 0 not in num_draft_tokens)) + if use_fusion_kernel: + # All data is random, use tmo rejection_sample + # Generate uniform probabilities for rejection sampling. 
+ # [num_tokens] + uniform_rand = vllm__v1__sample__rejection_sampler__generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + # generate random probs for recovered tokens + uniform_probs = generate_recovered_uniform_probs( + num_tokens, + vocab_size, + num_draft_tokens, + sampling_metadata, + device, + ) + # num_draft_tokens need to be a tensor + num_draft_tokens_tensor = torch.tensor(num_draft_tokens, dtype=torch.int32, device=device) + # tmo rejection_sample dtype need to be int32 + bonus_token_ids = bonus_token_ids.to(torch.int32) + draft_token_ids = draft_token_ids.to(torch.int32) + # use tmo rejection_sample + output_token_ids = mlu_ops.rejection_sample( + draft_token_ids, + num_draft_tokens_tensor, + cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + uniform_rand, + uniform_probs, + max_spec_len, + high_acc=True # for now, only support high_acc + ).view(batch_size, max_spec_len + 1) + if fixed_acceptance_rate is not None: + # set all speculative tokens to placeholder token + output_token_ids[:, 1:] = 0 + output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate) + return output_token_ids + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + # Create output buffer. + output_token_ids = torch.full( + (batch_size, max_spec_len + 1), + PLACEHOLDER_TOKEN_ID, + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. + device=device, + ) + + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. + target_argmax = target_probs.argmax(dim=-1) + vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + has_acceptance_rate=fixed_acceptance_rate is not None, + ) + if sampling_metadata.all_greedy: + output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate) + return output_token_ids + + # Generate uniform probabilities for rejection sampling. + # [num_tokens] + uniform_probs = vllm__v1__sample__rejection_sampler__generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + + # Sample recovered tokens for each position. + # [num_tokens] + recovered_token_ids = sample_recovered_tokens( + max_spec_len, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add fixed acceptance rate check + ''' + # Rejection sampling for random sampling requests. + vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + NO_DRAFT_PROBS=draft_probs is None, + has_acceptance_rate=fixed_acceptance_rate is not None, + ) + output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate) + ''' + ================== + End of MLU Hijack + ================== + ''' + return output_token_ids + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. 
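# Hedged worked example of the accept test in the random-sampling kernel below
# (the numbers are illustrative assumptions): for a draft token with
# draft_prob = 0.20, target_prob = 0.15 and uniform draw u = 0.60, the check
# target_prob / draft_prob = 0.75 >= u accepts the draft token; with u = 0.80
# it is rejected and that position instead takes the recovered token drawn from
# max(target_prob - draft_prob, 0). When has_acceptance_rate is set, the kernel
# accepts every position and filter_with_acceptance_rate() afterwards drops
# whole batches of speculative tokens to steer toward the target rate.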
+@triton.jit(do_not_specialize=["max_spec_len"]) +def vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + NO_DRAFT_PROBS: tl.constexpr, + has_acceptance_rate: tl.constexpr, +): + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if is_greedy: + # Early exit for greedy sampling requests. + return + + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + if NO_DRAFT_PROBS: + draft_prob = 1 + else: + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) + uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add accept rate check, always accept if has_acceptance_rate is True + ''' + # NOTE(woosuk): While the draft probability should never be 0, + # we check it to avoid NaNs. If it happens to be 0, we reject. + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob or has_acceptance_rate: + # Accept. + token_id = draft_token_id + else: + # Reject. Use recovered token. + rejected = True + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Check whether to accept bonus token through acceptance_rate_ptr + ''' + # If has acceptance rate, all tokens are accepted + if has_acceptance_rate: + rejected = False + if not rejected: + # If all tokens are accepted, append the bonus token. + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. +@triton.jit(do_not_specialize=["max_spec_len"]) +def vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + target_argmax_ptr, # [num_tokens] + bonus_token_ids_ptr, # [batch_size] + is_greedy_ptr, # [batch_size] or None + max_spec_len, + has_acceptance_rate: tl.constexpr, +): + req_idx = tl.program_id(0) + # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, + # re-compilation may happen during runtime when is_greedy_ptr is None. + is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx) + if not is_greedy: + # Early exit for non-greedy sampling requests. 
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                target_argmax_id,
+            )
+            if draft_token_id != target_argmax_id:
+                # Reject.
+                rejected = True
+    if has_acceptance_rate:
+        rejected = False
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+def vllm__v1__sample__rejection_sampler__generate_uniform_probs(
+    num_tokens: int,
+    num_draft_tokens: list[int],
+    generators: dict[int, torch.Generator],
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generates a batch of uniform random samples, with optional seeding
+    if available.
+
+    This method creates a tensor of shape `(num_tokens, )` filled
+    with uniform random values in the range [0, 1). If `generators` is provided,
+    the requests with their own seeds will use the provided `torch.Generator`
+    for reproducibility. The samples for the other requests will be generated
+    without a seed.
+
+    Args:
+        num_tokens: int
+            Total number of tokens.
+        num_draft_tokens: list[int]
+            Number of draft tokens per request.
+        generators: dict[int, torch.Generator]
+            A dictionary mapping indices in the batch to
+            `torch.Generator` objects.
+        device: torch.device
+            The device on which to allocate the tensor.
+    Returns:
+        uniform_probs: torch.Tensor
+            A tensor of shape `(num_tokens, )` containing uniform
+            random values in the range [0, 1).
+    """
+    # NOTE(woosuk): Upstream deliberately uses float64 instead of float32 here
+    # because when using float32, there's a non-negligible chance that
+    # uniform_prob is sampled to be exactly 0.0, as reported in
+    # https://github.com/pytorch/pytorch/issues/16706. Using float64
+    # mitigates the issue.
+    '''
+    =============================
+    Modify by vllm_mlu
+    =============================
+    @brief: Changed torch.float64 to torch.float32
+    '''
+    uniform_probs = torch.rand(
+        (num_tokens,),
+        dtype=torch.float32,
+        device=device,
+    )
+    '''
+    ==================
+    End of MLU Hijack
+    ==================
+    '''
+    start_idx = 0
+    for req_idx, n in enumerate(num_draft_tokens):
+        # Do not generate random numbers for requests with no draft tokens.
+        # This can be important for reproducibility.
+ if n == 0: + continue + end_idx = start_idx + n + generator = generators.get(req_idx) + if generator is not None: + uniform_probs[start_idx:end_idx].uniform_(generator=generator) + start_idx = end_idx + return uniform_probs + + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.generate_uniform_probs, + vllm__v1__sample__rejection_sampler__generate_uniform_probs) + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.expand_batch_to_tokens, + vllm__v1__sample__rejection_sampler__expand_batch_to_tokens) + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.expand_kernel, + vllm__v1__sample__rejection_sampler__expand_kernel) + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.sample_recovered_tokens_kernel, + vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel) + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.rejection_sample, + vllm__v1__sample__rejection_sampler__rejection_sample) + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.rejection_random_sample_kernel, + vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel) + +MluHijackObject.apply_hijack(rejection_sampler, + rejection_sampler.rejection_greedy_sample_kernel, + vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel) diff --git a/vllm_mlu/v1/sample/sampler.py b/vllm_mlu/v1/sample/sampler.py new file mode 100644 index 0000000..bd68f69 --- /dev/null +++ b/vllm_mlu/v1/sample/sampler.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 + +import torch +from vllm.config.model import LogprobsMode +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS + +from vllm_mlu._mlu_utils import * +from vllm_mlu import _mlu_ops as mlu_ops + +""" +@brief: use tmo random_sample +""" +def mlu_random_sample( + probs: torch.Tensor, + generators: dict[int, torch.Generator], +) -> torch.Tensor: + is_gumbel_max = True + return mlu_ops.random_sample(probs, is_gumbel_max, generators).view(-1) + +class MluSampler(Sampler): + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + logprobs_mode_override: LogprobsMode | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Sample logits based on sampling metadata. + + The various logits processing functions called in this method + may update the logits tensor in-place. + """ + + logprobs_mode = logprobs_mode_override or self.logprobs_mode + assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) + if sampling_metadata.all_random: + greedy_sampled = None + else: + greedy_sampled = self.greedy_sample(logits) + if sampling_metadata.all_greedy: + processed_logprobs = None + if sampling_metadata.max_num_logprobs is not None: + if logprobs_mode == "processed_logits": + processed_logprobs = logits + elif logprobs_mode == "processed_logprobs": + processed_logprobs = self.compute_logprobs(logits) + return greedy_sampled, processed_logprobs + + assert sampling_metadata.temperature is not None + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: use tmo topk_topp_sampler to sample. 
+ """ + use_tmo = (sampling_metadata.top_k is not None) or (sampling_metadata.top_p is not None) + if use_tmo: + batch_size, vocab_size = logits.shape + index_in = torch.arange(vocab_size, dtype=torch.int32, device=logits.device) + ( + logits_out, + sorted_logits_out, + index_out, + true_select_len, + ) = mlu_ops.apply_topkp_v2( + logits, + index_in, + sampling_metadata.temperature, + None, + sampling_metadata.top_k.to(torch.int32) if sampling_metadata.top_k is not None else None, + sampling_metadata.top_p, + ) + + processed_logprobs = None + if logprobs_mode == "processed_logits": + processed_logprobs = logits + elif logprobs_mode == "processed_logprobs": + processed_logprobs = self.compute_logprobs(logits) + + probs = logits_out.softmax(dim=-1, dtype=torch.float32) + random_sampled = mlu_random_sample(probs, sampling_metadata.generators) + else: + # Apply temperature. + logits = self.apply_temperature( + logits, sampling_metadata.temperature, sampling_metadata.all_random + ) + + # Apply logits processors that only apply to random sampling + # (argmax invariant) + for processor in sampling_metadata.logitsprocs.argmax_invariant: + logits = processor.apply(logits) + + # Apply top_k and/or top_p. + random_sampled, processed_logprobs = self.topk_topp_sampler( + logits, + sampling_metadata.generators, + sampling_metadata.top_k, + sampling_metadata.top_p, + ) + """ + ================= + End of MLU Hijack + ================= + """ + + if greedy_sampled is None: + return random_sampled, processed_logprobs + + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, + out=greedy_sampled, # Reuse tensor + ) + return sampled, processed_logprobs diff --git a/vllm_mlu/v1/spec_decode/dp_eagle.py b/vllm_mlu/v1/spec_decode/dp_eagle.py new file mode 100644 index 0000000..8f9f53d --- /dev/null +++ b/vllm_mlu/v1/spec_decode/dp_eagle.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +from typing import List, Optional, Any +import copy + +import torch +import torch.nn.functional as F +from vllm.config.vllm import CUDAGraphMode +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, + FlashAttentionMetadata) +from vllm.v1.sample.metadata import SamplingMetadata + +from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, logger + +from vllm.distributed.communication_op import tensor_model_parallel_all_gather_into_list +from vllm.distributed import ( + get_logits_tp_world_size, + get_logits_tp_group, + get_tensor_model_parallel_world_size, +) +from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata +from vllm_mlu.v1.attention.backends.mla.flashmla import FlashMLAMetadataBuilder +from vllm_mlu.v1.attention.backends.utils import ( + MLUCommonAttentionMetadata, COMMON_METADATA_STR) +from vllm_mlu._mlu_utils import * +from vllm_mlu.v1.attention.backends.utils import MLUInferMode + +from vllm_mlu.mlu_forward_context import MLUDPMetadata +from vllm_mlu.v1.spec_decode.eagle import MluEagleProposer +from vllm_mlu.model_executor.models.dp_utils import ( + enable_data_parallel, + DataParallelRuntimeParams +) + +class DPMluEagleProposer(MluEagleProposer): + + def get_logits_batch_sizes(self, batch_size: int) -> Optional[List[int]]: + tp_world_size, logits_batch_sizes = get_logits_tp_world_size(), None + if 
tp_world_size != get_tensor_model_parallel_world_size(): + tp_tensor = torch.tensor([batch_size]).to(self.runner.device) + outputs = tensor_model_parallel_all_gather_into_list(tp_tensor, get_logits_tp_group()) + # Convert device tensor to host list + outputs = torch.cat(outputs).tolist() + logits_batch_sizes = [outputs[i] for i in range(tp_world_size)] + return logits_batch_sizes + + def propose_ds_execute_dummy_batch( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + dp_params: DataParallelRuntimeParams, + ) -> tuple[torch.Tensor, torch.Tensor]: + # num_scheduled_tokens + num_tokens = target_token_ids.shape[0] + input_ids = self.input_ids[:num_tokens] + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + input_ids[:-1] = target_token_ids[1:] + + # always skip attn compute + attn_metadata: Optional[dict[str, Any]] = None + + # Get graph capture related infomation for deepseek model. + + with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens): + hidden_states = self.model( + input_ids=input_ids, + positions=target_positions, + hidden_states=target_hidden_states, + intermediate_tensors=None, + inputs_embeds=None, + dp_params=dp_params, + ) + if dp_params is not None: + dp_params.logits_batch_split_list = self.get_logits_batch_sizes(num_tokens) + _ = self.model.compute_logits(hidden_states, dp_params=dp_params) + + if self.num_speculative_tokens == 1: + return + ''' + ============================= + Modify by vllm_mlu + @brief: support k > 1, need run draft model k-1 times + ============================= + ''' + # support k > 1 + for _ in range(self.num_speculative_tokens - 1): + new_dp_params = self.runner._get_data_parallel_metadata( + num_tokens, num_tokens, True, [1] * num_tokens) + with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens): + hidden_states = self.model( + input_ids=input_ids, + positions=target_positions, + hidden_states=target_hidden_states, + intermediate_tensors=None, + inputs_embeds=None, + dp_params=new_dp_params, + ) + _ = self.model.compute_logits(hidden_states, dp_params=new_dp_params) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, + common_attn_metadata: MLUCommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + # [batch_size] + num_rejected_tokens: torch.Tensor, + # [num_tokens] + token_indices: torch.Tensor, + whole_block_table: torch.Tensor, + main_model_dp_params: Optional[DataParallelRuntimeParams] = None, + time_markers: List =[], + ) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if self.method == "eagle3": + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. 
+ # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + hidden_states_indices = last_token_indices + + assert self.runner is not None + + if self.attn_metadata_builder is None: + attn_metadata_builder = self._get_attention_metadata_builder() + else: + attn_metadata_builder = self.attn_metadata_builder + + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Use full graph with draft model and pad batch_size for dp + ''' + dp_group_max_token_num = max(main_model_dp_params.token_split_list) + if dp_group_max_token_num <= self.vllm_config.compilation_config.max_cudagraph_capture_size: + batch_descriptor_num_tokens = self.vllm_config.pad_for_cudagraph(dp_group_max_token_num) + captured_already = True + else: + batch_descriptor_num_tokens = num_tokens + captured_already = False + + # Determine if we can use full graph + decode_only = all(not prefill for prefill in main_model_dp_params.dp_is_prefill) + # FIXME(wangchao2): disable mtp graph for ds3.2 with dp fow now(core dump) + is_dsv32 = self.vllm_config.model_config.hf_config.model_type == "deepseek_v32" + use_full_graph = (self.use_cuda_graph + and decode_only and captured_already and not is_dsv32) + if (self.use_cuda_graph and decode_only and not use_full_graph and not is_dsv32): + logger.warning_once( + f"Select MLU-V1 Full-MLUGraph mode with drafter, however running in " + + f"eager mode: decode_only={decode_only}, captured_already={captured_already}, " + + f"num_tokens={num_tokens}." + ) + + cudagraph_runtime_mode = CUDAGraphMode.FULL if use_full_graph else CUDAGraphMode.NONE + batch_descriptor = BatchDescriptor( + num_tokens=batch_descriptor_num_tokens, + uniform_decode=True, + ) + + # dp pad batch_size + if use_full_graph: + K = self.num_speculative_tokens + num_input_tokens = batch_descriptor_num_tokens + padded_batch_size = num_input_tokens // (K + 1) + else: + padded_batch_size = batch_size + num_input_tokens = num_tokens + + # change attn metadata num_actual_tokens + attn_metadata.num_actual_tokens = num_input_tokens + + common_attn_metadata_copy = None + # copy common_attn_metadata when k>1 for draft model, + # because dp pad batch_size will change common_attn_metadata + if self.num_speculative_tokens > 1: + common_attn_metadata_copy = copy.deepcopy(common_attn_metadata) + # pad attn metadata + if use_full_graph and enable_data_parallel() and num_input_tokens != num_tokens: + assert self.runner is not None + # Update attention metadata. + pad_attn_metadata( + attn_metadata, + common_attn_metadata, + whole_block_table, + self.runner, + num_tokens, + num_input_tokens, + batch_size, + padded_batch_size, + ) + + # Update input ids, pad with 0 if necessary. + token_pad_size = num_input_tokens - num_tokens + assert token_pad_size >= 0 + # Update target hidden states, pad with zeros if necessary. + if token_pad_size > 0: + target_hidden_states = F.pad( + target_hidden_states, + (0, 0, 0, token_pad_size), + value=0.0 + ) + + # Update positions, pad with zeros if necessary. 
+ if token_pad_size > 0: + target_positions = F.pad( + target_positions, + (0, token_pad_size), + value=0 + ) + + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata + + # copy inputs to buffer for cudagraph + self.positions[:num_input_tokens] = target_positions + self.hidden_states[:num_input_tokens] = target_hidden_states + + kwargs = {} if main_model_dp_params is None else {"dp_params": main_model_dp_params} + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + start = torch.mlu.Event(enable_timing=True) + start.record() + with set_forward_context(per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + batch_descriptor=batch_descriptor if use_full_graph else None, + cudagraph_runtime_mode=cudagraph_runtime_mode): + if use_full_graph: + ret_hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + intermediate_tensors=None, + inputs_embeds=None, + is_running_drafter=True, + **kwargs, + ) + else: + ret_hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + **kwargs, + ) + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + time_markers.append([start, end]) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + ''' + ============================= + End of MLU Hijack + ============================= + ''' + if main_model_dp_params is not None: + # Ensure main_model_dp_params has required attribute before assignment + if hasattr(main_model_dp_params, 'logits_batch_split_list'): + main_model_dp_params.logits_batch_split_list = self.get_logits_batch_sizes(batch_size) + else: + raise AttributeError("dp_params must have 'logits_batch_split_list' attribute") + + sample_hidden_states = last_hidden_states[hidden_states_indices] + logits = self.model.compute_logits(sample_hidden_states, dp_params=main_model_dp_params) + + draft_token_ids = logits.argmax(dim=-1) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + if self.uses_mrope: + positions = target_positions[:, last_token_indices] + else: + positions = target_positions[last_token_indices] + + ''' + ============================= + Modify by vllm_mlu + ============================= + ''' + hidden_states = last_hidden_states[hidden_states_indices] + ''' + ============================= + End of MLU Hijack + ============================= + ''' + # Generate the remaining draft tokens. 
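+        # The remaining num_speculative_tokens - 1 draft tokens are produced
+        # iteratively: each step feeds the previous draft token back as input,
+        # advances the positions and slot mappings by one, and rebuilds the
+        # per-layer attention metadata for a single-token decode per request.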
+ draft_token_ids_list = [draft_token_ids] + + input_batch_size = batch_size + + if common_attn_metadata.infer_mode != MLUInferMode.DECODE_ONLY: + seq_lens_cpu = torch.ones(input_batch_size, dtype=torch.int32,) + cu_num_tokens = torch.cumsum(seq_lens_cpu, dim=0) + query_start_loc_cpu = torch.empty(input_batch_size + 1, dtype=torch.int32) + query_start_loc_cpu[0] = 0 + query_start_loc_cpu[1:] = cu_num_tokens + seq_start_loc_cpu = self.arange[:input_batch_size + 1] + common_attn_metadata_k = MLUCommonAttentionMetadata.build( + query_start_loc=query_start_loc_cpu.to(self.device, non_blocking=True), + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens_cpu.to(self.device, non_blocking=True), + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + seq_start_loc=seq_start_loc_cpu.to(self.device, non_blocking=True), + is_start_loc_match=False, # not prefill + max_query_len=1, + num_actual_tokens=input_batch_size, + num_input_tokens=input_batch_size, + num_speculative_tokens=self.num_speculative_tokens, + has_prefill_reqs=common_attn_metadata.infer_mode == MLUInferMode.CHUNKED, + ) + else: + common_attn_metadata_k = common_attn_metadata_copy + common_attn_metadata_k.num_actual_tokens = batch_size + common_attn_metadata_k.num_input_tokens = batch_size + common_attn_metadata_k.max_query_len = 1 + common_attn_metadata_k.query_start_loc = self.arange[: batch_size + 1] + common_attn_metadata_k.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[: batch_size + 1] + ).clone() + # In padded drafter batch, we need to adjust the sequence lengths + # to remove the "padding" (i.e. rejected tokens). + # Only apply this adjustment when we have rejected tokens + # (i.e., not the first proposal). + for token_index in range(self.num_speculative_tokens - 1): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: get dp_params for draft model + ''' + # dp_params for draft model + if main_model_dp_params is not None: + dp_params = self.runner._get_data_parallel_metadata( + input_batch_size, + input_batch_size, + common_attn_metadata.is_decode_only, + [1] * input_batch_size + ) + kwargs = {} if main_model_dp_params is None else {"dp_params": dp_params} + ''' + ============================= + End of MLU Hijack + ============================= + ''' + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_list[-1].int() + if self.uses_mrope: + positions += 1 + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. + # Since it is complex to remove such requests from the batch, + # we keep them in the batch but adjust the position ids + # and slot mappings to avoid the + # out-of-range access during the model execution. + # The draft tokens generated with this adjustment + # should be ignored. + exceeds_max_model_len = positions[0] >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. 
+ clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), + positions, + ) + else: + positions += 1 + exceeds_max_model_len = positions >= self.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) + + # For data integrity when async scheduling, we shouldn't use in place + # operations in case they are modified in next step's `prepare_input` + # of main model. + # Increment the sequence lengths. + common_attn_metadata_k.seq_lens += 1 + # This is an out-of-place operation to avoid modifying the original tensor. + common_attn_metadata_k.seq_lens_cpu = common_attn_metadata_k.seq_lens_cpu + 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + + common_attn_metadata_k.seq_lens.masked_fill_(exceeds_max_model_len, 1) + + common_attn_metadata_k.num_computed_tokens_cpu = ( + common_attn_metadata_k.seq_lens_cpu - 1 + ) + + # Compute the slot mapping. + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size + block_ids = common_attn_metadata_k.block_table_tensor.gather( + dim=1, index=block_numbers.view(-1, 1) + ) + block_ids = block_ids.view(-1) + if self.uses_mrope: + common_attn_metadata_k.slot_mapping = ( + block_ids * self.block_size + clamped_positions[0] % self.block_size + ) + else: + common_attn_metadata_k.slot_mapping = ( + block_ids * self.block_size + clamped_positions % self.block_size + ) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + common_attn_metadata_k.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID + ) + + # Rebuild attention metadata + attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore + common_attn_metadata=common_attn_metadata_k, draft_index=token_index + 1 + ) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata_k + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + if self.supports_mm_inputs: + self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) + + input_ids = None + inputs_embeds = self.inputs_embeds[:input_batch_size] + else: + input_ids = self.input_ids[:input_batch_size] + inputs_embeds = None + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: record latency + ''' + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + start = torch.mlu.Event(enable_timing=True) + start.record() + ''' + ============================= + End of MLU Hijack + ============================= + ''' + # Run the model. 
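+            # Each request contributes exactly one token in this step; the
+            # persistent input_ids / positions / hidden_states buffers were
+            # refreshed above and are sliced to input_batch_size for this call.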
+ with set_forward_context(per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size): + ret_hidden_states = self.model( + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + **kwargs, + ) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + time_markers.append([start, end]) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + hidden_states = last_hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size], + dp_params=dp_params) + + # TODO(wenlong): get more than one token for tree attention + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_list.append(draft_token_ids) + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids diff --git a/vllm_mlu/v1/spec_decode/eagle.py b/vllm_mlu/v1/spec_decode/eagle.py new file mode 100644 index 0000000..885009e --- /dev/null +++ b/vllm_mlu/v1/spec_decode/eagle.py @@ -0,0 +1,1067 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +import ast +from dataclasses import replace +from importlib.util import find_spec +import numpy as np +import torch +import torch.nn as nn +from typing import Any, List, Optional +from vllm.config.vllm import ( + CompilationMode, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, +) +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.platforms import current_platform +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, +) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp + +from vllm.v1.spec_decode.eagle import EagleProposer, PADDING_SLOT_ID, logger +from vllm.v1.utils import CpuGpuBuffer + +from vllm_mlu.compilation.mlu_graph import MLUGraphWrapper +from vllm_mlu.v1.attention.backends.mla.flashmla import FlashMLAMetadataBuilder +from vllm_mlu.v1.attention.backends.utils import ( + MLUCommonAttentionMetadata, get_common_metadata_from_attn_metadata, + get_common_metadata, COMMON_METADATA_STR) +from vllm_mlu.model_executor.models.sp_utils import set_sp_forward_context +from vllm_mlu._mlu_utils import * +from vllm_mlu.v1.attention.backends.utils import MLUInferMode + + +class MluEagleProposer(EagleProposer): + def __init__( + self, + vllm_config: VllmConfig, 
+ device: torch.device, + runner=None, + ): + self.vllm_config = vllm_config + self.speculative_config = vllm_config.speculative_config + assert self.speculative_config is not None + self.draft_model_config = self.speculative_config.draft_model_config + self.method = self.speculative_config.method + + self.runner = runner + self.device = device + self.dtype = vllm_config.model_config.dtype + self.max_model_len = vllm_config.model_config.max_model_len + self.block_size = vllm_config.cache_config.block_size + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.token_arange_np = np.arange(self.max_num_tokens) + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = self.draft_model_config.get_hidden_size() + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + vllm_config.model_config + ) + + self.attn_metadata_builder: AttentionMetadataBuilder | None = None + self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None + self.attn_layer_names: list[str] = [] + self.indexer_layer_names: list[str] = [] + + self.use_cuda_graph = True + + compilation_config = self.vllm_config.compilation_config + if compilation_config.mode == CompilationMode.VLLM_COMPILE: + cudagraph_mode = compilation_config.cudagraph_mode + if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( + CUDAGraphMode.PIECEWISE + ): + logger.warning( + "Currently the eagle proposer only supports cudagraph_mode " + "PIECEWISE, if you want the drafter to use cuda graphs, " + "please set compilation_config.cudagraph_mode to PIECEWISE " + "or FULL_AND_PIECEWISE" + ) + self.use_cuda_graph = ( + not self.speculative_config.enforce_eager + ) + + self.cudagraph_batch_sizes = ( + (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes)) + if self.use_cuda_graph + else [] + ) + + self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes) + # persistent buffers for cuda graph + self.input_ids = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=device + ) + self.uses_mrope = self.vllm_config.model_config.uses_mrope + if self.uses_mrope: + # M-RoPE need (3, max_num_tokens) + self.mrope_positions = torch.zeros( + (3, self.max_num_tokens), dtype=torch.int64, device=device + ) + else: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: tmo positions need to be int32 + ''' + # RoPE need (max_num_tokens,) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=device) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) + + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. 
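+        # e.g. with 3 requests, query_start_loc is [0, q1, q1 + q2, q1 + q2 + q3],
+        # i.e. batch_size + 1 entries.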
+ max_batch_size = vllm_config.scheduler_config.max_num_seqs + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange( + max_num_slots_for_arange, device=device, dtype=torch.int32 + ) + + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) + + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + + # Determine allowed attention backends once during initialization. + from vllm.attention.backends.registry import AttentionBackendEnum + + self.allowed_attn_types: tuple | None = None + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # ROCM_AITER_FA is an optional backend + if find_spec( + AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) + ): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata, + ) + + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + + # Parse the speculative token tree. + spec_token_tree = self.speculative_config.speculative_token_tree + self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) + tree_depth = len(self.tree_choices[-1]) + # Precompute per-level properties of the tree. + num_drafts_per_level = [0] * tree_depth + for node in self.tree_choices: + num_drafts_per_level[len(node) - 1] += 1 + self.cu_drafts_per_level = [num_drafts_per_level[0]] + self.child_drafts_per_level = [num_drafts_per_level[0]] + for level in range(1, tree_depth): + self.cu_drafts_per_level.append( + self.cu_drafts_per_level[-1] + num_drafts_per_level[level] + ) + self.child_drafts_per_level.append( + num_drafts_per_level[level] // num_drafts_per_level[level - 1] + ) + # Precompute draft position offsets in flattened tree. + self.tree_draft_pos_offsets = torch.arange( + 1, + len(self.tree_choices) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, + dtype=torch.int32) + ''' + ============================= + Modify by vllm_mlu + @brief: Now kv_cache is stored in groups, need to get the corresponding group_id + FIXME: need to be removed after update https://github.com/vllm-project/vllm/pull/20022 + ============================= + ''' + self.kv_cache_group_id = None + ''' + ============================= + End of MLU Hijack + ============================= + ''' + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) + + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + + # Determine allowed attention backends once during initialization. + from vllm.attention.backends.registry import AttentionBackendEnum + + self.allowed_attn_types: tuple | None = None + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # ROCM_AITER_FA is an optional backend + if find_spec( + AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) + ): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata, + ) + + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + + # Parse the speculative token tree. 
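+        # Illustrative example (assumed format): a tree string like
+        # "[(0,), (1,), (0, 0)]" describes two first-level drafts and one
+        # child of the first draft, giving num_drafts_per_level == [2, 1].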
+ spec_token_tree = self.speculative_config.speculative_token_tree + self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) + tree_depth = len(self.tree_choices[-1]) + # Precompute per-level properties of the tree. + num_drafts_per_level = [0] * tree_depth + for node in self.tree_choices: + num_drafts_per_level[len(node) - 1] += 1 + self.cu_drafts_per_level = [num_drafts_per_level[0]] + self.child_drafts_per_level = [num_drafts_per_level[0]] + for level in range(1, tree_depth): + self.cu_drafts_per_level.append( + self.cu_drafts_per_level[-1] + num_drafts_per_level[level] + ) + self.child_drafts_per_level.append( + num_drafts_per_level[level] // num_drafts_per_level[level - 1] + ) + # Precompute draft position offsets in flattened tree. + self.tree_draft_pos_offsets = torch.arange( + 1, + len(self.tree_choices) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) + + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, + common_attn_metadata: MLUCommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + # [batch_size] + num_rejected_tokens: torch.Tensor, + # [num_tokens] + token_indices: torch.Tensor, + time_markers: List = [], + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + ) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if self.method == "eagle3": + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + hidden_states_indices = last_token_indices + + assert self.runner is not None + + if self.attn_metadata_builder is None: + attn_metadata_builder = self._get_attention_metadata_builder() + else: + attn_metadata_builder = self.attn_metadata_builder + + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + + # FIXME: support hybrid kv for draft model (remove separate indexer) + if self.draft_indexer_metadata_builder: + draft_indexer_metadata = ( + self.draft_indexer_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + ) + else: + draft_indexer_metadata = None + + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. 
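+        # Concretely, every draft attention layer is pointed at the one
+        # attn_metadata built above, plus a shared common-metadata entry
+        # keyed by COMMON_METADATA_STR.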
+ per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata + + for layer_name in self.indexer_layer_names: + assert draft_indexer_metadata is not None + per_layer_attn_metadata[layer_name] = draft_indexer_metadata + + cudagraph_runtime_mode = CUDAGraphMode.NONE + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + else: + num_input_tokens = num_tokens + + # copy inputs to buffer for cudagraph + self._set_positions(num_tokens, target_positions) + self.hidden_states[:num_tokens] = target_hidden_states + + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + + self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + start = torch.mlu.Event(enable_timing=True) + start.record() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Use full graph with draft model + @brief: Add set_sp_forward_context for sequence parallel. + ''' + use_full_graph = False + batch_descriptor = BatchDescriptor( + num_tokens=num_tokens, + uniform_decode=True, + ) + + if batch_descriptor in self.model.concrete_cudagraph_entries: + cudagraph_runtime_mode = CUDAGraphMode.FULL + use_full_graph = True + + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states + + with set_forward_context(per_layer_attn_metadata, self.vllm_config, + num_tokens=num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor if use_full_graph else None), \ + set_sp_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens, + ): + ret_hidden_states = self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + is_running_drafter=use_full_graph + ) + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + time_markers.append([start, end]) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + ''' + ============================= + End of MLU Hijack + ============================= + ''' + sample_hidden_states = last_hidden_states[hidden_states_indices] + logits = self.model.compute_logits(sample_hidden_states) + draft_token_ids = logits.argmax(dim=-1) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + if self.uses_mrope: + positions = target_positions[:, last_token_indices] + else: + positions = target_positions[last_token_indices] + + ''' + ============================= + Modify by vllm_mlu + ============================= + ''' + hidden_states = sample_hidden_states + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + # Generate the remaining draft tokens. 
+ draft_token_ids_list = [draft_token_ids] + + input_batch_size = batch_size + + if common_attn_metadata.infer_mode != MLUInferMode.DECODE_ONLY: + seq_lens_cpu = torch.ones(input_batch_size, dtype=torch.int32,) + cu_num_tokens = torch.cumsum(seq_lens_cpu, dim=0) + query_start_loc_cpu = torch.empty(input_batch_size + 1, dtype=torch.int32) + query_start_loc_cpu[0] = 0 + query_start_loc_cpu[1:] = cu_num_tokens + seq_start_loc_cpu = self.arange[:input_batch_size + 1] + common_attn_metadata_k = MLUCommonAttentionMetadata.build( + query_start_loc=query_start_loc_cpu.to(self.device, non_blocking=True), + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens_cpu.to(self.device, non_blocking=True), + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + seq_start_loc=seq_start_loc_cpu.to(self.device, non_blocking=True), + is_start_loc_match=False, # not prefill + max_query_len=1, + num_actual_tokens=input_batch_size, + num_input_tokens=input_batch_size, + num_speculative_tokens=self.num_speculative_tokens, + has_prefill_reqs=common_attn_metadata.infer_mode == MLUInferMode.CHUNKED, + ) + else: + common_attn_metadata_k = common_attn_metadata + common_attn_metadata_k.num_actual_tokens = batch_size + common_attn_metadata_k.num_input_tokens = batch_size + common_attn_metadata_k.max_query_len = 1 + common_attn_metadata_k.query_start_loc = self.arange[: batch_size + 1] + common_attn_metadata_k.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[: batch_size + 1] + ).clone() + for token_index in range(self.num_speculative_tokens - 1): + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_list[-1].int() + if self.uses_mrope: + positions += 1 + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. + # Since it is complex to remove such requests from the batch, + # we keep them in the batch but adjust the position ids + # and slot mappings to avoid the + # out-of-range access during the model execution. + # The draft tokens generated with this adjustment + # should be ignored. + exceeds_max_model_len = positions[0] >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), + positions, + ) + else: + positions += 1 + exceeds_max_model_len = positions >= self.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) + + # For data integrity when async scheduling, we shouldn't use in place + # operations in case they are modified in next step's `prepare_input` + # of main model. + # Increment the sequence lengths. + common_attn_metadata_k.seq_lens += 1 + # This is an out-of-place operation to avoid modifying the original tensor. + common_attn_metadata_k.seq_lens_cpu = common_attn_metadata_k.seq_lens_cpu + 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. 
+ + common_attn_metadata_k.seq_lens.masked_fill_(exceeds_max_model_len, 1) + + common_attn_metadata_k.num_computed_tokens_cpu = ( + common_attn_metadata_k.seq_lens_cpu - 1 + ) + + # Compute the slot mapping. + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size + block_ids = common_attn_metadata_k.block_table_tensor.gather( + dim=1, index=block_numbers.view(-1, 1) + ) + block_ids = block_ids.view(-1) + if self.uses_mrope: + common_attn_metadata_k.slot_mapping = ( + block_ids * self.block_size + clamped_positions[0] % self.block_size + ) + else: + common_attn_metadata_k.slot_mapping = ( + block_ids * self.block_size + clamped_positions % self.block_size + ) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + common_attn_metadata_k.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID + ) + + # Rebuild attention metadata + attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore + common_attn_metadata=common_attn_metadata_k, draft_index=token_index + 1 + ) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata_k + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + if self.supports_mm_inputs: + self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) + + input_ids = None + inputs_embeds = self.inputs_embeds[:input_batch_size] + else: + input_ids = self.input_ids[:input_batch_size] + inputs_embeds = None + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: record latency + @brief: add set_sp_forward_context for sequence parallel. + ''' + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + start = torch.mlu.Event(enable_timing=True) + start.record() + ''' + ============================= + End of MLU Hijack + ============================= + ''' + # Run the model. 
+ with set_forward_context(per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size + ), set_sp_forward_context( + per_layer_attn_metadata, + self.vllm_config, + input_batch_size, + ): + ret_hidden_states = self.model( + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: adapt to different methods + ''' + if self.method == "mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + ''' + ============================= + End of MLU Hijack + ============================= + ''' + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: record latency + ''' + if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + time_markers.append([start, end]) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size]) + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_list.append(draft_token_ids) + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids + + def prepare_inputs( + self, + common_attn_metadata: MLUCommonAttentionMetadata, + # [batch_size] + num_rejected_tokens: torch.Tensor + ) -> tuple[MLUCommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for the spec decode. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. 
+ # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = self.token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add seq_start_loc compute, use MLUCommonAttentionMetadata + ''' + new_seq_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_seq_start_loc_np = new_seq_start_loc_cpu.numpy() + np.cumsum(new_seq_lens_cpu.numpy(), out=new_seq_start_loc_np[1:]) + + spec_common_attn_metadata = MLUCommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. 
+ num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + seq_start_loc=new_seq_start_loc_cpu.to(device, non_blocking=True), + num_input_tokens=total_num_tokens, + num_prefill_query_tokens=total_num_tokens, + num_prefill_kv_tokens=total_num_tokens, + infer_mode=common_attn_metadata.infer_mode, + ) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + return spec_common_attn_metadata, token_indices + + def load_model( + self, target_model: nn.Module) -> None: + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()) + + from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): + self.model = get_model(vllm_config=self.vllm_config, + model_config=draft_model_config) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: use graph wrapper for draft model + ''' + self.model = MLUGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() - + target_attn_layer_names) + + self.attn_layer_names = list(draft_attn_layer_names) + + if supports_multimodal(target_model): + # handle multimodality + self.model.config.image_token_index = ( + target_model.config.image_token_index) + target_language_model = target_model.get_language_model() + else: + target_language_model = target_model + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: only eagle and eagle3 need to share embed_tokens with the target model + ''' + if self.method in ["eagle", "eagle3"] or self.vllm_config.model_config.hf_config.model_type == "glm4_moe": + # share embed_tokens with the target model if needed + if get_pp_group().world_size == 1 \ + and self.model.model.embed_tokens.weight.shape \ + == target_language_model.model.embed_tokens.weight.shape: + logger.info( + "Assuming the EAGLE head shares the same vocab embedding" \ + " with the target model." + ) + del self.model.model.embed_tokens + self.model.model.embed_tokens = target_language_model.model.embed_tokens + else: + logger.info( + "The EAGLE head's vocab embedding will be loaded separately" \ + " from the target model." 
+ ) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + # share lm_head with the target model if needed + # some model definition do not define lm_head explicitly + # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.vllm_config.speculative_config.method not in ["eagle3", "longcat_flash_mtp"] and \ + hasattr(target_language_model, "lm_head"): + logger.info("Loading EAGLE LM head weights from the target model.") + self.model.lm_head = target_language_model.lm_head + + target_lm_head = target_model.lm_head + if target_lm_head is None: + logger.warning("Target model lm_head is None") + return + if self.vllm_config.model_config.hf_config.model_type == "glm4_moe": + self._process_moe_mtp_layers(target_lm_head) + + def _process_moe_mtp_layers(self, target_lm_head): + # For GLM4 MoE MTP models, share weights with all MTP layer shared_head.head + # instead of replacing the module (to preserve DPParallelLMHead functionality) + if not (hasattr(self.model, "model") and hasattr(self.model.model, "layers")): + return + for layer_name, layer in self.model.model.layers.items(): + if not (hasattr(layer, "shared_head") and hasattr(layer.shared_head, "head")): + continue + if not (hasattr(target_lm_head, "weight") and hasattr(layer.shared_head.head, "weight")): + continue + if layer.shared_head.head.weight.shape != target_lm_head.weight.shape: + logger.debug( + f"Skipping weight sharing for layer {layer_name}: " + f"shape mismatch (mtp: {layer.shared_head.head.weight.shape}, " + f"target: {target_lm_head.weight.shape})" + ) + continue + # Safe replacement + del layer.shared_head.head + layer.shared_head.head = target_lm_head + logger.info(f"Replaced MTP layer {layer_name} shared_head.head with target lm_head") + + @torch.inference_mode() + def dummy_run( + self, + attn_metadata: Any, + num_tokens: int, + use_cudagraphs=True, + ) -> None: + # Determine if CUDA graphs should be used for this run. + cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph + if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @breif: add set_sp_forward_context for sequence parallel. + @brief: capture drafter model + ''' + cudagraph_runtime_mode = (CUDAGraphMode.FULL if cudagraphs_enabled + else CUDAGraphMode.NONE) + + batch_descriptor = BatchDescriptor( + num_tokens=num_tokens, + uniform_decode=True, + ) + + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=cudagraph_runtime_mode, + ), set_sp_forward_context(None, self.vllm_config, num_tokens): + if self.supports_mm_inputs: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + self.model( + input_ids=input_ids, + positions=self._get_positions(num_tokens), + hidden_states=self.hidden_states[:num_tokens], + inputs_embeds=inputs_embeds, + is_running_drafter=True + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + def validate_same_kv_cache_group( + self, + kv_cache_config: KVCacheConfig) -> None: + """ + Validate that all eagle layers belong to the same KVCacheGroup. + Need this assumption to ensure all eagle layers can use the + same AttentionMetadata. + May extend to multiple AttentionMetadata in the future. 
+ """ + kv_cache_groups: dict[str, int] = {} + for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + for layer_name in kv_cache_group.layer_names: + kv_cache_groups[layer_name] = id + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: get kv_cache_group_id and filter kv_cache_groups + ''' + eagle_cache_groups = set(kv_cache_groups[layer_name] + for layer_name in self.attn_layer_names + if layer_name in kv_cache_groups) + assert len(eagle_cache_groups) == 1, ( + "All eagle layers should belong to the same kv cache group") + self.kv_cache_group_id = next(iter(eagle_cache_groups)) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + def prepare_inputs_padded( + self, + common_attn_metadata: MLUCommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[MLUCommonAttentionMetadata, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. + """ + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] + - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] + ) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu), + ) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add seq_start_loc compute, use MLUCommonAttentionMetadata + ''' + new_seq_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_seq_start_loc_np = new_seq_start_loc_cpu.numpy() + np.cumsum(common_attn_metadata.seq_lens.cpu().numpy(), out=new_seq_start_loc_np[1:]) + + spec_common_attn_metadata = MLUCommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + seq_start_loc=new_seq_start_loc_cpu.to(self.device, non_blocking=True), + num_input_tokens=total_num_tokens, + num_prefill_query_tokens=total_num_tokens, + num_prefill_kv_tokens=total_num_tokens, + infer_mode=common_attn_metadata.infer_mode, + ) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - 
num_rejected_tokens_gpu + ) + return spec_common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu + + + def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: + """Find and return the attention metadata builders for EAGLE layers. + + Returns: + The metadata builders for EAGLE layers. + + Raises: + AssertionError: If no metadata builders are found for EAGLE layers. + """ + builder = None + chosen_layer = self.attn_layer_names[0] + + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: replace attn metadata name to prefill_attn name + """ + if self.draft_model_config.is_deepseek_mla and chosen_layer.endswith("self_attn.attn"): + chosen_layer = chosen_layer.replace( + "self_attn.attn", "self_attn.mla_attn") + """ + ================= + End of MLU Hijack + ================= + """ + for kv_cache_group in self.runner.attn_groups: + for attn_group in kv_cache_group: + if chosen_layer in attn_group.layer_names: + builder = attn_group.get_metadata_builder() + break + if builder is not None: + break + + assert builder is not None, ( + "Failed to find attention metadata builder for EAGLE layers." + ) + return builder \ No newline at end of file diff --git a/vllm_mlu/v1/worker/__init__.py b/vllm_mlu/v1/worker/__init__.py new file mode 100644 index 0000000..e6152d2 --- /dev/null +++ b/vllm_mlu/v1/worker/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + diff --git a/vllm_mlu/v1/worker/block_table.py b/vllm_mlu/v1/worker/block_table.py new file mode 100644 index 0000000..4e0622d --- /dev/null +++ b/vllm_mlu/v1/worker/block_table.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + +from vllm.distributed import get_dcp_group +from vllm.logger import init_logger +from vllm.v1.worker.block_table import BlockTable + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + +logger = init_logger(__name__) + + +class BlockTable_MluHijack(BlockTable): + + def __init__( + self, + block_size: int, + max_num_reqs: int, + max_num_blocks_per_req: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + kernel_block_size: int, + dcp_kv_cache_interleave_size: int, + ): + """ + Args: + block_size: Block size used for KV cache memory allocation + max_num_reqs: Maximum number of concurrent requests supported. + max_num_blocks_per_req: Maximum number of blocks per request. + max_num_batched_tokens: Maximum number of tokens in a batch. + pin_memory: Whether to pin memory for faster GPU transfers. + device: Target device for the block table. + kernel_block_size: The block_size of underlying attention kernel. + Will be the same as `block_size` if `block_size` is supported + by the attention kernel. 
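+            dcp_kv_cache_interleave_size: Number of consecutive tokens stored
+                on one decode-context-parallel (DCP) rank before the KV cache
+                switches to the next rank (interleaving granularity across
+                DCP ranks); meaning assumed from the upstream vLLM parameter
+                of the same name.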
+ """ + self.max_num_reqs = max_num_reqs + self.max_num_batched_tokens = max_num_batched_tokens + self.pin_memory = pin_memory + self.device = device + + if kernel_block_size == block_size: + # Standard case: allocation and computation use same block size + # No block splitting needed, direct mapping + self.block_size = block_size + self.blocks_per_kv_block = 1 + self.use_hybrid_blocks = False + else: + # Hybrid case: allocation block size differs from kernel block size + # Memory blocks are subdivided to match kernel requirements + # Example: 32-token memory blocks with 16-token kernel blocks + # → Each memory block corresponds to 2 kernel blocks + if block_size % kernel_block_size != 0: + raise ValueError( + f"kernel_block_size {kernel_block_size} must divide " + f"kv_manager_block_size size {block_size} evenly" + ) + + self.block_size = kernel_block_size + self.blocks_per_kv_block = block_size // kernel_block_size + self.use_hybrid_blocks = True + + self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block + + self.block_table = self._make_buffer( + self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 + ) + self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: change slot_mapping dtype for int64 to int32 + ''' + self.slot_mapping = self._make_buffer( + self.max_num_batched_tokens, dtype=torch.int32 + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if self.use_hybrid_blocks: + self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape( + 1, -1 + ) + else: + self._kernel_block_arange = None + + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size + + +MluHijackObject.apply_hijack( + BlockTable, + BlockTable.__init__, + BlockTable_MluHijack.__init__ +) \ No newline at end of file diff --git a/vllm_mlu/v1/worker/dp_gpu_model_runner.py b/vllm_mlu/v1/worker/dp_gpu_model_runner.py new file mode 100644 index 0000000..8ebc35c --- /dev/null +++ b/vllm_mlu/v1/worker/dp_gpu_model_runner.py @@ -0,0 +1,1007 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from copy import copy +from typing import TYPE_CHECKING, Dict, Optional, List, Tuple, Any + +import torch +import numpy as np +import cnpx + +from vllm.distributed.parallel_state import ( + get_tp_group, get_pp_group) +from vllm.distributed.kv_transfer import has_kv_transfer_group, get_kv_transfer_group +from vllm.distributed import ( + divide, get_moe_expert_parallel_world_size +) +from vllm.config import VllmConfig, CUDAGraphMode +from vllm.forward_context import set_forward_context, BatchDescriptor +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput) +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.medusa import MedusaProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm.v1.utils import record_function_or_nullcontext +from 
vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm.v1.worker.gpu_model_runner import ExecuteModelState +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput, GrammarOutput + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata +from vllm_mlu.v1.attention.backends.utils import ( + MLUCommonAttentionMetadata, unpad_common_attn_metadata, + get_common_metadata_from_attn_metadata, MLUInferMode) +from vllm_mlu.v1.worker.gpu_model_runner import ( + MLUModelRunner, AsyncMLUModelRunnerOutput, apply_grammar_bitmask) +from vllm_mlu.mlu_forward_context import MLUDPMetadata +from vllm_mlu.model_executor.models.dp_utils import ( + enable_emb_logits_custom_parallel, + get_runtime_infos_per_dp_group, + get_deepseek_layer_split_list, +) + +from vllm_mlu.model_executor.models.dp_utils import ( + DataParallelRuntimeParams +) +from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp +from vllm_mlu.distributed.parallel_state import ( + init_cnclep, get_cnclep +) + +from vllm_mlu._mlu_utils import * +import vllm_mlu._mlu_utils as mlu_envs + +logger = init_logger(__name__) + + +class DPMLUModelRunner(MLUModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + vllm_config.mlu_config.enable_custom_data_parallel_opt = True + super().__init__(vllm_config, device) + self.use_cuda_graph = ( + self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and not self.model_config.enforce_eager) + if not self.use_cuda_graph and not self.model_config.enforce_eager: + logger.warning("Can not use cudagraph for dp mlu model runner. Dp mlu model runner can " + "only support cudagraph_mode with CUDAGraphMode.FULL_DECODE_ONLY.") + self.use_all2all = self.mlu_config.decode_dispatch_combine_use_all2all + if self.use_all2all: + assert get_moe_expert_parallel_world_size() > 1, ( + "all2all requires that expert parallel is enabled") + kwargs = self.make_cnclep_kwargs() + init_cnclep(**kwargs) + if self.model_config.is_longcat_flash: + kwargs_bf16 = self.make_cnclep_kwargs(use_quant_dispatch=False) + init_cnclep(**kwargs_bf16) + self.dp_metadata = None + + def _get_data_parallel_metadata( + self, + num_tokens: int, + num_reqs: int, + is_decode_only: bool, + query_len_per_batch: Optional[List[int]], + ) -> "MLUDPMetadata": + (dp_query_lens, dp_group_bs, dp_is_prefill, + seq_len_per_batch) = get_runtime_infos_per_dp_group( + num_tokens, + num_reqs, + not is_decode_only, + query_len_per_batch, + self.device, + self.vllm_config, + ) + (emb_query_lens, logits_batch_sizes, + dense_attn_token_split_list) = get_deepseek_layer_split_list( + dp_query_lens, + dp_group_bs, + ) + return MLUDPMetadata.make_oot( + self.parallel_config.data_parallel_rank, + self.parallel_config.data_parallel_size, + self.parallel_config.tensor_parallel_size, + dp_query_lens, + dp_is_prefill, + self.vllm_config.mlu_config.prefill_dispatch_use_RS_AG, + seq_lens=(seq_len_per_batch if all(dp_is_prefill) else None), + batch_sizes=dp_group_bs, + emb_query_lens=emb_query_lens, + logits_batch_sizes=logits_batch_sizes, + dense_attn_token_split_list=dense_attn_token_split_list, + ) + + def _get_dp_graph_info(self, + K: int, + num_scheduled_tokens: int, + dp_metadata: "MLUDPMetadata"): + """ + Check if the DeepSeek model can enter graph mode and retrieve input + tokens and batch. 
+ + This function also applies to other eligible MoE models with DP enabled, + reusing the same graph mode compatibility logic. + + Returns: + tuple: Contains three elements: + num_input_tokens: Retrieved input token + num_input_batchs: Retrieved input batch + use_graph: Whether the model can use graph mode + """ + if (self.use_cuda_graph + and all(not prefill for prefill in dp_metadata.dp_is_prefill) + and all(token_num <= self.cudagraph_batch_sizes[-1] + for token_num in dp_metadata.token_split_list)): + num_input_tokens = self.vllm_config.pad_for_cudagraph( + max(dp_metadata.token_split_list)) + assert num_input_tokens % (K + 1) == 0, \ + f"num_input_tokens ({num_input_tokens}) must be divisible by (K + 1) = {K + 1}" + num_input_batchs = num_input_tokens // (1 + K) + use_graph = True + else: + num_input_batchs = self.input_batch.num_reqs + num_input_tokens = num_scheduled_tokens + use_graph = False + return num_input_tokens, num_input_batchs, use_graph + + @torch.inference_mode() + def moe_dp_execute_dummy_batch( + self, num_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + # MUST do comm across dp group first when enable data parallel. + # Here we set dummy run state as prefill only to prevent other dp + # group use graph. + dp_metadata = self._get_data_parallel_metadata( + num_tokens, num_reqs, False, [num_tokens // num_reqs] * num_reqs + ) + + # always skip attn compute + attn_metadata: Optional[Dict[str, Any]] = None + + input_ids = self.input_ids.gpu[:num_tokens] + positions = self.positions.gpu[:num_tokens] + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + hidden_states = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=None, + dp_params=dp_metadata, + ) + + kwargs = ({"dp_params": dp_metadata} + if enable_emb_logits_custom_parallel() else {}) + self.model.compute_logits( + hidden_states[:num_tokens], **kwargs) + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + target_token_ids = input_ids + target_positions = positions + # hidden_states no need to be sliced + target_hidden_states = hidden_states + self.drafter.propose_ds_execute_dummy_batch( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + dp_params=dp_metadata) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + logit_indices_device = torch.from_numpy(logit_indices).to( + self.device, non_blocking=True + ) + return hidden_states, hidden_states[logit_indices_device] + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | IntermediateTensors | None: + if self.execute_model_state is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." 
+ ) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with record_function_or_nullcontext("dp_gpu_model_runner: preprocess"): + with self.synchronize_input_prep(): + # Update persistent batch states. + self._update_states(scheduler_output) + + if not num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward( + scheduler_output, self.vllm_config + ) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add mlu_infer_mode. + @brief: prepare mlu dp metadata in _prepare_inputs instead of ubatch_slices + and num_tokens_across_dp. + ''' + max_computed_tokens = np.max(self.input_batch.num_computed_tokens_cpu[:num_reqs]) + self.mlu_infer_mode = MLUInferMode.build( + max_query_len=max_num_scheduled_tokens, + max_computed_tokens=max_computed_tokens, + uniform_decode_query_len=self.uniform_decode_query_len, + ) + + num_tokens_across_dp = None + ( + logits_indices, + spec_decode_metadata, + ubatch_slices, + dp_metadata, + ) = self._prepare_inputs( + scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens + ) + self.dp_metadata = dp_metadata + ''' + ================== + End of MLU Hijack + ================== + ''' + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and ubatch_slices is None: + # Pre-compute cascade attention prefix lengths + # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + scheduler_output.num_common_prefix_blocks, + ) + + # TODO(lucas): move cudagraph dispatching here: + # https://github.com/vllm-project/vllm/issues/23789 + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( + total_num_scheduled_tokens=total_num_scheduled_tokens, + max_num_scheduled_tokens=max_num_scheduled_tokens, + num_reqs=num_reqs, + ubatch_slices=ubatch_slices, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + mlu_infer_mode=self.mlu_infer_mode, + ) + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: pad attn metadata for mlu grpah. + @brief: pad num_input_tokens based on all dp groups and spec decode. + @brief: add dp_params to model_kwargs. 
+ ''' + dp_can_use_graph = False + if self.use_cuda_graph: + num_input_tokens_dp, num_padded_reqs, dp_can_use_graph = self._get_dp_graph_info( + self.num_spec_tokens, num_scheduled_tokens, dp_metadata) + if dp_can_use_graph: + # all layers share the same attn_metadata + assert len(self.kv_cache_config.kv_cache_groups) == 1 + attn_metadata_val = next(iter(attn_metadata.values())) + common_metadata = get_common_metadata_from_attn_metadata(attn_metadata) + block_table = self.input_batch.block_table[0] + pad_attn_metadata( + attn_metadata_val, common_metadata, block_table, self, + num_scheduled_tokens, num_input_tokens_dp, num_reqs, num_padded_reqs) + + dp_rank = self.parallel_config.data_parallel_rank + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + else: + num_input_tokens = ( + num_input_tokens_dp if dp_can_use_graph else num_scheduled_tokens) + + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess( + scheduler_output, num_input_tokens, intermediate_tensors + ) + + model_kwargs["dp_params"] = dp_metadata + ''' + ================== + End of MLU Hijack + ================== + ''' + + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + batch_descriptor, + use_cascade_attn=cascade_attn_prefix_lens is not None, + ) + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: check if we can use cudagraph using dp_can_use_graph. + ''' + if not dp_can_use_graph: + cudagraph_runtime_mode = CUDAGraphMode.NONE + batch_descriptor = None + ''' + ================== + End of MLU Hijack + ================== + ''' + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_runtime_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: debug disagg cnpx. + ''' + if mlu_envs.VLLM_DISAGG_CNPX_EXECUTE: + self.execute_cnpx_mark = cnpx.rangeStart("DP_" + str(self.parallel_config.data_parallel_rank) + "_TP_" \ + + str(get_tensor_model_parallel_rank()) + "_execute_model" + \ + ("_no_graph" if cudagraph_runtime_mode == CUDAGraphMode.NONE else "")) + if mlu_envs.VLLM_DISAGG_CNPX_REQUEST: + self.request_cnpx_mark.clear() + for req in scheduler_output.scheduled_new_reqs: + self.request_cnpx_mark[req.req_id] = cnpx.rangeStart(req.req_id) + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + self.request_cnpx_mark[req_id] = cnpx.rangeStart(req_id) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + start = torch.mlu.Event(enable_timing=True) + start.record() + + # Run the model. 
+ # Use persistent buffers for CUDA graphs. + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), + record_function_or_nullcontext("dp_gpu_model_runner: forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + model_output = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + with record_function_or_nullcontext("dp_gpu_model_runner: postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. + hidden_states, aux_hidden_states = model_output + else: + # Common case. + hidden_states = model_output + aux_hidden_states = None + + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + # Return the pooling output. + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) + output.kv_connector_output = kv_connector_output + return output + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support embed logits custom parallel. + ''' + sample_hidden_states = hidden_states[logits_indices] + logits_kwargs = ({"dp_params": dp_metadata} + if enable_emb_logits_custom_parallel() else {}) + logits = self.model.compute_logits(sample_hidden_states, **logits_kwargs) + ''' + ================== + End of MLU Hijack + ================== + ''' + else: + # Rare case. + assert not self.is_pooling_model + + sample_hidden_states = hidden_states[logits_indices] + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + else: + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + self.time_markers = [] + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + self.time_markers.append([start, end]) + + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + ) + return None + + @torch.inference_mode + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncMLUModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. 
+ if not kv_connector_output: + return None # noqa + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. + if grammar_output is not None: + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) + + with record_function_or_nullcontext("gpu_model_runner: sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + self.input_batch.prev_sampled_token_ids = None + + def propose_draft_token_ids( + sampled_token_ids: torch.Tensor | list[np.ndarray], + ) -> None: + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("gpu_model_runner: draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + whole_block_table=self.input_batch.block_table[0], + main_model_dp_params=self.dp_metadata, + ) + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + effective_drafter_max_model_len = self.max_model_len + if effective_drafter_max_model_len is None: + effective_drafter_max_model_len = self.model_config.max_model_len + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): + effective_drafter_max_model_len = ( + self.speculative_config.draft_model_config.max_model_len + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Force `input_fits_in_drafter` to be True to ensure that `self.uniform_decode_query_len` tokens are scheduled per batch during model execution. + This is required for graph validation and to keep the batch token count consistent with `self.uniform_decode_query_len` immediately after the prefill stage. + ''' + # input_fits_in_drafter = spec_decode_common_attn_metadata and ( + # spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + # <= effective_drafter_max_model_len + # ) + input_fits_in_drafter = True + ''' + ================== + End of MLU Hijack + ================== + ''' + if use_padded_batch_for_eagle: + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. 
+ propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. + propose_draft_token_ids(valid_sampled_token_ids) + + with record_function_or_nullcontext("gpu_model_runner: eplb"): + self.eplb_step() + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + ec_connector_output=ec_connector_output + if self.supports_mm_inputs + else None, + num_nans_in_logits=num_nans_in_logits, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + get_kv_transfer_group().clear_connector_metadata() + + if mlu_envs.VLLM_DISAGG_CNPX_EXECUTE: + current_stream = torch.mlu.current_stream() + current_stream.synchronize() + cnpx.rangeEnd(self.execute_cnpx_mark) + if mlu_envs.VLLM_DISAGG_CNPX_REQUEST: + current_stream = torch.mlu.current_stream() + current_stream.synchronize() + for req in scheduler_output.scheduled_new_reqs: + cnpx.rangeEnd(self.request_cnpx_mark[req.req_id]) + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + cnpx.rangeEnd(self.request_cnpx_mark[req_id]) + ''' + ================== + End of MLU Hijack + ================== + ''' + if not self.use_async_scheduling: + return output + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncMLUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that require output ids. 
+ self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) + + return async_output + + def propose_draft_token_ids( + self, + scheduler_output: "SchedulerOutput", + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + common_attn_metadata: MLUCommonAttentionMetadata, + whole_block_table: torch.Tensor, + main_model_dp_params: Optional[DataParallelRuntimeParams] = None, + ) -> list[list[int]]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: draft model will build new FlashMLAMetadata, + so just unpad common_attn_metadata here. + ''' + unpad_common_attn_metadata( + common_metadata=common_attn_metadata, + num_reqs=self.input_batch.num_reqs, + num_scheduled_tokens=num_scheduled_tokens, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.speculative_config.method == "ngram": + assert isinstance(self.drafter, NgramProposer) + spec_token_ids = self.propose_ngram_draft_token_ids( + sampled_token_ids) + elif self.speculative_config.method == "medusa": + assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. + hidden_states = sample_hidden_states + else: + indices = [] + offset = 0 + for num_draft, tokens in zip( + spec_decode_metadata.num_draft_tokens, + sampled_token_ids): + indices.append(offset + len(tokens) - 1) + offset += num_draft + 1 + indices = torch.tensor(indices, device=self.device) + hidden_states = sample_hidden_states[indices] + + spec_token_ids = self.drafter.propose( + target_hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + elif self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # TODO(woosuk): Refactor the loop. + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" + "padded-batch is disabled." + ) + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" + "padded-batch is enabled." + ) + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + if spec_decode_metadata is None: + token_indices_to_sample = None + # input_ids can be None for multimodal models. 
+ target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] + # TODO(woosuk): Support M-RoPE. + target_positions = self._get_positions(num_scheduled_tokens) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + num_rejected_tokens_gpu = None + token_indices = None + else: + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) + else: + common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = ( + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count, + ) + ) + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = self._get_positions(token_indices) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = hidden_states[token_indices] + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add debug info for draft accepted rate + ''' + if mlu_envs.VLLM_MTP_DEBUG: + batch_total_draft = sum(spec_decode_metadata.num_draft_tokens) + batch_total_rejected = sum(num_rejected_tokens_gpu) + self.total_draft_tokens += batch_total_draft + self.total_accepted_tokens += ( + batch_total_draft - batch_total_rejected) + if batch_total_draft > 0: + batch_accept_rate = ( + batch_total_draft - batch_total_rejected + ) / batch_total_draft + print(f"Batch Accept Rate: {batch_accept_rate:.4f}, " + f"Total Accept Rate: {self.get_accept_rate():.4f}") + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.supports_mm_inputs: + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: keep full scheduled tokens for draft model compute + ''' + target_token_ids = target_token_ids[:num_scheduled_tokens] + target_positions = target_positions[:num_scheduled_tokens] + target_hidden_states = target_hidden_states[:num_scheduled_tokens] + ''' + ================== + End of MLU Hijack + ================== + ''' + + spec_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + num_rejected_tokens=num_rejected_tokens_gpu, + token_indices=token_indices, + whole_block_table=whole_block_table, + main_model_dp_params=main_model_dp_params, + time_markers=self.time_markers, + ) + return spec_token_ids + + def make_cnclep_kwargs(self, use_quant_dispatch: bool = True) -> dict[Any, Any]: + + K = (self.drafter.num_speculative_tokens + if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer) + else 0) + seq_len = K + 1 + config = self.model_config.hf_config + num_experts = (config.n_routed_experts if hasattr(config, "n_routed_experts") + else config.num_experts) + topk = 
getattr(config, "num_experts_per_tok", None) or getattr(config, "moe_topk", None) + assert topk is not None, "failed to get topk from config" + hidden_size = config.hidden_size + dispatch_token_size = hidden_size * get_dtype_size(self.dtype) + if use_quant_dispatch: + dispatch_token_size = hidden_size * get_dtype_size(torch.int8) + get_dtype_size(torch.float32) + combine_token_size = hidden_size * get_dtype_size(self.dtype) + + max_num_seqs_per_dp = self.scheduler_config.max_num_seqs + # max number of tokens that an ep rank could send + max_num_tokens_per_rank = divide(max_num_seqs_per_dp * seq_len * topk, + self.parallel_config.tensor_parallel_size) + + return dict(dispatch_token_size=dispatch_token_size, + combine_token_size=combine_token_size, + max_num_tokens_per_rank=max_num_tokens_per_rank, + num_global_experts=num_experts, + use_quant_dispatch=use_quant_dispatch) + + def prepare_all2all_buffer_for_model( + self, model: torch.nn.Module) -> None: + """ + Prepare all2all buffer for the model. + """ + if not self.use_all2all: + return + + moe_modules = [ + module for module in self.model.modules() + if isinstance(module, SparseMoeMlp) + ] + if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer): + draft_moes = [ + module for module in self.drafter.model.modules() + if isinstance(module, SparseMoeMlp) and not mlu_envs.VLLM_MTP_NO_QUANT + ] + moe_modules.extend(draft_moes) + for module in moe_modules: + if self.load_config.load_format == "dummy": + module.pack_params() + module.pack_params_after_loading() + use_quant_dispatch = module.quant_config is not None + module.prepare_for_cnclep(get_cnclep(use_quant_dispatch=use_quant_dispatch)) + + def load_model(self, eep_scale_up: bool = False) -> None: + super().load_model() + if self.use_all2all: + self.prepare_all2all_buffer_for_model(self.model) diff --git a/vllm_mlu/v1/worker/gpu_input_batch.py b/vllm_mlu/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000..e8e98f8 --- /dev/null +++ b/vllm_mlu/v1/worker/gpu_input_batch.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def split_decodes_and_prefills(self): + decodes = 0 + prefills = 0 + for i, req_id in enumerate(self.req_ids): + req_index = self.req_id_to_index.get(req_id) + num_prompt_tokens = self.num_prompt_tokens[req_index] + num_computed_tokens = self.num_computed_tokens_cpu[req_index] + if num_computed_tokens < num_prompt_tokens: + prefills += 1 + else: + decodes += 1 + return decodes, prefills + + +MluHijackObject.apply_hijack(InputBatch, + "split_decodes_and_prefills", + split_decodes_and_prefills) \ No newline at end of file diff --git a/vllm_mlu/v1/worker/gpu_model_runner.py b/vllm_mlu/v1/worker/gpu_model_runner.py new file mode 100644 index 0000000..c165188 --- /dev/null +++ b/vllm_mlu/v1/worker/gpu_model_runner.py @@ -0,0 +1,4166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 + +from copy import copy +import gc +import time +from contextlib import contextmanager +from itertools import product +from typing import TYPE_CHECKING, Dict, List, Tuple, cast +import re + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm +import cnpx + +import vllm.envs as envs +from vllm.attention.layer import Attention +from 
vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import ( + CompilationMode, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, +) +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.eplb.eplb_state import EplbState +from vllm.distributed.kv_transfer import has_kv_transfer_group, get_kv_transfer_group +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.parallel_state import ( + get_dcp_group, + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, + get_tensor_model_parallel_rank, +) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + is_mixture_of_experts, + supports_eagle3, + supports_multimodal_pruning, +) +from vllm.model_executor.models.interfaces_base import VllmModelForPooling +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors +from vllm.utils.import_utils import LazyLoader +from vllm.utils.math_utils import cdiv +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import ( + get_dtype_size, + kv_cache_dtype_str_to_dtype, + supports_dynamo, + weak_ref_tensor, +) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + get_dcp_local_seq_lens, + split_attn_metadata, +) +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + EncoderOnlyAttentionSpec, + KVCacheConfig, + KVCacheSpec, + FullAttentionSpec, + MambaSpec, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + LogprobsTensors, + ModelRunnerOutput, + make_empty_encoder_model_runner_output, + LogprobsLists, + SamplerOutput, +) +from vllm.v1.sample.logits_processor import build_logitsprocs +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.medusa import MedusaProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_model_runner import ( + AsyncGPUModelRunnerOutput, + ExecuteModelState, + 
GPUModelRunner, + PerLayerAttnMetadata, +) +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, + UBatchSlices, + check_ubatch_thresholds, +) +from vllm.v1.worker.utils import ( + AttentionGroup, + MultiModalBudget, + bind_kv_cache, + is_residual_scattered_for_sp, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders, +) +if TYPE_CHECKING: + import xgrammar as xgr + import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + xgr_torch_compile = LazyLoader( + "xgr_torch_compile", globals(), + "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") + +import vllm_mlu._mlu_utils as mlu_envs +from vllm_mlu.compilation.mlu_graph import MLUGraphWrapper +from vllm_mlu.distributed.parallel_state import mlu_graph_capture +from vllm_mlu.model_executor.layers.feed_forward import FeedForward +from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding +from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp +from vllm_mlu.v1.kv_cache_interface import MLUMLAAttentionSpec +from vllm_mlu.v1.attention.backends.flash_attn import FlashAttentionMetadata, pad_attn_metadata +from vllm_mlu.v1.attention.backends.utils import ( + COMMON_METADATA_STR, + MLUCommonAttentionMetadata, + MLUInferMode, + get_common_metadata, + unpad_common_attn_metadata) +from vllm_mlu.model_executor.models.sp_utils import set_sp_forward_context +from vllm_mlu.v1.sample.sampler import MluSampler +from vllm_mlu.v1.spec_decode.dp_eagle import DPMluEagleProposer +from vllm_mlu.v1.spec_decode.eagle import MluEagleProposer +import vllm_mlu._mlu_utils as mlu_envs + +logger = init_logger(__name__) + + +_NUM_WARMUP_ITERS = 2 + + +def _model_forward_pre_hook(self, args, kwargs): + ''' + This hook function will be called before model.forward + ''' + assert len(args) == 0 and len(kwargs) > 0, \ + f"The pre-forward's expected inputs are not passed by kwargs. " + \ + f"Expected len(args)=0, len(kwargs)>0, " + \ + f"now, len(args)={len(args)}, len(kwargs)={len(kwargs)}." + + common_metadata: MLUCommonAttentionMetadata = get_common_metadata() + + if common_metadata: + # Prepare attributes for all rope in model + MLURotaryEmbedding.set_mlu_var_v1(common_metadata=common_metadata) + + if self.config.model_type == "deepseek_v4": + args, kwargs = self.update_forward_args(args, kwargs) + + return (args, kwargs) + + +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncMLUModelRunnerOutput(AsyncGPUModelRunnerOutput): + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + logprobs_tensors: torch.Tensor | None, + invalid_req_indices: list[int], + async_output_copy_stream: torch.mlu.Stream, + vocab_size: int, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self.async_copy_ready_event = torch.mlu.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + self.vocab_size = vocab_size + self._logprobs_tensors = logprobs_tensors + + # Initiate the copy on a separate stream, but do not synchronize it. 
+ default_stream = torch.mlu.current_stream() + with torch.mlu.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self.sampled_token_ids_cpu = self._sampled_token_ids.to( + "cpu", non_blocking=True + ) + self._logprobs_tensors_cpu = ( + self._logprobs_tensors.to_cpu_nonblocking() + if self._logprobs_tensors + else None + ) + self.async_copy_ready_event.record() + +def apply_grammar_bitmask( + scheduler_output: SchedulerOutput, + grammar_output: GrammarOutput, + input_batch: InputBatch, + logits: torch.Tensor, +) -> None: + """ + Apply grammar bitmask to output logits of the model with xgrammar function. + + Args: + scheduler_output (SchedulerOutput): The result of engine scheduling. + input_batch (InputBatch): The input of model runner. + logits (torch.Tensor): The output logits of model forward. + """ + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = grammar_output.grammar_bitmask + + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. + struct_out_req_batch_indices: dict[str, int] = {} + cumulative_offset = 0 + seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) + if req_id in grammar_output.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index + + out_indices = [] + + # Reorder the bitmask to match the order of the requests in the batch. + sorted_bitmask = np.full( + shape=(logits.shape[0], grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype, + ) + cumulative_index = 0 + for req_id in grammar_output.structured_output_request_ids: + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) + if req_id in struct_out_req_batch_indices: + logit_index = struct_out_req_batch_indices[req_id] + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + + # Copy async to device as tensor. + grammar_bitmask = torch.from_numpy(sorted_bitmask).to( + logits.device, non_blocking=True + ) + + # If the length of out indices and the logits have the same shape + # we don't need to pass indices to the kernel, + # since the bitmask is already aligned with the logits. + skip_out_indices = len(out_indices) == logits.shape[0] + + index_tensor = None + if not skip_out_indices: + # xgrammar expects a python list of indices but it will actually work with + # a tensor. If we copy the tensor ourselves here we can do it in a non_blocking + # manner and there should be no cpu sync within xgrammar. 
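+        # (The CPU staging tensor below is allocated with pin_memory=True;
+        # pinned host memory is what makes the host-to-device copy genuinely
+        # non-blocking.)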
+ index_tensor = torch.tensor( + out_indices, dtype=torch.int32, device="cpu", pin_memory=True + ) + index_tensor = index_tensor.to(logits.device, non_blocking=True) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: remove index_put_ from inductor lowering denylist to + avoid torch.compile error when using xgrammar + ''' + from torch_mlu._inductor import remove_from_lowering_denylist + remove_from_lowering_denylist([torch.ops.aten.index_put_]) + ''' + ================== + End of MLU Hijack + ================== + ''' + + xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor) + +class MLUModelRunner(GPUModelRunner): + + def _init_kv_state( + self, + ): + hf_config = self.model_config.hf_config + if hf_config.model_type != "deepseek_v4": + return + + CACHED_STATE_NUM = self.scheduler_config.max_num_seqs + hf_config.cached_state_num = CACHED_STATE_NUM + self.kv_state_free_slots = set(range(CACHED_STATE_NUM)) + self.req_id_to_kv_state = dict() + + def _insert_req_id( + self, + scheduler_output: "SchedulerOutput", + ): + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + assert req_id not in self.req_id_to_kv_state, \ + f"try to insert req_id: {req_id}, which has been stored int kv_state." + assert self.kv_state_free_slots, "fail to allocate kv states" + slot = self.kv_state_free_slots.pop() + self.req_id_to_kv_state[req_id] = slot + + def _remove_req_id( + self, + scheduler_output: "SchedulerOutput", + ): + for req_id in scheduler_output.finished_req_ids: + slot = self.req_id_to_kv_state.pop(req_id) + self.kv_state_free_slots.add(slot) + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + self.mlu_config = vllm_config.mlu_config + + from vllm.model_executor.models.utils import set_cpu_offload_max_bytes + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( + cache_config.cache_dtype, self.model_config + ) + + self.is_pooling_model = model_config.runner_type == "pooling" + self.enable_prompt_embeds = model_config.enable_prompt_embeds + self.is_multimodal_raw_input_only_model = ( + model_config.is_multimodal_raw_input_only_model + ) + # This will be overridden in load_model() + self.is_multimodal_pruning_enabled = False + self.max_model_len = model_config.max_model_len + + # Always set to false after the first forward pass + self.calculate_kv_scales = self.cache_config.calculate_kv_scales + self.dcp_world_size = self.parallel_config.decode_context_parallel_size + self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group + self.max_num_tokens = scheduler_config.max_num_batched_tokens + 
self.max_num_reqs = scheduler_config.max_num_seqs + + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + self.broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + + # Model-related. + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) + self.hidden_size = model_config.get_hidden_size() + self.attention_chunk_size = model_config.attention_chunk_size + # Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = model_config.uses_alibi + + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + model_config + ) + + if self.model_config.is_encoder_decoder: + # Maximum length of the encoder input, only for encoder-decoder + # models. + self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens + else: + self.max_encoder_len = 0 + + # Sampler + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: use tmo topk_topp_sampler to sample. + """ + sampler_cls = (MluSampler + if self.model_config.is_deepseek_mla or self.model_config.is_longcat_flash + else Sampler) + self.sampler = sampler_cls(logprobs_mode=self.model_config.logprobs_mode) + """ + ================= + End of MLU Hijack + ================= + """ + + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: Add an extra field to indicate infer mode. + """ + self.mlu_infer_mode = MLUInferMode.PREFILL_ONLY + """ + ================= + End of MLU Hijack + ================= + """ + + self.eplb_state: EplbState | None = None + """ + State of the expert parallelism load balancer. + + Will be lazily initialized when the model is loaded. + """ + + # Lazy initializations + # self.model: nn.Module # Set after load_model + # Initialize in initialize_kv_cache + self.kv_caches: list[torch.Tensor] = [] + # indexes: [kv_cache_group_id][attn_group] + self.attn_groups: list[list[AttentionGroup]] = [] + # self.kv_cache_config: KVCacheConfig + + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} + + self.use_aux_hidden_state_outputs = False + # Set up speculative decoding. + # NOTE(Jiayi): currently we put the entire draft model on + # the last PP rank. This is not ideal if there are many + # layers in the draft model. 
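+        # Proposer selection below: ngram and suffix decoding propose from
+        # CPU token ids, while the EAGLE variants (MluEagleProposer, or
+        # DPMluEagleProposer when the custom data-parallel opt is enabled)
+        # and Medusa run on the device.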
+ if self.speculative_config and get_pp_group().is_last_rank: + self.drafter: ( + NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + ) + if self.speculative_config.method == "ngram": + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "suffix": + self.drafter = SuffixDecodingProposer(self.vllm_config) + elif self.speculative_config.use_eagle(): + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Use MluEagleProposer instead of EagleProposer + ''' + if vllm_config.mlu_config.enable_custom_data_parallel_opt: + proposer_cls = DPMluEagleProposer + else: + proposer_cls = MluEagleProposer + self.drafter = proposer_cls(self.vllm_config, self.device, self) + self.previous_hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + ''' + ============================= + End of MLU Hijack + ============================= + ''' + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True + elif self.speculative_config.method == "medusa": + self.drafter = MedusaProposer( + vllm_config=self.vllm_config, device=self.device + ) + else: + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) + self.rejection_sampler = RejectionSampler(self.sampler) + + self.num_spec_tokens = 0 + if self.speculative_config: + self.num_spec_tokens = self.speculative_config.num_speculative_tokens + + # Request states. + self.requests: dict[str, CachedRequestState] = {} + self.comm_stream = torch.mlu.Stream() + + # Input Batch + # NOTE(Chen): Ideally, we should initialize the input batch inside + # `initialize_kv_cache` based on the kv cache config. However, as in + # https://github.com/vllm-project/vllm/pull/18298, due to some unknown + # reasons, we have to initialize the input batch before `load_model`, + # quantization + weight offloading will fail otherwise. As a temporary + # solution, we initialize the input batch here, and re-initialize it + # in `initialize_kv_cache` if the block_sizes here is different from + # the block_sizes in the kv cache config. + custom_logitsprocs = model_config.logits_processors + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Adjust `max_model_len` to expand input_batch.token_ids_cpu_tensor, prevent + overflow when the total length (including speculative tokens) exceeds max_model_len. + ''' + max_model_len_revise=max(self.max_model_len, self.max_encoder_len) + if self.num_spec_tokens > 1: + max_model_len_revise = max_model_len_revise + self.num_spec_tokens - 1 + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. + max_model_len=max_model_len_revise, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, + self.device, + self.pin_memory, + self.is_pooling_model, + custom_logitsprocs, + ), + # We currently don't know whether a particular custom logits processor + # uses output token ids so we set this conservatively. 
+ logitsprocs_need_output_token_ids=bool(custom_logitsprocs), + is_pooling_model=self.is_pooling_model, + dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + self.use_async_scheduling = self.scheduler_config.async_scheduling + # Separate cuda stream for overlapping transfer of sampled token ids from + # GPU to CPU when async scheduling is enabled. + self.async_output_copy_stream: torch.mlu.Stream | None = None + # cuda event to synchronize use of reused CPU tensors between steps + # when async scheduling is enabled. + self.prepare_inputs_event: torch.mlu.Event | None = None + if self.use_async_scheduling: + self.async_output_copy_stream = torch.mlu.Stream() + self.prepare_inputs_event = torch.mlu.Event() + + # self.cudagraph_batch_sizes sorts in ascending order. + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + self.cudagraph_batch_sizes = sorted( + self.compilation_config.cudagraph_capture_sizes + ) + + # Cache the device properties. + self._init_device_properties() + + # Persistent buffers for CUDA graphs. + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: change postions dtype from int64 to int32 + @brief: add seq_start_loc buffer for chunk fa + ''' + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.seq_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) + + self.prefill_enable_mlugraph = self.mlu_config.prefill_enable_mlugraph + self.prefill_mlugraph_batch_size = self.mlu_config.prefill_mlugraph_batch_size + self.prefill_mlugraph_seq_len = self.mlu_config.prefill_mlugraph_seq_len + ''' + ================== + End of MLU Hijack + ================== + ''' + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Add kv_state buffer for deepseekv4 + ''' + self._init_kv_state() + ''' + ================== + End of MLU Hijack + ================== + ''' + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + if self.dcp_world_size > 1: + self.dcp_local_seq_lens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + # Because inputs_embeds may be bfloat16 and we don't need a numpy + # version of this tensor, avoid a RuntimeError by not creating a + # numpy buffer. + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_indices = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) + self.num_discarded_requests = 0 + + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) + + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + # NOTE: `mrope_positions` is implemented with one additional dummy + # position on purpose to make it non-contiguous so that it can work + # with torch compile. 
+ # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: change postions dtype from int64 to int32 + ''' + self.mrope_positions = self._make_buffer( + (3, self.max_num_tokens + 1), dtype=torch.int32 + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: IntermediateTensors | None = None + + # OPTIMIZATION: Cache the tensors rather than creating them every step. + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) + + # Layer pairings for cross-layer KV sharing. + # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. + self.shared_kv_cache_layers: dict[str, str] = {} + self.kv_sharing_fast_prefill_eligible_layers: set[str] = set() + + self.kv_sharing_fast_prefill_logits_indices = None + if self.cache_config.kv_sharing_fast_prefill: + self.kv_sharing_fast_prefill_logits_indices = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=self.device + ) + + self.uniform_decode_query_len = 1 + self.num_spec_tokens + + # Cudagraph dispatcher for runtime cudagraph dispatching. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: enable reorder batch to ensure correct splitting + between prefill chunks and decode tokens in chunked prefill mode. + ''' + self.reorder_batch_threshold: int | None = self.uniform_decode_query_len + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + + # Cached outputs. + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self.transfer_event = torch.mlu.Event() + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: change sampled_token_ids dtype from int64 to int32 + @brief: add draft accepted counter + ''' + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_num_reqs, 1), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) + + # Pre-allocated tensor for copying valid sampled token counts to CPU, + # with dedicated stream for overlapping and event for coordination. 
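+        # (Only allocated when async scheduling and speculative decoding are
+        # both enabled; see the condition below.)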
+ self.valid_sampled_token_count_event: torch.mlu.Event | None = None + self.valid_sampled_token_count_copy_stream: torch.mlu.Stream | None = None + if self.use_async_scheduling and self.num_spec_tokens: + self.valid_sampled_token_count_event = torch.mlu.Event() + self.valid_sampled_token_count_copy_stream = torch.mlu.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + self.total_draft_tokens = 0 + self.total_accepted_tokens = 0 + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Ephemeral state transferred between execute_model() and sample_tokens(). + self.execute_model_state: ExecuteModelState | None = None + self.kv_connector_output: KVConnectorOutput | None = None + + self.execute_cnpx_mark = None + self.request_cnpx_mark = {} + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + self.mamba_block_num = 1 + self.mamba_tensor_size = 0 + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Note: used for model runner override. + def _init_device_properties(self) -> None: + """Initialize attributes from torch.cuda.get_device_properties""" + self.device_properties = torch.mlu.get_device_properties(self.device) + self.num_sms = self.device_properties.multi_processor_count + + # Note: used for model runner override. + def _sync_device(self) -> None: + torch.mlu.synchronize() + + + def get_accept_rate(self) -> float: + if self.total_draft_tokens == 0: + return 0.0 + return self.total_accepted_tokens / self.total_draft_tokens + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: only support pad tokens in decode mode. + ''' + if ( + self.mlu_infer_mode == MLUInferMode.DECODE_ONLY + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): + # Use CUDA graphs. + # Add padding to the batch size. + return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): + return round_up(num_scheduled_tokens, tp_size) + return num_scheduled_tokens + + def _prepare_inputs( + self, + scheduler_output: "SchedulerOutput", + num_scheduled_tokens: np.ndarray, + max_num_scheduled_tokens: int, + ) -> tuple[ + torch.Tensor, + SpecDecodeMetadata | None, + UBatchSlices | None, + torch.Tensor | None, + ]: + """ + :return: tuple[ + logits_indices, spec_decode_metadata, + ubatch_slices, num_tokens_across_dp, + ] + """ + self._insert_req_id(scheduler_output) + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit_block_table(num_reqs) + + # Get request indices. 
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + + # Get positions. + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) + token_indices_tensor = torch.from_numpy(token_indices) + + # # Get the number of scheduled tokens for each request. + # req_ids = self.input_batch.req_ids + # tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + # num_scheduled_tokens = np.array(tokens, dtype=np.int32) + # max_num_scheduled_tokens = max(tokens) + + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) + if self.enable_prompt_embeds: + is_token_ids = self.input_batch.is_token_ids_tensor.flatten() + torch.index_select( + is_token_ids, + 0, + token_indices_tensor, + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) + + # Because we did not pre-allocate a massive prompt_embeds CPU tensor on + # the InputBatch, we need to fill in the prompt embeds into the expected + # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. + if self.input_batch.req_prompt_embeds: + output_idx = 0 + for req_idx in range(num_reqs): + num_sched = num_scheduled_tokens[req_idx] + + # Skip if this request doesn't have embeddings + if req_idx not in self.input_batch.req_prompt_embeds: + output_idx += num_sched + continue + + # Skip if no tokens scheduled + if num_sched <= 0: + output_idx += num_sched + continue + + req_embeds = self.input_batch.req_prompt_embeds[req_idx] + start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] + + # Skip if trying to read beyond available embeddings + if start_pos >= req_embeds.shape[0]: + output_idx += num_sched + continue + + # Copy available embeddings + end_pos = start_pos + num_sched + actual_end = min(end_pos, req_embeds.shape[0]) + actual_num_sched = actual_end - start_pos + + if actual_num_sched > 0: + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) + + output_idx += num_sched + + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + # Prepare the attention metadata. 
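+        # E.g., with num_scheduled_tokens [2, 5, 3], cu_num_tokens is
+        # [2, 7, 10], so query_start_loc becomes [0, 2, 7, 10] and the unused
+        # tail is padded with the last value (10).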
+ self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) + self.query_start_loc.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] + + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded) + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + + # Disable DP padding when running eager to avoid excessive padding when + # running prefills. This lets us set enforce_eager on the prefiller in + # a P/D setup and still use CUDA graphs (enabled by this padding) on the + # decoder. + allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: build mlu dp metadata for dp opt. + ''' + ubatch_slices, num_tokens_across_dp = None, None + if self.vllm_config.mlu_config.enable_custom_data_parallel_opt: + cur_num_reqs = (num_reqs if self.mlu_infer_mode.is_prefill_only + else num_reqs * (1 + self.num_spec_tokens)) + query_len_per_batch = (self.query_start_loc.np[1:] - + self.query_start_loc.np[:-1]).tolist() + dp_metadata = self._get_data_parallel_metadata( + total_num_scheduled_tokens, cur_num_reqs, + self.mlu_infer_mode.is_decode_only, query_len_per_batch[:cur_num_reqs] + ) + # replace num_tokens_across_dp with dp_metadata for return + num_tokens_across_dp = dp_metadata + else: + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens_unpadded, + parallel_config=self.parallel_config, + allow_microbatching=True, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + self.seq_lens.np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) + # Fill unused with 0 for full cuda graph mode. + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add seq_start_loc for chunk fa. + ''' + self.seq_start_loc.np[0] = 0 + self.seq_start_loc.np[1:num_reqs + 1] = np.cumsum(self.seq_lens.np[:num_reqs]) + self.seq_start_loc.np[num_reqs + 1 :].fill(self.seq_start_loc.np[num_reqs]) + self.seq_start_loc.copy_to_gpu() + ''' + ================== + End of MLU Hijack + ================== + ''' + + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = ( + discard_request_indices + ) + + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + + # Copy the tensors to the GPU. 
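+        # (The copies below are issued asynchronously (non_blocking) through
+        # the pre-allocated host buffers, so they can overlap with the
+        # remaining CPU-side input preparation.)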
+ self._prepare_input_ids( + scheduler_output, + total_num_scheduled_tokens, + cu_num_tokens, + ) + + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) + else: + # Common case (1D positions) + self.positions.copy_to_gpu(total_num_scheduled_tokens) + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + num_draft_tokens = None + spec_decode_metadata = None + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + # For chunked prefills, use -1 as mask rather than 0, as guided + # decoding may rollback speculative tokens. + num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens + ) + logits_indices = spec_decode_metadata.logits_indices + num_sampled_tokens = num_draft_tokens + 1 + # For DECODE only cuda graph of some attention backends (e.g., GDN). + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[num_reqs:].fill(-1) + self.num_decode_draft_tokens.copy_to_gpu() + + # Hot-Swap lora model + if self.lora_config: + assert ( + np.sum(num_sampled_tokens) + <= self.vllm_config.scheduler_config.max_num_batched_tokens + ) + self.set_active_loras( + self.input_batch, num_scheduled_tokens, num_sampled_tokens + ) + + return ( + logits_indices, + spec_decode_metadata, + ubatch_slices, + num_tokens_across_dp, + ) + + def get_model(self) -> nn.Module: + # get raw model out of the cudagraph wrapper. + if isinstance(self.model, (MLUGraphWrapper, UBatchWrapper)): + return self.model.unwrap() + return self.model + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + # Batch the multi-modal inputs using the helper method. + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output + ) + + if not mm_kwargs: + return + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: v1 offline benchmark + ''' + self.mm_time_markers = [] + + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + mm_start = torch.mlu.Event(enable_timing=True) + mm_start.record() + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. 
+ # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. + model = cast(SupportsMultiModal, self.model) + encoder_outputs = [] + for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + multimodal_cpu_fields=model.multimodal_cpu_fields, + ): + curr_group_outputs = [] + + # EVS-related change. + # (ekhvedchenia): Temporary hack to limit peak memory usage when + # processing multimodal data. This solves the issue with scheduler + # putting too many video samples into a single batch. Scheduler + # uses pruned vision tokens count to compare it versus compute + # budget which is incorrect (Either input media size or non-pruned + # output vision tokens count should be considered) + # TODO(ywang96): Fix memory profiling to take EVS into account and + # remove this hack. + if ( + self.is_multimodal_pruning_enabled + and modality == "video" + and num_items > 1 + ): + for video_mm_kwargs_item in filter( + lambda item: item.modality == "video", mm_kwargs + ): + _, _, micro_batch_mm_inputs = next( + group_mm_kwargs_by_modality( + [video_mm_kwargs_item], + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + multimodal_cpu_fields=model.multimodal_cpu_fields, + ) + ) + + micro_batch_outputs = model.embed_multimodal( + **micro_batch_mm_inputs + ) + + curr_group_outputs.extend(micro_batch_outputs) + else: + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, + # each of shape (feature_size, hidden_size) in case the feature + # size is dynamic depending on the input multimodal items. + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=num_items, + ) + encoder_outputs.extend(curr_group_outputs) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: v1 offline benchmark + ''' + if encoder_outputs and mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + mm_end = torch.mlu.Event(enable_timing=True) + mm_end.record() + self.mm_time_markers.append([mm_start, mm_end]) + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( + output, + is_embed=pos_info.is_embed, + ) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + self._remove_req_id(scheduler_output) + + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + # Remove the finished requests from the persistent batch. 
+ # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" + + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, + mm_features=new_req_data.mm_features, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=getattr(new_req_data, 'new_token_ids', None) or [], + lora_request=new_req_data.lora_request, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + self.requests[req_id] = req_state + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._init_mrope_positions(req_state) + + reqs_to_add.append(req_state) + + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. 
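+        # E.g., if a request had prev_num_draft_len = 2 draft tokens and the
+        # valid sampled count is 2 (1 accepted draft + the newly sampled
+        # token), then num_rejected = 2 - 1 = 1 and num_computed_tokens is
+        # rolled back by 1 below.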
+        valid_sampled_token_count = self._get_valid_sampled_token_count()
+
+        for i, req_id in enumerate(req_data.req_ids):
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids
+            num_output_tokens = req_data.num_output_tokens[i]
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            # prev_num_draft_len is used in async scheduling mode with
+            # spec decode. It indicates whether num_computed_tokens of the
+            # request needs to be updated. For example:
+            # first step: num_computed_tokens = 0, spec_tokens = [],
+            # prev_num_draft_len = 0.
+            # second step: num_computed_tokens = 100 (prompt length),
+            # spec_tokens = [a,b], prev_num_draft_len = 0.
+            # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
+            # prev_num_draft_len = 2.
+            # num_computed_tokens in the first and second steps doesn't
+            # contain the spec tokens length, but in the third step it does.
+            # We only need to update num_computed_tokens when
+            # prev_num_draft_len > 0.
+            if req_state.prev_num_draft_len:
+                if req_index is None:
+                    req_state.prev_num_draft_len = 0
+                else:
+                    assert self.input_batch.prev_req_id_to_index is not None
+                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
+                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
+                    num_rejected = req_state.prev_num_draft_len - num_accepted
+                    num_computed_tokens -= num_rejected
+                    req_state.output_token_ids.extend([-1] * num_accepted)
+
+            # Update the cached states.
+            req_state.num_computed_tokens = num_computed_tokens
+
+            if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (
+                    num_computed_tokens + len(new_token_ids) - req_state.num_tokens
+                )
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
+            elif num_output_tokens < len(req_state.output_token_ids):
+                # Some output tokens were discarded due to a sync-KV-load
+                # failure. Align the cached state.
+                del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (
+                        self.input_batch.num_prompt_tokens[req_index]
+                        + num_output_tokens
+                    )
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
+            if not resumed_from_preemption:
+                if new_block_ids is not None:
+                    # Append the new blocks to the existing block IDs.
+                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
+                        block_ids.extend(new_ids)
+            else:
+                assert req_index is None
+                assert new_block_ids is not None
+                # The request is resumed from preemption.
+                # Replace the existing block IDs with the new ones.
+                req_state.block_ids = new_block_ids
+
+            if req_index is None:
+                # The request is not in the persistent batch.
+                # The request was either preempted and resumed later, or was not
+                # scheduled in the previous step and needs to be added again.
+ + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. + resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] + + reqs_to_add.append(req_state) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens + if new_block_ids is not None: + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, [] + ) + num_spec_tokens = len(spec_token_ids) + # For async scheduling, token_ids_cpu assigned from + # spec_token_ids are placeholders and will be overwritten in + # _prepare_input_ids. + if num_spec_tokens: + start_index = self.input_batch.num_tokens_no_spec[req_index] + end_token_index = start_index + num_spec_tokens + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index + ] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] += num_spec_tokens + + # When speculative decoding is used with structured output, + # the scheduler can drop draft tokens that do not + # conform to the schema. This can result in + # scheduler_output.scheduled_spec_decode_tokens being empty, + # even when speculative decoding is enabled. + self.input_batch.spec_token_ids[req_index] = spec_token_ids + + # there are no draft tokens with async scheduling, + # we clear the spec_decoding info in scheduler_output and + # use normal sampling but rejection_sampling. + if self.use_async_scheduling: + req_state.prev_num_draft_len = num_spec_tokens + if num_spec_tokens and self._draft_token_ids is None: + scheduler_output.total_num_scheduled_tokens -= num_spec_tokens + scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens + scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + self.input_batch.add_request(request) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: cache the unverified spec_decode token + ''' + req_id = request.req_id + req_index = self.input_batch.req_id_to_index.get(req_id) + assert req_index is not None + num_tokens = self.input_batch.num_tokens[req_index] + end_token_index = num_tokens + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, [] + ) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec decode tokens. 
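+                # (end_token_index here is prompt + generated output tokens
+                # plus the still-unverified spec tokens cached above.)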
+ self.input_batch.num_tokens[req_index] = end_token_index + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. + self.input_batch.refresh_metadata() + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | IntermediateTensors | None: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: clear time markers before execute model. + ''' + self.time_markers = [] + self.mm_time_markers = [] + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.execute_model_state is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with record_function_or_nullcontext("gpu_model_runner: preprocess"): + with self.synchronize_input_prep(): + # Update persistent batch states. + self._update_states(scheduler_output) + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward( + scheduler_output, self.vllm_config + ) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add mlu_infer_mode. 
+ ''' + max_computed_tokens = np.max(self.input_batch.num_computed_tokens_cpu[:num_reqs]) + self.mlu_infer_mode = MLUInferMode.build( + max_query_len=max_num_scheduled_tokens, + max_computed_tokens=max_computed_tokens, + uniform_decode_query_len=self.uniform_decode_query_len, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + ( + logits_indices, + spec_decode_metadata, + ubatch_slices, + num_tokens_across_dp, + ) = self._prepare_inputs( + scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens + ) + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and ubatch_slices is None: + # Pre-compute cascade attention prefix lengths + # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + scheduler_output.num_common_prefix_blocks, + ) + + # TODO(lucas): move cudagraph dispatching here: + # https://github.com/vllm-project/vllm/issues/23789 + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( + total_num_scheduled_tokens=total_num_scheduled_tokens, + max_num_scheduled_tokens=max_num_scheduled_tokens, + num_reqs=num_reqs, + ubatch_slices=ubatch_slices, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + mlu_infer_mode=self.mlu_infer_mode, + ) + ) + + dp_rank = self.parallel_config.data_parallel_rank + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + else: + ''' + ============================= + Modify by vllm_mlu + ============================= + pad num_input_tokens after supporting pad decode graph. 
+ ''' + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) + capture_already = False + K = 0 + if hasattr(self.speculative_config, "num_speculative_tokens"): + K = self.speculative_config.num_speculative_tokens + if (hasattr(self, 'cudagraph_batch_sizes') and + self.cudagraph_batch_sizes is not None): + decode_cudagraph_batch_sizes = [ + x + for x in self.cudagraph_batch_sizes + if max_num_tokens >= x >= self.uniform_decode_query_len + ] + capture_already = len(decode_cudagraph_batch_sizes) > 0 and \ + num_reqs*(1+K) <= max(decode_cudagraph_batch_sizes) + if self.mlu_infer_mode == MLUInferMode.DECODE_ONLY and not \ + all(x == K + 1 for x in scheduler_output.num_scheduled_tokens.values()): + capture_already = False + if capture_already: + num_input_tokens = self._get_num_input_tokens( + scheduler_output.total_num_scheduled_tokens + ) + else: + num_input_tokens = scheduler_output.total_num_scheduled_tokens + ''' + ================== + End of MLU Hijack + ================== + ''' + + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess( + scheduler_output, num_input_tokens, intermediate_tensors + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @breif: add padding for attention metadata in decode graph. + ''' + # padding + self._padding_attn_metadata(attn_metadata, + input_ids, inputs_embeds, capture_already, + num_input_tokens, num_scheduled_tokens) + ''' + ================== + End of MLU Hijack + ================== + ''' + + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + batch_descriptor, + use_cascade_attn=cascade_attn_prefix_lens is not None, + ) + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @breif: add prefill graph & capture already check + ''' + if (self.prefill_enable_mlugraph and + attn_metadata.get(COMMON_METADATA_STR) is not None and + attn_metadata[COMMON_METADATA_STR].infer_mode == MLUInferMode.PREFILL_ONLY): + cudagraph_runtime_mode = CUDAGraphMode.FULL + if not capture_already: + cudagraph_runtime_mode = CUDAGraphMode.NONE + ''' + ================== + End of MLU Hijack + ================== + ''' + + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_runtime_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: debug disagg cnpx. 
+ ''' + if mlu_envs.VLLM_DISAGG_CNPX_EXECUTE: + self.execute_cnpx_mark = cnpx.rangeStart("DP_" + str(self.parallel_config.data_parallel_rank) + "_TP_" \ + + str(get_tensor_model_parallel_rank()) + "_execute_model" + \ + ("_no_graph" if cudagraph_runtime_mode == CUDAGraphMode.NONE else "")) + if mlu_envs.VLLM_DISAGG_CNPX_REQUEST: + self.request_cnpx_mark.clear() + for req in scheduler_output.scheduled_new_reqs: + self.request_cnpx_mark[req.req_id] = cnpx.rangeStart(req.req_id) + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + self.request_cnpx_mark[req_id] = cnpx.rangeStart(req_id) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + start = torch.mlu.Event(enable_timing=True) + start.record() + + ''' + ============================= + Modify by vllm_mlu + ============================= + @breif: add set_sp_forward_context for sequence parallel. + ''' + # Run the model. + # Use persistent buffers for CUDA graphs. + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), + set_sp_forward_context( + attn_metadata, + self.vllm_config, + num_input_tokens, + ), + record_function_or_nullcontext("gpu_model_runner: forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + if self.model_config.hf_config.model_type == "deepseek_v4": + model_kwargs["batch_to_kv_state"] = torch.tensor([ + self.req_id_to_kv_state[req_id] for req_id in self.input_batch._req_ids + ], dtype=torch.int32, device=input_ids.device) + model_output = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + prefill_enable_mlugraph=self.prefill_enable_mlugraph, + **model_kwargs, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + with record_function_or_nullcontext("gpu_model_runner: postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. + hidden_states, aux_hidden_states = model_output + else: + # Common case. + hidden_states = model_output + aux_hidden_states = None + + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + # Return the pooling output. + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) + output.kv_connector_output = kv_connector_output + ''' + ============================= + Modify by vllm_mlu + ============================= + @breif: add time markers for pooling model + ''' + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + self.time_markers.append([start, end]) + ''' + ================== + End of MLU Hijack + ================== + ''' + return output + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + else: + # Rare case. 
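+                # broadcast_pp_output path (external_launcher backend): only
+                # the last PP rank computes logits and then broadcasts them to
+                # the other ranks.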
+ assert not self.is_pooling_model + + sample_hidden_states = hidden_states[logits_indices] + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + else: + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN: + end = torch.mlu.Event(enable_timing=True) + end.record() + self.time_markers.append([start, end]) + + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + ) + return None + + def response_remote_alloc_once(self) -> None: + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase) + kv_connector.response_remote_alloc_once() + + @torch.inference_mode + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncMLUModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + if not kv_connector_output: + return None # noqa + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. 
+ if grammar_output is not None: + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) + + with record_function_or_nullcontext("gpu_model_runner: sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + self.input_batch.prev_sampled_token_ids = None + + def propose_draft_token_ids( + sampled_token_ids: torch.Tensor | list[np.ndarray], + ) -> None: + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("gpu_model_runner: draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + effective_drafter_max_model_len = self.max_model_len + if effective_drafter_max_model_len is None: + effective_drafter_max_model_len = self.model_config.max_model_len + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): + effective_drafter_max_model_len = ( + self.speculative_config.draft_model_config.max_model_len + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Force `input_fits_in_drafter` to be True to ensure that `self.uniform_decode_query_len` tokens are scheduled per batch during model execution. + This is required for graph validation and to keep the batch token count consistent with `self.uniform_decode_query_len` immediately after the prefill stage. + ''' + # input_fits_in_drafter = spec_decode_common_attn_metadata and ( + # spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + # <= effective_drafter_max_model_len + # ) + input_fits_in_drafter = True + ''' + ================== + End of MLU Hijack + ================== + ''' + if use_padded_batch_for_eagle: + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
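+            # valid_sampled_token_ids comes from _bookkeeping_sync above, so the
+            # CPU-side copy of the accepted tokens is already available here.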
+ propose_draft_token_ids(valid_sampled_token_ids) + + with record_function_or_nullcontext("gpu_model_runner: eplb"): + self.eplb_step() + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + ec_connector_output=ec_connector_output + if self.supports_mm_inputs + else None, + num_nans_in_logits=num_nans_in_logits, + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + get_kv_transfer_group().clear_connector_metadata() + + if mlu_envs.VLLM_DISAGG_CNPX_EXECUTE: + current_stream = torch.mlu.current_stream() + current_stream.synchronize() + cnpx.rangeEnd(self.execute_cnpx_mark) + if mlu_envs.VLLM_DISAGG_CNPX_REQUEST: + current_stream = torch.mlu.current_stream() + current_stream.synchronize() + for req in scheduler_output.scheduled_new_reqs: + cnpx.rangeEnd(self.request_cnpx_mark[req.req_id]) + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + cnpx.rangeEnd(self.request_cnpx_mark[req_id]) + ''' + ================== + End of MLU Hijack + ================== + ''' + if not self.use_async_scheduling: + return output + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncMLUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) + + return async_output + + def propose_draft_token_ids( + self, + scheduler_output: "SchedulerOutput", + sampled_token_ids: torch.Tensor | list[np.ndarray], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + common_attn_metadata: MLUCommonAttentionMetadata, + ) -> torch.Tensor | list[list[int]]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: draft model will build new FlashMLAMetadata, + so just unpad common_attn_metadata here. 
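+        @note: the decode-graph padding added in _padding_attn_metadata is not
+               needed by the drafter, so the metadata is trimmed back to the real
+               num_reqs / num_scheduled_tokens before proposing draft tokens.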
+ ''' + unpad_common_attn_metadata( + common_metadata=common_attn_metadata, + num_reqs=self.input_batch.num_reqs, + num_scheduled_tokens=num_scheduled_tokens + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, NgramProposer) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs, + ) + elif self.speculative_config.method == "suffix": + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, SuffixDecodingProposer) + draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) + elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, MedusaProposer) + + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. + hidden_states = sample_hidden_states + else: + indices = [] + offset = 0 + assert spec_decode_metadata is not None, ( + "No spec decode metadata for medusa" + ) + for num_draft, tokens in zip( + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): + indices.append(offset + tokens.shape[0] - 1) + offset += num_draft + 1 + indices = torch.tensor(indices, device=self.device) + hidden_states = sample_hidden_states[indices] + + draft_token_ids = self.drafter.propose( + target_hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + elif self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" + "padded-batch is disabled." + ) + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" + "padded-batch is enabled." + ) + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + if spec_decode_metadata is None: + token_indices_to_sample = None + # input_ids can be None for multimodal models. 
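+            # No draft tokens were scheduled this step, so the drafter inputs are
+            # simply the first num_scheduled_tokens entries of the persistent
+            # buffers; no rejection-based reindexing is needed.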
+ target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + num_rejected_tokens_gpu = None + token_indices = None + else: + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) + else: + common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = ( + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count, + ) + ) + + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = self._get_positions(token_indices) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = hidden_states[token_indices] + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: add debug info for draft accepted rate + ''' + if mlu_envs.VLLM_MTP_DEBUG: + batch_total_draft = sum(spec_decode_metadata.num_draft_tokens) + batch_total_rejected = sum(num_rejected_tokens_gpu) + self.total_draft_tokens += batch_total_draft + self.total_accepted_tokens += ( + batch_total_draft - batch_total_rejected) + if batch_total_draft > 0: + batch_accept_rate = ( + batch_total_draft - batch_total_rejected + ) / batch_total_draft + print(f"Batch Accept Rate: {batch_accept_rate:.4f}, " + f"Total Accept Rate: {self.get_accept_rate():.4f}") + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.supports_mm_inputs: + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: keep full scheduled tokens for draft model compute + ''' + target_token_ids = target_token_ids[:num_scheduled_tokens] + target_positions = target_positions[:num_scheduled_tokens] + target_hidden_states = target_hidden_states[:num_scheduled_tokens] + ''' + ================== + End of MLU Hijack + ================== + ''' + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + num_rejected_tokens=num_rejected_tokens_gpu, + token_indices=token_indices, + time_markers=self.time_markers, + ) + + return draft_token_ids + + def load_model(self, eep_scale_up: bool = False) -> None: + """ + Args: + eep_scale_up: the model loading is for elastic EP scale up. + """ + logger.info_once( + "Starting to load model %s...", + self.model_config.model, + scope="global", + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: + 1. Set max_batched_token for SparseMoeMlp when enable avg moe. + 2. modify rope's max_position_embeddings to max_model_len. 
+ Those MUST be set before init model. + ''' + if mlu_envs.VLLM_AVG_MOE_EN: + logger.warning("Inference with Moe avg dispatch, " + "it's only for deepseek-v3/r1 model's performance test," + " and will result in precision anomalies. Be careful!") + SparseMoeMlp.max_batched_token = max(self.model_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + MLURotaryEmbedding.max_seq_len = self.model_config.max_model_len + MLURotaryEmbedding.max_model_len = self.model_config.max_model_len + ''' + ================== + End of MLU Hijack + ================== + ''' + global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( + EplbState.get_eep_state(self.parallel_config) + if eep_scale_up + else (None, None, None) + ) + + if self.parallel_config.enable_eplb: + self.eplb_state = EplbState(self.parallel_config, self.device) + eplb_models = 0 + with DeviceMemoryProfiler() as m: + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: register model pre forward for rope optimization + ''' + self.model.register_forward_pre_hook(_model_forward_pre_hook, with_kwargs=True) + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.lora_config: + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) + if hasattr(self, "drafter"): + logger.info_once("Loading drafter model...") + self.drafter.load_model(self.model) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: Apply forward prehook to draft model. 
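+        @note: this reuses the same _model_forward_pre_hook registered on the
+               target model above, so the drafter also benefits from the rope
+               optimization.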
+ ''' + self.drafter.model.register_forward_pre_hook(_model_forward_pre_hook, with_kwargs=True) + ''' + ================== + End of MLU Hijack + ================== + ''' + if ( + hasattr(self.drafter, "model") + and is_mixture_of_experts(self.drafter.model) + and self.parallel_config.enable_eplb + ): + logger.info_once( + "EPLB is enabled for drafter model %s.", + self.vllm_config.speculative_config.draft_model_config.model, + ) + + global_expert_load = ( + global_expert_loads[eplb_models] + if global_expert_loads + else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + if self.eplb_state is None: + self.eplb_state = EplbState(self.parallel_config, self.device) + self.eplb_state.add_model( + self.drafter.model, + self.vllm_config.speculative_config.draft_model_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + eplb_models += 1 + + if self.use_aux_hidden_state_outputs: + if not supports_eagle3(self.get_model()): + raise RuntimeError( + "Model does not support EAGLE3 interface but " + "aux_hidden_state_outputs was requested" + ) + + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers, + ) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + + self.model.set_aux_hidden_state_layers(aux_layers) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + logger.info_once( + "Model loading took %.4f GiB memory and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + scope="local", + ) + prepare_communication_buffer_for_model(self.model) + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.get_model()) + and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + ) + + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info_once("EPLB is enabled for model %s.", self.model_config.model) + global_expert_load = ( + global_expert_loads[eplb_models] if global_expert_loads else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + assert self.eplb_state is not None + self.eplb_state.add_model( + self.model, + self.model_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + + if ( + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + and supports_dynamo() + ): + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) + compilation_counter.stock_torch_compile_count += 1 + self.model.compile(fullgraph=True, backend=backend) + return + # for other compilation modes, cudagraph behavior is controlled by + # CudagraphWraper and CudagraphDispatcher of vllm. + + # wrap the model with full cudagraph wrapper if needed. 
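+        # Wrapping strategy: MLUGraphWrapper provides full-graph capture when DBO
+        # is disabled; with DBO enabled the model is wrapped in UBatchWrapper,
+        # using FULL or NONE graph mode depending on the configured cudagraph_mode.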
+ if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.enable_dbo + ): + self.model = MLUGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + elif self.parallel_config.enable_dbo: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) + else: + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) + + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, LogprobsTensors | None]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + num_tokens = num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True + ) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: remove the prompt_logprobs for final chunk request + ''' + del prompt_logprobs_dict[req_id] + ''' + ================== + End of MLU Hijack + ================== + ''' + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc.np[req_idx].item() + prompt_hidden_states = hidden_states[offset : offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) + + # Get the "target" tokens for each index. 
For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids + ) + + # Transfer GPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_( + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_( + ranks, non_blocking=True + ) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + if prompt_logprobs_dict: + self._sync_device() + + return prompt_logprobs_dict + + def _build_attention_metadata( + self, + total_num_scheduled_tokens: int, + max_num_scheduled_tokens: int, + num_reqs: int, + ubatch_slices: UBatchSlices | None = None, + logits_indices: torch.Tensor | None = None, + use_spec_decode: bool = False, + for_cudagraph_capture: bool = False, + scheduled_encoder_inputs: dict[str, list[int]] | None = None, + cascade_attn_prefix_lens: list[list[int]] | None = None, + mlu_infer_mode: MLUInferMode | None = None, + ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: + """ + :return: tuple[attn_metadata, spec_decode_common_attn_metadata] + """ + logits_indices_padded = None + num_logits_indices = 0 + if logits_indices is not None: + num_logits_indices = logits_indices.size(0) + if self.cache_config.kv_sharing_fast_prefill: + logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices + ) + + # update seq_lens of decode reqs under DCP. + if self.dcp_world_size > 1: + self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + self.seq_lens.cpu[:num_reqs], + self.dcp_world_size, + self.dcp_rank, + self.parallel_config.dcp_kv_cache_interleave_size, + ) + self.dcp_local_seq_lens.copy_to_gpu(num_reqs) + + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] + + # Used in the below loop + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] + seq_lens = self.seq_lens.gpu[:num_reqs] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ] + dcp_local_seq_lens = ( + self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None + ) + spec_decode_common_attn_metadata = None + + if for_cudagraph_capture: + # For some attention backends (e.g. FA) with sliding window models we need + # to make sure the backend see a max_seq_len that is larger to the sliding + # window size when capturing to make sure the correct kernel is selected. 
+ max_seq_len = self.max_model_len + else: + max_seq_len = self.seq_lens.np[:num_reqs].max().item() + + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_gid, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups + ): + encoder_seq_lens = self._get_encoder_seq_lens( + scheduled_encoder_inputs or {}, + kv_cache_group.kv_cache_spec, + num_reqs, + ) + + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_gid] + blk_table_tensor = blk_table.get_device_tensor(num_reqs) + slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: replace CommonAttentionMetadata with MLUCommonAttentionMetadata + """ + common_attn_metadata = MLUCommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + max_seq_len=max_seq_len, + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + causal=True, + dcp_local_seq_lens=dcp_local_seq_lens, + seq_start_loc=self.seq_start_loc.gpu[: num_reqs + 1], + seq_start_loc_cpu=self.seq_start_loc.cpu[: num_reqs + 1], + infer_mode=mlu_infer_mode, + num_prefill_query_tokens=total_num_scheduled_tokens, + num_prefill_kv_tokens=total_num_scheduled_tokens, + ) + """ + ================= + End of MLU Hijack + ================= + """ + if self.speculative_config and spec_decode_common_attn_metadata is None: + if isinstance(self.drafter, EagleProposer): + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: replace attn metadata name to prefill_attn name + """ + attn_layer_name = self.drafter.attn_layer_names[0] + if self.model_config.is_deepseek_mla and attn_layer_name.endswith("self_attn.attn"): + attn_layer_name = attn_layer_name.replace( + "self_attn.attn", "self_attn.mla_attn") + if attn_layer_name in kv_cache_group.layer_names: + spec_decode_common_attn_metadata = common_attn_metadata + """ + ================= + End of MLU Hijack + ================= + """ + else: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]): + cascade_attn_prefix_len = ( + cascade_attn_prefix_lens[kv_cache_gid][attn_gid] + if cascade_attn_prefix_lens + else 0 + ) + builder = attn_group.get_metadata_builder() + + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + extra_attn_metadata_args = dict( + 
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs + ], + ) + + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata + ) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list + ): + builder = attn_group.get_metadata_builder(ubatch_id=ubid) + if for_cudagraph_capture: + attn_metadata_i = builder.build_for_cudagraph_capture( + common_attn_metadata + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=cascade_attn_prefix_len, + common_attn_metadata=common_attn_metadata, + ) + for layer_name in kv_cache_group.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + assert isinstance(attn_metadata, dict) + if for_cudagraph_capture: + attn_metadata_i = builder.build_for_cudagraph_capture( + common_attn_metadata + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=cascade_attn_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: bind decode_attn metadata to prefill_attn + """ + for layer_name in attn_group.layer_names: + if ( + self.model_config.is_deepseek_mla + and layer_name.endswith("self_attn.mla_attn") + ): + prefill_attn_name = layer_name.replace( + "self_attn.mla_attn", "self_attn.attn" + ) + attn_metadata[prefill_attn_name] = attn_metadata[layer_name] + + # matches self_attn.0.attn or self_attn.1.attn for longcat-flash + if ( + self.model_config.is_longcat_flash + and (match := re.match(r".*self_attn\.(0|1)\.mla_attn$", layer_name)) + ): + # Extract the captured digit (0 or 1) + digit = match.group(1) + prefill_attn_name = layer_name.replace( + f"self_attn.{digit}.mla_attn", + f"self_attn.{digit}.attn" + ) + attn_metadata[prefill_attn_name] = attn_metadata_i + """ + ================= + End of MLU Hijack + ================= + """ + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: Add common_attn_metadata to attn_metadata + """ + attn_metadata[COMMON_METADATA_STR] = common_attn_metadata + """ + ================= + End of MLU Hijack + ================= + """ + + return attn_metadata, spec_decode_common_attn_metadata + + def _padding_attn_metadata( + self, + attn_metadata: MLACommonMetadata | FlashAttentionMetadata, + input_ids: torch.Tensor | None, + inputs_embeds: torch.Tensor | None, + captured_already: bool, + num_input_tokens: int, + num_scheduled_tokens: int + ) -> None: + common_metadata = attn_metadata[COMMON_METADATA_STR] + decode_only = common_metadata.is_decode_only + if decode_only and captured_already: + # If the model is decode only, we can use full graph. + # use_full_graph = use_full_graph # and captured_already + # Update attn_metadata for full graph. 
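+            # K is the number of speculative tokens per request; the padded request
+            # count below is computed in units of (1 + K) query tokens so the padded
+            # batch still maps onto a captured decode-graph size.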
+ K = 0 + if (self.speculative_config is not None + and self.speculative_config.num_speculative_tokens > 0 + ): + K = self.speculative_config.num_speculative_tokens + + if num_input_tokens != num_scheduled_tokens: + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_table = self.input_batch.block_table[kv_cache_group_id] + first_layer_name = kv_cache_group_spec.layer_names[0] + attn_metadata_i = attn_metadata[first_layer_name] + num_reqs = self.input_batch.num_reqs + num_padded_reqs = self.vllm_config.pad_for_cudagraph(num_reqs * (1 + K)) // (1 + K) + pad_attn_metadata( + attn_metadata_i, common_metadata, block_table, + self, num_scheduled_tokens, num_input_tokens, + num_reqs, num_padded_reqs, + ) + + + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: Add prefill input parameters + @parameters: is_capturing_prefill, prefill_batch_size, prefill_seq_len + """ + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + is_capturing_prefill: bool = False, + prefill_batch_size: int = None, + prefill_seq_len: int = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a dummy forward pass to warm up/profile run or capture the + CUDA graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: used to control the behavior. + - if not set will determine the cudagraph mode based on using + the self.cudagraph_dispatcher. + - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. + - CUDAGraphMode.FULL: Full cudagraph, attention metadata is + needed. + force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. + remove_lora: If False, dummy LoRAs are not destroyed after the run + activate_lora: If False, dummy_run is performed without LoRAs. + """ + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. 
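+        # Example: with num_speculative_tokens == 2, uniform_decode_query_len is 3
+        # (one sampled token plus two draft tokens), so a uniform decode batch
+        # schedules 3 query tokens per request.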
+ max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + if create_mixed_batch: + assert not uniform_decode + # Create mixed batch: + # first half decode tokens, second half one prefill + num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + + # Create decode requests (1 token each) followed by prefill request + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] + # Note: Overriding max_query_len to be the prefill tokens + max_query_len = num_prefill_tokens + elif uniform_decode: + assert not create_mixed_batch + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + # Disable DP padding when running eager + allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + + # We currently only microbatch if the number of tokens is + # over a certain threshold. + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=total_num_scheduled_tokens, + parallel_config=self.vllm_config.parallel_config, + allow_microbatching=allow_microbatching, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=total_num_scheduled_tokens, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, + ) + num_tokens_after_padding = num_tokens + if num_tokens_across_dp is not None: + dp_rank = self.parallel_config.data_parallel_rank + num_tokens_after_padding = int(num_tokens_across_dp[dp_rank]) + + attn_metadata: PerLayerAttnMetadata | None = None + + # If force_attention is True, we always capture attention. Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: use prefill_seq_len to build seq_lens + when prefill capture + """ + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. 
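+                # seq_lens selection: mixed warmup batches use short contexts,
+                # prefill-graph capture uses the configured prefill_seq_len, and
+                # the default path falls back to max_query_len.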
+ # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + elif is_capturing_prefill: + seq_lens = prefill_seq_len + else: + seq_lens = max_query_len + """ + ================= + End of MLU Hijack + ================= + """ + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: compute seq_start_loc and mlu_infer_mode. + @brief: use prefill_batch_size to build seq_start_loc + """ + cu_seqlens_k = np.cumsum(self.seq_lens.np[:num_reqs]) + self.seq_start_loc.np[0] = 0 + self.seq_start_loc.np[1 : num_reqs + 1] = cu_seqlens_k + self.seq_start_loc.copy_to_gpu() + + max_computed_tokens = np.max(self.input_batch.num_computed_tokens_cpu[:num_reqs]) + mlu_infer_mode = MLUInferMode.build( + max_query_len=max_query_len, + max_computed_tokens=max_computed_tokens, + uniform_decode_query_len=self.uniform_decode_query_len) + if is_capturing_prefill: + attn_metadata, _ = self._build_attention_metadata( + total_num_scheduled_tokens=num_tokens, + max_num_scheduled_tokens=max_query_len, + num_reqs=prefill_batch_size, + ubatch_slices=ubatch_slices, + for_cudagraph_capture=True, + mlu_infer_mode=MLUInferMode.PREFILL_ONLY, + ) + else: + attn_metadata, _ = self._build_attention_metadata( + total_num_scheduled_tokens=num_tokens, + max_num_scheduled_tokens=max_query_len, + num_reqs=num_reqs, + ubatch_slices=ubatch_slices, + for_cudagraph_capture=True, + mlu_infer_mode=mlu_infer_mode, + ) + """ + ================= + End of MLU Hijack + ================= + """ + + with self.maybe_dummy_run_with_lora( + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + remove_lora, + ): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_after_padding <= self.max_num_tokens + model_kwargs = self._init_model_kwargs(num_tokens_after_padding) + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding] + model_kwargs = { + **model_kwargs, + **self._dummy_mm_kwargs(num_reqs), + } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding] + model_kwargs = self._init_model_kwargs(num_tokens_after_padding) + else: + input_ids = self.input_ids.gpu[:num_tokens_after_padding] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens_after_padding] + else: + positions = self.positions.gpu[:num_tokens_after_padding] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + ) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens_after_padding, None, False + ) + + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + has_lora=activate_lora and self.lora_config is not None, + ) + ) + if not is_profile + else 
(CUDAGraphMode.NONE, None) + ) + """ + ============================= + Modify by vllm_mlu + ============================= + @brief: adjust cudagraph mode for + prefill graph capture + """ + if is_capturing_prefill: + _cg_mode = cudagraph_runtime_mode + """ + ================= + End of MLU Hijack + ================= + """ + if cudagraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for cudagraph capture + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( + f"Cudagraph runtime mode mismatch at dummy_run. " + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." + ) + else: + cudagraph_runtime_mode = _cg_mode + + if ubatch_slices is not None: + # Adjust values to reflect a single ubatch. + # TODO(sage,lucas): this is cruft that should be addressed in + # the padding refactor. + num_tokens_after_padding = ubatch_slices[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = num_tokens_after_padding + + ''' + ============================= + Modify by vllm_mlu + ============================= + @breif: add set_sp_forward_context for sequence parallel. + ''' + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_after_padding, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), + set_sp_forward_context( + attn_metadata, + self.vllm_config, + num_tokens_after_padding, + ), + ): + if self.model_config.hf_config.model_type == "deepseek_v4": + assert self.kv_state_free_slots, \ + "At least one slot is needed to run dummy model" + model_kwargs["batch_to_kv_state"] = torch.tensor([ + list(self.kv_state_free_slots)[0] + ] * num_reqs, + dtype=torch.int32, + device=input_ids.device, + ) + outputs = self.model( + is_capturing_prefill=is_capturing_prefill, + prefill_batch_size=prefill_batch_size, + prefill_seq_len=prefill_seq_len, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + use_cudagraphs = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL + and not self.speculative_config.enforce_eager + ) + + # Note(gnovack) - We need to disable cudagraphs for one of the two + # lora cases when cudagraph_specialize_lora is enabled. This is a + # short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/28334 + if self.compilation_config.cudagraph_specialize_lora and activate_lora: + use_cudagraphs = False + + self.drafter.dummy_run( + attn_metadata, + num_tokens, + use_cudagraphs=use_cudagraphs, + ) + + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. 
+ if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + logit_indices_device = torch.from_numpy(logit_indices).to( + self.device, non_blocking=True + ) + return hidden_states, hidden_states[logit_indices_device] + + def _capture_cudagraphs( + self, + is_capturing_prefill: bool = False, + prefill_batch_size: int = 0, + prefill_seq_len: int = 0, + compilation_cases: list[tuple[int, bool]] = [] , + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + uniform_decode: bool = False, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + + # Only rank 0 should print progress bar during capture + if is_global_first_rank(): + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA graphs ({}, {})".format( + "prefill" if is_capturing_prefill else "decode", + cudagraph_runtime_mode.name, + ), + ) + if (self.speculative_config is not None + and self.speculative_config.num_speculative_tokens > 0 + ): + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA draft graphs ({}, {})".format( + "decode", + cudagraph_runtime_mode.name, + ), + ) + + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens, activate_lora in compilation_cases: + # We currently only capture ubatched graphs when its a FULL + # cudagraph, a uniform decode batch, and the number of tokens + # is above the threshold. Otherwise we just capture a non-ubatched + # version of the graph + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode + and check_ubatch_thresholds( + config=self.vllm_config.parallel_config, + num_tokens=num_tokens, + uniform_decode=uniform_decode, + ) + ) + + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens=num_tokens, + is_capturing_prefill=is_capturing_prefill, + prefill_batch_size=prefill_batch_size, + prefill_seq_len=prefill_seq_len, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + activate_lora=activate_lora, + ) + self._dummy_run( + num_tokens=num_tokens, + is_capturing_prefill=is_capturing_prefill, + prefill_batch_size=prefill_batch_size, + prefill_seq_len=prefill_seq_len, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + activate_lora=activate_lora, + ) + self.maybe_remove_all_loras(self.lora_config) + + def capture_model(self) -> int: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + logger.warning( + "Skipping CUDA graph capture. 
To turn on CUDA graph capture, " + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) + return 0 + + compilation_counter.num_gpu_runner_capture_triggers += 1 + + start_time = time.perf_counter() + + @contextmanager + def freeze_gc(): + # Optimize garbage collection during CUDA graph capture. + # Clean up, then freeze all remaining objects from being included + # in future collections. + gc.collect() + should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + gc.collect() + + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + set_cudagraph_capturing_enabled(True) + with freeze_gc(), mlu_graph_capture(device=self.device): + start_free_gpu_memory = torch.mlu.mem_get_info()[0] + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + + if self.lora_config: + if self.compilation_config.cudagraph_specialize_lora: + lora_cases = [True, False] + else: + lora_cases = [True] + else: + lora_cases = [False] + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: prefill graph capture + ''' + if self.prefill_enable_mlugraph: + # capture prefill mlugraph + batch_size = self.prefill_mlugraph_batch_size + seq_len = self.prefill_mlugraph_seq_len + num_tokens = batch_size * seq_len + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + assert batch_size <= self.scheduler_config.max_num_seqs + logger.info("Capture prefill mlugraph for batch size " + f"{batch_size} and seq len {seq_len}") + prefill_compilation_cases = list( + product([num_tokens], lora_cases) + ) + self._capture_cudagraphs( + is_capturing_prefill=True, + prefill_batch_size=batch_size, + prefill_seq_len=seq_len, + compilation_cases=prefill_compilation_cases, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=False, + ) + ''' + ================== + End of MLU Hijack + ================== + ''' + + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + cudagraph_runtime_mode = cudagraph_mode.mixed_mode() + # make sure we capture the largest batch size first + compilation_cases = list( + product(reversed(self.cudagraph_batch_sizes), lora_cases) + ) + self._capture_cudagraphs( + compilation_cases=compilation_cases, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=False, + ) + + # Capture full cudagraph for uniform decode batches if we + # don't already have full mixed prefill-decode cudagraphs. + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and hasattr(self, 'cudagraph_batch_sizes') + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) + decode_cudagraph_batch_sizes = [ + x + for x in self.cudagraph_batch_sizes + if max_num_tokens >= x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + product(reversed(decode_cudagraph_batch_sizes), lora_cases) + ) + self._capture_cudagraphs( + compilation_cases=compilation_cases_decode, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True, + ) + + torch.mlu.synchronize() + end_free_gpu_memory = torch.mlu.mem_get_info()[0] + + # Disable cudagraph capturing globally, so any unexpected cudagraph + # capturing will be detected and raise an error after here. 
+ # Note: We don't put it into graph_capture context manager because + # we may do lazy capturing in future that still allows capturing + # after here. + set_cudagraph_capturing_enabled(False) + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info_once( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + scope="local", + ) + return cuda_graph_size + + def _allocate_kv_cache_tensors( + self, + kv_cache_config: KVCacheConfig + ) -> dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. + + Args: + kv_cache_config: The KV cache config + Returns: + dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next, deepseek v3.2 indexer cache and mlu kv8 + ''' + + kv_cache_group = kv_cache_config.kv_cache_groups[0] + + if self.mlu_config.enable_mamba_split_page_size: + # hybrid attention, try to find full attention + for group in kv_cache_config.kv_cache_groups: + if isinstance(group.kv_cache_spec, FullAttentionSpec): + kv_cache_group = group + break + self.mamba_block_num = self.mlu_config.mamba_support_max_batch_size + self.mamba_tensor_size = (kv_cache_group.kv_cache_spec.page_size_bytes \ + * self.mlu_config.mamba_to_attn_block_ratio * self.mamba_block_num) + logger.info(f"one linear attn layer cache tensor size {self.mamba_tensor_size}") + + kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + kv_cache_spec = kv_cache_group.kv_cache_spec + assert kv_cache_tensor.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = kv_cache_tensor.size // kv_cache_spec.page_size_bytes + if kv_cache_spec.dtype in [torch.int8, torch.uint8]: + # mlu kv8 + assert isinstance(kv_cache_spec, AttentionSpec) + cache_ = torch.zeros( + num_blocks * kv_cache_spec.cache_size_bytes, + dtype=torch.int8, device=self.device, + ) + scale_ = torch.zeros( + num_blocks * kv_cache_spec.scale_size_bytes, + dtype=torch.int8, device=self.device, + ) + else: + # not mlu kv8 + cache_ = torch.zeros( + num_blocks * kv_cache_spec.cache_size_bytes, + dtype=torch.int8, + device=self.device + ) + scale_ = torch.tensor([], dtype=torch.int8, device=self.device) + + if (isinstance(kv_cache_spec, MLUMLAAttentionSpec) + and kv_cache_spec.index_n_heads > 0): + index_cache_ = torch.zeros((num_blocks * + kv_cache_spec.index_cache_size_bytes), + dtype=torch.int8, + device=self.device) + else: + index_cache_ = torch.tensor([], dtype=torch.int8, device=self.device) + + for layer_name in kv_cache_tensor.shared_by: + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + if self.mlu_config.enable_mamba_split_page_size: + if 'linear_attn' in layer_name: + mamba_tensor = torch.zeros( + self.mamba_tensor_size, dtype=torch.int8, device=self.device + ) + kv_cache_raw_tensors[layer_name] = [mamba_tensor, scale_, index_cache_] + else: + kv_cache_raw_tensors[layer_name] = [cache_, scale_, index_cache_] + else: + kv_cache_raw_tensors[layer_name] = [cache_, scale_, index_cache_] + ''' + 
================== + End of MLU Hijack + ================== + ''' + ''' + ================== + End of MLU Hijack + ================== + ''' + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) + return kv_cache_raw_tensors + + def _reshape_kv_cache_tensors( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + kernel_block_sizes: list[int], + ) -> dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape and dtype. + + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + kernel_block_sizes: The kernel block sizes for each KV cache group. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support mlu kv8 and deepseek v3.2 indexer + ''' + kv_caches: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + has_attn, has_mamba = False, False + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + if group.kv_cache_group_id == len(kernel_block_sizes): + # There may be a last group for layers without kv cache. + continue + kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + raw_tensor = kv_cache_raw_tensors[layer_name] + cache_, scale_, index_cache_ = raw_tensor + total_numel = cache_.numel() + scale_.numel() + index_cache_.numel() + assert total_numel % kv_cache_spec.page_size_bytes == 0 + num_blocks = total_numel // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, AttentionSpec): + has_attn = True + num_blocks_per_kv_block = ( + kv_cache_spec.block_size // kernel_block_size + ) + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + kv_cache_shape = attn_backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype, + ) + dtype = kv_cache_spec.dtype + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + # The allocation respects the backend-defined stride order + # to ensure the semantic remains consistent for each + # backend. We first obtain the generic kv cache shape and + # then permute it according to the stride order which could + # result in a non-contiguous tensor. + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) + # Maintain original KV shape view. 
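+                        # Illustrative example (hypothetical numbers): if the
+                        # generic shape is (num_blocks, 2, block_size,
+                        # num_kv_heads, head_size) and the backend reports a
+                        # stride order of (1, 0, 2, 3, 4), the buffer is first
+                        # viewed with the permuted shape (2, num_blocks,
+                        # block_size, num_kv_heads, head_size) and then
+                        # permuted back with the inverse order below, so
+                        # callers still index the generic layout while the
+                        # memory follows the backend-preferred strides.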
+ inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + cache_ = ( + cache_ + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) + # Reshape kv cache scale tensor + if dtype in [torch.int8, torch.uint8]: + kv_cache_scale_shape = attn_backend.get_kv_cache_scale_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + ) + scale_ = ( + scale_ + .view(torch.float32) + .view(kv_cache_scale_shape) + ) + # Reshape index_cache + if (isinstance(kv_cache_spec, MLUMLAAttentionSpec) + and kv_cache_spec.index_n_heads > 0): + index_cache_shape = ( + kernel_num_blocks, + kv_cache_spec.index_n_heads, + kernel_block_size, + kv_cache_spec.index_head_dim, + ) + index_cache_ = index_cache_.view(dtype).view(index_cache_shape) + kv_caches[layer_name] = [cache_, scale_, index_cache_] + elif isinstance(kv_cache_spec, MambaSpec): + has_mamba = True + cache_ = kv_cache_raw_tensors[layer_name] + raw_tensor, scale_, index_cache_ = cache_ + state_tensors = [] + storage_offset_bytes = 0 + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size + ) + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + if self.mlu_config.enable_mamba_split_page_size: + num_element_per_page *= self.mlu_config.mamba_to_attn_block_ratio + num_blocks = self.mamba_block_num + ''' + ================== + End of MLU Hijack + ================== + ''' + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size + + kv_caches[layer_name] = state_tensors + else: + raise NotImplementedError + ''' + ================== + End of MLU Hijack + ================== + ''' + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: support qwen3-next + ''' + if has_attn and has_mamba and not self.mlu_config.enable_mamba_split_page_size: + self._update_hybrid_attention_mamba_layout(kv_caches) + ''' + ================== + End of MLU Hijack + ================== + ''' + + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] + ) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + kernel_block_sizes: The kernel block sizes for each KV cache group. + + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
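+
+        Note:
+            Allocation and reshaping are split across the two helpers above:
+            _allocate_kv_cache_tensors creates one flat int8 buffer per
+            KVCacheTensor (plus scale and indexer buffers, left empty when
+            unused) and generally shares it between every layer listed in
+            shared_by, while _reshape_kv_cache_tensors views those bytes into
+            the backend shapes. Attention layers end up with a
+            [cache, scale, index_cache] triple; Mamba-style layers get a list
+            of state tensors carved out of the same buffer with
+            torch.as_strided, one per (shape, dtype) pair in their spec.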
+ """ + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes + ) + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # validate all draft model layers belong to the same kv cache + # group + self.drafter.validate_same_kv_cache_group(kv_cache_config) + + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: bind kv cache to deepseek prefill attn + ''' + if self.model_config.is_deepseek_mla: + forward_context = self.vllm_config.compilation_config.static_forward_context + for layer_name, kv_cache in kv_caches.items(): + if layer_name.endswith("self_attn.mla_attn"): + layer_name = layer_name.replace( + "self_attn.mla_attn", "self_attn.attn") + forward_context[layer_name].kv_cache = [kv_cache] + + # matches self_attn.0.attn or self_attn.1.attn + if self.model_config.is_longcat_flash: + forward_context = self.vllm_config.compilation_config.static_forward_context + for layer_name, kv_cache in kv_caches.items(): + if (match := re.match(r".*self_attn\.(0|1)\.mla_attn$", layer_name)): + digit = match.group(1) # Extract the captured digit (0 or 1) + layer_name = layer_name.replace( + f"self_attn.{digit}.mla_attn", + f"self_attn.{digit}.attn" + ) + forward_context[layer_name].kv_cache = [kv_cache] + ''' + ================== + End of MLU Hijack + ================== + ''' + + return kv_caches + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + # block_size = self.vllm_config.cache_config.block_size + # use_mla = self.vllm_config.model_config.use_mla + kv_cache_spec: dict[str, KVCacheSpec] = {} + attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) + for layer_name, attn_module in attn_layers.items(): + if isinstance(attn_module, Attention) and ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ): + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. 
+ self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: skip deepseek prefill attn init kv_cache + ''' + if ( + self.model_config.is_deepseek_mla + and layer_name.endswith("self_attn.attn") + ): + continue + # matches self_attn.0.attn or self_attn.1.attn + if ( + self.model_config.is_longcat_flash + and re.match(r".*self_attn\.(0|1)\.attn$", layer_name) + ): + continue + ''' + ================== + End of MLU Hijack + ================== + ''' + # Skip modules that don't need KV cache (eg encoder-only attention) + if spec := attn_module.get_kv_cache_spec(self.vllm_config): + kv_cache_spec[layer_name] = spec + + return kv_cache_spec + + def reset_capture_context(self, + prefill_enable_mlugraph: bool, + batch_size: int, + input_len: int): + self.graph_runners = {} + self.context_graph_runner = None + self.graph_memory_pool = None + + # reset prefill mlugraph infos + self.prefill_enable_mlugraph = prefill_enable_mlugraph + self.prefill_mlugraph_batch_size = batch_size + self.prefill_mlugraph_seq_len = input_len + + gc.collect() + torch.mlu.empty_cache() + + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor + ) -> None: + if self.valid_sampled_token_count_event is None: + return + + ''' + ============================= + Modify by vllm_mlu + @brief: replace current stream for MLU device. + ======= + ''' + default_stream = torch.mlu.current_stream() + # Initialize a new stream to overlap the copy operation with + # prepare_input of draft model. + with torch.mlu.stream(self.valid_sampled_token_count_copy_stream): + self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore + counts = valid_sampled_tokens_count + counts_cpu = self.valid_sampled_token_count_cpu + counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) + self.valid_sampled_token_count_event.record() + ''' + ================== + End of MLU Hijack + ================== + ''' + + self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) + + + def _bookkeeping_sync( + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: torch.Tensor | None, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> tuple[ + dict[str, int], + LogprobsLists | None, + list[np.ndarray], + dict[str, LogprobsTensors | None], + list[str], + dict[str, int], + list[int], + ]: + num_nans_in_logits = {} + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits(logits) + + discard_sampled_tokens_req_indices = self.discard_request_indices.np[ + : self.num_discarded_requests + ] + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + invalid_req_indices = [] + valid_sampled_token_ids: list[np.ndarray] + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. 
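+                # sampled_token_ids has a trailing dim of 1 here, i.e.
+                # exactly one token per request, so it can be converted
+                # directly without the rejection-sampler parsing used in the
+                # spec-decode branch below.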
+ valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)] = np.array([]) + else: + valid_sampled_token_ids = [] + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() + invalid_req_indices_set = set(invalid_req_indices) + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + # With spec decoding, this is done in propose_draft_token_ids(). + if self.input_batch.prev_sampled_token_ids is None: + assert sampled_token_ids.shape[-1] == 1 + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + req_ids = self.input_batch.req_ids + logprobs_tensors = sampler_output.logprobs_tensors + cu_num_accepted_tokens = ( + [0] if spec_decode_metadata and logprobs_tensors else None + ) + for req_idx in range(num_sampled_tokens): + sampled_ids: np.ndarray | None + if self.use_async_scheduling: + sampled_ids = ( + np.array([-1]) if req_idx not in invalid_req_indices_set else None + ) + else: + sampled_ids = valid_sampled_token_ids[req_idx] + + num_sampled_ids: int = ( + sampled_ids.shape[0] if sampled_ids is not None else 0 + ) + + if cu_num_accepted_tokens is not None: + cu_num_accepted_tokens.append( + cu_num_accepted_tokens[-1] + num_sampled_ids + ) + + if sampled_ids is None or num_sampled_ids == 0: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + num_sampled_ids + ''' + ============================= + Modify by vllm_mlu + @brief: end_idx may exceed max_model_len for sepculative tokens in MTP mode. + ======= + ''' + num_async_sched_tokens = 1 if self.use_async_scheduling else 0 + max_model_len = self.num_spec_tokens + self.max_model_len + num_async_sched_tokens + assert end_idx <= max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{max_model_len}" + ) + if end_idx > self.max_model_len: + end_idx = self.max_model_len + sampled_ids = sampled_ids[:end_idx - start_idx] + ''' + ================== + End of MLU Hijack + ================== + ''' + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + + req_id = req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + logprobs_lists = ( + logprobs_tensors.tolists(cu_num_accepted_tokens) + if not self.use_async_scheduling and logprobs_tensors is not None + else None + ) + + # Compute prompt logprobs if needed. 
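+        # hidden_states may be padded beyond the scheduled tokens (e.g. up to
+        # a CUDA graph batch size), so only the first num_scheduled_tokens
+        # rows are passed to the prompt-logprobs computation.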
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + diff --git a/vllm_mlu/v1/worker/gpu_worker.py b/vllm_mlu/v1/worker/gpu_worker.py new file mode 100644 index 0000000..f9401c5 --- /dev/null +++ b/vllm_mlu/v1/worker/gpu_worker.py @@ -0,0 +1,638 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +# SPDX-License-Identifier: Apache-2.0 +"""A GPU worker class.""" +import copy +import gc +import os +from contextlib import AbstractContextManager, nullcontext +from types import NoneType +from typing import TYPE_CHECKING, Optional + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_tp_group, get_pp_group +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.utils import report_usage_stats +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.utils.mem_constants import GiB_bytes + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +from vllm_mlu.model_executor.warmup.kernel_warmup import kernel_warmup +from vllm_mlu.profiler.mlu_profiler import MluProfilerWrapper +from vllm_mlu.utils import MemorySnapshot, memory_profiling +from vllm_mlu._mlu_utils import VLLM_DUMP_MLU_INFO_EN +from vllm_mlu.device_allocator.cnmem import CnMemAllocator +from vllm_mlu.v1.worker.mlu_quant import MLUWorkerQuant +from vllm_mlu.v1.worker.gpu_model_runner import MLUModelRunner +from vllm_mlu.v1.worker.dp_gpu_model_runner import DPMLUModelRunner + +logger = init_logger(__name__) + + +class MLUWorker(Worker, MLUWorkerQuant): + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + + WorkerBase.__init__(self, vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker) + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils.import_utils import init_cached_hf_modules + init_cached_hf_modules() + + # Buffers saved before sleep + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" + logger.info( + "Profiling enabled. 
Traces will be saved to: %s", + torch_profiler_trace_dir, + ) + logger.debug( + "Profiler config: record_shapes=%s," + "profile_memory=%s,with_stack=%s,with_flops=%s", + envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, + envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, + envs.VLLM_TORCH_PROFILER_WITH_STACK, + envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + ) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.MLU, + ], + record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, + profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, + with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True + ), + ) + elif envs.VLLM_TORCH_CUDA_PROFILE: + self.profiler = MluProfilerWrapper() + else: + self.profiler = None + + def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.mlu.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() for name, buffer in model.named_buffers() + } + + allocator = CnMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.mlu.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + allocator = CnMemAllocator.get_instance() + allocator.wake_up(tags) + + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CnMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be used for one instance per process." + ) + context = allocator.use_memory_pool(tag=tag) + else: + context = nullcontext() + return context + + def init_device(self): + if self.device_config.device.type == "mlu": + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("CNCL_ASYNC_ERROR_HANDLING", None) + # if ( + # self.parallel_config.data_parallel_size > 1 + # and self.parallel_config.data_parallel_size_local > 0 + # and self.parallel_config.distributed_executor_backend + # not in ["ray", "external_launcher"] + # and self.vllm_config.parallel_config.data_parallel_backend != "ray" + # ): + # # Use local DP rank if available, otherwise use global DP rank. 
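+            # # Worked example (hypothetical sizes): with TP=4, PP=1 the
+            # # tp_pp_world_size is 4, so for dp_local_rank=1 a worker whose
+            # # original local_rank is 2 would be remapped to 2 + 1 * 4 = 6.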
+ # dp_local_rank = self.parallel_config.data_parallel_rank_local + # if dp_local_rank is None: + # dp_local_rank = self.parallel_config.data_parallel_rank + + # tp_pp_world_size = ( + # self.parallel_config.pipeline_parallel_size + # * self.parallel_config.tensor_parallel_size + # ) + + # # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK + # self.local_rank += dp_local_rank * tp_pp_world_size + # assert self.local_rank < torch.mlu.device_count(), ( + # f"DP adjusted local rank {self.local_rank} is out of bounds. " + # ) + + self.device = torch.device(f"mlu:{self.local_rank}") + current_platform.set_device(self.device) + + current_platform.check_if_supports_dtype(self.model_config.dtype) + + # Initialize the distributed environment BEFORE taking + # memory snapshot + # This ensures NCCL buffers are allocated before we measure + # available memory + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) + + # Set random seed. + set_random_seed(self.model_config.seed) + + gc.collect() + torch.mlu.empty_cache() + + # take current memory snapshot + self.init_snapshot = MemorySnapshot() + self.requested_memory = ( + self.init_snapshot.total_memory + * self.cache_config.gpu_memory_utilization + ) + if self.init_snapshot.free_memory < self.requested_memory: + GiB = lambda b: round(b / GiB_bytes, 2) + raise ValueError( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup " + f"is less than desired GPU memory utilization " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). Decrease GPU memory " + f"utilization or reduce GPU memory used by other processes." + ) + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + # Construct the model runner + model_runner_cls = (DPMLUModelRunner + if self._enable_moe_dp_opt() else MLUModelRunner) + self.model_runner: MLUModelRunner = model_runner_cls( + self.vllm_config, self.device) + + if self.rank == 0: + # If usage stat is enabled, collect relevant info. + report_usage_stats(self.vllm_config) + + @torch.inference_mode() + def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the free memory that can be used for KV cache in + bytes. + + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + GiB = lambda b: b / GiB_bytes + if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # still need a profile run which compiles the model for + # max_num_batched_tokens + self.model_runner.profile_run() + + msg = ( + f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for " + "KV Cache as specified by kv_cache_memory_bytes config and " + "skipped memory profiling. This does not respect the " + "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + "config when you want manual control of KV cache memory " + "size. If OOM'ed, check the difference of initial free " + "memory between the current run and the previous run " + "where kv_cache_memory_bytes is suggested and update it " + "correspondingly." 
+ ) + logger.info(msg) + return kv_cache_memory_bytes + + torch.mlu.empty_cache() + torch.mlu.reset_peak_memory_stats() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + with memory_profiling( + self.init_snapshot, + weights_memory=int(self.model_runner.model_memory_usage), + ) as profile_result: + self.model_runner.profile_run() + + self.non_torch_memory = profile_result.non_torch_increase + self.peak_activation_memory = profile_result.torch_peak_increase + + free_gpu_memory = profile_result.after_profile.free_memory + GiB = lambda b: b / GiB_bytes + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + with memory_profiling( + self.init_snapshot, + weights_memory=int( + self.model_runner.model_memory_usage)) as profile_result: + self.model_runner.profile_run() + + free_gpu_memory = profile_result.after_profile.free_memory + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_snapshot.free_memory > free_gpu_memory, ( + "Error in memory profiling. " + f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, " + f"current free memory {GiB(free_gpu_memory)} GiB. " + "This happens when other processes sharing the same container " + "release GPU memory while vLLM is profiling during initialization. " + "To fix this, ensure consistent GPU memory allocation or " + "isolate vLLM in its own container." + ) + self.available_kv_cache_memory_bytes = ( + self.requested_memory - profile_result.non_kv_cache_memory + ) + + unrequested_memory = self.init_snapshot.free_memory - self.requested_memory + logger.debug( + "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB", + GiB(self.init_snapshot.free_memory), + self.cache_config.gpu_memory_utilization, + GiB(self.requested_memory), + ) + logger.debug( + "Free memory after profiling: %.2f GiB (total), " + "%.2f GiB (within requested)", + GiB(free_gpu_memory), + GiB(free_gpu_memory - unrequested_memory), + ) + logger.debug(profile_result) + logger.info_once( + "Available KV cache memory: %.2f GiB", + GiB(self.available_kv_cache_memory_bytes), + scope="local", + ) + gc.collect() + + self.peak_memory = profile_result.non_kv_cache_memory + self.block_memory = self.available_kv_cache_memory_bytes + + + return int(self.available_kv_cache_memory_bytes) + + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + + # Init kv cache connector here, because it requires + # `kv_cache_config`. + # NOTE(Kuntai): This need to be done before `initialize_kv_cache`, + # because `initialize_kv_cache` will inject kv cache groups not + # related to kv cache connector (e.g. kv cache sharing layers). + ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) + + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CnMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + context = nullcontext() + with context: + self.model_runner.initialize_kv_cache(kv_cache_config) + + def compile_or_warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. 
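+        # For instance (hypothetical values): with compile_sizes=[1, 2, 4, 8192]
+        # and cudagraph_capture_sizes=[1, 2, 4], only 8192 is warmed up here
+        # when graphs are enabled; the smaller sizes are exercised anyway
+        # during graph capture.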
+ warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes + if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + # We skip EPLB here since we don't want to record dummy metrics + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False) + self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) + + # Warmup and tune the kernels used during model execution before + # cuda graph capture. + kernel_warmup(self) + + cuda_graph_memory_bytes = 0 + if not self.model_config.enforce_eager: + cuda_graph_memory_bytes = self.model_runner.capture_model() + + if self.cache_config.kv_cache_memory_bytes is None and hasattr( + self, "peak_activation_memory" + ): + # Suggests optimal kv cache memory size if we rely on + # memory_profiling to guess the kv cache memory size which + # provides peak_activation_memory and a few other memory + # consumption. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all gpu memory. + # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + + # empirically observed that the memory profiling may + # slightly underestimate the memory consumption. + # So leave a small buffer (=150MiB) to avoid OOM. + redundancy_buffer_memory = 150 * (1 << 20) + non_kv_cache_memory = ( + self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes + ) + kv_cache_memory_bytes_to_gpu_limit = ( + self.init_snapshot.free_memory + - non_kv_cache_memory + - redundancy_buffer_memory + ) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) + - non_kv_cache_memory + - redundancy_buffer_memory + ) + + msg = ( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). " + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weight, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` " + f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit " + f"into requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` " + f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully " + f"utilize gpu memory. Current kv cache memory in use is " + f"{GiB(self.available_kv_cache_memory_bytes)} GiB." + ) + + logger.debug(msg) + + # Warm up sampler and preallocate memory buffer for logits and other + # sampling related tensors of max possible shape to avoid memory + # fragmentation issue. + # NOTE: This is called after `capture_model` on purpose to prevent + # memory buffers from being cleared by `torch.cuda.empty_cache`. 
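+        # Worked example for the --kv-cache-memory suggestion computed above
+        # (all numbers hypothetical): with 64 GiB total, 60 GiB free at
+        # startup and gpu_memory_utilization=0.9 (57.6 GiB requested), plus
+        # 30 GiB weights, 4 GiB peak activation, 2 GiB non-torch and 1 GiB of
+        # graphs, non_kv_cache_memory is 37 GiB and the 150 MiB buffer is
+        # ~0.15 GiB, so the suggestions are roughly 57.6 - 37 - 0.15 = 20.45
+        # GiB (requested limit) and 60 - 37 - 0.15 = 22.85 GiB (gpu limit).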
+ if get_pp_group().is_last_rank: + max_num_reqs = min( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + ) + + # We skip EPLB here since we don't want to record dummy metrics + hidden_states, last_hidden_states = self.model_runner._dummy_run( + num_tokens=max_num_reqs, + skip_eplb=True, + ) + if self.model_runner.is_pooling_model: + self.model_runner._dummy_pooler_run(hidden_states) + else: + self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states) + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def execute_model( + self, scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput | None: + intermediate_tensors = None + forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens) + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + if forward_pass and not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + ) + + with self.annotate_profile(scheduler_output): + output = self.model_runner.execute_model( + scheduler_output, intermediate_tensors + ) + if isinstance(output, (ModelRunnerOutput, NoneType)): + return output + + assert isinstance(output, IntermediateTensors) + parallel_config = self.vllm_config.parallel_config + assert ( + parallel_config.distributed_executor_backend != "external_launcher" + and not get_pp_group().is_last_rank + ) + + get_pp_group().send_tensor_dict( + output.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + + return None + + def _enable_moe_dp_opt(self): + ''' + We will enable the MLU-optimized DP scheme for the specified MoE models, + otherwise the native DP implementation will be used. + ''' + # case0 enable data parallel + enable_dp = self.parallel_config.data_parallel_size > 1 + # case1 ds mla + is_ds_mla = self.model_config.is_deepseek_mla + # case2 qwen3 moe + is_supported_moe_model = hasattr(self.model_config.hf_text_config, "model_type") and \ + self.model_config.hf_text_config.model_type in ('qwen3_moe', 'glm4_moe') + # case 3, private model + is_private_model = getattr(self.model_config.hf_config, "is_private", False) + return enable_dp and (is_ds_mla or is_supported_moe_model or is_private_model) + + def execute_dummy_batch(self) -> None: + if self._enable_moe_dp_opt(): + self.model_runner.moe_dp_execute_dummy_batch(1) + else: + self.model_runner._dummy_run(1, uniform_decode=True) + + def response_remote_alloc_once(self) -> None: + self.model_runner.response_remote_alloc_once() + + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: + from vllm.distributed.parallel_state import get_ep_group + + if get_ep_group().rank == 0: + logger.info( + "[Elastic EP] Starting expert resharding before scaling down..." 
+ ) + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange( + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping, + ) + torch.mlu.synchronize() + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + from vllm.config import set_current_vllm_config + from vllm.distributed.parallel_state import ( + cleanup_dist_env_and_memory, + get_ep_group, + ) + + old_ep_size = get_ep_group().world_size + old_ep_rank = get_ep_group().rank + new_ep_size = ( + reconfig_request.new_data_parallel_size + * get_tp_group().world_size + * get_pp_group().world_size + ) + if new_ep_size < old_ep_size: + self._eplb_before_scale_down(old_ep_size, new_ep_size) + + cleanup_dist_env_and_memory() + + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): + assert old_ep_rank >= new_ep_size + # shutdown + return + + self._reconfigure_parallel_config(reconfig_request) + + with set_current_vllm_config(self.vllm_config): + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) + + global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size) + + if new_ep_size > old_ep_size: + assert global_expert_loads is not None + self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads) + + def get_hfu_info(self, batch, input_len, output_len): + try: + self.model_runner.model.collect_hfu_io_effciency_info(batch, input_len, output_len) + if VLLM_DUMP_MLU_INFO_EN: + return self.model_runner.model.hfu_info, self.model_runner.model.io_efficiency + else: + return self.model_runner.model.flops_info, 0.0 + except Exception as e: + raise RuntimeError( + "Model match failure when get HFU info, please check if an init method was registed." 
+ ) + + def _get_latency(self, time_markers): + total_latency = 0 + if not isinstance(time_markers, list): + time_markers = [time_markers] + for time_marker in time_markers: + start, end = time_marker + latency = start.elapsed_time(end) + total_latency += latency + return total_latency + + def get_latency(self): + return self._get_latency(self.model_runner.time_markers) + + def get_mm_encoder_latency(self): + if not hasattr(self.model_runner, "mm_time_markers"): + return None + mm_time_markers = self.model_runner.mm_time_markers + return None if len(mm_time_markers) == 0 else\ + self._get_latency(mm_time_markers) + + def get_memory_usage(self): + return (self.peak_memory, self.block_memory) + + def recapture_model(self, + prefill_enable_mlugraph: bool, + batch_size: int, + input_len: int): + # Reset history capture context + self.model_runner.reset_capture_context( + prefill_enable_mlugraph, batch_size, input_len) + # Re-capture decode graph(full graph or peicewise graph) + self.compile_or_warm_up_model() diff --git a/vllm_mlu/v1/worker/kv_connector_model_runner_mixin.py b/vllm_mlu/v1/worker/kv_connector_model_runner_mixin.py new file mode 100644 index 0000000..67351d9 --- /dev/null +++ b/vllm_mlu/v1/worker/kv_connector_model_runner_mixin.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project +""" +Define KV connector functionality mixin for model runners. +""" + +import copy +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import ( + TYPE_CHECKING, # noqa: UP035 +) + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_shutdown, + get_kv_transfer_group, + has_kv_transfer_group, +) +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.logger import init_logger +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +from vllm_mlu.mlu_hijack_utils import MluHijackObject +logger = init_logger(__name__) + + +# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) +class KVConnectorModelRunnerMixin_MluHijack(KVConnectorModelRunnerMixin): + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. 
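+            @note: request_remote_memory_send() appears to be a
+                   vllm_mlu-specific connector extension used by the
+                   disaggregated prefill path; it does not come from
+                   upstream vLLM's KVConnectorBase, so this hook is only
+                   meaningful with the MLU connector implementations.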
+ ''' + kv_connector.request_remote_memory_send() + ''' + ================== + End of MLU Hijack + ================== + ''' + + # This context manager must be used within an active forward context. + # It encapsulates the entire KV connector lifecycle within execute_model + @staticmethod + @contextmanager + def _get_kv_connector_output( + scheduler_output: "SchedulerOutput", wait_for_save: bool = True + ) -> Generator[KVConnectorOutput, None, None]: + output = KVConnectorOutput() + + # Update KVConnector with the KVConnector metadata forward(). + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + ''' + ============================= + Modify by vllm_mlu + ============================= + @brief: supoort disagg for mlu. + ''' + kv_connector.request_remote_memory_send() + ''' + ================== + End of MLU Hijack + ================== + ''' + + try: + yield output + finally: + output.finished_sending, output.finished_recving = ( + kv_connector.get_finished(scheduler_output.finished_req_ids) + ) + output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() + + output.kv_connector_stats = ( + KVConnectorModelRunnerMixin.get_kv_connector_stats() + ) + + +MluHijackObject.apply_hijack(KVConnectorModelRunnerMixin, + KVConnectorModelRunnerMixin.maybe_setup_kv_connector, + KVConnectorModelRunnerMixin_MluHijack.maybe_setup_kv_connector) +MluHijackObject.apply_hijack(KVConnectorModelRunnerMixin, + KVConnectorModelRunnerMixin._get_kv_connector_output, + KVConnectorModelRunnerMixin_MluHijack._get_kv_connector_output) \ No newline at end of file diff --git a/vllm_mlu/v1/worker/lora_model_runner_mixin.py b/vllm_mlu/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 0000000..fde085b --- /dev/null +++ b/vllm_mlu/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +from typing import List + +from vllm.lora.request import LoRARequest +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +from vllm_mlu.mlu_hijack_utils import MluHijackObject + + +def vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras(self, num_loras: int) -> List[LoRARequest]: + assert num_loras > 0 + assert self.lora_manager is not None + + dummy_lora_requests: list[LoRARequest] = [] + with self.lora_manager.dummy_lora_cache(): + for idx in range(num_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"capture_graph_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=self.LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + return dummy_lora_requests + + +MluHijackObject.apply_hijack(LoRAModelRunnerMixin, + "add_dummy_loras", + vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras) diff --git a/vllm_mlu/v1/worker/mlu_quant.py b/vllm_mlu/v1/worker/mlu_quant.py new file mode 100644 index 0000000..03a7933 --- /dev/null +++ b/vllm_mlu/v1/worker/mlu_quant.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project + +"""A MLU quant class.""" +import functools +from collections import defaultdict +from typing import Dict, Any, List, Optional, Union + +import numpy as np +import torch +import torch.distributed + +from vllm.distributed import ( + get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, + get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +import vllm.envs as envs +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.model_executor.layers.vocab_parallel_embedding import (VocabParallelEmbedding, + ParallelLMHead) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention +from vllm_mlu.model_executor.layers.feed_forward import FeedForward +from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def default_act_range_value(): + return { + "x": None, + "split": None, + "is_linear": False, + "is_qkv": False, + "q_proj_size": 0, + "num_kv_head_replicas": 1, + "is_merge": False, + "input_id": [], + "self_rank": 0, + "rank": None, + "tensor_rank": None, + "tp_world_size": None, + "moe_tp_rank": None, + "moe_tp_world_size": None, + "moe_ep_rank": None, + "moe_ep_world_size": None, + "weight": None, + } + + +def _str_to_torch_dtype(dtype: str) -> torch.dtype: + dtype = dtype.split(".")[-1] + # STR_DTYPE_TO_TORCH_DTYPE dict does not have float16 type + return STR_DTYPE_TO_TORCH_DTYPE[dtype] if dtype != "float16" else torch.float16 + + +class ActRangeValue: + """ + ActRangeValue for v1 MsgpackEncoder and MsgpackDecoder. This is a *WorkAround*. + The decode tensor can be wrong if we pass act range dict directly. + + NOTE: here, we transfer torch.Tensor to numpy ndarray because torch.Tensor + may cause core dump. 
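+
+    Typical round trip (sketch; the layer name below is hypothetical and
+    `act_range` is one entry of the dict built by
+    MLUWorkerQuant.setup_smooth_hook):
+
+        value = ActRangeValue.serial("model.layers.0.self_attn.qkv_proj",
+                                     act_range)    # tensors -> numpy
+        # ... MsgpackEncoder / MsgpackDecoder transport ...
+        restored = value.deserial()                # numpy -> tensors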
+ """ + def __init__(self): + self.layer_name: str = "" + self.x: Optional[np.ndarray] = None + self.split: str = None + self.is_linear: bool = False + self.is_qkv: bool = False + self.q_proj_size: int = 0 + self.num_kv_head_replicas: int = 1 + self.is_merge: bool = False + self.input_id_dtype: str = None + self.input_id: Optional[List[np.ndarray]] = [] + self.self_rank: int = 0 + self.rank: Optional[int] = None + self.tensor_rank: Optional[int] = None + self.tp_world_size: Optional[int] = None + self.moe_tp_rank: Optional[int] = None + self.moe_tp_world_size: Optional[int] = None + self.moe_ep_rank: Optional[int] = None + self.moe_ep_world_size: Optional[int] = None + self.weight: Optional[np.ndarray] = None + self.weight_dtype: str = None + + @classmethod + def serial(cls, layer_name: str, act_range: Dict[str, Any]) -> 'ActRangeValue': + instance = cls() + instance.layer_name = layer_name + instance.x = act_range.get("x") + instance.split = act_range.get("split") + instance.is_linear = act_range.get("is_linear", False) + instance.is_qkv = act_range.get("is_qkv", False) + instance.q_proj_size = act_range.get("q_proj_size", 0) + instance.num_kv_head_replicas = act_range.get("num_kv_head_replicas", 1) + instance.is_merge = act_range.get("is_merge", False) + instance.input_id = act_range.get("input_id", []) + instance.self_rank = act_range.get("self_rank", 0) + instance.rank = act_range.get("rank") + instance.tensor_rank = act_range.get("tensor_rank") + instance.tp_world_size = act_range.get("tp_world_size") + instance.moe_tp_rank = act_range.get("moe_tp_rank") + instance.moe_tp_world_size = act_range.get("moe_tp_world_size") + instance.moe_ep_rank = act_range.get("moe_ep_rank") + instance.moe_ep_world_size = act_range.get("moe_ep_world_size") + instance.weight = act_range.get("weight") + + if instance.x is not None: + instance.x = instance.x.numpy() + # input_id and weight are used for debug + if isinstance(instance.input_id, torch.Tensor): + instance.input_id_dtype = str(instance.input_id.dtype) + instance.input_id = instance.input_id.float().numpy() + else: + input_id_np = [] + for input_id in instance.input_id: + instance.input_id_dtype = str(input_id.dtype) + input_id_np.append(input_id.float().numpy()) + instance.input_id = input_id_np + if instance.weight is not None: + instance.weight_dtype = str(instance.weight.dtype) + instance.weight = instance.weight.float().numpy() + + return instance + + def deserial(self) -> Dict[str, Any]: + act_range = self.to_dict() + if self.x is not None: + act_range["x"] = torch.from_numpy(self.x) + if self.input_id is not None: + if isinstance(self.input_id, torch.Tensor): + act_range["input_id"] = torch.from_numpy(self.input_id).to( + _str_to_torch_dtype(self.input_id_dtype)) + else: + input_id_tensor = [] + for input_id in self.input_id: + input_id_tensor.append(torch.from_numpy(input_id).to( + _str_to_torch_dtype(self.input_id_dtype))) + act_range["input_id"] = input_id_tensor + if self.weight_dtype is not None: + act_range["weight"] = torch.from_numpy(self.weight).to( + _str_to_torch_dtype(self.weight_dtype)) + return act_range + + def to_dict(self) -> Dict[str, Any]: + return { + "x": self.x, + "split": self.split, + "is_linear": self.is_linear, + "is_qkv": self.is_qkv, + "q_proj_size": self.q_proj_size, + "num_kv_head_replicas": self.num_kv_head_replicas, + "is_merge": self.is_merge, + "input_id": self.input_id, + "self_rank": self.self_rank, + "rank": self.rank, + "tensor_rank": self.tensor_rank, + "tp_world_size": self.tp_world_size, + 
"moe_tp_rank": self.moe_tp_rank, + "moe_tp_world_size": self.moe_tp_world_size, + "moe_ep_rank": self.moe_ep_rank, + "moe_ep_world_size": self.moe_ep_world_size, + "weight": self.weight, + } + + def __repr__(self) -> str: + return f"layer: {self.layer_name}, ActRangeValue({self.to_dict()})" + + +class MLUWorkerQuant(object): + ''' + mlu quant + ''' + def stat_tensor(self, name, tensor, act_range, key, dim): + logger.debug(f"name:{name}, key:{key}, dim:{dim}, tensor.shape:{tensor.shape}") + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs() + comming_max = torch.max(tensor, dim=dim)[0].float() + + if act_range[name][key] is None: + act_range[name][key] = comming_max + else: + act_range[name][key] = torch.max(act_range[name][key], comming_max) + + def stat_input_hook(self, m, x, y, name, act_range, is_linear, is_save_input_id): + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + logger.debug(f"name:{name}, x.shape:{x.shape}, y.shape:{y.shape}, m.weight.shape:{m.weight.shape}") + if is_linear: + self.stat_tensor(name, x, act_range, "x", 0) + if act_range[name]["is_qkv"] and is_save_input_id and ".0." in name: + x_cpu = x.clone().to("cpu") + act_range[name]["input_id"].append(x_cpu) + + def setup_smooth_hook(self, is_save_input_id: bool = False, is_save_moe_info: bool = False): + models = [self.model_runner.model] + if hasattr(self.model_runner, "drafter") and self.model_runner.drafter is not None: + models += [self.model_runner.drafter.model] + self.act_range = defaultdict(default_act_range_value) + self.hooks = [] + linear_class_list = (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) + other_class_list = (VocabParallelEmbedding, ParallelLMHead) + class_list = linear_class_list + other_class_list + row_class_list = (RowParallelLinear) + + for model in models: + for name, m in model.named_modules(): + if isinstance(m, FeedForward): + m.use_bt_ffn = False + if isinstance(m, SparseMoeMlp): + m.is_use_fused_moe = False + if isinstance(m, DeepseekV2MLAAttention): + m.use_fused_mla_qkv = False + + if isinstance(m, class_list): + is_linear = True if isinstance(m, linear_class_list) else False + split_type = "row" if isinstance(m, row_class_list) else "col" + self.act_range[name]["split"] = split_type + self.act_range[name]["is_linear"] = is_linear + if isinstance(m, QKVParallelLinear): + self.act_range[name]["is_qkv"] = True + self.act_range[name]["q_proj_size"] = m.num_heads * m.head_size + self.act_range[name]["num_kv_head_replicas"] = m.num_kv_head_replicas + self.act_range[name]["is_merge"] = isinstance(m, MergedColumnParallelLinear) + if is_save_moe_info: + self.act_range[name]["rank"] = torch.distributed.get_rank() + self.act_range[name]["tensor_rank"] = get_tensor_model_parallel_rank() + self.act_range[name]["tp_world_size"] = get_tensor_model_parallel_world_size() + self.act_range[name]["moe_tp_rank"] = get_moe_tensor_parallel_rank() + self.act_range[name]["moe_tp_world_size"] = get_moe_tensor_parallel_world_size() + self.act_range[name]["moe_ep_rank"] = get_moe_expert_parallel_rank() + self.act_range[name]["moe_ep_world_size"] = get_moe_expert_parallel_world_size() + if ".expert." 
in name: + self.act_range[name]["weight"] = m.weight + logger.info(f"rank:{self.rank}, add hook to {name}, is_linear:{is_linear}, split_type:{split_type}") + self.hooks.append(m.register_forward_hook(functools.partial(self.stat_input_hook, + name=name, act_range=self.act_range, + is_linear=is_linear, + is_save_input_id=is_save_input_id))) + + def remove_hooks(self): + for h in self.hooks: + h.remove() + + def get_act_range(self): + act_range = defaultdict(default_act_range_value) + for layer_name, layer_range in self.act_range.items(): + for tensor_key, tensor_value in layer_range.items(): + if isinstance(tensor_value, torch.Tensor): + act_range[layer_name][tensor_key] = tensor_value.to("cpu") + elif tensor_key == "input_id" and isinstance(tensor_value, list): + input_id_len = len(tensor_value) + for i in range(input_id_len): + if isinstance(tensor_value[i], torch.Tensor): + act_range[layer_name][tensor_key].append(tensor_value[i].to("cpu")) + else: + act_range[layer_name][tensor_key].append(tensor_value[i]) + else: + act_range[layer_name][tensor_key] = tensor_value + + serialization_result = [] + for layer_name, layer_range in act_range.items(): + serialization_result.append(ActRangeValue.serial(layer_name, layer_range)) + return serialization_result + + @torch.no_grad() + def get_named_parameters(self): + name_parameters = {} + for name, param in self.model_runner.model.named_parameters(): + name_parameters[name] = param.to("cpu") + + return name_parameters
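+
+
+# A minimal calibration-driver sketch (illustrative only): `worker` is assumed
+# to be an MLUWorker, which mixes in MLUWorkerQuant, and
+# `run_calibration_batches` stands in for whatever drives the forward passes.
+def _example_collect_act_ranges(worker, run_calibration_batches):
+    # Install forward hooks on the parallel linear / embedding layers and
+    # disable the fused paths that would bypass them.
+    worker.setup_smooth_hook(is_save_input_id=False, is_save_moe_info=True)
+    try:
+        # Run representative prompts so the hooks record per-layer activation
+        # maxima into worker.act_range.
+        run_calibration_batches()
+    finally:
+        # Detach the hooks so normal serving is unaffected afterwards.
+        worker.remove_hooks()
+    # Move everything to CPU / numpy and wrap it as ActRangeValue objects
+    # ready for msgpack serialization.
+    return worker.get_act_range()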