Initial commit for vLLM-Kunlun Plugin

This commit is contained in:
dongxinyu03
2025-12-10 12:05:39 +08:00
commit c728e52505
131 changed files with 28816 additions and 0 deletions

53
.gitignore vendored Normal file
View File

@@ -0,0 +1,53 @@
# Virtualenv
/.venv/
/venv/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
/bin/
/build/
/develop-eggs/
/dist/
/eggs/
/lib/
/lib64/
/output/
/parts/
/sdist/
/var/
/*.egg-info/
/.installed.cfg
/*.egg
/.eggs
# AUTHORS and ChangeLog will be generated while packaging
/AUTHORS
/ChangeLog
# BCloud / BuildSubmitter
/build_submitter.*
/logger_client_log
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
.tox/
.coverage
.cache
.pytest_cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Sphinx documentation
docs/_build

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.10.10

16
.readthedocs.yaml Normal file
View File

@@ -0,0 +1,16 @@
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
sphinx:
configuration: docs/source/conf.py
fail_on_warning: false
formats: []
python:
install:
- requirements: docs/requirements-docs.txt

24
CHANGELOG.md Normal file
View File

@@ -0,0 +1,24 @@
Changelog
===# Change Chinese to English comments
The following records all changes worth noting in the project, formatted based on [Keep a Changelog].
This project version follows [Semantic Versioning] and [PEP-440].
[Unreleased]
---
### Added
- This records new content added
### Changed
- This records changed content
0.1.0 - 2025-08-12
---
### Added
- Create project
[Unreleased]: http://icode.baidu.com/repos/baidu/hac-aiacc/vllm-kunlun/merge/0.1.0...master
[Keep a Changelog]: https://keepachangelog.com/zh-CN/1.0.0/
[Semantic Versioning]: https://semver.org/lang/zh-CN/
[PEP-440]: https://www.python.org/dev/peps/pep-0440/

201
LICENSE.txt Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

314
README.md Normal file
View File

@@ -0,0 +1,314 @@
![vLLM Kunlun Logo](vllm_kunlun/patches/vLLM_Kunlun.jpg)
<p align="center">
<a href="./docs/_build/html/documentation.html"><b>Documentation</b></a> |
<a href=""><b>Users Forum</b></a> |
<a href="join.slack.com/t/vllm-kunlun/shared_invite/zt-3iinb8u5z-FcqZKbNNdMJ_32fHmipzvwjoin.slack.com/t/vllm-kunlun/shared_invite/zt-3iinb8u5z-FcqZKbNNdMJ_32fHmipzvw"><b>slack</b></a> |
</p>
---
## Latest News🔥
- [2025/11]
- [2025/11]
- [2025/11]
- [2025/11] Initial release of vLLM Kunlun
---
# Overview
vLLM Kunlun (vllm-kunlun) is a community-maintained hardware plugin designed to seamlessly run vLLM on the Kunlun XPU. It is the recommended approach for integrating the Kunlun backend within the vLLM community, adhering to the principles outlined in the [RFC]: Hardware pluggable. This plugin provides a hardware-pluggable interface that decouples the integration of the Kunlun XPU with vLLM.
By utilizing the vLLM Kunlun plugin, popular open-source models, including Transformer-like, Mixture-of-Expert, Embedding, and Multi-modal LLMs, can run effortlessly on the Kunlun XPU.
---
## Prerequisites
- **Hardware**: Kunlun3 P800
- **OS**: Ubuntu 22.04
- **Software**:
- Python >=3.10
- PyTorch ≥ 2.5.1
- vLLM (same version as vllm-kunlun)
---
## Supported Models
<style>
table {
width: 100%;
border-collapse: collapse;
background: white;
margin: 20px 0;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border-radius: 8px;
overflow: hidden;
}
th {
background: linear-gradient(135deg, #0E7DC6 0%, #0A5BA8 100%);
color: white;
padding: 14px 12px;
text-align: left;
font-weight: 600;
font-size: 13px;
letter-spacing: 0.5px;
border: none;
}
td {
padding: 12px;
border-bottom: 1px solid #e8e8e8;
font-size: 13px;
color: #333;
}
tr:last-child td {
border-bottom: none;
}
tbody tr {
transition: background-color 0.2s ease;
}
tbody tr:hover {
background-color: #f5faff;
}
tbody tr:nth-child(even) {
background-color: #fafbfc;
}
tbody tr:nth-child(even):hover {
background-color: #f0f7fc;
}
.status-support {
color: #22863a;
font-weight: 600;
font-size: 14px;
}
.status-progress {
color: #f6a909;
font-weight: 600;
font-size: 14px;
}
.status-coming {
color: #999;
font-size: 12px;
background-color: #f5f5f5;
padding: 2px 6px;
border-radius: 3px;
display: inline-block;
}
.model-name {
font-weight: 500;
color: #1e40af;
}
h3 {
color: #1e40af;
font-size: 16px;
margin-top: 30px;
margin-bottom: 15px;
font-weight: 600;
}
h3:first-of-type {
margin-top: 0;
}
</style>
<h3>Generaltive Models</h3>
<table>
<thead>
<tr>
<th width="20%">Model</th>
<th width="12%">Support</th>
<th width="15%">Quantization</th>
<th width="10%">LoRA</th>
<th width="20%">Piecewise Kunlun Graph</th>
<th width="23%">Note</th>
</tr>
</thead>
<tbody>
<tr>
<td class="model-name">Qwen2/2.5</td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen3</td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen3-Moe/Coder</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">QwQ-32B</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">LLama2/3/3.1</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">GLM-4.5/Air</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen3next</td>
<td class="status-progress">⚠️</td>
<td></td>
<td></td>
<td></td>
<td><span class="status-coming">comming soon</span></td>
</tr>
<tr>
<td class="model-name">Gpt oss</td>
<td class="status-progress">⚠️</td>
<td></td>
<td></td>
<td></td>
<td><span class="status-coming">comming soon</span></td>
</tr>
<tr>
<td class="model-name">Deepseek v3/3.2</td>
<td class="status-progress">⚠️</td>
<td></td>
<td></td>
<td></td>
<td><span class="status-coming">comming soon</span></td>
</tr>
</tbody>
</table>
<h3>Multimodal Language Models</h3>
<table>
<thead>
<tr>
<th width="20%">Model</th>
<th width="12%">Support</th>
<th width="15%">Quantization</th>
<th width="10%">LoRA</th>
<th width="20%">Piecewise Kunlun Graph</th>
<th width="23%">Note</th>
</tr>
</thead>
<tbody>
<tr>
<td class="model-name">Qianfan-VL</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen2.5VL</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">InternVL2.5/3/3.5</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">InternVL3.5</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">InternS1</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen2.5 omini</td>
<td class="status-progress">⚠️</td>
<td></td>
<td></td>
<td></td>
<td><span class="status-coming">comming soon</span></td>
</tr>
<tr>
<td class="model-name">Qwen3vl</td>
<td class="status-progress">⚠️</td>
<td></td>
<td></td>
<td></td>
<td><span class="status-coming">comming soon</span></td>
</tr>
</tbody>
</table>
## Performance Visualization 🚀
### High-performance computing at work: How different models perform on the Kunlun3 P800.
Current environment: 16-way concurrency, input/output size 2048.
![Models and tgs](./vllm_kunlun/patches/performance.png)
## Getting Started
Please use the following recommended versions to get started quickly:
| Version | Release type | Doc |
|----------|---------------|-----|
| v0.10.1.1 | Latest stable version | [QuickStart](./docs/_build/html/quick_start.html) and [Installation](./docs/_build/html/installation.html) for more details |
---
## Contributing
See [CONTRIBUTING]() for more details, which is a step-by-step guide to help you set up the development environment, build, and test.
We welcome and value any contributions and collaborations:
- Open an [Issue]() if you find a bug or have a feature request
## License
Apache License 2.0, as found in the [LICENSE](./LICENSE) file.

25
build.sh Normal file
View File

@@ -0,0 +1,25 @@
#!/bin/bash
set -euo pipefail
echo "========= build enter ========="
echo "$PATH"
WORK_DIR=$(cd $(dirname $0) && pwd) && cd $WORK_DIR
echo_cmd() {
echo $1
$1
}
echo "========= build vllm ========="
echo_cmd "rm -rf output"
echo_cmd "mkdir -p output"
cd ${WORK_DIR}
rm -rf output/.scm/
tar -zcvf ../vllm-kunlun.tar.gz ../vllm-kunlun/
mv ../vllm-kunlun.tar.gz ./output/
echo "========= build exit ========="

25
docs/Makefile Normal file
View File

@@ -0,0 +1,25 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
intl:
sphinx-intl build
@$(SPHINXBUILD) -b html -D language=zh_CN "$(SOURCEDIR)" "$(BUILDDIR)/html/zh-cn" $(SPHINXOPTS) $(O)

57
docs/README.md Normal file
View File

@@ -0,0 +1,57 @@
## 🚀 Installation
```bash
uv venv myenv --python 3.12 --seed
source myenv/bin/activate
# Step 1: Enter the docs directory
cd docs
# Step 2: Install dependencies (using uv)
uv pip install -r requirements-docs.txt
# Install sphinx-autobuild (if not in requirements file)
uv pip install sphinx-autobuild
# Run from the docs directory:
sphinx-autobuild ./source ./_build/html --port 8000
# Step 1: Clean up old files
make clean
# Step 2: Build HTML
make html
# Step 3: Local preview
python -m http.server -d _build/html/
Browser access: http://localhost:8000
🌍 Internationalization
Internationalization translation process (taking Chinese as an example)
# Step 1: Extract translatable text (generate .pot)
sphinx-build -b gettext source _build/gettext
# Step 2: Generate/update Chinese .po file
sphinx-intl update -p _build/gettext -l zh_CN
# Step 3: Manually translate .po file
# Use a text editor to open source/locale/zh_CN/LC_MESSAGES/*.po
# Fill in the Chinese translation in msgstr ""
# Step 4: Compile and build Chinese documentation
make intl
# Step 5: View the effect
python -m http.server -d _build/html
Browser access:
English version: http://localhost:8000
Chinese version: http://localhost:8000/zh-cn
```

183
docs/envs.py Normal file
View File

@@ -0,0 +1,183 @@
#
# Copyright (c) 2025 Baidu Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-kunlun project.
#
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Any, Callable, Dict
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
env_variables: Dict[str, Callable[[], Any]] = {
# max compile thread number for package building. Usually, it is set to
# the number of CPU cores. If not set, the default value is None, which
# means all number of CPU cores will be used.
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
# The build type of the package. It can be one of the following values:
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
"CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
# Whether to compile custom kernels. If not set, the default value is True.
# If set to False, the custom kernels will not be compiled. Please note that
# the sleep mode feature will be disabled as well if custom kernels are not
# compiled.
"COMPILE_CUSTOM_KERNELS": lambda: bool(
int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))
),
# The CXX compiler used for compiling the package. If not set, the default
# value is None, which means the system default CXX compiler will be used.
"CXX_COMPILER": lambda: os.getenv("CXX_COMPILER", None),
# The C compiler used for compiling the package. If not set, the default
# value is None, which means the system default C compiler will be used.
"C_COMPILER": lambda: os.getenv("C_COMPILER", None),
"SOC_VERSION": lambda: os.getenv("SOC_VERSION", "KUNLUNP800"),
# If set, vllm-kunlun will print verbose logs during compilation
"VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))),
# /usr/local/Kunlun/kunlun-toolkit/latest
"KUNLUN_HOME_PATH": lambda: os.getenv("KUNLUN_HOME_PATH", None),
# The path for XCCL library, it's used by pyxccl communicator backend. If
# not set, the default value is libxccl.so。
"XCCL_SO_PATH": lambda: os.environ.get("XCCL_SO_PATH", None),
# The version of vllm is installed. This value is used for developers who
# installed vllm from source locally. In this case, the version of vllm is
# usually changed. For example, if the version of vllm is "0.9.0", but when
# it's installed from source, the version of vllm is usually set to "0.9.1".
# In this case, developers need to set this value to "0.9.0" to make sure
# that the correct package is installed.
"VLLM_VERSION": lambda: os.getenv("VLLM_VERSION", None),
# Whether to enable the trace recompiles from pytorch.
"VLLM_KUNLUN_TRACE_RECOMPILES": lambda: bool(
int(os.getenv("VLLM_KUNLUN_TRACE_RECOMPILES", "0"))
),
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": lambda: bool(
int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", "0"))
),
# Whether to enable the model execute time observe profile. Disable it when
# running vllm kunlun in production environment.
"VLLM_KUNLUN_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(
int(os.getenv("VLLM_KUNLUN_MODEL_EXECUTE_TIME_OBSERVE", "0"))
),
# Some models are optimized by vllm kunlun. While in some case, e.g. rlhf
# training, the optimized model may not be suitable. In this case, set this
# value to False to disable the optimized model.
"USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv("USE_OPTIMIZED_MODEL", "1"))),
# The tolerance of the kv cache size, if the difference between the
# actual kv cache size and the cached kv cache size is less than this value,
# then the cached kv cache size will be used.
"VLLM_KUNLUN_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE": lambda: int(
os.getenv("VLLM_KUNLUN_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)
),
# Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
# We'll remove this flag in the future once it's stable enough.
"VLLM_KUNLUN_ENABLE_TOPK_TOPP_OPTIMIZATION": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_TOPK_TOPP_OPTIMIZATION", "1"))
),
# `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
# used for llmdatadist to build the communication topology for kv cache transfer, it is
# a required variable if `LLMDataDistCMgrConnector` is used as kv connector for disaggregated
# pd. The rank table can be generated by adopting the script `gen_ranktable.sh`
# in vllm_kunlun's example folder.
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH": lambda: os.getenv(
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None
),
# `LLMDataDistCMgrConnector` required variable. `VLLM_KUNLUN_LLMDD_RPC_IP` is used as the
# rpc communication listening ip, which will be used to receive the agent metadata from the
# remote worker.
"VLLM_KUNLUN_LLMDD_RPC_IP": lambda: os.getenv(
"VLLM_KUNLUN_LLMDD_RPC_IP", "0.0.0.0"
),
# `LLMDataDistCMgrConnector` required variable. `VLLM_KUNLUN_LLMDD_RPC_PORT` is used as the
# rpc communication listening port, which will be used to receive the agent metadata from the
# remote worker.
"VLLM_KUNLUN_LLMDD_RPC_PORT": lambda: int(
os.getenv("VLLM_KUNLUN_LLMDD_RPC_PORT", 5557)
),
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
# and the mla_pa will be the default path of deepseek decode path.
"VLLM_KUNLUN_MLA_PA": lambda: int(os.getenv("VLLM_KUNLUN_MLA_PA", 0)),
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
"VLLM_KUNLUN_ENABLE_MATMUL_ALLREDUCE": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_MATMUL_ALLREDUCE", "0"))
),
# Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large.
"VLLM_KUNLUN_ENABLE_FLASHCOMM1": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_FLASHCOMM1", "0"))
),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_KUNLUN_ENABLE_PREFETCH_MLP": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_PREFETCH_MLP", "0"))
),
# buffer size for gate up prefetch
"VLLM_KUNLUN_MLP_GATE_UP_PREFETCH_SIZE": lambda: int(
os.getenv("VLLM_KUNLUN_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)
),
# buffer size for down proj prefetch
"VLLM_KUNLUN_MLP_DOWN_PREFETCH_SIZE": lambda: int(
os.getenv("VLLM_KUNLUN_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)
),
# Whether to enable dense model and general optimizations for better performance.
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
"VLLM_KUNLUN_ENABLE_DENSE_OPTIMIZE": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_DENSE_OPTIMIZE", "0"))
),
# Whether to enable mlp optimize when tensor parallel is enabled.
# this feature in eager mode will get better performance.
"VLLM_KUNLUN_ENABLE_MLP_OPTIMIZE": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_MLP_OPTIMIZE", "0"))
),
# Determine the number of physical devices in a non-full-use scenario
# caused by the initialization of the Mooncake connector.
"PHYSICAL_DEVICES": lambda: os.getenv("PHYSICAL_DEVICES", None),
# Whether to enable msMonitor tool to monitor the performance of vllm-kunlun.
"MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))),
# Timeout (in seconds) for delayed KVCache block release. In the prefill
# node, if a request is marked for delayed KV block release and the blocks
# are not freed within this timeout, they will be forcibly released.
"VLLM_KUNLUN_KVCACHE_DELAY_FREE_TIMEOUT": lambda: int(
os.getenv("VLLM_KUNLUN_KVCACHE_DELAY_FREE_TIMEOUT", 250)
),
"VLLM_KUNLUN_ENABLE_MLAPO": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_MLAPO", "0"))
),
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
"VLLM_KUNLUN_ENABLE_NZ": lambda: int(os.getenv("VLLM_KUNLUN_ENABLE_NZ", 1)),
# Decide whether we should enable CP parallelism.
"VLLM_KUNLUN_ENABLE_CONTEXT_PARALLEL": lambda: bool(
int(os.getenv("VLLM_KUNLUN_ENABLE_CONTEXT_PARALLEL", "0"))
),
}
# end-env-vars-definition
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in env_variables:
return env_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(env_variables.keys())

View File

@@ -0,0 +1,10 @@
sphinx
sphinx-argparse
sphinx-book-theme
sphinx-copybutton
sphinx-design
sphinx-togglebutton
myst-parser
msgspec
sphinx-substitution-extensions
sphinx-intl

View File

@@ -0,0 +1,58 @@
<!--
**********************************************************************
* Copyright (c) 2025 Baidu Technologies Co., Ltd. All Rights Reserved.
* Copyright 2023 The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* This file is a part of the vllm-kunlun project.
* Adapted from https://github.com/vllm-project/vllm/blob/main/docs/source/_templates/sections/header.html
**********************************************************************
-->
<style>
.notification-bar {
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
font-size: 16px;
}
.notification-bar p {
margin: 0;
}
.notification-bar a {
font-weight: bold;
text-decoration: none;
}
/* Light mode styles (default) */
.notification-bar {
background-color: #fff3cd;
color: #856404;
}
.notification-bar a {
color: #d97706;
}
/* Dark mode styles */
html[data-theme="dark"] .notification-bar {
background-color: #333;
color: #ddd;
}
html[data-theme="dark"] .notification-bar a {
color: #ffa500; /* Brighter color for visibility */
}
</style>
<!-- <div class="notification-bar">
<p>You are viewing the latest developer preview docs. <a href="https://vllm-kunlun.readthedocs.io/en/v0.9.1-dev">Click here</a> to view docs for the latest stable release(v0.9.1).</p>
</div> -->

View File

@@ -0,0 +1,38 @@
# Maintainers and Acknowledgments
## Maintainers
| Name | Github ID | Date |
| :----------: | :----------------------------------------------: | :-----: |
| Xinyu Dong | [@xyDong0223](https://github.com/xyDong0223) | 2025/11 |
| Qian Bao | [@baoqian426](https://github.com/baoqian426) | 2025/11 |
| Zhennan Chen | [@chanzhennan](https://github.com/chanzhennan) | 2025/11 |
| Yili Chen | [@chenyili0619](https://github.com/chenyili0619) | 2025/11 |
| Hanyu Jin | [@Hanyu-Jin](https://github.com/Hanyu-Jin) | 2025/11 |
| Donghua Li | [@ldh2020](https://github.com/ldh2020) | 2025/11 |
## Acknowledgments
| Name |
| :------------: |
| Haowen Han |
| Tianyu Ma |
| Jizhong Yuan |
| Yucheng Liang |
| Hanshuo Yang |
| Wei Li |
| Hao Wang |
| Zhihui Wang |
| Hao Wang |
| YingZhuo Zhao |
| Wanli Yang |
| Xin Zhao |
| Yuqi Lin |
| Xiaokang Cheng |
| Zeyu You |
| Jingyu Zhang |
| Lidang Jiang |
| Yijin Qiao |
| Chenchao Hu |
| Weijie Hong |
| Song Jiang |

View File

@@ -0,0 +1,51 @@
# Governance
## Mission
As a vital component of vLLM, the vLLM Kunlun project is dedicated to providing an easy, fast, and cheap LLM Serving for everyone on Kunlun XPUs and to actively contributing to the enrichment of vLLM.
## Principles
vLLM Kunlun follows the vLLM community's code of conduct: [vLLM - CODE OF CONDUCT](https://github.com/vllm-project/vllm/blob/main/CODE_OF_CONDUCT.md)
## Governance - Mechanics
vLLM Kunlun is an open-source project under the vLLM community, where the authority to appoint roles is ultimately determined by the vLLM community. It adopts a hierarchical technical governance structure.
- Contributor:
**Responsibility:** Help new contributors on boarding, handle and respond to community questions, review RFCs and code.
**Requirements:** Complete at least 1 contribution. A contributor is someone who consistently and actively participates in a project, including but not limited to issue/review/commits/community involvement.
The contributor permissions are granted by the [vllm-kunlun]'s repo `Triage` on GitHub, including repo read and clone, issue and PR management, facilitating efficient collaboration between community developers.
- Maintainer:
**Responsibility:** Develop the project's vision and mission. Maintainers are responsible for shaping the technical direction of the project and ensuring its long-term success. With code merge permissions, they lead roadmap planning, review community contributions, make ongoing code improvements, and actively participate in community engagement—such as regular meetings and events.
**Requirements:** Deep understanding of vLLM and vLLM Kunlun code bases, with a commitment to sustained code contributions and competency in design, development, and PR review workflows.
- **Review quality:** Actively participate in community code reviews, ensuring high-quality code integration.
- **Quality contribution:** Successfully develop and deliver at least one major feature while maintaining consistent high-quality contributions.
- **Community involvement:** Actively address issues, respond to forum inquiries, participate in discussions, and engage in community-driven tasks.
The approval from existing Maintainers is required. The vLLM community has the final decision-making authority.
Maintainers will be granted write access to the [vllm-kunlun] GitHub repo. This includes permission to read, clone, and push to the repository, as well as manage issues and pull requests.
## Nominating and Removing Maintainers
### The Principles
- Membership in vLLM Kunlun is given to individuals on merit basis after they demonstrate their strong expertise in vLLM/vLLM Kunlun through contributions, reviews, and discussions.
- For membership in the maintainer group, individuals have to demonstrate strong and continued alignment with the overall vLLM/vLLM Kunlun principles.
- Maintainers who have been inactive for a long time may be transitioned to **emeritus** status under lenient criteria.
- The membership is for an individual, not a company.
### Nomination and Removal
- Nomination: Anyone can nominate a candidate to become a maintainer, including self-nominations. All existing maintainers are responsible for reviewing and evaluating each nomination. The nominator should provide relevant information about the nominee's qualifications—such as review quality, quality contribution, and community involvement—among other strengths.
- Removal: Anyone may nominate an individual for removal from the maintainer role, including self-nominations. All current maintainers are responsible for reviewing and evaluating such nominations. The nominator should provide relevant information about the nominee—such as prolonged inactivity, misalignment with the project's overall direction, or other factors that may render them unsuitable for the maintainer position.

View File

@@ -0,0 +1,3 @@
# User stories
Comming soon...

View File

@@ -0,0 +1,3 @@
# Versioning policy
Comming soon...

144
docs/source/conf.py Normal file
View File

@@ -0,0 +1,144 @@
#
# Copyright (c) 2025 Baidu Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
# Adapted from vllm-project/vllm/docs/source/conf.py
#
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import json
import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = "vllm-kunlun"
copyright = "2025, vllm-kunlun team"
author = "the vllm-kunlun team"
# The full version, including alpha/beta/rc tags
release = ""
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
# Copy from https://github.com/vllm-project/vllm/blob/main/docs/source/conf.py
extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.intersphinx",
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"myst_parser",
"sphinxarg.ext",
"sphinx_design",
"sphinx_togglebutton",
"sphinx_substitution_extensions",
]
myst_enable_extensions = ["colon_fence", "substitution"]
# Change this when cut down release
myst_substitutions = {
# the branch of vllm, used in vllm clone
# - main branch: 'main'
# - vX.Y.Z branch: 'vX.Y.Z'
"vllm_version": "0.10.1.1",
# the branch of vllm-kunlun, used in vllm-kunlun clone and image tag
# - main branch: 'main'
# - vX.Y.Z branch: latest vllm-kunlun release tag
"vllm_kunlun_version": "0.10.1.1",
# the newest release version of vllm-kunlun and matched vLLM, used in pip install.
# This value should be updated when cut down release.
"pip_vllm_kunlun_version": "0.10.1.1",
"pip_vllm_version": "0.10.1.1",
# vllm version in ci
"ci_vllm_version": "0.10.1.1",
}
# For cross-file header anchors
myst_heading_anchors = 5
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
locale_dirs = ["locale/"]
gettext_compact = False
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [
"_build",
"Thumbs.db",
".DS_Store",
".venv",
"README.md",
"user_guide/release.template.md",
"**/*.zh.md",
]
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_title = project
html_theme = "sphinx_book_theme"
html_logo = "logos/vllm-kunlun-logo-text-light.png"
html_theme_options = {
"path_to_docs": "docs/source",
"repository_url": "https://github.com/xxxxx/vllm-kunlun",
"use_repository_button": True,
"use_edit_page_button": True,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
READTHEDOCS_VERSION_TYPE = os.environ.get("READTHEDOCS_VERSION_TYPE")
if READTHEDOCS_VERSION_TYPE == "tag":
# remove the warning banner if the version is a tagged release
header_file = os.path.join(
os.path.dirname(__file__), "_templates/sections/header.html"
)
# The file might be removed already if the build is triggered multiple times
# (readthedocs build both HTML and PDF versions separately)
if os.path.exists(header_file):
os.remove(header_file)
def setup(app):
pass
if __name__ == "__main__":
print(json.dumps(myst_substitutions))

View File

@@ -0,0 +1,70 @@
# Contributing
## Building and Testing
It's recommended to set up a local development environment to build vllm-kunlun and run tests
before you submit a PR.
#### Run models locally
After completing Run lint setup which is shown in quicksatrt, you can run your changed locally:
```{code-block} bash
:substitutions:
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8356 \
--model /your_modified_models\
--trust-remote-code \
--tensor-parallel-size 1 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name your_modified_models \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
Please save a screenshot of your service running successfully, and attach an accuracy report.
#### Submit the commit
```bash
# Commit changed files using `-s`
git commit -sm "your commit info"
```
🎉 Congratulations! You have completed the development environment setup.
## PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
- `[Attention]` for new features or optimization in attention.
- `[Communicator]` for new features or optimization in communicators.
- `[ModelRunner]` for new features or optimization in model runner.
- `[Platform]` for new features or optimization in platform.
- `[Worker]` for new features or optimization in worker.
- `[Core]` for new features or optimization in the core vllm-kunlun logic (such as platform, attention, communicators, model runner)
- `[Kernel]` for changes affecting compute kernels and ops.
- `[Bugfix]` for bug fixes.
- `[Doc]` for documentation fixes and improvements.
- `[Test]` for tests (such as unit tests).
- `[CI]` for build or continuous integration improvements.
- `[Misc]` for PRs that do not fit the above categories. Please use this sparingly.
:::{note}
If the PR spans more than one category, please include all relevant prefixes.
:::
## Others
If you find any problem when contributing, you can join our slack group to talk with us and then feel free to submit a PR to improve the doc to help other developers.
:::{toctree}
:caption: Index
:maxdepth: 1
testing
multi_node_test
:::

View File

@@ -0,0 +1,271 @@
## Operator accuracy test
### torch_xray
torch_xray is an operator precision analysis tool that can dump module-level input-output precision comparisons and automatically construct operator unit tests.
#### 1.Download and install
***\*python3.10:\****
bos:/klx-sdk-release-public/xpytorch/dev_kl3/torch_xray/latest/torch_xray-999.9.9-cp310-cp310-linux_x86_64.whl
[https://su.bcebos.com/klx-sdk-release-public/xpytorch/dev_kl3/torch_xray/latest/](https://su.bcebos.com/klx-sdk-release-public/xpytorch/dev_kl3/torch_xray/latest/torch_xray-999.9.9-py3-none-any.whl)torch_xray-999.9.9-cp310-cp310-linux_x86_64.whl
***\*python3.8:\****
bos:/klx-sdk-release-public/xpytorch/dev_kl3/torch_xray/latest/torch_xray-999.9.9-cp38-cp38-linux_x86_64.whl
[https://su.bcebos.com/klx-sdk-release-public/xpytorch/dev_kl3/torch_xray/latest/](https://su.bcebos.com/klx-sdk-release-public/xpytorch/dev_kl3/torch_xray/latest/torch_xray-999.9.9-py3-none-any.whl)torch_xray-999.9.9-cp38-cp38-linux_x86_64.whl
Note that the same installation package must be used when using it in different environments.
#### 2.Use
##### Dump module-level inputs and outputs and compare their precision.
Below is a sample code snippet used to dump the input and output of the vision module and compare the errors in the vllm framework.
```bash
from torch_xray import PrecisionDebugger
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
# dump_path # Path to store dump results
# rank # Rank that needs to be dumped
# step # Setting the inference value to 1 is sufficient.
# model # The module to be dumped must be of type nn.module
debugger = PrecisionDebugger(dump_path="dump-vision", hook_name="dump", rank=[0], step=[1], model=self.model.visual, dump_torch_api=False)
debugger.start()
........
```
The results directory will generate an h5 file and a csv file.
```bash
-rw-r--r-- 1 root root 471231309 Oct 31 13:12 globalrank-0_localrank-0.h5
-rw-r--r-- 1 root root 71 Oct 31 13:11 globalrank-0_localrank-0_summary.csv
```
##### Data processing
```bash
summary xxx.h5 sum.txt
```
The generated h5 file is processed using the summary command to generate a txt file in which the results are presented in tabular form.
```bash
+-------+------+------+-----------------------------------------------------------+-------------+-------------+--------------+-------------+
| Index | Step | Rank | Module | Min | Max | Mean | Std |
+-------+------+------+-----------------------------------------------------------+-------------+-------------+--------------+-------------+
| 0 | 1 | 0 | patch_embed.proj.Conv3d.0.forward_params.weight | -0.0776367 | 0.0795898 | 6.8e-06 | 0.0072608 |
| 1 | 1 | 0 | patch_embed.proj.Conv3d.0.forward_params.bias | -3.046875 | 2.953125 | 0.0113748 | 0.3257138 |
| 2 | 1 | 0 | patch_embed.proj.Conv3d.0.forward_input.0 | -0.7490234 | 0.7021484 | 0.3302804 | 0.2339017 |
| 3 | 1 | 0 | patch_embed.proj.Conv3d.0.forward_output.0 | -4.0078125 | 5.1210938 | 0.0147052 | 0.3815643 |
| 4 | 1 | 0 | pos_embed.Embedding.0.forward_params.weight | -13.8125 | 20.25 | 0.0010043 | 0.2428094 |
| 5 | 1 | 0 | pos_embed.Embedding.0.forward_input.0 | 0.0 | 2303.0 | 1153.9191895 | 714.594360 |
| 6 | 1 | 0 | pos_embed.Embedding.0.forward_output.0 | -13.8125 | 20.25 | 0.0007552 | 0.2643428 |
| 7 | 1 | 0 | rotary_pos_emb.Qwen2_5_VisionRotaryEmbedding.0.forward... | 0.0 | 25.0 | 1.7337022 | 3.9271674 |
| 8 | 1 | 0 | blocks.0.norm1.LayerNorm.0.forward_params.weight | -0.5351562 | 3.140625 | 0.4660275 | 0.7907906 |
| 9 | 1 | 0 | blocks.0.norm1.LayerNorm.0.forward_params.bias | -2.359375 | 2.921875 | 0.0013793 | 0.1879374 |
| 10 | 1 | 0 | blocks.0.norm1.LayerNorm.0.forward_input.0 | -15.65625 | 20.21875 | 0.0155256 | 0.4382802 |
| 11 | 1 | 0 | blocks.0.norm1.LayerNorm.0.forward_output.0 | -6.1640625 | 6.7460938 | 0.0006746 | 0.2708515 |
| 12 | 1 | 0 | blocks.0.attn.qkv.QKVParallelLinear.0.forward_params.bias | -6.125 | 6.1875 | -0.0292423 | 0.8602651 |
| 13 | 1 | 0 | blocks.0.attn.qkv.QKVParallelLinear.0.forward_input.0 | -6.1640625 | 6.7460938 | 0.0006746 | 0.2708515 |
| 14 | 1 | 0 | blocks.0.attn.qkv.QKVParallelLinear.0.forward_output.0 | -6.5859375 | 7.6171875 | -0.0125549 | 1.0678084 |
| 15 | 1 | 0 | blocks.0.attn.proj.RowParallelLinear.0.forward_params... | -3.578125 | 3.203125 | -0.0043617 | 0.4846557 |
| 16 | 1 | 0 | blocks.0.attn.proj.RowParallelLinear.0.forward_input.0 | -1.9130859 | 1.4375 | 0.0005577 | 0.0947055 |
| 17 | 1 | 0 | blocks.0.attn.proj.RowParallelLinear.0.forward_output.0 | -9.109375 | 7.3867188 | -0.0034284 | 0.4465481 |
| 18 | 1 | 0 | blocks.0.norm2.LayerNorm.1.forward_params.weight | -0.1376953 | 14.5625 | 1.9166113 | 3.017405 |
| 19 | 1 | 0 | blocks.0.norm2.LayerNorm.1.forward_params.bias | -1.6328125 | 3.84375 | 0.0062865 | 0.2443586 |
| 20 | 1 | 0 | blocks.0.norm2.LayerNorm.1.forward_input.0 | -8.5859375 | 11.109375 | 0.0120974 | 0.4243064 |
| 21 | 1 | 0 | blocks.0.norm2.LayerNorm.1.forward_output.0 | -12.015625 | 14.265625 | -0.0012364 | 0.4973041 |
| 22 | 1 | 0 | blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forwar... | -9.4375 | 0.7304688 | -2.4200516 | 1.6754951 |
| 23 | 1 | 0 | blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forwar... | -12.015625 | 14.265625 | -0.0012364 | 0.4973041 |
| 24 | 1 | 0 | blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forwar... | -12.59375 | 13.0625 | -2.1465943 | 1.8433502 |
| 25 | 1 | 0 | blocks.0.mlp.act_fn.GELU.0.forward_input.0 | -12.59375 | 13.0625 | -2.1465943 | 1.8433502 |
+-------+------+------+-----------------------------------------------------------+-------------+-------------+--------------+-------------+
```
##### Accuracy Comparison
```bash
# The results are stored in result.csv
compare xpu.h5 gpu.h5 result.csv
```
The `compare` command is used to process the H5 files generated on the GPU and XPU, resulting in a CSV file. This CSV file is then downloaded to the local machine and opened with Excel, yielding a result similar to the image below.
If you encounter a "no matched keys" problem, please refer to the instructions at the end of this article for a solution.
##### Example of results
```bash
+-------+--------+-----------------------------------------------------------+--------+-----------+-------------+-------------+--------+
| Index | Status | Module (Bench/Target) | Cosine | RMSE | IsClose (%) | Max Err (t) | GtNum |
+-------+--------+-----------------------------------------------------------+--------+-----------+-------------+-------------+--------+
| 0 | | patch_embed.proj.Conv3d.0.forward_params.weight | 1 | 0 | 100 | 0 | 0 |
| 1 | | patch_embed.proj.Conv3d.0.forward_params.bias | 1 | 0 | 100 | 0 | 0 |
| 2 | | patch_embed.proj.Conv3d.0.forward_input.0 | 1 | 0 | 100 | 0 | 0 |
| 3 | | patch_embed.proj.Conv3d.0.forward_output.0 | 1 | 9.90E-06 | 100 | 0.001953 | 267 |
| 4 | | pos_embed.Embedding.0.forward_params.weight | 1 | 0 | 100 | 0 | 0 |
| 5 | | pos_embed.Embedding.0.forward_input.0 | 1 | 0 | 100 | 0 | 0 |
| 6 | | pos_embed.Embedding.0.forward_output.0 | 1 | 0 | 100 | 0 | 0 |
| 7 | | rotary_pos_emb.Qwen2_5_VisionRotaryEmbedding.0.forward... | 1 | 0 | 100 | 0 | 0 |
| 8 | | blocks.0.norm1.LayerNorm.0.forward_params.weight | 1 | 0 | 100 | 0 | 0 |
| 9 | | blocks.0.norm1.LayerNorm.0.forward_params.bias | 1 | 0 | 100 | 0 | 0 |
| 10 | | blocks.0.norm1.LayerNorm.0.forward_input.0 | 1 | 1.14E-05 | 100 | 0.00390625 | 216 |
| 11 | | blocks.0.norm1.LayerNorm.0.forward_output.0 | 1 | 1.84E-05 | 99.98 | 0.0078125 | 1585 |
| 12 | | blocks.0.attn.qkv.QKVParallelLinear.0.forward_params.bias | 1 | 0 | 100 | 0 | 0 |
| 13 | | blocks.0.attn.qkv.QKVParallelLinear.0.forward_input.0 | 1 | 1.84E-05 | 99.98 | 0.0078125 | 1585 |
| 14 | | blocks.0.attn.qkv.QKVParallelLinear.0.forward_output.0 | 1 | 0.0002776 | 99.53 | 0.00390625 | 119074 |
| 15 | | blocks.0.attn.proj.RowParallelLinear.0.forward_params... | 1 | 0 | 100 | 0 | 0 |
| 16 | | blocks.0.attn.proj.RowParallelLinear.0.forward_input.0 | 1 | 3.40E-05 | 99.07 | 0.0012207 | 52482 |
| 17 | | blocks.0.attn.proj.RowParallelLinear.0.forward_output.0 | 1 | 0.0001283 | 99.07 | 0.00390625 | 50591 |
| 18 | | blocks.0.norm2.LayerNorm.1.forward_params.weight | 1 | 0 | 100 | 0 | 0 |
| 19 | | blocks.0.norm2.LayerNorm.1.forward_params.bias | 1 | 0 | 100 | 0 | 0 |
| 20 | | blocks.0.norm2.LayerNorm.1.forward_input.0 | 1 | 0.0001437 | 99.01 | 0.0039062 | 31376 |
| 21 | Fail | blocks.0.norm2.LayerNorm.1.forward_output.0 | 1 | 0.0002779 | 98.72 | 0.015625 | 40770 |
| 22 | | blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forward... | 1 | 0 | 100 | 0 | 0 |
| 23 | Fail | blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forward... | 1 | 0.0002779 | 98.72 | 0.015625 | 40770 |
| 24 | | blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forward... | 1 | 0.000779 | 98.67 | 0.0078125 | 196313 |
| 25 | | blocks.0.mlp.act_fn.GELU.0.forward_input.0 | 1 | 0.000779 | 98.67 | 0.0078125 | 196313 |
| 26 | | blocks.0.mlp.act_fn.GELU.0.forward_output.0 | 1 | 0.0001012 | 98.08 | 0.0039062 | 153508 |
+-------+--------+-----------------------------------------------------------+--------+-----------+-------------+-------------+--------+
```
Generally, the main focus is on Min Err/Max Err.
##### Indicator Explanation
To be improved...
#### The dump operator is tested and run.
```bash
X_DEBUG=0x102 # trace operator name、arguments shape、dtype、data_range
X_DEDUP=True # Remove duplicates based on shape and dtype.
X_DUMP_NUM # The default value is 0, meaning no tensor data is saved. Setting it to n means that n parameters are randomly selected from each operator to save the actual parameters.
```
Below is a sample code snippet that dumps information such as the size and dtype of the forward operator of Qwen3_VisionTransformer. During runtime, an xray_debug directory will be automatically created in the current directory to store the dump results.
```bash
from torch_xray import begin_dump, end_dump
.............

class Qwen3_VisionTransformer(nn.Module):

def __init__(
self,
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
..........
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
# Start dump
# X_DEBUG=0x102 # trace operator name、arguments shape、dtype、data_range
# X_DEDUP=True # Remove duplicates based on shape and dtype.
# The default value is 0, meaning no tensor data is saved. Setting it to n means that n parameters are randomly selected from each operator to save the actual parameters.
begin_dump(X_DEBUG=0x102, X_DEDUP=True, X_DUMP_NUM=5)
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
...........
# End dump
end_dump(clear_context=True)
return hidden_states
```
This is the file directory.
```bash
├── xary_debug/
│ ├── proc_xxx/ # Process-based storage results
│ ├── dump/ # The dumped tensor
│ ├── dump.json # Information needed to generate unit tests, such as input/output size and dtype.
```
##### Generate unit test
jprof --cpu_init --blacklist --factory=load dump.json
Create a pytests directory in the current directory to store unit tests.
##### Run unit test
The GPU only needs to copy the XPU's pytests directory and execute it.
Since the unit test program defaults to finding the actual dumped tensors using relative paths, this step must be performed in the xary_debug/ directory.
```bash
# detail_compare_path stores the unit test results.
pytest --detail_compare_path=./xxx.csv proc_xxx/pytests/ --seed 42
```
##### Results Comparison
```bash
# After obtaining two result CSV files, compare them and generate result.csv.
summary_diff_check ./xpu.csv ./gpu.csv ./result.csv
```
##### Example of results
```bash
+------------+-----------------------+-------------+-------------+-----------+----------+---------+---------+----------+
| name | op_name | dtype | shape | min-val | max-val | is_pass | xpu_max | gpu_max |
+------------+-----------------------+-------------+-------------+-----------+----------+---------+---------+----------+
| 00004-aten | aten.linspace.default | torch.float | [10] | 0 | 47 | pass | 0 | 1.91E-06 |
| 00005-aten | aten.linspace.default | torch.float | [26] | 0 | 47 | pass | 0 | 0 |
| 00027-aten | aten.add.Tensor | torch.int64 | [10, 26] | 0 | 0 | pass | 0 | 0 |
| 00028-aten | aten.add.Tensor | torch.int64 | [10, 26] | 0 | 0 | pass | 0 | 0 |
| 00037-aten | aten.add.Tensor | torch.float | [260, 1152] | -29.09375 | 33.75 | pass | 0 | 0 |
| 00038-aten | aten.add.Tensor | torch.float | [260, 1152] | -27.1875 | 37.625 | pass | 0 | 0 |
| 00047-aten | aten.add.Tensor | torch.float | [260, 1152] | -28.98438 | 42.34375 | pass | 0 | 0 |
| 00082-aten | aten.sub.Tensor | torch.int32 | [1] | 0 | 0 | pass | 0 | 0 |
+------------+-----------------------+-------------+-------------+-----------+----------+---------+---------+----------+
```
The main focus is on the values of gpu_1e-1, xpu_1e-1, etc., which represent the number of elements whose error between the gpu/xpu result and the cpu result exceeds the order of 1e-n. This serves as the primary basis for determining whether there is a problem with the operator's precision.
#### Replenish
##### Bypassing the issue of differing naming conventions between Kunlun Card and GPU modules, which prevents diff calculation.
```bash
#
blocks.0.mlp.linear_fc1.ColumnParallelLinear.0.forward_params.bias
#
blocks.0.mlp.linear_fc1.ColumnParalleLinear.forward_params.bias
```
As shown in the figure above, due to various reasons, the module names dumped by the GPU and XPU are often different, and the compare command cannot be used to identify them directly.
```python
for step in steps: # (['/'] for group creation order h5py >= 3.10.0)
# for bench_key, target_key in get_matched_names(
# list(dump_ben[str(step)].keys()),
# list(dump_tar[str(step)].keys()),
# fuzzy_match,
# ):
for bench_key, target_key in zip(
list(dump_ben[str(step)].keys()),
list(dump_tar[str(step)].keys()),
):
```
Modify torch_xray/compare/compare.py to skip the get_matched_name step. This modification will allow for line-by-line comparison even if module names differ, producing a compare result. However, it's crucial to ensure that the number of rows in the GPU and XPU dumps is consistent.

View File

@@ -0,0 +1,240 @@
## Overall accuracy test
### EvalScope
#### 1.Download and install
EvalScope supports use in Python environments. Users can install EvalScope via pip or from source code. Here are examples of both installation methods:
```bash
#pip
pip install evalscope[perf] -U
#git
git clone https://github.com/modelscope/evalscope.git
cd evalscope
pip install -e '.[perf]'
```
#### 2.Dataset preparation script
```python
from evalscope.collections import CollectionSchema, DatasetInfo, WeightedSampler
from evalscope.utils.io_utils import dump_jsonl_data
import os # Step 1: Import the os module
schema = CollectionSchema(
name="VL-Test",
datasets=[
CollectionSchema(
name="PureText",
weight=1,
datasets=[
DatasetInfo(
name="mmlu_pro",
weight=1,
task_type="exam",
tags=["en"],
args={"few_shot_num": 0},
),
DatasetInfo(
name="ifeval",
weight=1,
task_type="instruction",
tags=["en"],
args={"few_shot_num": 0},
),
DatasetInfo(
name="gsm8k",
weight=1,
task_type="math",
tags=["en"],
args={"few_shot_num": 0},
),
],
),
CollectionSchema(
name="Vision",
weight=2,
datasets=[
DatasetInfo(
name="math_vista",
weight=1,
task_type="math",
tags=["en"],
args={"few_shot_num": 0},
),
DatasetInfo(
name="mmmu_pro",
weight=1,
task_type="exam",
tags=["en"],
args={"few_shot_num": 0},
),
],
),
],
)
# get the mixed data
mixed_data = WeightedSampler(schema).sample(1000)
output_path = "outputs/vl_test.jsonl" # Step 2: Define the output file path
output_dir = os.path.dirname(output_path) # Step 3: Obtain the directory name
if not os.path.exists(output_dir): # Step 4: Check if the directory exists
os.makedirs(output_dir, exist_ok=True) # Step 5: Automatically create directories
# dump the mixed data to a jsonl file
dump_jsonl_data(mixed_data, output_path) # Step 6: Securely write to the file
```
Dataset composition visualization:
```
┌───────────────────────────────────────┐
│ VL-Test (1000 samples) │
├─────────────────┬─────────────────────┤
│ PureText │ Vision │
│ (333 samples) │ (667 samples) │
├─────────────────┼─────────────────────┤
│ • mmlu_pro │ • math_vista │
│ • ifeval │ • mmmu_pro │
│ • gsm8k │ │
└─────────────────┴─────────────────────┘
```
#### 3.Test
```python
from dotenv import dotenv_values
from evalscope import TaskConfig, run_task
from evalscope.constants import EvalType
task_cfg = TaskConfig(
model="Qwen2.5-VL-7B-Instruct",
api_url="http://localhost:8804/v1",
api_key="EMPTY",
eval_type=EvalType.SERVICE,
datasets=[
"data_collection",
],
dataset_args={
"data_collection": {
"local_path": "../outputs/vl_test.jsonl",
}
},
eval_batch_size=5,
generation_config={
"max_tokens": 30000, # The maximum number of tokens that can be generated should be set to a large value to avoid output truncation.
"temperature": 0.6, # Sampling temperature (recommended value from qwen report)
"top_p": 0.95, # top-p sampling (recommended value from qwen report)
"top_k": 20, # Top-k sampling (recommended value from qwen report)
"n": 1, # Number of responses generated per request
"repetition_penalty": 1.0, # 1.0 = Penalty disabled, >1.0 = Penalty repeated.
},
)
run_task(task_cfg=task_cfg)
```
Parameter Tuning Guide:
| Parameter | Current value | Effect | Adjustment suggestions |
| ----------------- | ------------- | ---------------------------------------- | -------------------------------------------------------- |
| `temperature` | 0.6 | Control output diversity | Math problems ↓ 0.3 / Creative writing ↑ 0.9 |
| `top_p` | 0.95 | Filtering low-probability tokens | Reduce "nonsense" |
| `eval_batch_size` | 5 | Number of requests processed in parallel | With sufficient video memory, it can be increased to 10. |
Run the test:
```bash
#!/bin/bash
# ========================================
# Step 1: Set the log file path
# ========================================
LOG_FILE="accuracy_$(date +%Y%m%d_%H%M).log"
# ========================================
# Step 2: Execute the Python script and capture all output
# Meaning of 2>&1:
# - 2 represents standard error output (stderr)
# ->& represents redirection and merging
# - 1 represents standard output (stdout)
# Function: Merges error messages into standard output as well.
# ========================================
python accuracy.py 2>&1 | tee "$LOG_FILE"
# ========================================
# Step 3: Check Execution Status
# ${PIPESTATUS[0]} Get the exit code of the first command (Python) in the pipeline
# ========================================
EXIT_CODE=${PIPESTATUS[0]}
if [ $EXIT_CODE -eq 0 ]; then
echo "✅ Evaluation completed! Log saved to: $LOG_FILE"
else
echo "❌ Evaluation failed! Exit code: $EXIT_CODE Please check the log: $LOG_FILE"
fi
```
#### 4.Common problem fixes
##### 4.1 NLTK resource missing fix
```bash
Resource punkt_tab not found.
```
Solution
```python
import nltk
import os
# Step 1: Set the download path (select a writable directory)
download_dir = "/workspace/myenv/nltk_data"
os.makedirs(download_dir, exist_ok=True)
# Step 2: Configure NLTK data path
nltk.data.path.append(download_dir)
# Step 3: Download necessary resources
print("🔽 Start downloading punkt_tab resource...")
try:
nltk.download("punkt_tab", download_dir=download_dir)
print("✅ Download successful!")
except Exception as e:
print(f"❌ Download failed: {e}")
print("💡 Alternative: Download manually from GitHub")
print(
" URL: https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/tokenizers/punkt_tab.zip"
)
```
repair:
```bash
# Activate environment
source /workspace/myenv/bin/activate
# Run the repair script
python fix_nltk.py
# Rerun the test
bash run_accuracy_test.sh
```
#### 5.Results Display
```bash
+-------------+---------------------+--------------+---------------+-------+
| task_type | metric | dataset_name | average_score | count |
+-------------+---------------------+--------------+---------------+-------+
| exam | acc | mmmu_pro | 0.521 | 334 |
| math | acc | math_vista | 0.6066 | 333 |
| exam | acc | mmlu_pro | 0.5405 | 111 |
| instruction | prompt_level_strict | ifeval | 0.6937 | 111 |
| math | acc | gsm8k | 0.8288 | 111 |
+-------------+---------------------+--------------+---------------+-------+
```

View File

@@ -0,0 +1,10 @@
# Accuracy
This document details the accuracy testing methods for vllm-kunlun and the analysis of the results.
:::{toctree}
:caption: Accuracy
:maxdepth: 1
accuracy_server
accuracy_kernel
:::

View File

@@ -0,0 +1,18 @@
# GLM-Air-4.5
* vLLM Version: vLLM: 0.10.1.1 , vLLM-KunLun Version: v0.10.1.1
* Software Environment:OS: Ubuntu 22.04, PyTorch ≥ 2.5.1
* Hardware Environment: KunLun P800
* Parallel mode:TP8
```bash
+-------------+----------+---------------+---------+-----+--------+---------+
| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
+-------------+----------+---------------+---------+-----+--------+---------+
| GLM-4.5-Air | math_500 | AveragePass@1 | Level 1 | 43 | 0.9302 | default |
| GLM-4.5-Air | math_500 | AveragePass@1 | Level 2 | 90 | 0.9222 | default |
| GLM-4.5-Air | math_500 | AveragePass@1 | Level 3 | 105 | 0.8762 | default |
| GLM-4.5-Air | math_500 | AveragePass@1 | Level 4 | 128 | 0.8984 | default |
| GLM-4.5-Air | math_500 | AveragePass@1 | Level 5 | 134 | 0.8955 | default |
+-------------+----------+---------------+---------+-----+--------+---------+
```

View File

@@ -0,0 +1,18 @@
# GLM-4.5
* vLLM Version: vLLM: 0.10.1.1 , vLLM-KunLun Version: v0.10.1.1
* Software Environment:OS: Ubuntu 22.04, PyTorch ≥ 2.5.1
* Hardware Environment: KunLun P800
* Parallel mode:TP8
```bash
+---------+----------+---------------+---------+-----+--------+---------+
| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
+---------+----------+---------------+---------+-----+--------+---------+
| GLM-4.5 | math_500 | AveragePass@1 | Level 1 | 43 | 0.9302 | default |
| GLM-4.5 | math_500 | AveragePass@1 | Level 2 | 90 | 0.8111 | default |
| GLM-4.5 | math_500 | AveragePass@1 | Level 3 | 105 | 0.7143 | default |
| GLM-4.5 | math_500 | AveragePass@1 | Level 4 | 128 | 0.6172 | default |
| GLM-4.5 | math_500 | AveragePass@1 | Level 5 | 134 | 0.5149 | default |
+---------+----------+---------------+---------+-----+--------+---------+
```

View File

@@ -0,0 +1,18 @@
# InternVL3_5-30B-A3B
* vLLM Version: vLLM: 0.10.1.1 , vLLM-KunLun Version: v0.10.1.1
* Software Environment:OS: Ubuntu 22.04, PyTorch ≥ 2.5.1
* Hardware Environment: KunLun P800
* Parallel mode:TP8
```
+-------------+---------------------+--------------+---------------+-------+
| task_type | metric | dataset_name | average_score | count |
+-------------+---------------------+--------------+---------------+-------+
| exam | acc | mmmu_pro | 0.5449 | 334 |
| math | acc | math_vista | 0.6847 | 333 |
| exam | acc | mmlu_pro | 0.6126 | 111 |
| instruction | prompt_level_strict | ifeval | 0.7658 | 111 |
| math | acc | gsm8k | 0.9369 | 111 |
+-------------+---------------------+--------------+---------------+-------+
```

View File

@@ -0,0 +1,18 @@
# Qwen2.5-VL-7B-Instruct
* vLLM Version: vLLM: 0.10.1.1 , vLLM-KunLun Version: v0.10.1.1
* Software Environment:OS: Ubuntu 22.04, PyTorch ≥ 2.5.1
* Hardware Environment: KunLun P800
* Parallel mode:TP1
```
+-------------+---------------------+--------------+---------------+-------+
| task_type | metric | dataset_name | average_score | count |
+-------------+---------------------+--------------+---------------+-------+
| exam | acc | mmmu_pro | 0.521 | 334 |
| math | acc | math_vista | 0.6066 | 333 |
| exam | acc | mmlu_pro | 0.5405 | 111 |
| instruction | prompt_level_strict | ifeval | 0.6937 | 111 |
| math | acc | gsm8k | 0.8288 | 111 |
+-------------+---------------------+--------------+---------------+-------+
```

View File

@@ -0,0 +1,10 @@
# Accuracy Report
:::{toctree}
:caption: Accuracy Report
:maxdepth: 1
Qwen2.5-VL-7B-Instruct
InternVL3_5-30B-A3B
GLM-4.5
GLM-4.5-Air
:::

View File

@@ -0,0 +1,8 @@
# Accuracy
:::{toctree}
:caption: Accuracy
:maxdepth: 1
accuracy/index
accuracy_report/index
:::

View File

@@ -0,0 +1,76 @@
# Kunlun Graph
## Why we need Kunlun Graph?
When in LLM inference, each token requires nearly thousand operator executions, and when host launching operators are slower than device, it will cause host bound. In severe cases, the device will be idle for more than half of the time. To solve this problem, we use graph in LLM inference.
```
eager mode:
host: | launch op1 | launch op2 | launch op3 | launch op4 | launch op5 |
device: | run op1 |free| run op2 |free| run op3 |free| run op4 |free| run op5 |
| <----- total time -----> |
graph mode:
host: | launch graph |
device: | run op1 | run op2 | run op3 | run op4 | run op5 |
| <----- total time -----> |
```
## How to use Kunlun Graph?
Kunlun Graph is enabled by default in V1 Engine, just need to check that `enforce_eager` is not set to `True`.
## How it works?
In short, graph mode works in two steps: **capture and replay**. When engine starts, we will capture all of the ops in model forward and save it as a graph, and when req come in, we just replay the graph on devices, and waiting for result.
But in reality, graph mode is not that simple.
### Padding and Bucketing
Due to graph can only replay the ops captured before, without doing tiling and checking graph input, we need to ensure the consistency of the graph input, but we know that model input's shape depends on the request scheduled by Scheduler, we can't ensure the consistency.
Obviously, we can solve this problem by capturing the biggest shape and padding all of the model input to it. But it will bring a lot of redundant computing and make performance worse. So we can capture multiple graphs with different shape, and pad the model input to the nearest graph, which will greatly reduce redundant computing. But when `max_num_batched_tokens` is very large, the number of graphs that need to be captured will also become very large. But we know that when intensor's shape is large, the computing time will be very long, and graph mode is not necessary in this case. So all of things we need to do is:
1. Set a threshold;
2. When `num_scheduled_tokens` is bigger than the threshold, use `eager_mode`;
3. Capture multiple graphs within a range below the threshold;
```
| graph1 |
| graph2 |
| graph3 |
| graph4 | # the threshold
| input1 | pad | # use graph1
| input2 | # don't need pad
| input3 | pad | # use graph4
| input4 | # use eager mode
```
### Piecewise and Full graph
Due to the increasing complexity of the attention layer in current LLM, we can't ensure all types of attention can run in graph. In MLA, prefill_tokens and decode_tokens have different calculation method, so when a batch has both prefills and decodes in MLA, graph mode is difficult to handle this situation.
vLLM solves this problem with piecewise graph mode. We use eager mode to launch attention's ops, and use graph to deal with others. But it also bring some problems: The cost of launching ops has become large again, although much smaller than eager mode, but it will also lead to host bound when cpu is poor or `num_tokens` is small.
## How it be implemented?
vLLM has already implemented most of the modules in graph mode. You can see more details at: [CUDA Graphs](https://docs.vllm.ai/en/latest/design/cuda_graphs.html)
When in graph mode, vLLM will call `current_platform.get_static_graph_wrapper_cls` to get current device's graph model wrapper, so what we need to do is to implement the graph mode wrapper on Kunlun: `Kunlun Graph Wrapper`.
vLLM has added `support_torch_compile` decorator to all models, this decorator will replace the `__init__` and `forward` interface of the model class, and when `forward` called, the code inside the `vllm_kunlun.compilation` will be executed, and it will do capture or replay as mentioned above.
## Limitation
1. `FULL` and `FULL_AND_PIECEWISE` are not supported now;
3. `use_inductor` is not supported now;

View File

@@ -0,0 +1,9 @@
# Feature Guide
This section provides an overview of the features implemented in vLLM-Kunlun. Developers can refer to this guide to understand how vLLM-Kunlun works.
:::{toctree}
:caption: Feature Guide
:maxdepth: 1
Kunlun_Graph
:::

View File

@@ -0,0 +1,7 @@
# Performance
:::{toctree}
:caption: Performance
:maxdepth: 1
performance_benchmark/index
:::

View File

@@ -0,0 +1,147 @@
## Operator performance
### XProfiler
#### 1.Download and install
- The download link for the x86_64 platform installation package xre-Linux-x86_64 is:
`https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.26/peermem/xre-Linux-x86_64-5.0.21.26.run`
`https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.26/peermem/xre-Linux-x86_64-5.0.21.26.tar.gz`
- If the client is using bdCentOS, we recommend using the following download link:
`https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.26/xre-bdcentos-x86_64-5.0.21.26.tar.gz`
After downloading and extracting, you can directly execute `xpu-installer` and `install_rt.sh` to install.
#### 2.Start using
XProfiler supports three modes: 1) fork mode; 2) time mode; and 3) daemon mode. After execution, XProfiler will generate two types of JSON files:
- xprofiler.settings.json: Records the event configuration for this trace.
- xprofiler.trace.json: Records the results of this trace.
The specific modes will be introduced below.
##### fork mode
The fork pattern is used to track the entire time period from the start to the end of a user program. This pattern is suitable for most inference tasks and is the simplest to use. An example is shown below:
```bash
/xxxx/xxxx/xprofiler -r500 --xpu=0 python test.py
```
- --r: Sets the trace time resolution in nanoseconds (ns). The default is 100. If an "out of space error" occurs, try increasing the -r value to 500.
- --xpu: Specifies the acquisition device ID, supporting multi-card configuration. --xpu=all enables all cards; the default is card 0.
More parameters can be found in the command-line parameters section later.
##### time mode
The time mode is used to track user programs for a period of time. This method is suitable for tasks that need to run for a long time.
Using the -t or --time command-line parameter, XPorfiler will run for the specified time and then exit, in seconds. In this mode, the application needs to be started separately. An example is as follows:
(1) Starting XPorfiler
```bash
/xxxx/xxxx/xprofiler -r 500 --xpu=0 -t600 # Time mode collects events within a specified time period, measured in seconds (s).
```
A temporary .sock file will be generated in the execution directory. The path needs to be configured in the environment variables.
(2) Start the program
```bash
export XPU_ENABLE_PROFILER_TRACING=1
export XPU_TRACING_OUTPUT_NAME=<xprofiler execution directory>/xprofiler.sock
# Start your own program
python xxx.py
```
##### deamon mode
The daemon mode is used to track the event timeline of a specified code segment, eliminating interference from redundant information. The startup command is the same as in fork mode.
(1) Insert start and stop interfaces.
```python
import xtorch_ops
# Only capture events during the generate phase
xtorch_ops.kunlun_profiler_start()
outputs = llm.generate(
inputs,
sampling_params=sampling_params,
lora_request=lora_request,
)
xtorch_ops.kunlun_profiler_end()
```
(2) Launch X profiler in a terminal
```python
# Specify the output file as the trace_output file in the current path.
/xxxx/xxxx/xprofiler-Linux_x86_64-2.0.2.0/bin/xprofiler -r 500 --xpu=0 -e ./trace_output -d
```
After startup, a .sock file will be generated in the current directory.
```bash
xprofiler.sock
```
(3) Launch your own program on another terminal.
```python
export XPU_ENABLE_PROFILER_TRACING=1
# Here, the path to the .sock file from step 2 is used for assignment.
export XPU_TRACING_OUTPUT_NAME=<xprofiler execution directory>/xprofiler.sock
# Start your own program
python xxx.py
```
Note: If you want to specify a particular card to run on, you must import the XPU_VISIBLE_DEVICES environment variable in the terminal in steps 2 and 3; otherwise, you will not be able to capture the data.
##### More parameters
| parameters | Example | default value | describe |
| -------------------------- | --------------------------------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| -b or --buffer-size | -b=512 | 256 | Specifies the size of the trace buffer in MB. This is generally not required. However, if there are many trace signals, the buffer size can be increased appropriately to avoid OOS (Out of Size). |
| -x or --xpu | -x=0--xpu=0 | 0 | Set the card number to be tracked; multiple cards or all cards can be set. |
| -t or --time | -t=10 | off | Enable time mode, in seconds, to capture information over a specified period. |
| -d or --deamonize | -r500 | 0 | Enable daemon mode to retrieve events in the background. |
| -r or --export-profile | -e ./trace_output-e ./output/trace.json | ./ | Record the trace results to a document or folder. If this parameter is not specified, a default xprofiler.trace.json file will be generated in the execution directory. |
| -S or --settings | -S xprofiler.trace.json | off | xprofiler reads a JSON file containing the events that need to be traced. If this parameter is not configured, xprofiler enables `--profile-api-trace` and `--sse-trace` by default. |
| -A or --profiler-api-trace | -A | on | Get driver events. |
| -s or --sse-trace | -s | on | Get all SSE events. |
| -C or --cluster-trace | -C | off | Retrieve all cluster events. |
| -n or --sdnn-trace | -n | off | Get all SDNN events. |
| -c or --sdnn-cluster-trace | -c | off | Retrieve all SDNN cluster events. |
| -E or --cache-trace | -E | off | Get bandwidth statistics events. |
| -u or --debug | -u44:open logdebug level-u0:close log | 33 | Debug the interface and enable driver event/device event logging.。 |
#### 3.View Results
The generated xprofiler.trace.json file can be viewed and analyzed using a visual interface. Two tools are introduced here.
##### Chrome browser
Enter chrome://tracing/ in your browser (you may need to enable developer tools the first time you access this site), and click "load" in the top left corner to import the file. Interface display.
![img](https://rte.weiyun.baidu.com/wiki/attach/image/api/imageDownloadAddress?attachId=89aef70f112a4394adcac8b03ef994db&docGuid=WFoZOcuqnSXJIE)
##### prefetto ui
Search directly, or visit[Perfetto UI](https://ui.perfetto.dev/#!/viewer?local_cache_key)The interface is as follows。
![img](https://rte.weiyun.baidu.com/wiki/attach/image/api/imageDownloadAddress?attachId=895a715344e9473c9ee93518c3064b27&docGuid=WFoZOcuqnSXJIE)
#### 4.Performance Analysis
With various performance data available, analysis and optimization can then be performed based on the results.
(Further details to be added later)

View File

@@ -0,0 +1,199 @@
## vLLM server performance
### vLLM benchmark CLI
You can directly use vLLM's CLI benchmark. For more details, please refer to[vLLM Developer Guide Benchmark Suites](https://docs.vllm.ai/en/stable/contributing/benchmarks.html)
#### 1.Online testing
##### 1.1Start the vLLM server
Server startup script reference
```bash
USE_ORI_ROPE=1 VLLM_USE_V1=1 python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port xxxx \
--model /xxxx/xxxx/model\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--max_num_seqs 128 \
--max_num_batched_tokens 32768 \
--max-seq-len-to-capture 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name modelname \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
##### 1.2Execute test
To run the test script, you can refer to the code below.
```bash
#!/bin/bash
# Run benchmark tests
python -m vllm.entrypoints.cli.main bench serve \
--host 127.0.0.1 \
--port xxxx \
--backend vllm \
--model modelname \
--dataset-name random \
--num-prompts 500 \
--random-input-len 1024 \
--random-output-len 1024 \
--tokenizer /xxxx/xxxx/model \
--ignore-eos 2>&1 | tee benchmark.log
```
##### 1.3Result
The following content will be displayed after the process is complete.
```bash
========== Serving Benchmark Result ==========
Successful requests: 500
Benchmark duration (s): 144.89
Total input tokens: 510414
Total generated tokens: 512000
Request throughput (req/s): 3.45
Output token throughput (tok/s): 3533.68
Total Token throughput (tok/s): 7056.42
----------Time to First Token----------
Mean TTFT (ms): 57959.61
Median TTFT (ms): 43551.93
P99 TTFT (ms): 116202.52
----------Time per Output Token (excl. 1st token)----------
Mean TPOT (ms): 33.30
Median TPOT (ms): 34.15
P99 TPOT (ms): 35.59
----------Inter-token Latency----------
Mean ITL (ms): 33.30
Median ITL (ms): 29.05
P99 ITL (ms): 46.14
============================================
```
Key Parameter Explanation:
| index | meaning | Optimization Objective |
| --------------------------- | ------------------------------------| ---------- |
| ***\*Output Throughput\**** | Output token generation rate | ↑ The higher the better |
| ***\*Mean TTFT\**** | First Token Delay (Time To First Token) | ↓ The lower the better |
| ***\*P99 TTFT\**** | 99% of requests have delayed first token. | ↓ The lower the better |
| ***\*Mean TPOT\**** | Average generation time per output token | ↓ The lower the better |
| ***\*P99 TPOT\**** | 99% of requests' time per token generation | ↓ The lower the better |
| ***\*ITL\**** | Delay between adjacent output tokens | ↓ The lower the better |
#### 2.Offline testing
Comming soon...
### EvalScope
EvalScope is a comprehensive model testing tool that can test not only model accuracy but also performance. For more information, please visit [website address missing].[EvalScope](https://evalscope.readthedocs.io/en/latest/index.html)A brief introduction follows.
#### 1.Download and install
EvalScope supports use in Python environments. Users can install EvalScope via pip or from source code. Here are examples of both installation methods:
```bash
#pip
pip install evalscope[perf] -U
#git
git clone https://github.com/modelscope/evalscope.git
cd evalscope
pip install -e '.[perf]'
```
After downloading, some modules may be missing, causing the program to fail to run. Just follow the prompts to install them.
#### 2.Start using
The following demonstrates the performance test of the Qwen3-8B in a single-card scenario.
##### 2.1Start the server
The first step is to start the server. The example script is shown below.
```bash
USE_ORI_ROPE=1 VLLM_USE_V1=1 python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port xxxx \
--model /xxxx/xxxx/Qwen3-8B\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--max_num_seqs 128 \
--max_num_batched_tokens 32768 \
--max-seq-len-to-capture 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-8B \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
##### 2.2 Start EvalScope
Start EvalScope to begin performance testing.
```bash
evalscope perf \
--parallel 1 10\#The number of concurrent requests can be tested at once, separated by spaces.
--number 10 20\#The total number of requests per request, aligned with spaces and the concurrency count.
--model Qwen3-8B \
--url http://127.0.0.1:xxxx/v1/chat/completions \
--api openai \
--dataset random \
--max-tokens 1024 \
--min-tokens 1024 \
--prefix-length 0 \
--min-prompt-length 1024 \
--max-prompt-length 1024 \
--tokenizer-path /xxxx/xxxx/Qwen3-8B\
--extra-args '{"ignore_eos": true}'
```
##### 2.3Results Analysis
The following figure shows the results. You can view other data from a single test through the logs. For the specific meaning of the parameters, please refer to the parameter interpretation in the vLLM benchmark test.
```bash
Performance Test Summary Report
Basic Information:
+-------------------+------------------------+
| Model | Qwen3-8B |
| Total Generated | 30,720.0 tokens |
| Total Test Time | 199.79 seconds |
| Avg Output Rate | 153.76 tokens/sec |
+-------------------+------------------------+
Detailed Performance Metrics
+-------+------+------------+------------+-----------+-----------+-----------+-----------+-----------+---------------+
| Conc. | RPS | Avg Lat.(s)| P99 Lat.(s)| Gen. Toks/s| Avg TTFT(s)| P99 TTFT(s)| Avg TPOT(s)| P99 TPOT(s)| Success Rate |
+-------+------+------------+------------+-----------+-----------+-----------+-----------+-----------+---------------+
| 1 | 0.07 | 16.191 | 16.475 | 70.40 | 0.080 | 0.085 | 0.016 | 0.016 | 100.0% |
| 10 | 0.53 | 18.927 | 19.461 | 540.87 | 0.503 | 0.562 | 0.018 | 0.019 | 100.0% |
+-------+------+------------+------------+-----------+-----------+-----------+-----------+-----------+---------------+
Best Performance Configuration
Highest RPS: Concurrency 10 (0.53 req/sec)
Lowest Latency: Concurrency 1 (16.191 seconds)
Performance Recommendations:
* The system seems not to have reached its performance bottleneck, try higher concurrency
```

View File

@@ -0,0 +1,11 @@
# Performance_benchmark
This document details the performance testing methods for vllm-kunlun and the analysis of the results to ultimately optimize performance. The main considerations are server throughput and operator performance.
:::{toctree}
:caption: Performance
:maxdepth: 1
benchmark_server
benchmark_kernel
profiling
:::

View File

@@ -0,0 +1,418 @@
## Profiling
### 🔧 Action PlanThree Phases
#### Phase 1⃣: Multi-Device Log Redirection Configuration
##### Background
By default, kernel logs from all 8 XPU devices are interleaved and emitted to [stdout], resulting in:
- It becomes impossible to distinguish which log originates from which device.
- Timestamps become interleaved, making it difficult to analyze the temporal relationships.
- Single-device bottlenecks are masked by global aggregation.
##### Solution
During model initialization, create separate log files for each device.
##### Code Explanation (embedded in qwen2.py)
```python
import os # ← Ensure this is imported at the top of the file
from vllm.distributed import get_tensor_model_parallel_rank # ← Import function to get the tensor model parallel rank
class Qwen2Model(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
super().__init__()
# ========== [Expert Solution] Kunlun XPU Multi-Device Log Redirection ==========
try:
# Step 1: Get the current XPU device's rank (0~7)
rank = get_tensor_model_parallel_rank()
# Step 2: Create log directory (works with your get_kernel_time_ex.py)
log_dir = "./xpu_logs"
os.makedirs(log_dir, exist_ok=True)
# Step 3: Generate a separate log file for each device
log_file = os.path.join(log_dir, f"rank_{rank}.log")
# Step 4: Core operation redirect file descriptors
# os.O_TRUNC: Clear previous logs on each run to avoid mixing outputs
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o664)
os.dup2(fd, 1) # Redirect stdout → rank_X.log
os.dup2(fd, 2) # Redirect stderr → rank_X.log
os.close(fd) # Close original file descriptor; redirection persists
# Optional: print a confirmation message (will go into rank_X.log)
print(f"[Qwen2Model Init] Rank {rank} log redirected to {log_file}")
except Exception as e:
# Fallback mechanism: failure to redirect logs does not affect model loading
print(f"[WARNING] Failed to redirect log for rank: {e}", flush=True)
# ========== End of log redirection code ==========
```
##### ⚠️ Common Issues
**Q1**:Why not use Python's `logging` module?
**A**:The XPU runtime kernel logs are emitted from the C++ layer and cannot be captured by Pythons `logging` module. Redirection via low-level file descriptors is required.
**Q1**:Will logs be lost if the model fails to load??
**A**:The `try-except` block ensures that if log redirection fails, it falls back to the default behavior without affecting model startup.
#### Phase 2⃣: Profiling Environment Activation
##### 🚀 vLLM Launch
```bash
unset XPU_DUMMY_EVENT
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export XPU_USE_MOE_SORTED_THRES=1
export XFT_USE_FAST_SWIGLU=1
export XMLIR_CUDNN_ENABLED=1
export XPU_USE_DEFAULT_CTX=1
export XMLIR_FORCE_USE_XPU_GRAPH=1
export XPU_USE_FAST_SWIGLU=1
export VLLM_HOST_IP=$(hostname -i)
echo "VLLM_HOST_IP: $VLLM_HOST_IP"
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false
export XPUAPI_DEBUG=0x1 # Enable kernel performance logging
export XPURT_DISPATCH_MODE=PROFILING # Activate profiling mode
USE_ORI_ROPE=1 VLLM_USE_V1=1 python -m vllm.entrypoints.openai.api_server \
      --host 0.0.0.0 \
      --port 8000 \
      --model /models/Qwen2.5-72B-Instruct \
      --gpu-memory-utilization 0.9 \
      --trust-remote-code \
      --max-model-len 32768 \
      --tensor-parallel-size 8 \
      --dtype float16 \
      --max_num_seqs 512 \
      --max_num_batched_tokens 32768 \
      --max-seq-len-to-capture 32768 \
      --block-size 128 \
      --no-enable-prefix-caching \
      --no-enable-chunked-prefill \
      --distributed-executor-backend mp \
      --served-model-name Qwen2.5-72B-Instruct \
      --compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' 2>&1 | tee output_p800.log
```
##### 🚀 Client Load Testing
```bash
#!/bin/bash
# Define test combinations array (concurrency x input length x output length)
TEST_COMBINATIONS=(
"8x1024x1024" # Medium-low concurrency
)
# Create result directory
RESULT_DIR="bench_$(date +%Y%m%d_%H%M)"
mkdir -p $RESULT_DIR
# Summary results file
SUMMARY_FILE="$RESULT_DIR/summary_results.csv"
echo "num_prompts,input_len,output_len,throughput,latency_mean,latency_p50,latency_p90,latency_p99" >$SUMMARY_FILE
# Progress counter
TOTAL_TESTS=${#TEST_COMBINATIONS[@]}
CURRENT_TEST=0
# Loop through different test combinations
for COMBINATION in "${TEST_COMBINATIONS[@]}"; do
# Parse combination parameters
NUM_PROMPTS=$(echo $COMBINATION | cut -d'x' -f1)
INPUT_LEN=$(echo $COMBINATION | cut -d'x' -f2)
OUTPUT_LEN=$(echo $COMBINATION | cut -d'x' -f3)
# Update progress
CURRENT_TEST=$((CURRENT_TEST + 1))
echo "=========================================================="
echo "Test progress: $CURRENT_TEST/$TOTAL_TESTS ($(printf "%.1f" $(echo "$CURRENT_TEST/$TOTAL_TESTS*100" | bc -l))%)"
echo "Current test configuration: concurrency=$NUM_PROMPTS, input length=$INPUT_LEN, output length=$OUTPUT_LEN"
echo "=========================================================="
OUTPUT_FILE="$RESULT_DIR/p800_${NUM_PROMPTS}_${INPUT_LEN}_${OUTPUT_LEN}.log"
# Run benchmark
python3 -m vllm.entrypoints.cli.main bench serve \
--host 127.0.0.1 \
--port 8000 \
--backend vllm \
--model Qwen2.5-72B-Instruct \
--dataset-name random \
--num-prompts $NUM_PROMPTS \
--random-input-len $INPUT_LEN \
--random-output-len $OUTPUT_LEN \
--tokenizer /ssd1/models/Qwen2.5-72B-Instruct \
--ignore-eos 2>&1 | tee $OUTPUT_FILE
# Wait 15 seconds to let the service recover
echo "Waiting 15 seconds before the next round..."
sleep 15
# Extract key performance metrics from output and append to summary file
THROUGHPUT=$(grep "Throughput" $OUTPUT_FILE | awk '{print $2}')
LATENCY_MEAN=$(grep "Mean latency" $OUTPUT_FILE | awk '{print $3}')
LATENCY_P50=$(grep "p50 latency" $OUTPUT_FILE | awk '{print $3}')
LATENCY_P90=$(grep "p90 latency" $OUTPUT_FILE | awk '{print $3}')
LATENCY_P99=$(grep "p99 latency" $OUTPUT_FILE | awk '{print $3}')
echo "$NUM_PROMPTS,$INPUT_LEN,$OUTPUT_LEN,$THROUGHPUT,$LATENCY_MEAN,$LATENCY_P50,$LATENCY_P90,$LATENCY_P99" >>$SUMMARY_FILE
done
# Output summary report
echo "=========================================================="
echo "Benchmark completed! Results saved in: $RESULT_DIR"
echo "=========================================================="
```
#### Phase 3⃣: Log Analysis and Bottleneck Identification
```lua
xpu_logs/
├─ rank_0.log
├─ rank_1.log
├─ rank_2.log
├─ rank_3.log
├─ rank_4.log
├─ rank_5.log
├─ rank_6.log
└─ rank_7.log
```
##### 🔍 Script Workflow (op_log.py)
**Input**:Raw Kernel Logs (Sample Format)
```
[XPURT_PROF] void xblas_xpu3::fc_cdnn_infer<float16,...> 123456 ns
[XPURT_PROF] void kl3_all_reduce<float16> 987654 ns
```
**Processing logic**
:::::{tab-set}
::::{tab-item} op_log.py
```python
"""
A better version of 'get_op_time.py', get more level dump and support kl3.
 
Usage: python3 get_kernel_time_ex.py --help
"""
 
import os
import sys
import re
 
unit_factors = [0.9, 1.3, 1.45] # kunlun1, kunlun2, kunlun3
patterns = ["\[XPURT_PROF\] (\S+)\s+\S+\s+(\S+) ns", "\[XPURT_PROF\] (\S+)\s+(\S+)\s+\S+ ns"]
tab_space_num = int(4)
 
def get_total_time(res):
    total_time = 0.0
    for i in res.values():
        total_time += i
    return  total_time
 
def print_info_op(res, cnt, unit, op):
    total_time = get_total_time(res)
    total_cnt = 0
    # print detailed op time
    lis=sorted(res.items(), key=lambda d:d[1], reverse=True)
    if sys.version_info.major == 2:
        import commands
        for i in range(len(lis)):
            (status, cmd_output) = commands.getstatusoutput("c++filt {}".format(lis[i][0]))
            if status == 0:
                formt_type = (cmd_output.split('('))[0]
            total_cnt += cnt[lis[i][0]]
    elif sys.version_info.major == 3:
        import subprocess
        for i in range(len(lis)):
            (status, cmd_output) = subprocess.getstatusoutput("c++filt {}".format(lis[i][0]))
            if status == 0:
                formt_type = (cmd_output.split('('))[0]
            total_cnt += cnt[lis[i][0]]
    print(f"{op} {total_time / unit} {total_cnt}")
 
def print_info_kernel(res, cnt, unit):
    total_time = get_total_time(res)
    total_cnt = 0
    print("Total time(ms) is {}".format(total_time / unit))
    # print detailed op time
    lis=sorted(res.items(), key=lambda d:d[1], reverse=True)
    if sys.version_info.major == 2:
        print("{:<90}{:<10}{:<15}{:<15}".format("Op type", "count", "time(ms)", "%"))
        import commands
        for i in range(len(lis)):
            (status, cmd_output) = commands.getstatusoutput("c++filt {}".format(lis[i][0]))
            if status == 0:
                formt_type = (cmd_output.split('('))[0]
            print("{:<90}{:<10}{:<15}{:<15.5}".format(formt_type, cnt[lis[i][0]], lis[i][1] / unit, \
                lis[i][1] / total_time * 100))
            total_cnt += cnt[lis[i][0]]
    elif sys.version_info.major == 3:
        print("{:<90}{:<10}{:<20}{:<20}".format("Op type", "count", "time(ms)", "%"))
        import subprocess
        for i in range(len(lis)):
            (status, cmd_output) = subprocess.getstatusoutput("c++filt {}".format(lis[i][0]))
            if status == 0:
                formt_type = (cmd_output.split('('))[0]
            print("{:<150}{:<10}{:<25}{:<20.5}".format(formt_type, cnt[lis[i][0]], lis[i][1] / unit, \
                lis[i][1] / total_time * 100))
            total_cnt += cnt[lis[i][0]]
 
    print("Total count is {}".format(total_cnt))
 
def count_head_spaces(s: str) -> int:
   
    count = 0
    for char in s:
        if char == ' ':
            count += 1
        else:
            break
    return count
 
def process_line(lines, pattern1, unit_factor, dump_level):
    """ process a line in a file with profiling info
 
    Args:
        unit_factor: A factor differentiated by KUNLUN1 and KUNLUN2
 
    """
    res = {}
    cnt = {}
    op = "init_op"
    unit = unit_factor * 1000 * 1000 # ns -> ms
    wait_next_one = False
    for i in range(len(lines)):
        cur_line = lines[i]
        if "gtest_" in cur_line:
            cur_level = count_head_spaces(cur_line) / tab_space_num
            if cur_level == dump_level:
                wait_next_one = False
                print_info_op(res, cnt, unit, op)
                # clear buf
                res = {}
                cnt = {}
                op = cur_line.lstrip().rstrip()
            elif cur_level < dump_level:
                wait_next_one = True
                # skip record kernel time untime next one
                continue
        if wait_next_one:
            # skip record kernel time
            continue
        match = re.match(pattern1, lines[i])
        if match:
            op_type = match.group(1)
            op_time = match.group(2)
            if op_type in res:
                res[op_type] += float(op_time)
                cnt[op_type] += 1
            else:
                res[op_type] = float(op_time)
                cnt[op_type] = 1
 
    # get left total time
    if dump_level == -1:
        print_info_kernel(res, cnt, unit)
    else:
        print_info_op(res, cnt, unit, op)
    return res
 
def process_file(file_name, pattern2, unit_factor, dump_level = -1):
    """ Process a file line by line
 
    Iteratively process each line in the target file.
 
    """
 
    with open(file_name, "r") as f:
        lines = f.readlines()
        f1_res_list = process_line(lines, pattern2, unit_factor, dump_level)
 
if __name__ == '__main__':
    import argparse
 
    parser = argparse.ArgumentParser()
 
    group = parser.add_mutually_exclusive_group()
    group.add_argument('-xpu1', action='store_true', help='指定为 xpu1')
    group.add_argument('-xpu2', action='store_true', help='指定为 xpu2')
    group.add_argument('-xpu3', action='store_true', help='指定为 xpu3')
    parser.add_argument('--level', type=int, default=-1, help='指定 dump 缩进级别默认为 -1')
    parser.add_argument('filename', help='要处理的文件名')
 
    args = parser.parse_args()
 
    filename = args.filename
    xpu_version = 0
    if args.xpu2:
        xpu_version = 1
    if args.xpu3:
        xpu_version = 2
    dump_level = args.level
    print(f'Filename: {filename}')
    print(f'-xpu option: {xpu_version}')
    print(f'--level option: {dump_level}')
 
    unit_factor = unit_factors[xpu_version]
    pattern_idx = 0
    if xpu_version > 0:
        pattern_idx = 1
    process_file(filename, patterns[pattern_idx], unit_factor, dump_level)
 
```
::::
::::{tab-item} op_log.sh
```bash
for i in {0..7}; do
    python op_log.py -xpu3 xpu_logs/rank_${i}.log > analysis_rank${i}.log
    echo "Rank ${i} 分析完成"
done
for i in {0..7}; do
    echo "=== Rank $i ===" 
    head -n 6 analysis_rank${i}.log | tail -n 5
done
```
::::
:::::
##### 📈 Output Example (analysis_rank0.log)
```
Filename: xpu_logs/rank_0.log
-xpu option: 2
--level option: -1
Total time(ms) is 53742.29571862069
Op type                                                                                   count     time(ms)            %                   
void xblas_xpu3::fc_cdnn_infer<float16, float16, float16, float16, float, float, float, float, 1>                                                     661569    22736.262780689656       42.306              
void kl3_all_reduce<float16>                                                                                                                          176134    14782.525712413793       27.506              
void kl3_all_reduce_butterfly<float16>                                                                                                                164864    4197.28395862069         7.81           
```
##### 🚨 Troubleshooting Guide
|Symptom|Cause|Solution|
|-|-|-|
|`xpu_logs` directory is empty|XPUAPI_DEBUG not enabled|Verify that the environment variable is correctly set|
All 8 log files have identical content|Multi-process backend not activated|Ensure `--distributed-executor-backend` mp is specified|
|Throughput drops >15%|Profiling overhead too high|Enable profiling only during analysis; disable in production|

39
docs/source/faqs.md Normal file
View File

@@ -0,0 +1,39 @@
# FAQs
## Version Specific FAQs
- [[v0.10.1.1] FAQ & Feedback]
## General FAQs
### 1. What devices are currently supported?
Currently, **ONLY** Kunlun3 series(P800) series are supported
Below series are NOT supported yet:
- Kunlun4 series(M100 and M300)
- Kunlun2 series(R200)
- Kunlun1 series
We will support the kunlun4 M100 platform in early 2026.
### 2. How to get our docker containers?
**base**:`docker pull wjie520/vllm_kunlun:v0.0.1`.
### 3. How vllm-kunlun work with vLLM?
vllm-kunlun is a hardware plugin for vLLM. Basically, the version of vllm-kunlun is the same as the version of vllm. For example, if you use vllm 0.10.1.1, you should use vllm-kunlun 0.10.1.1 as well. For main branch, we will make sure `vllm-kunlun` and `vllm` are compatible by each commit.
### 4. How to handle the out-of-memory issue?
OOM errors typically occur when the model exceeds the memory capacity of a single XPU. For general guidance, you can refer to [vLLM OOM troubleshooting documentation](https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#out-of-memory).
In scenarios where XPUs have limited high bandwidth memory (HBM) capacity, dynamic memory allocation/deallocation during inference can exacerbate memory fragmentation, leading to OOM. To address this:
- **Limit `--max-model-len`**: It can save the HBM usage for kv cache initialization step.
- **Adjust `--gpu-memory-utilization`**: If unspecified, the default value is `0.9`. You can decrease this value to reserve more memory to reduce fragmentation risks. See details in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).

69
docs/source/index.md Normal file
View File

@@ -0,0 +1,69 @@
# Welcome to vLLM Kunlun Plugin
:::{figure} ./logos/vllm-kunlun-logo-text-light.png
:align: center
:alt: vLLM
:class: no-scaled-link
:width: 70%
:::
:::{raw} html
<p style="text-align:center">
<strong>vLLM Kunlun Plugin
</strong>
</p>
<p style="text-align:center">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/vllm-project/vllm" data-show-count="true" data-size="large" aria-label="Star">Star</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
</p>
:::
vLLM Kunlun (vllm-kunlun) is a community-maintained hardware plugin designed to seamlessly run vLLM on the Kunlun XPU. It is the recommended approach for integrating the Kunlun backend within the vLLM community, adhering to the principles outlined in the [[RFC]: Hardware pluggable](https://github.com/vllm-project/vllm/issues/11162). This plugin provides a hardware-pluggable interface that decouples the integration of the Kunlun XPU with vLLM.
By utilizing the vLLM Kunlun plugin, popular open-source models, including Transformer-like, Mixture-of-Expert, Embedding, and Multi-modal LLMs, can run effortlessly on the Kunlun XPU.
## Documentation
% How to start using vLLM on Kunlun XPU?
:::{toctree}
:caption: Getting Started
:maxdepth: 1
quick_start
installation
tutorials/index.md
faqs
:::
% What does vLLM Kunlun Plugin support?
:::{toctree}
:caption: User Guide
:maxdepth: 1
user_guide/support_matrix/index
user_guide/configuration/index
user_guide/feature_guide/index
user_guide/release_notes
:::
% How to contribute to the vLLM Kunlun project
:::{toctree}
:caption: Developer Guide
:maxdepth: 1
developer_guide/contribution/index
developer_guide/feature_guide/index
developer_guide/evaluation/index
developer_guide/performance/index
:::
% How to involve vLLM Kunlun
:::{toctree}
:caption: Community
:maxdepth: 1
community/governance
community/contributors
community/versioning_policy
community/user_stories/index
:::

129
docs/source/installation.md Normal file
View File

@@ -0,0 +1,129 @@
# Installation
This document describes how to install vllm-kunlun manually.
## Requirements
- **OS**: Ubuntu 22.04
- **Software**:
- Python >=3.10
- PyTorch ≥ 2.5.1
- vLLM (same version as vllm-kunlun)
## Setup environment using container
We provide a clean, minimal base image for your use`wjie520/vllm_kunlun:v0.0.1`.You can pull it using the `docker pull` command.
### Container startup script
:::::{tab-set}
:sync-group: install
::::{tab-item} start_docker.sh
:selected:
:sync: pip
```{code-block} bash
:substitutions:
#!/bin/bash
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
for idx in $(seq 0 $((XPU_NUM-1))); do
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi
export build_image="wjie520/vllm_kunlun:v0.0.1"
docker run -itd ${DOCKER_DEVICE_CONFIG} \
--net=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
--cap-add=SYS_PTRACE \
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
--name "$1" \
-w /workspace \
"$build_image" /bin/bash
```
::::
:::::
## Install vLLM-kunlun
### Install vLLM 0.10.1.1
```
conda activate python310_torch25_cuda
pip install vllm==0.10.1.1 --no-build-isolation --no-deps
```
### Build and Install
Navigate to the vllm-kunlun directory and build the package:
```
git clone https://github.com/baidu/vLLM-Kunlun # TODO: replace with Github Url to install vllm-kunlun
cd vllm-kunlun
pip install -r requirements.txt
python setup.py build
python setup.py install
```
### Replace eval_frame.py
Copy the eval_frame.py patch:
```
cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/python310_torch25_cuda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
```
## Update xpytorch
```
wget https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/0830/xpytorch-cp310-torch251-ubuntu2004-x64.run
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
```
## Install custom ops
```
pip install \
https://xtorch_ops
pip install \
https://xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl
```
## Quick Start
### Set up the environment
```
chmod +x /workspace/vllm-kunlun/setup_env.sh && source /workspace/vllm-kunlun/setup_env.sh
```
### Run the server
:::::{tab-set}
:sync-group: install
::::{tab-item} start_service.sh
:selected:
:sync: pip
```{code-block} bash
:substitutions:
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8356 \
--model /models/Qwen3-8B\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--max_num_seqs 128 \
--max_num_batched_tokens 32768 \
--max-seq-len-to-capture 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-8B \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
::::
:::::

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

200
docs/source/quick_start.md Normal file
View File

@@ -0,0 +1,200 @@
# Quickstart
## Prerequisites
### Supported Devices
- Kunlun3 P800
## Setup environment using container
:::::{tab-set}
::::{tab-item} Ubuntu
```{code-block} bash
:substitutions:
#!/bin/bash
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
for idx in $(seq 0 $((XPU_NUM-1))); do
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi
export build_image="wjie520/vllm_kunlun:v0.0.1"
docker run -itd ${DOCKER_DEVICE_CONFIG} \
--net=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
--cap-add=SYS_PTRACE \
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
--name "$1" \
-w /workspace \
"$build_image" /bin/bash
```
::::
:::::
Start docker:
```bash
#start
bash ./rundocker.sh <container_name>
#Enter container
docker exec -it <container_name> bash
```
The default working directory is `/workspace`. With the fully provisioned environment image we provide, you can quickly start developing and running tasks within this directory.
## Set up system environment
```
#Set environment
chmod +x /workspace/vllm-kunlun/setup_env.sh && source /workspace/vllm-kunlun/setup_env.sh
```
## Usage
You can start the service quickly using the script below.
:::::{tab-set}
::::{tab-item} Offline Batched Inference
With vLLM installed, you can start generating texts for list of input prompts (i.e. offline batch inferencing).
Try to run below Python script directly or use `python3` shell to generate texts:
<!-- tests/e2e/doctest/001-quickstart-test.sh should be considered updating as well -->
```python
import os
from vllm import LLM, SamplingParams
def main():
model_path = "/models/Qwen3-8B"
llm_params = {
"model": model_path,
"tensor_parallel_size": 1,
"trust_remote_code": True,
"dtype": "float16",
"enable_chunked_prefill": False,
"distributed_executor_backend": "mp",
}
llm = LLM(**llm_params)
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is your name?"
}
]
}
]
sampling_params = SamplingParams(
max_tokens=200,
temperature=1.0,
top_k=50,
top_p=1.0,
stop_token_ids=[181896]
)
outputs = llm.chat(messages, sampling_params=sampling_params)
response = outputs[0].outputs[0].text
print("=" * 50)
print("Input content:", messages)
print("Model response:\n", response)
print("=" * 50)
if __name__ == "__main__":
main()
```
::::
::::{tab-item} OpenAI Completions API
vLLM can also be deployed as a server that implements the OpenAI API protocol. Run
the following command to start the vLLM server with the
[Qwen3-8B]model:
<!-- tests/e2e/doctest/001-quickstart-test.sh should be considered updating as well -->
```bash
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8356 \
--model /models/Qwen3-8B\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--max_num_seqs 128 \
--max_num_batched_tokens 32768 \
--max-seq-len-to-capture 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-8B \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
If you see a log as below:
```bash
(APIServer pid=51171) INFO: Started server process [51171]
(APIServer pid=51171) INFO: Waiting for application startup.
(APIServer pid=51171) INFO: Application startup complete.
(Press CTRL+C to quit)
```
Congratulations, you have successfully started the vLLM server!
You can query the model with input prompts:
```bash
curl http://localhost:8356/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen3-8B",
"prompt": "What is your name?",
"max_tokens": 7,
"temperature": 0
}'
```
vLLM is serving as a background process, you can use `kill -2 $VLLM_PID` to stop the background process gracefully, which is similar to `Ctrl-C` for stopping the foreground vLLM process:
<!-- tests/e2e/doctest/001-quickstart-test.sh should be considered updating as well -->
```bash
VLLM_PID=$(pgrep -f "vllm serve")
kill -2 "$VLLM_PID"
```
The output is as below:
```
INFO: Shutting down FastAPI HTTP server.
INFO: Shutting down
INFO: Waiting for application shutdown.
INFO: Application shutdown complete.
```
Finally, you can exit the container by using `ctrl-D`.
::::
:::::

View File

@@ -0,0 +1,9 @@
# Tutorials
:::{toctree}
:caption: Deployment
:maxdepth: 1
single_xpu_Qwen3-8B
multi_xpu_GLM-4.5
multi_xpu_Qwen3-Coder-480B-A35B(W8A8)
:::

View File

@@ -0,0 +1,153 @@
# Multi XPU (GLM-4.5)
## Run vllm-kunlun on multi XPU
Setup environment using container:
```bash
docker run -itd \
--net=host \
--cap-add=SYS_PTRACE --security-opt=seccomp=unconfined \
--ulimit=memlock=-1 --ulimit=nofile=120000 --ulimit=stack=67108864 \
--shm-size=128G \
--privileged \
--name=glm-vllm-01011 \
-v ${PWD}:/data \
-w /workspace \
-v /usr/local/bin/:/usr/local/bin/ \
-v /lib/x86_64-linux-gnu/libxpunvidia-ml.so.1:/lib/x86_64-linux-gnu/libxpunvidia-ml.so.1 \
iregistry.baidu-int.com/hac_test/aiak-inference-llm:xpu_dev_20251113_221821 bash
docker exec -it glm-vllm-01011 /bin/bash
```
### Offline Inference on multi XPU
Start the server in a container:
```{code-block} bash
:substitutions:
import os
from vllm import LLM, SamplingParams
def main():
model_path = "/data/GLM-4.5"
llm_params = {
"model": model_path,
"tensor_parallel_size": 8,
"trust_remote_code": True,
"dtype": "float16",
"enable_chunked_prefill": False,
"distributed_executor_backend": "mp",
}
llm = LLM(**llm_params)
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello, who are you?"
}
]
}
]
sampling_params = SamplingParams(
max_tokens=100,
temperature=0.7,
top_k=50,
top_p=1.0,
stop_token_ids=[181896]
)
outputs = llm.chat(messages, sampling_params=sampling_params)
response = outputs[0].outputs[0].text
print("=" * 50)
print("Input content:", messages)
print("Model response:\n", response)
print("=" * 50)
if __name__ == "__main__":
main()
```
:::::
If you run this script successfully, you can see the info shown below:
```bash
==================================================
Input content: [{'role': 'user', 'content': [{'type': 'text', 'text': 'Hello, who are you?'}]}]
Model response:
<think>
Well, the user asked a rather direct question about identity. This question seems simple, but there could be several underlying intentions—perhaps they are testing my reliability for the first time, or they simply want to confirm the identity of the conversational partner. From the common positioning of AI assistants, the user has provided a clear and flat way to define identity while leaving room for potential follow-up questions.\n\nThe user used "you" instead of "your", which leans towards a more informal tone, so the response style can be a bit more relaxed. However, since this is the initial response, it is better to maintain a moderate level of professionalism. Mentioning
==================================================
```
### Online Serving on Single XPU
Start the vLLM server on a single XPU:
```{code-block} bash
python -m vllm.entrypoints.openai.api_server \
--host localhost \
--port 8989 \
--model /data/GLM-4.5 \
--gpu-memory-utilization 0.95 \
--trust-remote-code \
--max-model-len 131072 \
--tensor-parallel-size 8 \
--dtype float16 \
--max_num_seqs 128 \
--max_num_batched_tokens 4096 \
--max-seq-len-to-capture 4096 \
--block-size 128 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name GLM-4.5 \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun", "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2"]}' > log_glm_plugin.txt 2>&1 &
```
If your service start successfully, you can see the info shown below:
```bash
(APIServer pid=51171) INFO: Started server process [51171]
(APIServer pid=51171) INFO: Waiting for application startup.
(APIServer pid=51171) INFO: Application startup complete.
```
Once your server is started, you can query the model with input prompts:
```bash
curl http://localhost:8989/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "GLM-4.5",
"messages": [
{"role": "user", "content": "Hello, who are you?"}
],
"max_tokens": 100,
"temperature": 0.7
}'
```
If you query the server successfully, you can see the info shown below (client):
```bash
{"id":"chatcmpl-6af7318de7394bc4ae569e6324a162fa","object":"chat.completion","created":1763101638,"model":"GLM-4.5","choices":[{"index":0,"message":{"role":"assistant","content":"\n<think>The user asked, \"Hello, who are you?\" This is a question about my identity. First, I need to confirm the user's intent. They might be using this service for the first time or have never interacted with similar AI assistants before, so they want to know my background and capabilities.\n\nNext, I should ensure my answer is clear and friendly, focusing on key points: who I am, who developed me, and what I can do. I should avoid technical jargon and keep the response conversational so it's easy to understand.\n\nAdditionally, the user may have potential needs, such as wanting to know what I am capable of.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":11,"total_tokens":111,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"kv_tr
```
Logs of the vllm server:
```bash
(APIServer pid=54567) INFO: 127.0.0.1:60338 - "POST /v1/completions HTTP/1.1" 200 OK
(APIServer pid=54567) INFO 11-13 14:35:48 [loggers.py:123] Engine 000: Avg prompt throughput: 0.5 tokens/s, Avg generation throughput: 0.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```

View File

@@ -0,0 +1,132 @@
# Multi XPU (Qwen3-Coder-480B-A35B(W8A8))
## Run vllm-kunlun on Multi XPU
Setup environment using container:
```bash
# !/bin/bash
# rundocker.sh
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
for idx in $(seq 0 $((XPU_NUM-1))); do
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi
export build_image="xxxxxxxxxxxxxxxxx"
docker run -itd ${DOCKER_DEVICE_CONFIG} \
--net=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
--cap-add=SYS_PTRACE \
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
--name "$1" \
-w /workspace \
"$build_image" /bin/bash
```
### Preparation Weight
* Pull Qwen3-Coder-480B-A35B-Instruct bf16 weights
* Modify the weights configuration.json file and add the fields quantization_config and compression_config.
```json
{
"architectures": [
"Qwen3MoeForCausalLM"
],
"attention_dropout": 0.0,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 6144,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 262144,
"max_window_layers": 62,
"mlp_only_layers": [],
"model_type": "qwen3_moe",
"moe_intermediate_size": 2560,
"norm_topk_prob": true,
"num_attention_heads": 96,
"num_experts": 160,
"num_experts_per_tok": 8,
"num_hidden_layers": 62,
"num_key_value_heads": 8,
"output_router_logits": false,
"qkv_bias": false,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000000,
"router_aux_loss_coef": 0.0,
"shared_expert_intermediate_size": 0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_qk_norm": true,
"use_sliding_window": false,
"vocab_size": 151936,
"quantization_config": {
"quant_method": "compressed-tensors"
},
"compression_config": {
"format": "pack_quantized",
"config_groups": {
"linear_w8a8": {
"targets": ["Linear"],
"weights": {
"type": "int",
"num_bits": 8,
"strategy": "channel",
"group_size": null,
"symmetric": true,
"dynamic": false
},
"input_activations": {
"type": "int",
"num_bits": 8,
"strategy": "token",
"group_size": null,
"symmetric": true,
"dynamic": true
}
}
},
"ignore": [],
"sparsity_config": null
}
}
```
### Online Serving on Multi XPU
Start the vLLM server on multi XPU:
```bash
python3 -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8898 \
--model /Qwen/Qwen3-Coder-480B-A35B-Instruct \
--dtype float16 \
--trust-remote-code \
--tensor-parallel-size 8 \
--block-size 128 \
--max-model-len 40960 \
--max-num-seqs 512 \
--max-num-batched-tokens 40960 \
--max-seq-len-to-capture 40960 \
--distributed-executor-backend mp \
--enable-chunked-prefill=False \
--no-enable-prefix-caching \
--disable-log-requests \
--gpu-memory-utilization 0.85
```

View File

@@ -0,0 +1,168 @@
# Single XPU (Qwen3-8B)
## Run vllm-kunlun on Single XPU
Setup environment using container:
```bash
# !/bin/bash
# rundocker.sh
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
for idx in $(seq 0 $((XPU_NUM-1))); do
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi
export build_image="xxxxxxxxxxxxxxxxx"
docker run -itd ${DOCKER_DEVICE_CONFIG} \
--net=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
--cap-add=SYS_PTRACE \
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
--name "$1" \
-w /workspace \
"$build_image" /bin/bash
```
### Offline Inference on Single XPU
Start the server in a container:
```{code-block} bash
from vllm import LLM, SamplingParams
def main():
model_path = "/models/Qwen3-8B"
llm_params = {
"model": model_path,
"tensor_parallel_size": 1,
"trust_remote_code": True,
"dtype": "float16",
"enable_chunked_prefill": False,
"distributed_executor_backend": "mp",
}
llm = LLM(**llm_params)
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "tell a joke"
}
]
}
]
sampling_params = SamplingParams(
max_tokens=200,
temperature=1.0,
top_k=50,
top_p=1.0,
stop_token_ids=[181896]
)
outputs = llm.chat(messages, sampling_params=sampling_params)
response = outputs[0].outputs[0].text
print("=" * 50)
print("Input content:", messages)
print("Model response:\n", response)
print("=" * 50)
if __name__ == "__main__":
main()
```
:::::
If you run this script successfully, you can see the info shown below:
```bash
==================================================
Input content: [{'role': 'user', 'content': [{'type': 'text', 'text': 'tell a joke'}]}]
Model response:
<think>
Okay, the user asked me to tell a joke. First, I need to consider the user's needs. They might just want to relax or need some entertainment. Next, I need to choose a suitable joke that is not too complicated, easy to understand, and also interesting.
The user might expect the joke to be in Chinese, so I need to ensure that the joke conforms to the language habits and cultural background of Chinese. I need to avoid sensitive topics, such as politics, religion, or anything that might cause misunderstanding. Then, I have to consider the structure of the joke, which usually involves a setup and an unexpected ending to create humor.
For example, I could tell a light-hearted story about everyday life, such as animals or common scenarios. For instance, the story of a turtle and a rabbit racing, but with a twist. However, I need to ensure that the joke is of moderate length and not too long, so the user doesn't lose interest. Additionally, I should pay attention to using colloquial language and avoid stiff or complex sentence structures.
I might also need to check if this joke is common to avoid repetition. If the user has heard something similar before, I may need to come up with a different angle.
==================================================
```
### Online Serving on Single XPU
Start the vLLM server on a single XPU:
```{code-block} bash
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 9000 \
--model /models/Qwen3-8B\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--max_num_seqs 128 \
--max_num_batched_tokens 32768 \
--max-seq-len-to-capture 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-8B \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
If your service start successfully, you can see the info shown below:
```bash
(APIServer pid=118459) INFO: Started server process [118459]
(APIServer pid=118459) INFO: Waiting for application startup.
(APIServer pid=118459) INFO: Application startup complete.
```
Once your server is started, you can query the model with input prompts:
```bash
curl http://localhost:9000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen3-8B",
"prompt": "What is your name?",
"max_tokens": 100,
"temperature": 0
}'
```
If you query the server successfully, you can see the info shown below (client):
```bash
{"id":"cmpl-80ee8b893dc64053947b0bea86352faa","object":"text_completion","created":1763015742,"model":"Qwen3-8B","choices":[{"index":0,"text":" is the S, and ,","logprobs":null,"finish_reason":"length","stop_reason":null,"prompt_logprobs":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":5,"total_tokens":12,"completion_tokens":7,"prompt_tokens_details":null},"kv_transfer_params":null}
```
Logs of the vllm server:
```bash
(APIServer pid=54567) INFO: 127.0.0.1:60338 - "POST /v1/completions HTTP/1.1" 200 OK
(APIServer pid=54567) INFO 11-13 14:35:48 [loggers.py:123] Engine 000: Avg prompt throughput: 0.5 tokens/s, Avg generation throughput: 0.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```

View File

@@ -0,0 +1,17 @@
# Environment Variables
vllm-kunlun uses the following environment variables to configure the system:
| *Environment Variables* | ***\*Recommended value\**** | ***\*Function description\**** |
| ---------------------------------------- | ----------------- | ------------------------------------------------------------ |
| `unset XPU_DUMMY_EVENT` | | ***\*Unsets\**** `XPU_DUMMY_EVENT` variable, usually done to ensure real XPU events are used for synchronization and performance measurement. |
| `export XPU_VISIBLE_DEVICES` | `0,1,2,3,4,5,6,7` | ***\*Specify visible XPU Devices\****. Here, 8 devices (0 to 7) are specified for inference tasks. This is required for multi-card or distributed inference. |
| `export XPU_USE_MOE_SORTED_THRES` | `1` | Enables the Moe Model ***\*Sort Optimization\****.Setting to `1` usually enables this performance optimization. |
| `export XFT_USE_FAST_SWIGLU` | `1` | Enables the ***\*Fast SwiGLU Ops\****. SwiGLU is a common activation function, and enabling this accelerates model inference. |
| `export XPU_USE_FAST_SWIGLU` | `1` | Enables the ***\*Fast SwiGLU Ops\****. Similar to `XFT_USE_FAST_SWIGLU`, this enables the fast SwiGLU calculation in Fused MoE Fusion Ops. |
| `export XMLIR_CUDNN_ENABLED` | `1` | Enables XMLIR (an intermediate representation/compiler) to use the ***\*cuDNN compatible/optimized path\**** (which may map to corresponding XPU optimized libraries in the KunlunCore environment). |
| `export XPU_USE_DEFAULT_CTX` | `1` | Sets the XPU to use the default context. Typically used to simplify environment configuration and ensure runtime consistency. |
| `export XMLIR_FORCE_USE_XPU_GRAPH` | `1` | ***\*Forces the enablement of XPU Graph mode.\****. This can capture and optimize the model execution graph, significantly boosting inference performance. |
| `export VLLM_HOST_IP` | `$(hostname -i)` | ***\*Sets the host IP address for the vLLM service\****. This uses a shell command to dynamically get the current host's internal IP. It's used for inter-node communication in a distributed environment. |
| `export XMLIR_ENABLE_MOCK_TORCH_COMPILE` | `false` | ***\*Disable Mock Torch Compile Function\****. Set to `false` to ensure the actual compilation and optimization flow is used, rather than mock mode. |
| `FUSED_QK_ROPE_OP` | `0` | ***\*Control whether to use the Fused QK-Norm and RoPE implementation\****. Default is `0` (use original/standard RoPE). Setting to `1` may be used to enable QWEN3. |

View File

@@ -0,0 +1,9 @@
# Configuration Guide
This section provides a detailed configuration guide of vLLM Kunlun.
:::{toctree}
:caption: Configuration Guide
:maxdepth: 1
env_vars
:::

View File

@@ -0,0 +1,82 @@
# Graph Mode Guide
This guide provides instructions for using Kunlun Graph Mode with vLLM Kunlun. Please note that graph mode is available both on V1 and V0 Engine. All supported models are highly compatible with Kunlun Graph.
## Getting Started
From vLLM-KunLun-0.10.1.1 with V1 Engine, vLLM Kunlun will run models in graph mode by default to keep the same behavior with vLLM. If you hit any issues, please feel free to open an issue on GitHub and fallback to the eager mode temporarily by setting `enforce_eager=True` when initializing the model.
There is a graph mode supported by vLLM Kunlun:
- **KunlunGraph**: This is the default graph mode supported by vLLM Kunlun. In vLLM-KunLun-0.10.1.1, Qwen, GLM and InternVL series models are well tested.
## Using KunlunGraph
KunlunGraph is enabled by default. Take Qwen series models as an example, just set to use V1 Engine(default) is enough.
Offline example:
```python
import os
from vllm import LLM
model = LLM(model="models/Qwen3-8B-Instruct")
outputs = model.generate("Hello, how are you?")
```
Online example:
```shell
vllm serve Qwen3-8B-Instruct
```
## Using KunlunGraph
Enabling Kunlun Graph on the Kunlun platform requires the use of splitting ops.
Online example:
```shell
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8000 \
--model /models/Qwen3-8B-Instruct\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-8B-Instruct \
--compilation-config '{"splitting_ops": ["vllm.unified_attention_with_output_kunlun",
"vllm.unified_attention", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2"]}' \
```
## Fallback to the Eager Mode
If `KunlunGraph` fail to run, you should fallback to the eager mode.
Online example:
```shell
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8000 \
--model /models/Qwen3-8B-Instruct\
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-8B-Instruct \
--enforce_eager
```

View File

@@ -0,0 +1,11 @@
# Feature Guide
This section provides a detailed usage guide of vLLM Kunlun features.
:::{toctree}
:caption: Feature Guide
:maxdepth: 1
graph_mode
quantization
lora
:::

View File

@@ -0,0 +1,27 @@
# LoRA Adapters Guide
## Overview
Like vLLM, vllm_kunlun supports LoRA as well. The usage and more details can be found in [vLLM official document ](https://docs.vllm.ai/en/latest/features/lora.html).
You can refer to [Supported Models ](https://docs.vllm.ai/en/latest/models/supported_models.html#list-of-text-only-language-models)to find which models support LoRA in vLLM.
Currently, only vLLM v0 mode (including eager and CUDA Graph modes) supports multi-LoRA inference in vllm_kunlun.
## Example
We provide a simple LoRA example here:
```bash
export ENABLE_KUNLUN_LARGE_OPS=0
USE_ORI_ROPE=0 VLLM_USE_V1=0 vllm serve qwen3-8b \
--enable-lora \
--max-lora-rank 64 \
--lora-modules lora1=/path/to/lora1 lora2=/path/to/lora2
```
## Custom LoRA Operators
We have implemented LoRA-related custom operators for Kunlun hardware, such as `bgmv_shrink`, `bgmv_expand`, `sgmv_shrink`, and `sgmv_expand`. The implementation can be found in `vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py`.

View File

@@ -0,0 +1,45 @@
# Quantization Guide
>Note: This feature is currently experimental. In future versions, there may be behavioral changes around configuration, coverage, performance improvement.
Like vLLM, we now support quantization methods such as compressed-tensors, AWQ, and GPTQ, enabling various precision configurations including W8A8, W4A16, and W8A16. These can help reduce memory consumption and accelerate inference while preserving model accuracy.
## Usages
### Compressed-tensor
To run a `compressed-tensors` model with vLLM-kunlun, you should first add the below configuration to the model's `config.json`:
```Bash
"quantization_config": {
"quant_method": "compressed-tensors"
}
```
Then you run `Qwen/Qwen3-30B-A3B` with dynamic W8A8 quantization with the following command:
```Bash
python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen3-30B-A3B \
--quantization compressed-tensors
```
### AWQ
To run an `AWQ` model with vLLM-kunlun, you can use `Qwen/Qwen3-32B-AWQ` with the following command:
```Bash
python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen3-32B-AWQ \
--quantization awq
```
### GPTQ
To run a `GPTQ` model with vLLM-kunlun, you can use `Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4` with the following command:
```Bash
python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 \
--quantization gptq
```

View File

@@ -0,0 +1,3 @@
# Release Notes
Comming soon...

View File

@@ -0,0 +1,10 @@
# Features and Models
This section provides a detailed matrix supported by vLLM-Kunlun.
:::{toctree}
:caption: Support Matrix
:maxdepth: 1
supported_models
supported_features
:::

View File

@@ -0,0 +1,14 @@
# Supported Features
The feature support principle of vLLM-KunLun is: **aligned with the vLLM**. We are also actively collaborating with the community to accelerate support.
You can check the [support status of vLLM V1 Engine][v1_user_guide]. Below is the feature support status of vLLM-KunLun:
## Features Supported
|Feature|Status|Note|
|-|-|-|
|Tensor Parallel|🟢 Functional||
|Experts Parallel|🟢 Functional||
|Graph Mode|🟢 Functional||
|Quantization| 🟢 Functional||
|LoRA|⚠️ Need Test|Only LLM models|

View File

@@ -0,0 +1,33 @@
# Supported Models
## Generative Models
| Model | Support | W8A8 | LoRA | Tensor Parallel | Expert Parallel | Data Parallel | Piecewise Kunlun Graph |
| :------------ | :------------ | :--- | :--- | :-------------- | :-------------- | :------------ | :--------------------- |
| Qwen2 | ✅ | | ✅ | ✅ | | ✅ | ✅ |
| Qwen2.5 | ✅ | | ✅ | ✅ | | ✅ | ✅ |
| Qwen3 | ✅ | | ✅ | ✅ | | ✅ | ✅ |
| Qwen3-Moe | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen3-Coder | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| QwQ-32B | ✅ | | | ✅ | | ✅ | ✅ |
| LLama2 | ✅ | | | ✅ | | ✅ | ✅ |
| LLama3 | ✅ | | | ✅ | | ✅ | ✅ |
| LLama3.1 | ✅ | | | ✅ | | ✅ | ✅ |
| GLM-4.5 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| GLM-4.5-Air | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen3-next | 🔜Comming soon | | | | | | |
| gpt-oss | 🔜Comming soon | | | | | | |
| DeepSeek-V3 | 🔜Comming soon | | | | | | |
| DeepSeek-V3.2 | 🔜Comming soon | | | | | | |
## Multimodal Language Models
| Model | Support | W8A8 | LoRA | Tensor Parallel | Expert Parallel | Data Parallel | Piecewise Kunlun Graph |
| :----------- | :------------ | :--- | :--- | :-------------- | :-------------- | :------------ | :--------------------- |
|Qianfan-VL | ✅ | | | ✅| |✅ |✅|
| Qwen2.5VL | ✅ | | | ✅ | | ✅ | ✅ |
| InternVL2.5 | ✅ | | | ✅ | | ✅ | ✅ |
| InternVL3 | ✅ | | | ✅ | | ✅ | ✅ |
| InternVL3.5 | ✅ | | | ✅ | | ✅ | ✅ |
| InternS1 | ✅ | | | ✅ | | ✅ | ✅ |
| Qwen2.5-Omni | 🔜Comming soon | | | | | | |
| Qwen3-VL | 🔜Comming soon | | | | | | |

30
pyproject.toml Normal file
View File

@@ -0,0 +1,30 @@
[build-system]
requires = ["hatchling>=1.22"]
build-backend = "hatchling.build"
[project]
name = "vllm-kunlun"
version = "0.10.1.1"
description = "vLLM Kunlun3 backend plugin"
readme = "README.md"
requires-python = ">=3.10"
license = { text = "MIT" }
authors = [{ name = "kunlun"}]
dependencies = []
[project.scripts]
vllm-kunlun = "vllm_kunlun.cmdline:main"
[project.entry-points."vllm.platform_plugins"]
kunlun = "vllm_kunlun:register"
[project.entry-points."vllm.general_plugins"]
kunlun_model = "vllm_kunlun:register_model"
[tool.hatch.build]
packages = ["vllm_kunlun"]
include = ["vllm_kunlun/conf/*", "vllm_kunlun/data/*"]
[tool.hatch.build.targets.wheel]
packages = ["vllm_kunlun"]
output-dir = "output/dist"

34
requirements.txt Normal file
View File

@@ -0,0 +1,34 @@
setuptools==80.9.0
black==23.3.0
blake3==1.0.5
cachetools==6.1.0
cbor2==5.7.0
cloudpickle==3.1.1
compressed-tensors==0.10.2
diskcache==5.6.3
gguf==0.17.1
mistral_common==1.8.3
msgspec==0.19.0
numba==0.61.2
openai==1.99.1
openai-harmony==0.0.4
opencv-contrib-python==4.12.0.88
partial-json-parser==0.2.1.1.post6
prometheus_client==0.22.1
pybase64==1.4.1
pyzmq==27.0.1
ray==2.48.0
setproctitle==1.3.7
watchfiles==1.1.0
pydantic==2.11.7
tokenizers>=0.21.2
uvloop==0.21.0
prometheus-fastapi-instrumentator==7.1.0
transformers>=4.56.1
hatchling>=1.25
build>=1.0.3
pytest
mock

66
setup.py Normal file
View File

@@ -0,0 +1,66 @@
#
# setup.py for vllm_kunlun
#
import os
import shutil
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CppExtension, BuildExtension
ROOT_DIR = os.path.dirname(__file__)
ext_modules = [
CppExtension(
name='vllm_kunlun._kunlun',
sources=['vllm_kunlun/csrc/utils.cpp'],
include_dirs=[
'vllm_kunlun/csrc',
"/usr/local/cuda/include",
],
library_dirs=["/usr/local/cuda/lib64"],
extra_compile_args=['-O3'],
)
]
class CustomBuildExt(BuildExtension):
def run(self):
super().run()
for ext in self.extensions:
ext_path = self.get_ext_fullpath(ext.name)
file_name = os.path.basename(ext_path)
target_path = os.path.join("vllm_kunlun", file_name)
if os.path.exists(target_path):
os.remove(target_path)
shutil.copyfile(ext_path, target_path)
print(f"[BuildExt] Copied {ext_path} -> {target_path}")
if __name__ == '__main__':
setup(
name='vllm_kunlun',
version="v1.0",
author="vLLM-Kunlun team",
license="Apache 2.0",
description="vLLM Kunlun3 backend plugin",
packages=find_packages(exclude=("docs", "examples", "tests*")),
package_data={
'vllm_kunlun': ['_kunlun.so', 'so/*.so', 'include/*.h']
},
python_requires=">=3.10",
ext_modules=ext_modules,
cmdclass={
'build_ext': CustomBuildExt,
},
entry_points={
'vllm.platform_plugins': ["kunlun = vllm_kunlun:register"],
'vllm.general_plugins': [
"kunlun_model = vllm_kunlun:register_model",
"kunlun_quant = vllm_kunlun:register_quant_method"
],
"console_scripts": [
"vllm_kunlun = vllm_kunlun.entrypoints.main:main"
]
}
)

11
setup_env.sh Normal file
View File

@@ -0,0 +1,11 @@
unset XPU_DUMMY_EVENT
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export XPU_USE_MOE_SORTED_THRES=1
export XFT_USE_FAST_SWIGLU=1
export XMLIR_CUDNN_ENABLED=1
export XPU_USE_DEFAULT_CTX=1
export XMLIR_FORCE_USE_XPU_GRAPH=1
export XPU_USE_FAST_SWIGLU=1
export VLLM_HOST_IP=$(hostname -i)
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false
export FUSED_QK_ROPE_OP=0

180
vllm_kunlun/__init__.py Normal file
View File

@@ -0,0 +1,180 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Xinyu Dong
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""vllm kunlun init"""
from .platforms import current_platform
import sys
import importlib
import warnings
import builtins
import os
import time
import vllm.envs as envs
OLD_IMPORT_HOOK = builtins.__import__
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
try:
start_time = time.time()
# Module mapping table
module_mappings = {
"vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer",
"vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe",
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
}
# Keep the original imported modules
original_imports = [
"vllm.model_executor.layers.fused_moe.base",
"vllm.model_executor.layers.fused_moe.config",
"vllm.model_executor.layers.fused_moe.layer",
]
if module_name in original_imports:
if module_name == "vllm.model_executor.layers.fused_moe.layer" and fromlist:
if "FusedMoEMethodBase" in fromlist:
return OLD_IMPORT_HOOK(
module_name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level,
)
if module_name in module_mappings:
if module_name in sys.modules:
return sys.modules[module_name]
target_module = module_mappings[module_name]
module = importlib.import_module(target_module)
sys.modules[module_name] = module
sys.modules[target_module] = module
return module
relative_mappings = {
(
"compressed_tensors_moe",
"compressed_tensors",
): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer",
}
if level == 1:
parent = globals.get("__package__", "").split(".")[-1] if globals else ""
key = (module_name, parent)
if key in relative_mappings:
if module_name in sys.modules:
return sys.modules[module_name]
target_module = relative_mappings[key]
module = importlib.import_module(target_module)
sys.modules[module_name] = module
sys.modules[target_module] = module
return module
except Exception:
pass
return OLD_IMPORT_HOOK(
module_name, globals=globals, locals=locals, fromlist=fromlist, level=level
)
def import_hook():
"""Apply import hook for VLLM Kunlun"""
if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")):
builtins.__import__ = _custom_import
try:
modules_to_preload = [
"vllm_kunlun.ops.quantization.compressed_tensors_moe",
"vllm_kunlun.ops.fused_moe.custom_ops",
"vllm_kunlun.ops.fused_moe.layer",
"vllm_kunlun.ops.quantization.fp8",
]
for module_name in modules_to_preload:
importlib.import_module(module_name)
except Exception:
pass
def register():
"""Register the Kunlun platform"""
from .utils import redirect_output
from .vllm_utils_wrapper import (
direct_register_custom_op,
patch_annotations_for_schema,
)
import_hook()
if envs.VLLM_USE_V1:
patch_V1blockTable()
patch_V1top_p_K()
patch_V1penalties()
else:
patch_sampler()
return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
def register_model():
"""Register models for training and inference"""
from .models import register_model as _reg
_reg()
def patch_sampler():
try:
custom_sampler = importlib.import_module("vllm_kunlun.ops.sample.sampler")
sys.modules["vllm.model_executor.layers.sampler"] = custom_sampler
print("[vllm_kunlun] sampler patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] sampler patch failed: {e!r}")
def patch_V1top_p_K():
try:
custom_sampler = importlib.import_module(
"vllm_kunlun.v1.sample.ops.topk_topp_sampler"
)
sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler
print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}")
def patch_V1penalties():
try:
custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.penalties")
sys.modules["vllm.v1.sample.ops.penalties"] = custom_sampler
print("[vllm_kunlun] V1sampler penalties patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] V1 sampler penalties patch failed: {e!r}")
def patch_V1blockTable():
try:
custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table")
sys.modules["vllm.v1.worker.block_table"] = custom_sampler
print("[vllm_kunlun] V1 block table patched ->", custom_sampler.__file__)
except Exception as e:
warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}")
# Automatically apply patches when modules are imported
import_hook()

View File

View File

@@ -0,0 +1,148 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian
# Email: baoqian@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
class TorchCompileWrapperWithCustomDispatcher:
"""
A wrapper class for torch.compile, with a custom dispatch logic.
Subclasses should:
1. Implement the forward method
2. Implement the dispatch logic in the __call__ method
It can use `self.compiled_codes` to access the compiled bytecode,
and `with self.dispatch_to_code(index):` to dispatch to
the compiled code.
3. Implement the `__init__` method to determine how to call
`torch.compile` over the forward method.
"""
def __init__(self,
compiled_callable: Optional[Callable] = None,
compilation_level: int = 0):
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
if compiled_callable is None:
# default compilation settings
# compiling the forward method
backend = vllm_config.compilation_config.init_backend(vllm_config)
options = None
if isinstance(backend, str) and backend == "inductor":
options = get_current_vllm_config(
).compilation_config.inductor_compile_config
compiled_callable = torch.compile(
self.forward,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend,
options=options)
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: list[CodeType] = []
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
# read the env var to determine whether to use the custom dispatcher
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
from vllm.config import CompilationLevel
self.use_custom_dispatcher: bool = \
compilation_level >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
NOTE: this function can have additional arguments beyond the forward
method, for directly dispatching to the compiled code.
"""
return self.compiled_callable(*args, **kwargs)
@abstractmethod
def forward(self, *args, **kwargs):
...
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
"""Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object:
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame = sys._getframe()
while frame and frame.f_back:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
if code_name == "_compile" and file_name == "convert_frame.py":
break
frame = frame.f_locals["frame"]
assert frame.f_code == old_code
if frame.f_locals["self"] is not self:
return
self.compiled_codes.append(new_code)
local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
if isinstance(local_cache_dir, str):
decompiled_file = os.path.join(local_cache_dir,
"transformed_code.py")
if not os.path.exists(decompiled_file):
try:
# usually the decompilation will succeed for most models,
# as we guarantee a full-graph compilation in Dynamo.
# but there's no 100% guarantee, since decompliation is
# not a reversible process.
import depyf
src = depyf.decompile(new_code)
with open(decompiled_file, "w") as f:
f.write(src)
logger.debug("Dynamo transformed code saved to %s",
decompiled_file)
except Exception:
pass
# if self.vllm_config.compilation_config.use_cudagraph and \
# "update" in new_code.co_names:
# import depyf
# src = depyf.decompile(new_code)
# msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
# raise RuntimeError(msg)
@contextmanager
def dispatch_to_code(self, index: int):
"""Context manager to dispatch to the compiled code.
Why does this work? Because Dynamo guarantees that the compiled
bytecode has exactly the same arguments, cell variables, and free
variables as the original code. Therefore we can directly switch
the code object in the function and call it.
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa
self.__class__.forward.__code__ = self.compiled_codes[index]
yield
self.__class__.forward.__code__ = self.original_code_object

View File

@@ -0,0 +1,49 @@
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#pragma once
#include <torch/all.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

View File

@@ -0,0 +1,32 @@
#include "xops.h"
#include "dispatch_utils.h"
#include <torch/extension.h>
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
// Ensure tensor is on CUDA
if (!tensor.is_cuda()) {
throw std::runtime_error("Tensor must be on CUDA device");
}
// Get the raw data pointer
void* data_ptr = tensor.data_ptr();
// Get tensor sizes and strides
std::vector<int64_t> sizes = tensor.sizes().vec();
std::vector<int64_t> strides = tensor.strides().vec();
// Get tensor options (dtype, device)
auto options = tensor.options();
// Create a new tensor from the raw data pointer
auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
TORCH_LIBRARY(_kunlun, m) {
m.def("weak_ref_tensor", &weak_ref_tensor);
}
PYBIND11_MODULE(_kunlun, m) {
m.def("weak_ref_tensor", &weak_ref_tensor);
}

241
vllm_kunlun/csrc/xops.h Normal file
View File

@@ -0,0 +1,241 @@
#ifndef OPS_H
#define OPS_H
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
void rms_norm_xpu(torch::Tensor &output,
torch::Tensor &input,
torch::Tensor &weight,
double eps);
// inplace
void fused_add_rms_norm_xpu(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon);
void silu_and_mul_xpu(torch::Tensor &output,
torch::Tensor &input);
void quick_gelu_xpu(torch::Tensor &output,
torch::Tensor &input);
// neox && gptj
void rotary_embedding(torch::Tensor &positions,
torch::Tensor& query,
torch::Tensor& key,
int64_t head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding(torch::Tensor &positions,
torch::Tensor& query,
torch::Tensor& key,
int64_t head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
int64_t rot_dim,
torch::Tensor& offsets);
// x = 16 // sizeof(cache dtype)
void paged_attention_v1_xpu(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
int64_t num_kv_heads,
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
torch::Tensor& seq_lens_host, // [num_seqs]
int64_t block_size,
int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, // [num_heads]
const std::string& kv_cache_dtype,
double k_scale,
double v_scale,
int64_t tp_rank, int64_t blocksparse_local_blocks, // no used but to keep same with vllm-offficial
int64_t blocksparse_vert_stride, int64_t blocksparse_block_size, // no used but to keep same with vllm-offficial
int64_t blocksparse_head_sliding_step // no used but to keep same with vllm-offficial
);
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype,
const double k_scale,
const double v_scale);
void flash_attention_context_vllm_xpu(
torch::Tensor& query, // [num_tokens, num_heads, head_size]
torch::Tensor& key, // [num_tokens, num_kv_heads, head_size]
torch::Tensor& value, // [num_tokens, num_kv_heads, head_size]
torch::Tensor& out, // [num_tokens, num_heads, head_size]
torch::Tensor& seq_lod, // [batch_size + 1]
torch::Tensor& seq_lod_host, // [batch_size + 1]
int64_t max_seq_len,
int64_t max_kv_len,
double scale,
const c10::optional<torch::Tensor>& alibi_slopes, // [num_heads],
const c10::optional<torch::Tensor>& key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const c10::optional<torch::Tensor>& value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const c10::optional<torch::Tensor>& block_tables, // [num_seqs, max_num_blocks_per_seq]
const c10::optional<torch::Tensor>& kv_prefix_start_loc, // [lod of prefix]
const c10::optional<torch::Tensor>& kv_prefix_start_loc_host, // [lod of prefix]
const c10::optional<bool> is_causal // use causal mask or not, default true
);
void paged_attention_v2_xpu(
torch::Tensor &out,
torch::Tensor &exp_sums,
torch::Tensor &max_logits,
torch::Tensor &tmp_out,
torch::Tensor &query, // [num_seqs, num_heads, head_size]
torch::Tensor &
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor &
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
int64_t num_kv_heads,
double scale,
torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor &seq_lens, // [num_seqs]
torch::Tensor& seq_lens_host, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, // [num_heads]
const std::string &kv_cache_dtype, double k_scale, double v_scale,
int64_t tp_rank, int64_t blocksparse_local_blocks, // no used but to keep same with vllm-offficial
int64_t blocksparse_vert_stride, int64_t blocksparse_block_size, // no used but to keep same with vllm-offficial
int64_t blocksparse_head_sliding_step // no used but to keep same with vllm-offficial
);
void weight_only_quant_matmul_xpu(
torch::Tensor &x,
torch::Tensor &out,
torch::Tensor &qweight,
torch::Tensor &qscale
);
void multi_latent_attention_xpu(
torch::Tensor q,
torch::Tensor kv_rope_cache,
torch::Tensor out,
torch::Tensor block_tables,
torch::Tensor seq_lens,
double scale,
int64_t max_seq_len
);
void outplace_fused_experts_xpu(
torch::Tensor &hidden_states,
torch::Tensor &output,
torch::Tensor &w1,
torch::Tensor &w2,
torch::Tensor &topk_weights,
torch::Tensor &topk_ids
);
void outplace_fused_experts_sorted_xpu(
torch::Tensor &hidden_states,
torch::Tensor &output,
torch::Tensor &w1,
torch::Tensor &w2,
torch::Tensor &topk_weights,
torch::Tensor &topk_ids
);
void grouped_topk_xpu(torch::Tensor &router_logits,
torch::Tensor& score_bias,
torch::Tensor& topk_weight,
torch::Tensor& topk_ids,
double scale,
int64_t expert_group_num,
int64_t moe_topk_group,
int64_t moe_top_k);
void topk_softmax_xpu(torch::Tensor &topk_weights, /* [m, topk] */
torch::Tensor& topk_indices, /* [m, topk] */
torch::Tensor& token_expert_indices, /* no used in xpu */
torch::Tensor& gating_output /* [m, n] */
);
torch::Tensor weak_ref_tensor(torch::Tensor& tensor);
void dynamic_scaled_int8_quant_xpu(torch::Tensor &out,
torch::Tensor &x,
torch::Tensor &input_scale,
const c10::optional<torch::Tensor>& input_azp
);
void cutlass_scaled_mm_xpu(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void castte_xpu(
torch::Tensor& input, // [num_tokens, hidden_dim]
torch::Tensor& ouput, // [num_tokens, hidden_dim]
torch::Tensor& scale // [1]
);
void castte_per_token_xpu(
torch::Tensor& input, // [num_tokens, hidden_dim]
torch::Tensor& ouput, // [num_tokens, hidden_dim]
torch::Tensor& scale // [num_tokens]
);
void fc_fusion_castte_xpu(
torch::Tensor& x, // [num_tokens, in_dim]
torch::Tensor& ouput, // [num_tokens, out_dim]
torch::Tensor& x_scale, // [1]
torch::Tensor& qweight, // [out_dim, in_dim]
torch::Tensor& qscale, // [1]
const c10::optional<torch::Tensor>& bias // [out_dim]
);
void fc_fusion_castte_per_token_xpu(
torch::Tensor& x, // [num_tokens, in_dim]
torch::Tensor& ouput, // [num_tokens, out_dim]
torch::Tensor& x_scale, // [num_tokens]
torch::Tensor& qweight, // [out_dim, in_dim]
torch::Tensor& qscale, // [1]
const c10::optional<torch::Tensor>& bias // [out_dim]
);
// trival cutlass
bool cutlass_scaled_mm_supports_fp8_xpu(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8_xpu(int64_t cuda_device_capability);
void outplace_split_norm_rope_xpu(
torch::Tensor &qkv,
torch::Tensor &cos_sin_cache,
torch::Tensor &q_weight,
torch::Tensor &k_weight,
torch::Tensor &positions,
torch::Tensor &q_emb_out,
torch::Tensor &k_emb_out,
torch::Tensor &v_out,
const int64_t emb_batch_size,
const int64_t max_seqlen,
const int64_t head_num,
const int64_t kv_head_num,
const int64_t head_dim,
const int64_t rotary_dim
);
void moe_fc_int8(
torch::Tensor &hidden_states, // dtype : bfloat16
torch::Tensor &output,
torch::Tensor &w1,
torch::Tensor &w1_scale,
torch::Tensor &w2,
torch::Tensor &w2_scale,
torch::Tensor &topk_weights,
torch::Tensor &topk_ids
);
#endif // OPS_H

View File

View File

@@ -0,0 +1,102 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian, Dong Xinyu
# Email: baoqian@baidu.com, dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun_communicator"""
from contextlib import contextmanager
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
class KunlunCommunicator(CudaCommunicator):
"""KunlunCommunicator"""
def __init__(self,
device,
device_group,
cpu_group,
unique_name):
"""
Initializes the CUDA Communicator.
Args:
cpu_group (ProcessGroup): The CPU process group.
device (Optional[torch.device], optional): The device to use. Defaults to None.
device_group (Optional[ProcessGroup], optional): The device process group. Defaults to None.
unique_name (str, optional): The unique name of this communicator. Defaults to "".
Raises:
ValueError: If both ``device`` and ``device_group`` are not specified.
"""
DeviceCommunicatorBase.__init__(self, cpu_group, device, device_group, unique_name)
self.ca_comm = None
self.disabled = False
with torch.cuda.device(device):
self.stream = torch.cuda.Stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize()
del data
def all_reduce(self, input_):
"""all_reduce"""
return DeviceCommunicatorBase.all_reduce(self, input_)
def all_gather(self, input_, dim):
"""all_gather"""
return DeviceCommunicatorBase.all_gather(self, input_, dim)
def gather(self, input_, dst, dim):
"""gather"""
return DeviceCommunicatorBase.gather(self, input_, dst, dim)
def send(self, tensor, dst):
"""send"""
DeviceCommunicatorBase.send(self, tensor, dst)
def recv(self, size, dtype, src):
"""recv"""
return DeviceCommunicatorBase.recv(self, size, dtype, src)
def destroy(self):
"""destroy"""
pass
@contextmanager
def change_state(self, enable, stream):
"""
A context manager to change the state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available
if stream is None:
stream = self.stream
old_disable = self.disabled
old_stream = self.stream
self.stream = stream
self.disabled = not enable
yield
self.disabled = old_disable
self.stream = old_stream

View File

@@ -0,0 +1,16 @@
"""# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"""
from vllm_kunlun.lora.ops.kunlun_ops.lora_ops import (bgmv_expand,bgmv_expand_slice, bgmv_shrink,
sgmv_expand, sgmv_expand_slice,
sgmv_shrink)
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink"
]

View File

@@ -0,0 +1,443 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Author: Wang Hao
# Email: wanghao129@baidu.com
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun_ops for lora"""
import torch
from torch._C import dtype
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
"""
sgmv_shrink
"""
expert_num = 9
device = inputs.device
lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
device=device, dtype=torch.int32
)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
torch.ops._C.gen_block_statistic(lora_ids, block_statistic)
inputs_sorted = torch.zeros_like(inputs, dtype=inputs.dtype, device=device)
torch.ops._C.moe_pre_sorted(
inputs,
lora_ids,
block_statistic,
inputs_sorted,
moe_index,
expert_m,
sorted_tokens_num_lod
)
output_tensor.unsqueeze_(1)
torch.ops._C.moe_fc(
x=inputs_sorted,
weight=lora_a_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=output_tensor,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_tensor.squeeze_(1).mul_(scaling)
return output_tensor
def sgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
"""
sgmv_expand
"""
expert_num = 9
device = inputs.device
lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
device=device, dtype=torch.int32
)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
out = torch.zeros((token_nums, 1, slice_size), dtype=inputs.dtype, device=device)
torch.ops._C.moe_fc(
x=inputs,
weight=lora_b_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=out,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_post = out.squeeze(1)
torch.ops._C.moe_post(
output_post,
moe_index.unsqueeze(1),
normed_scale,
normed_scale,
output_post
)
common_len = min(output_post.shape[1], output_tensor.shape[1])
limit = min(output_post.shape[0], output_tensor.shape[0])
if add_inputs:
output_tensor[:limit, :common_len] += output_post[:limit, :common_len]
else:
output_tensor[:limit, :common_len] = output_post[:limit, :common_len]
return output_tensor
def sgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
"""
sgmv_expand_slice
"""
expert_num = 9
device = inputs.device
lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
device=device, dtype=torch.int32
)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
out = torch.zeros((token_nums, 1, slice_size), dtype=inputs.dtype, device=device)
torch.ops._C.moe_fc(
x=inputs,
weight=lora_b_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=out,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_post = out.squeeze(1)
torch.ops._C.moe_post(
output_post,
moe_index.unsqueeze(1),
normed_scale,
normed_scale,
output_post
)
slice_end = slice_offset + slice_size
actual_slice_size = min(slice_size, output_tensor.shape[1] - slice_offset)
limit = min(output_post.shape[0], output_tensor.shape[0])
if add_inputs:
output_tensor[:limit, slice_offset:slice_end] += output_post[:limit, :actual_slice_size]
else:
output_tensor[:limit, slice_offset:slice_end] = output_post[:limit, :actual_slice_size]
return output_tensor
def bgmv_shrink(
inputs: torch.Tensor, # [m, hidden_dim]
lora_a_weights: torch.Tensor, # [n, 1, r, hidden_dim]
output_tensor: torch.Tensor, # [m, r]
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
lora_indices_tensor: torch.Tensor, # [m]
scaling: float = 1.0
) -> torch.Tensor:
"""
bgmv_shrink
"""
expert_num = 9
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
torch.ops._C.gen_block_statistic(lora_ids.unsqueeze(1), block_statistic)
inputs_sorted = torch.empty_like(inputs, dtype=inputs.dtype, device=inputs.device)
torch.ops._C.moe_pre_sorted(
inputs,
lora_ids.unsqueeze(1),
block_statistic,
inputs_sorted,
moe_index,
expert_m,
sorted_tokens_num_lod
)
output_tensor.unsqueeze_(1) # Change to [m, 1, r]
torch.ops._C.moe_fc(
x=inputs_sorted,
weight=lora_a_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=output_tensor,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_tensor.squeeze_(1).mul_(scaling)
return output_tensor
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
""""
bgmv_expand
"""
expert_num = 9
device = inputs.device
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
out = torch.zeros((inputs.shape[0], 1, slice_size), dtype=inputs.dtype, device=device)
torch.ops._C.moe_fc(
x=inputs,
weight=lora_b_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=out,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_post = out.squeeze(1)
torch.ops._C.moe_post(output_post, moe_index.unsqueeze(1), normed_scale, normed_scale, output_post)
limit = output_tensor.shape[0]
if output_post.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
# LoRA adapter and model may add different amounts of padding to output
common_len = min(output_post.shape[1], output_tensor.shape[1])
if add_inputs:
output_tensor[:, :common_len] += output_post[:limit, :common_len]
else:
output_tensor[:, :common_len] = output_post[:limit, :common_len]
return output_tensor
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True
):
"""
bgmv_expand_slice
"""
expert_num = 9
device = inputs.device
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
out = torch.zeros((inputs.shape[0], 1, slice_size), dtype=inputs.dtype, device=device)
torch.ops._C.moe_fc(
x=inputs,
weight=lora_b_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=out,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_post = out.squeeze(1)
torch.ops._C.moe_post(output_post, moe_index.unsqueeze(1), normed_scale, normed_scale, output_post)
slice_end = slice_offset + slice_size
actual_slice_size = min(slice_size, output_tensor.shape[1] - slice_offset)
limit = min(output_post.shape[0], output_tensor.shape[0])
if add_inputs:
output_tensor[:limit, slice_offset:slice_end] += output_post[:limit, :actual_slice_size]
else:
output_tensor[:limit, slice_offset:slice_end] = output_post[:limit, :actual_slice_size]
return output_tensor

View File

@@ -0,0 +1,547 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Wang Hao
# Email: wanghao129@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import TYPE_CHECKING, Optional, Union, final
import torch
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union
from vllm_kunlun.lora.ops.kunlun_ops import (
bgmv_expand,
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
import time
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperKunlun(PunicaWrapperBase):
"""
PunicaWrapperKunlun with moe_fc
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: Union[torch.device, str],
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
scale: float,
):
expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
sgmv_shrink(
x,
w_t_all,
y,
block_statistic,
sorted_tokens_num_lod,
moe_index,
expert_m,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
scale: float,
):
expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
bgmv_shrink(
x,
w_t_all,
y,
block_statistic,
sorted_tokens_num_lod,
moe_index,
expert_m,
self.token_lora_indices,
scale,
)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
add_inputs: bool,
):
sgmv_expand(
x,
w_t_all,
y,
block_statistic,
sorted_tokens_num_lod,
moe_index,
*self.prefill_metadata,
add_inputs,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
add_inputs: bool,
):
bgmv_expand(
x,
w_t_all,
y,
block_statistic,
sorted_tokens_num_lod,
moe_index,
self.token_lora_indices,
add_inputs,
)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
sgmv_expand_slice(
x,
w_t_all,
y,
block_statistic,
sorted_tokens_num_lod,
moe_index,
normed_scale,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
bgmv_expand_slice(
x,
w_t_all,
y,
block_statistic,
sorted_tokens_num_lod,
moe_index,
normed_scale,
self.token_lora_indices,
y_offset,
y_slice_size,
add_inputs,
)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (
self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
)
expand_slice_fun(
y,
x,
w_t_all,
block_statistic,
sorted_tokens_num_lod,
moe_index,
y_offset,
y_slice_size,
add_inputs,
)
def _apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
scale: float,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (
self._shrink_prefill if self.is_prefill else self._shrink_decode
)
shrink_fun(
y, x, w_t_all, block_statistic, sorted_tokens_num_lod, moe_index, scale
)
y = y.view_as(y_org)
def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)): # Each slice represents a layer
self._apply_shrink(
y[slice_idx],
x,
lora_a_stacked[slice_idx],
block_statistic,
sorted_tokens_num_lod,
moe_index,
scale,
)
def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(
self.token_lora_indices, y, output_slices, lora_bias_stacked
)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
block_statistic,
sorted_tokens_num_lod,
moe_index,
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
expand_fun: Callable = (
self._expand_prefill if self.is_prefill else self._expand_decode
)
expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
if self.no_lora:
return
expert_num = 9
block_statistic = torch.zeros(
[12, expert_num], dtype=torch.int32, device=x.device
)
sorted_tokens_num_lod = torch.zeros(
expert_num + 1, dtype=torch.int32, device=x.device
)
token_nums = x.size(0)
moe_index = torch.zeros(token_nums, dtype=torch.int32, device=x.device)
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(
self.token_lora_indices, y, output_slices, lora_bias_stacked
)
if buffer is None:
r = lora_b_stacked[0].size(-1)
buffer = tuple(
torch.zeros((x.size(0), r), dtype=torch.float16, device=x.device)
for _ in range(len(output_slices))
)
# [tensor.squeeze_(1) for tensor in lora_a_stacked]
new_lora_a_stacked = tuple(lora_a.squeeze(1) for lora_a in lora_a_stacked)
self.add_shrink(
buffer,
x,
new_lora_a_stacked,
block_statistic,
sorted_tokens_num_lod,
moe_index,
scale,
**kwargs,
)
# [tensor.unsqueeze_(1) for tensor in lora_a_stacked]
# [tensor.squeeze_(1) for tensor in lora_b_stacked]
new_lora_b_stacked = tuple(lora_b.squeeze(1) for lora_b in lora_b_stacked)
self.add_expand(
y,
buffer,
new_lora_b_stacked,
block_statistic,
sorted_tokens_num_lod,
moe_index,
None,
output_slices,
add_inputs=True,
**kwargs,
)
# [tensor.unsqueeze_(1) for tensor in lora_b_stacked]
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
if lora_a_stacked.dim() == 2:
lora_a_stacked = lora_a_stacked.unsqueeze(0)
if lora_b_stacked.dim() == 2:
lora_b_stacked = lora_b_stacked.unsqueeze(0)
r = lora_a_stacked.size(-1)
if buffer is None:
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
indices = self.sampler_indices
if indices.max() >= lora_a_stacked.size(0):
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
y = y.view_as(y_org)

View File

@@ -0,0 +1,68 @@
from vllm import ModelRegistry
def register_model():
# from .demo_model import DemoModel # noqa: F401
from .qwen2_vl import Qwen2VLForConditionalGeneration #noqa: F401
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401
from .qwen3 import Qwen3ForCausalLM #noqa: F401
from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401
# ModelRegistry.register_model(
# "DemoModel",
# "vllm_kunlun.model_executor.models.demo_model:DemoModel")
ModelRegistry.register_model(
"Qwen2VLForConditionalGeneration",
"vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration")
ModelRegistry.register_model(
"Qwen2_5_VLForConditionalGeneration",
"vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration")
ModelRegistry.register_model(
"Qwen3ForCausalLM",
"vllm_kunlun.models.qwen3:Qwen3ForCausalLM")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM")
ModelRegistry.register_model(
"GlmForCausalLM",
"vllm_kunlun.models.glm:GlmForCausalLM")
ModelRegistry.register_model(
"GptOssForCausalLM",
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
ModelRegistry.register_model(
"InternLM2ForCausalLM",
"vllm_kunlun.models.internlm2:InternLM2ForCausalLM")
ModelRegistry.register_model(
"Qwen2ForCausalLM",
"vllm_kunlun.models.qwen2:Qwen2ForCausalLM")
ModelRegistry.register_model(
"InternVLChatModel",
"vllm_kunlun.models.internvl:InternVLChatModel")
ModelRegistry.register_model(
"InternS1ForConditionalGeneration",
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration")
ModelRegistry.register_model(
"Glm4MoeForCausalLM",
"vllm_kunlun.models.glm4_moe:Glm4MoeForCausalLM")
ModelRegistry.register_model(
"Glm4ForCausalLM",
"vllm_kunlun.models.glm4:Glm4ForCausalLM")
ModelRegistry.register_model(
"Glm4vForConditionalGeneration",
"vllm_kunlun.models.glm4_1v:Glm4vForConditionalGeneration")
def register_quant_method():
"""to do"""

24
vllm_kunlun/models/glm.py Normal file
View File

@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only HF format GLM-4 model compatible with THUDM weights."""
from vllm.config import VllmConfig
# from vllm.model_executor.models.llama import LlamaForCausalLM
from .llama import LlamaForCausalLM #noqa: F401
from vllm.model_executor.models.utils import PPMissingLayer
class GlmForCausalLM(LlamaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
print("glm for causalLM initialization!!!!", flush=True)
vllm_config.model_config.hf_config.partial_rotary_factor = 0.5
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Hack Llama model to fit HF format GLM implementation
# Attention difference between GLM and Llama:
# 1. Half partial rotary_dim and no Neox style.
# 2. There is no bias for o_proj in attention
for layer in self.model.layers:
if not isinstance(layer, PPMissingLayer):
layer.self_attn.rotary_emb.is_neox_style = False
layer.self_attn.o_proj.bias = None
layer.self_attn.o_proj.skip_bias_add = True

301
vllm_kunlun/models/glm4.py Normal file
View File

@@ -0,0 +1,301 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/glm4.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import Glm4Config
from vllm.attention import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm_kunlun.models.llama import LlamaMLP as Glm4MLP
from vllm_kunlun.models.llama import LlamaModel
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
class Glm4Attention(nn.Module):
def __init__(self,
config: Glm4Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: Optional[int] = None,
qkv_bias: bool = False,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.rotary_dim = self.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
partial_rotary_factor=partial_rotary_factor,
is_neox_style=False,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Glm4DecoderLayer(nn.Module):
def __init__(
self,
config: Glm4Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = Glm4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=AttentionType.DECODER,
)
self.mlp = Glm4MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_self_attn_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states = self.post_self_attn_layernorm(hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_mlp_layernorm(hidden_states)
return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
"attention": Glm4DecoderLayer,
}
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Glm4Model(LlamaModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=Glm4DecoderLayer)
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Glm4Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,716 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/glm4_moe.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
import os
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,get_dp_group,get_tp_group,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
logger = init_logger(__name__)
class Glm4MoeMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Glm4MoE(nn.Module):
def __init__(
self,
config: Glm4MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
# NOTE In the transformers implementation, the gate isn't an nn.Linear,
# so we cannot use ReplicatedLinear here.
# See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
self.gate = nn.Linear(
config.hidden_size,
config.n_routed_experts,
bias=False,
dtype=torch.float32,
)
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32))
# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.enable_eplb = enable_eplb
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = Glm4MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
kunlun_linear_weights = self.gate.weight
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
linear_weights=kunlun_linear_weights) * self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
class Glm4MoeAttention(nn.Module):
def __init__(
self,
config: Glm4MoeConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 131072,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-05,
qkv_bias: bool = False,
use_qk_norm: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_qk_norm = use_qk_norm
self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.getenv('USE_ORI_ROPE') == "1" or not self.use_qk_norm:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads,
self.head_dim)).reshape(q.shape)
k = self.k_norm(k.reshape(-1, self.num_kv_heads,
self.head_dim)).reshape(k.shape)
q, k = self.rotary_emb(positions, q, k)
else:
# Rope fusion operators
q, k, v = Split_Norm_Rope(qkv,
self.rotary_emb.cos_sin_cache,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.max_position_embeddings,
self.num_heads,
self.num_kv_heads,
self.head_dim,
partial_rotary_factor=self.partial_rotary_factor,
)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Glm4MoeDecoderLayer(nn.Module):
def __init__(
self,
config: Glm4MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
131072)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.self_attn = Glm4MoeAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=config.attention_bias,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_qk_norm=config.use_qk_norm,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace):
self.mlp = Glm4MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Glm4MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.config = config
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Glm4MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
enable_eplb=enable_eplb,
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Glm4MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.expert_weights = []
# Set MoE hyperparameters
self.num_moe_layers = (config.num_hidden_layers -
config.first_k_dense_replace)
self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Glm4MoeDecoderLayer)
if isinstance(layer.mlp, Glm4MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
raise RuntimeError("No Glm4MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig,
weight_name: str) -> Optional[int]:
if hasattr(config,
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
> 0):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if f"layers.{layer_idx+i}." in weight_name:
return layer_idx + i
return None

View File

@@ -0,0 +1,630 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/gpt_oss.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Optional
import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .utils import extract_layer_index, maybe_prefix
class OAIAttention(nn.Module):
def __init__(
self,
config: GptOssConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.head_dim = config.head_dim
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
dtype=torch.float32,
rope_scaling={
"rope_type":
"yarn",
"factor":
config.rope_scaling["factor"],
"original_max_position_embeddings":
config.rope_scaling["original_max_position_embeddings"],
"beta_fast":
config.rope_scaling["beta_fast"],
"beta_slow":
config.rope_scaling["beta_slow"],
},
is_neox_style=True,
)
tp_size = get_tensor_model_parallel_world_size()
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=torch.bfloat16,
requires_grad=False))
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.q_size = self.num_attention_heads * self.head_dim // tp_size
self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.qkv = QKVParallelLinear(
hidden_size=self.hidden_size,
head_size=self.head_dim,
total_num_heads=self.num_attention_heads,
total_num_kv_heads=self.num_key_value_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.num_attention_heads * self.head_dim,
output_size=self.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.num_local_attention_heads = config.num_attention_heads // tp_size
self.num_local_key_value_heads = config.num_key_value_heads // tp_size
# Only apply sliding window to every other layer
sliding_window = (config.sliding_window if self.layer_idx %
2 == 0 else None)
self.attn = Attention(
self.num_local_attention_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_local_key_value_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=AttentionType.DECODER,
prefix=f"{prefix}.attn",
sinks=self.sinks,
)
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
t = self.norm(hidden_states)
qkv, _ = self.qkv(t)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output + hidden_states
class MLPBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
layer_idx: int,
quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.router = torch.nn.Linear(config.hidden_size,
config.num_local_experts,
dtype=torch.bfloat16)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False,
has_bias=True,
activation="swigluoai")
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
g = self.router(t)
t = self.experts(hidden_states=t, router_logits=g)
return x + t
class TransformerBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
self.mlp = MLPBlock(config,
self.layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, hidden_states: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
attn_output = self.attn(hidden_states, positions)
output = self.mlp(attn_output)
return output
@support_torch_compile
class GptOssModel(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
)
self.layers = torch.nn.ModuleList([
TransformerBlock(
self.config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
) for layer_idx in range(self.config.num_hidden_layers)
])
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
def forward(self, input_ids: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x, positions)
x = self.norm(x)
return x
class GptOssForCausalLM(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config.hf_config
self.model = GptOssModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.lm_head = ParallelLMHead(
self.model_config.vocab_size,
self.model_config.hidden_size,
)
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
assert intermediate_tensors is None
assert inputs_embeds is None
return self.model(input_ids, positions)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def _load_weights_mxfp4(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
per_rank_intermediate_size = (per_rank_intermediate_size_block *
mxfp4_block)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if "gate_up_proj_blocks" in name:
# Handle MLP gate and up projection weights
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight = weight.view(num_experts, 2 * intermediate_size,
-1).contiguous()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_blocks" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_blocks", "w2_weight")
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight = weight.view(num_experts, -1,
intermediate_size // 2).contiguous()
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weights scale
new_name = name.replace("gate_up_proj_scales",
"w13_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_scales" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_scales", "w2_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[..., tp_rank_start //
mxfp4_block:tp_rank_end //
mxfp4_block]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
weight_loader(param,
weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def _load_weights_other(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
if ".experts.gate_up_proj" in name and "bias" not in name:
# Handle MLP gate and up projection weights
new_name = name.replace(".experts.gate_up_proj",
".experts.w13_weight")
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif ".experts.down_proj" in name and "bias" not in name:
# Handle MLP down projection weights
new_name = name.replace(".experts.down_proj",
".experts.w2_weight")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[new_name]
param.copy_(weight)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
quant_method = (self.model_config.quantization_config['quant_method']
if hasattr(self.model_config, "quantization_config")
else None)
if quant_method == "mxfp4":
return self._load_weights_mxfp4(weights)
else:
return self._load_weights_other(weights)

View File

@@ -0,0 +1,480 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm_kunlun.ops.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
}
class InternVisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size)
self.num_patches = (self.image_size // self.patch_size)**2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
target_dtype = pos_embed.dtype
pos_embed = pos_embed.float().reshape(
1, self.image_size // self.patch_size,
self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
pos_embed = F.interpolate(pos_embed,
size=(H, W),
mode='bicubic',
align_corners=False)
return pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
position_embedding = self.position_embedding
if self.num_patches == H * W:
return position_embedding
return torch.cat(
[
position_embedding[:, :1, :],
self._get_pos_embed(position_embedding[:, 1:, :], H, W),
],
dim=1,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
target_dtype)) # shape = [*, channel, width, height]
batch_size, _, height, width = patch_embeds.shape
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1,
-1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = self._get_position_embedding(height, width)
embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
class InternVisionPatchModel(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if pixel_values is None and pixel_embeds is None:
raise ValueError(
'You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
elif pixel_values is not None:
if pixel_values.ndim == 4:
hidden_states = self.embeddings(pixel_values)
else:
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
return hidden_states
class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
self.tp_size)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear(
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm(q)
k = self.k_norm(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
return q, k
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
out = self.attn(q, k, v)
out, _ = self.proj(out)
return out
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.dummy_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
return x
class InternMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class InternVisionEncoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attn")
self.mlp = InternMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
self.ls1 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
self.ls2 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
def _init_attn(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
prefix: str = "",
):
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads,
prefix=prefix)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
def forward(
self,
hidden_states: torch.Tensor,
):
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states)) * self.ls1
hidden_states = hidden_states + self.mlp(
self.norm2(hidden_states)) * self.ls2
return hidden_states
class InternVisionEncoder(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class InternVisionModel(nn.Module):
packed_modules_mapping = {
"qkv": ["qkv"],
}
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if pixel_values is None and pixel_embeds is None:
raise ValueError(
'You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
elif pixel_values is not None:
if pixel_values.ndim == 4:
hidden_states = self.embeddings(pixel_values)
else:
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import partial
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
# from vllm.attention import Attention
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
# from vllm.model_executor.layers.activation import SiluAndMul
from vllm_kunlun.ops.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from vllm.model_executor.models.utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class InternLM2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.w2 = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.w2(x)
return x
class InternLM2Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.total_num_heads = num_heads
assert self.total_num_heads % self.tp_size == 0
self.num_heads = self.total_num_heads // self.tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= self.tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % self.tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert self.tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.wqkv = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wqkv",
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wo",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def split_qkv(self, qkv: torch.Tensor):
seq_len = qkv.shape[0]
if self.tp_size > 1:
qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
qkv = tensor_model_parallel_all_gather(qkv)
qkv = torch.split(qkv, qkv_map, dim=-1)
qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
qkv = torch.cat(qkv, dim=-1)
qkv = qkv.view(seq_len, self.total_num_kv_heads,
self.key_value_groups + 2, self.head_dim)
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
q = q.reshape(seq_len, self.q_size * self.tp_size)
k = k.reshape(seq_len, self.kv_size * self.tp_size)
v = v.reshape(seq_len, self.kv_size * self.tp_size)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
v = splitter(v)[self.tp_rank]
return q, k, v
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.wo(attn_output)
return output
class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
else:
hidden_states, residual = self.attention_norm(
hidden_states, residual)
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.ffn_norm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@support_torch_compile
class InternLM2Model(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.tok_embeddings = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"wqkv": ["wqkv"],
"gate_up_proj": ["w1", "w3"],
}
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: type[InternLM2Model] = InternLM2Model):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.lora_config = lora_config
self.model = model_type(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "output"))
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: type[InternLM2Model] = InternLM2Model,
):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
model_type=model_type)
for attr in ("output", "logits_processor"):
delattr(self, attr)
config = vllm_config.model_config.hf_config
self.v_head = RowParallelLinear(
config.hidden_size,
1,
bias=False,
input_is_parallel=False,
prefix=maybe_prefix(prefix, "v_head"),
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
logits, _ = self.v_head(hidden_states)
return logits

View File

@@ -0,0 +1,869 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/interns1.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
import regex as re
import torch
import torch.nn as nn
from transformers import BatchFeature, InternVLProcessor, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.models.got_ocr2.image_processing_got_ocr2_fast import (
GotOcr2ImageProcessorFast)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from .interns1_vit import InternS1VisionModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
class InternS1MultiModalProjector(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size *
int(1 / config.downsample_ratio)**2)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size *
int(1 / config.downsample_ratio)**2,
config.text_config.hidden_size)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size)
def forward(self, image_features):
hidden_states = self.layer_norm(image_features)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class InternS1ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
class InternS1ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternS1ImageInputs = Union[InternS1ImagePixelInputs,
InternS1ImageEmbeddingInputs]
class InternS1VideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values: torch.Tensor
"""
Shape:
`(batch_size * num_video * num_frames, num_channels, height, width)`
"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class InternS1VideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
or a list of tensors of shape `(total_video_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternS1VideoInputs = Union[InternS1VideoPixelInputs,
InternS1VideoEmbeddingInputs]
def resolve_interns1_min_max_num(
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: bool,
use_thumbnail: bool,
) -> tuple[int, int]:
min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1:
max_dynamic_patch += 1
return min_dynamic_patch, max_dynamic_patch
def get_interns1_target_ratios(
min_num: int,
max_num: int,
) -> list[tuple[int, int]]:
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1) if min_num <= i * j <= max_num}
return sorted(target_ratios, key=lambda x: x[0] * x[1])
class InternS1ProcessingInfo(BaseProcessingInfo):
"""ProcessingInfo for InternS1-style models."""
def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
return self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional['GotOcr2ImageProcessorFast'] = None,
) -> int:
if processor is None:
processor = self.get_hf_processor().image_processor
if not isinstance(processor, GotOcr2ImageProcessorFast):
raise ValueError(f'GotOcr2ImageProcessorFast is expected but got '
f'{type(processor)}')
num_image_patches = processor.get_number_of_image_patches(
image_height, image_width, images_kwargs=dict())
num_image_tokens = self.get_hf_processor(
).image_seq_length * num_image_patches
return num_image_tokens
def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None):
image_processor = self.get_hf_processor().image_processor
min_dynamic_patch = image_processor.min_patches
max_dynamic_patch = image_processor.max_patches
# HF format's InternVL processor uses `crop_to_patches` which is
# equivalent to `use_thumbnail` in original format.
use_thumbnail = image_processor.crop_to_patches
dynamic_image_size = True
min_num, max_num = resolve_interns1_min_max_num(
min_dynamic_patch,
max_dynamic_patch,
dynamic_image_size,
use_thumbnail=use_thumbnail)
return get_interns1_target_ratios(min_num, max_num)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
hf_config = self.ctx.get_hf_config()
base_height, base_width = hf_config.vision_config.image_size
target_ratios = self.resolve_target_ratios()
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in target_ratios:
width, height = base_width * wr, base_height * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor.image_processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
assert not (largest_feature_size == 0 or largest_feature_pinpoint
is None), ("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
def get_max_image_tokens(self) -> int:
processor = self.get_hf_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=processor.image_processor,
)
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
processor = self.get_hf_processor()
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len -
max_image_tokens) // processor.image_seq_length
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
):
"""DummyInputsBuilder for InternS1-style models."""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_token = self.info.get_hf_processor().image_token
video_token = self.info.get_hf_processor().video_token
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
"""Generates dummy multimodal data on Kunlun3 platform for performance analysis and warmup.
Retrieves visual resolution based on configuration (defaulting to 224x224)
and generates resized dummy data for images and videos.
Args:
seq_len: Sequence length (unused).
mm_counts: A mapping of multimodal type counts, containing "image"
and "video" keys.
Returns:
MultiModalDataDict: A dictionary containing the generated dummy image
and video data, structured as:
{
"image": dummy_image_data,
"video": dummy_video_data
}
Author:
Dong Xinyu
"""
config = self.info.get_hf_config()
img_size = getattr(config.vision_config, "image_size", None)
if isinstance(img_size, (tuple, list)) and len(img_size) == 2:
cfg_h, cfg_w = int(img_size[0]), int(img_size[1])
else:
cfg_h, cfg_w = 224, 224
target_width = min(cfg_w, 224)
target_height = min(cfg_h, 224)
target_num_frames = 1
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
),
"video": self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
),
}
class InternS1MultiModalProcessor(
BaseMultiModalProcessor[InternS1ProcessingInfo]):
""" Basic image-only MultiModalProcessor for InternS1-style models."""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
images = mm_data.pop("images", [])
assert isinstance(videos, list)
assert isinstance(images, list)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
tokenizer = hf_processor.tokenizer
video_token_id = tokenizer.encode(hf_processor.video_token,
add_special_tokens=False)
assert len(video_token_id) == 1
video_token_id = video_token_id[0]
prompt = re.sub(hf_processor.image_token, "<image_placeholder>",
prompt)
prompt = re.sub(hf_processor.video_token, "<video_placeholder>",
prompt)
image_outputs = {}
if images:
image_pixel_values = []
for image in images:
processed_outputs = super()._call_hf_processor(
prompt=hf_processor.image_token,
mm_data={"images": image},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
image_pixel_values.append(
processed_outputs.pop("pixel_values"))
input_ids = processed_outputs.pop("input_ids")
image_placeholder = tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace("<image_placeholder>",
image_placeholder, 1)
num_patches = [len(item) for item in image_pixel_values]
image_outputs: dict[str, NestedTensors] = {
"pixel_values": torch.concat(image_pixel_values),
"image_num_patches": torch.tensor(num_patches),
"image_token_id": torch.tensor(hf_processor.image_token_id),
}
video_outputs = {}
if videos:
video_pixel_values = []
for video in videos:
processed_outputs = super()._call_hf_processor(
prompt=hf_processor.video_token,
mm_data={"videos": video},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
video_pixel_values.append(
processed_outputs.pop("pixel_values"))
input_ids = processed_outputs.pop("input_ids")
input_ids[input_ids ==
hf_processor.image_token_id] = video_token_id
video_placeholder = tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace("<video_placeholder>",
video_placeholder, 1)
num_frames = [len(item) for item in video_pixel_values]
video_outputs: dict[str, NestedTensors] = {
"pixel_values_videos": torch.concat(video_pixel_values),
"video_num_patches": torch.tensor(num_frames),
"video_token_id": torch.tensor(video_token_id),
}
prompt = re.sub("<image_placeholder>", hf_processor.image_token,
prompt)
prompt = re.sub("<video_placeholder>", hf_processor.video_token,
prompt)
text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")
combined_outputs = dict(
**text_outputs,
**image_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_images = len(image_num_patches)
num_videos = len(video_num_patches)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
img_context_token = hf_processor.image_token
start_image_token = hf_processor.start_image_token
end_image_token = hf_processor.end_image_token
video_token = hf_processor.video_token
if "video_num_patches" in out_mm_kwargs:
video_num_patches = out_mm_kwargs["video_num_patches"]
assert isinstance(video_num_patches, torch.Tensor)
video_num_patches = video_num_patches.tolist()
else:
video_num_patches = []
if "image_num_patches" in out_mm_kwargs:
image_num_patches = out_mm_kwargs["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
else:
image_num_patches = []
def get_replacement_interns1_image(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
num_patches = image_num_patches[item_idx]
feature_size = num_patches * hf_processor.image_seq_length
repl_features = img_context_token * feature_size
repl_full = start_image_token + repl_features + end_image_token
return PromptUpdateDetails.select_text(repl_full,
img_context_token)
def get_replacement_interns1_video(item_idx: int):
num_patches = video_num_patches[item_idx]
repl_features = video_token * hf_processor.image_seq_length
repl_features_with_sep = (start_image_token + repl_features +
end_image_token)
# num_patches is equal to num_frames
repl_full = '\n'.join([
f'Frame{i+1}: {repl_features_with_sep}'
for i in range(num_patches)
])
return PromptUpdateDetails.select_text(repl_full, video_token)
return [
PromptReplacement(
modality="image",
target=img_context_token,
replacement=get_replacement_interns1_image,
),
PromptReplacement(
modality="video",
target=video_token,
replacement=get_replacement_interns1_video,
),
]
@MULTIMODAL_REGISTRY.register_processor(
InternS1MultiModalProcessor,
info=InternS1ProcessingInfo,
dummy_inputs=InternS1DummyInputsBuilder)
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsLoRA):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
})
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
# transformers InternVLProcessor uses <IMG_CONTEXT> as the seperator
# refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116
if modality.startswith("image"):
return '<IMG_CONTEXT>'
if modality.startswith("video"):
return "<video>"
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
image_size = config.vision_config.image_size[0]
patch_size = config.vision_config.patch_size[0]
self.patch_size = patch_size
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.downsample_ratio = config.downsample_ratio
self.llm_arch_name = config.text_config.architectures[0]
self.vision_tower = self._init_vision_model(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = self._init_mlp1(config)
self.img_context_token_id = None
self.video_context_token_id = None
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
prefix: str,
):
num_hidden_layers = config.vision_config.num_hidden_layers
return InternS1VisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
return InternS1MultiModalProjector(config)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
vit_embeds = self.vision_tower(pixel_values=pixel_values)
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1]**0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds,
scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
vit_embeds.shape[-1])
vit_embeds = self.multi_modal_projector(vit_embeds)
return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h, w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternS1ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternS1ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
return InternS1ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=image_num_patches,
)
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("video_embeds", None)
if pixel_values_flat_video is None and video_embeds is None:
return None
if video_embeds is not None:
if not isinstance(video_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return InternS1ImageEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
)
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}")
if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}")
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
return InternS1VideoPixelInputs(
type="pixel_values_videos",
pixel_values=self._validate_pixel_values(
pixel_values_flat_video),
num_patches=video_num_patches,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_embeds = self.extract_feature(image_input["pixel_values"])
num_patches = image_input["num_patches"]
# Only one image in the current batch
if len(num_patches) == 1:
return (image_embeds.view(-1,
self.config.text_config.hidden_size), )
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in num_patches
]
return image_embeds.split(image_feature_sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in (
"pixel_values_videos", ) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower")

View File

@@ -0,0 +1,431 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/interns1_vit.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from transformers.utils import torch_int
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
}
class InternS1VisionPatchEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] //
patch_size[0])
patch_shape = (image_size[0] // patch_size[0],
image_size[1] // patch_size[1])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.patch_shape = patch_shape
self.projection = nn.Conv2d(num_channels,
hidden_size,
kernel_size=patch_size,
stride=patch_size)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values "
"match with the one set in the configuration.")
embeddings = self.projection(
pixel_values.to(self.projection.weight.dtype))
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, (patch_height, patch_width)
class InternS1VisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
if config.use_mask_token:
self.mask_token = nn.Parameter(
torch.zeros(1, 1, config.hidden_size))
else:
self.mask_token = None
self.patch_embeddings = InternS1VisionPatchEmbeddings(config)
self.patch_size = config.patch_size
self.image_size = (config.image_size if isinstance(
config.image_size, Iterable) else
(config.image_size, config.image_size))
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(
torch.zeros(1, num_patches + 1, config.hidden_size))
else:
self.position_embeddings = None
@torch._dynamo.disable
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
""" # noqa: E501
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
# always interpolate when tracing to ensure the exported model
# works for dynamic input shapes
if not torch.jit.is_tracing(
) and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
new_height = height // self.patch_size[0]
new_width = width // self.patch_size[1]
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings, (patch_height,
patch_width) = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
# replace the masked visual tokens by mask_tokens
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1 - w) + mask_tokens * w
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if self.position_embeddings is not None:
embeddings = embeddings + self.interpolate_pos_encoding(
embeddings, height, width)
return embeddings, (patch_height, patch_width)
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(self.embed_dim,
self.num_heads * self.head_dim,
bias=config.attention_bias)
self.k_proj = nn.Linear(self.embed_dim,
self.num_heads * self.head_dim,
bias=config.attention_bias)
self.v_proj = nn.Linear(self.embed_dim,
self.num_heads * self.head_dim,
bias=config.attention_bias)
self.qk_normalization = config.use_qk_norm
if self.qk_normalization:
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).reshape(B, N, -1)
x = self.projection_layer(x)
return x
class InternS1VisionMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
# self.activation_fn = GeluAndMul()
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class InternS1VisionLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.attention = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attention")
self.mlp = InternS1VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layernorm_before = NORM2FN[config.norm_type](
config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = NORM2FN[config.norm_type](
config.hidden_size, eps=config.layer_norm_eps)
init_values = config.layer_scale_init_value
self.lambda_1 = nn.Parameter(init_values *
torch.ones(config.hidden_size),
requires_grad=True)
self.lambda_2 = nn.Parameter(init_values *
torch.ones(config.hidden_size),
requires_grad=True)
def _init_attn(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
prefix: str = "",
):
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
def forward(
self,
hidden_states: torch.Tensor,
):
hidden_states = hidden_states + self.attention(
self.layernorm_before(hidden_states)) * self.lambda_1
hidden_states = hidden_states + self.mlp(
self.layernorm_after(hidden_states)) * self.lambda_2
return hidden_states
class InternS1VisionEncoder(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layer = nn.ModuleList([
InternS1VisionLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layer:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class InternS1VisionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embeddings = InternS1VisionEmbeddings(config)
self.encoder = InternS1VisionEncoder(
config=config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
)
self.layernorm = (nn.Identity() if config.use_mean_pooling else
nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps))
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if pixel_values is None and pixel_embeds is None:
raise ValueError(
'You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
elif pixel_values is not None:
if pixel_values.ndim == 4:
hidden_states, _ = self.embeddings(pixel_values)
else:
raise ValueError(
f'wrong pixel_values size: {pixel_values.shape}')
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
encoder_outputs = self.layernorm(encoder_outputs)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

File diff suppressed because it is too large Load Diff

643
vllm_kunlun/models/llama.py Normal file
View File

@@ -0,0 +1,643 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/llama.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.attention import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm_kunlun.ops.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim
# Phi models introduced a partial_rotary_factor parameter in the config
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self._init_rotary_emb(config,
rope_scaling=rope_scaling,
quant_config=quant_config)
if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
if isinstance(interleaved_sliding_window, int):
sliding_window = interleaved_sliding_window
elif isinstance(interleaved_sliding_window, list):
sw_idx = layer_idx % len(interleaved_sliding_window)
sliding_window = interleaved_sliding_window[sw_idx]
else:
raise ValueError(
f"{type(interleaved_sliding_window)} is not supported.")
else:
sliding_window = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
#TODO@hanhaowen:use kunlun ops to speed up
q, k = self.rotary_emb.forward_native(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
def _init_rotary_emb(self, config: LlamaConfig,
rope_scaling: Optional[dict[str, Any]],
quant_config: Optional[QuantizationConfig]) -> None:
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, 'qkv_bias'):
attention_bias = config.qkv_bias
# By default, Llama uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. parasail-ai/GritLM-7B-vllm)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
# @support_torch_compile
class LlamaModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int] = tuple()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
list[torch.Tensor]]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"q_fake_quantizer.qscale_act": "attn.q_scale",
"k_fake_quantizer.qscale_act": "k_scale",
"v_fake_quantizer.qscale_act": "v_scale",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm",
}
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
self,
name: str,
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
mapping = self.mistral_mapping
modules = name.split(".")
# rotary embeds should be sliced
if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
num_modules = len(modules)
for i in range(num_modules):
item = modules[i]
next_item = modules[i + 1] if i < num_modules - 1 else None
combined_item = (f"{item}.{next_item}"
if next_item is not None else None)
if combined_item in mapping:
name = name.replace(combined_item, mapping[combined_item])
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight

View File

@@ -0,0 +1,24 @@
class BitsAndBytesModelLoader():
"""Model loader to load model weights with BitAndBytes quantization."""
possible_config_file_names = ["adapter_config.json"]
def __init__(self):
# Save the module names without sharding.
self.unsharded_weights_modules: list[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: list[str] = []
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
# Store the mapping of expert parameters for MoE models.
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
self.pre_quant: bool = False
self.load_8bit: bool = False
self.is_pool_model: bool = False

498
vllm_kunlun/models/qwen2.py Normal file
View File

@@ -0,0 +1,498 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
import os
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm_kunlun.ops.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Qwen2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Qwen2Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Qwen2Model(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
"is less than `num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to Qwen2DecoderLayer
decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

530
vllm_kunlun/models/qwen3.py Normal file
View File

@@ -0,0 +1,530 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen3.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import xtorch_ops
import torch
import os
from torch import nn
from transformers import Qwen3Config
from vllm.attention import AttentionType, AttentionMetadata
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm import envs
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.platforms import current_platform
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
logger = init_logger(__name__)
class Qwen3Attention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position = max_position
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
self.max_position = int(self.max_position * scaling_factor)
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# TODO: Supports both original Rope and Kunlun Rope fusion operators
if os.getenv('FUSED_QK_ROPE_OP') == "1":
# Rope fusion operators
q, k, v = Split_Norm_Rope(qkv,
self.rotary_emb.cos_sin_cache,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.max_position,
self.num_heads,
self.num_kv_heads,
self.head_dim,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
# By default, Qwen3 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = Qwen3Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = Qwen3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
"attention": Qwen3DecoderLayer,
}
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Qwen3Model(nn.Module):
"""Qwen3Model"""
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Qwen3DecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
"is less than `num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to Qwen2DecoderLayer
decoder_layer_type = decoder_layer_type or Qwen3DecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""get_input_embeddings"""
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""
Args:
input_ids (torch.Tensor): Input sequence of shape `(batch, seq_len)`.
Indices are expected to be in the range `[0, config.vocab_size]`.
positions (torch.Tensor): Positional tensor of shape `(batch, seq_len)`.
intermediate_tensors (Optional[IntermediateTensors], optional):
Intermediate tensors from previous forward pass. Defaults to `None`.
inputs_embeds (Optional[torch.Tensor], optional):
Optionally, instead of positional embeddings, you can choose to
provide your own embedding lookup matrix of shape `(batch, seq_len, emb_dim)`.
If None, the model will create one on its own using the input ids.
Defaults to `None`.
Returns:
Union[torch.Tensor, IntermediateTensors]:
If `intermediate_tensors` is not None, returns a IntermediateTensors object.
Otherwise, returns a tensor of shape `(batch, seq_len, hidden_size)` representing
the output of the last transformer encoder layer.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i, layer in enumerate(self.layers[self.start_layer:self.end_layer], start=self.start_layer):
hidden_states, residual = layer(
positions,
hidden_states,
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
"""Load model weights.
Args:
weights (Iterable[tuple[str, torch.Tensor]]): An iterator containing weight names and their corresponding values.
Returns (set[str]):
A set of already loaded weight names.
Exceptions:
None.
"""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
kv_caches: list[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)

View File

@@ -0,0 +1,836 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import os
from collections.abc import Iterable
from typing import Any, Optional, Union, Tuple, Set
import torch
import os
from torch import nn
from transformers import PretrainedConfig
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm_kunlun.ops.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
logger = init_logger(__name__)
class Qwen3MoeMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}."
)
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.quant_config = quant_config
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
if self.quant_config is None:
kunlun_linear_weights = self.gate.get_weights()
final_hidden_states = self.experts(
hidden_states=hidden_states, linear_weights=kunlun_linear_weights
)
else:
kunlun_linear_weights = self.gate.get_weights()
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
linear_weights=kunlun_linear_weights,
)
if self.tp_size > 1:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
)
return final_hidden_states.view(orig_shape)
class Qwen3MoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
self.max_position_embeddings = int(
self.max_position_embeddings * scaling_factor
)
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.getenv("FUSED_QK_ROPE_OP") == "1":
# Rope fusion operators
q, k, v = Split_Norm_Rope(
qkv,
self.rotary_emb.cos_sin_cache,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.max_position_embeddings,
self.num_heads,
self.num_kv_heads,
self.head_dim,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen3MoeSparseMoeBlock(
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
)
else:
self.mlp = Qwen3MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens"
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
weights_to_quantize = {}
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Map to the parameter name in the model
name_mapped = name.replace(weight_name, param_name)
# Layer/PP skip judgment
if is_pp_missing_parameter(name_mapped, self):
continue
if (
name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
) and name_mapped not in params_dict:
continue
# Get the param and target module
param = params_dict.get(name_mapped, None)
if param is None:
continue
# === Only when the target MoE layer has int8 weights and scales, and the name matches, the "streaming quantization" is performed ===
if self._should_stream_quantize(name_mapped):
# Note: Pass the mapped name_mapped instead of the original name
self._stream_quantize_moe_weight(
name_mapped,
param,
loaded_weight,
expert_id=expert_id,
shard_id=shard_id,
)
loaded_params.add(name_mapped)
else:
# Fallback: Normal weight loading (non-quantized)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(name_mapped)
break
else:
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale"
)
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded."
)
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
# loaded_params.add(name)
return loaded_params
def _is_moe_weight(self, name: str) -> bool:
"""Check if the weight is MoE weight"""
return name.endswith("w13_weight") or name.endswith("w2_weight")
def _is_expert_complete(self, cache_key):
cache = self._moe_weight_cache.get(cache_key)
if cache is None:
return False
w13_ok = (0 in cache["w13_shards"]) and (1 in cache["w13_shards"])
w2_ok = cache["w2_weight"] is not None
return w13_ok and w2_ok
@torch.no_grad()
def _stream_quantize_moe_weight(
self,
param_name: str,
param: nn.Parameter,
loaded_weight: torch.Tensor,
*,
expert_id,
shard_id,
):
rank = os.environ.get("RANK", "0")
# Ensure expert_id is an integer
try:
expert_id = int(expert_id)
except (ValueError, TypeError):
if isinstance(expert_id, str):
expert_id = int(expert_id)
# Process shard_id
if isinstance(shard_id, str):
if shard_id in ("gate", "w1"):
shard_id = 0
elif shard_id in ("up", "w3"):
shard_id = 1
elif shard_id == "w2":
shard_id = 0
else:
try:
shard_id = int(shard_id)
except ValueError:
shard_id = 0
else:
shard_id = int(shard_id)
# Initialize cache
if not hasattr(self, "_moe_weight_cache"):
self._moe_weight_cache = {}
self._expert_batch_count = 0 # Batch counter
module_path = ".".join(param_name.split(".")[:-1])
cache_key = (module_path, expert_id)
cache = self._moe_weight_cache.get(cache_key)
if cache is None:
cache = {
"w13_shards": {},
"w2_weight": None,
"target_module": self.get_submodule(module_path),
"done": False,
}
self._moe_weight_cache[cache_key] = cache
if cache.get("done", False):
return
# Cache weights (keep original precision)
if "w13_weight" in param_name:
cache["w13_shards"][shard_id] = loaded_weight.clone()
elif "w2_weight" in param_name:
cache["w2_weight"] = loaded_weight.clone()
# Check if complete
if self._is_expert_complete(cache_key):
# Quantize this expert
self._quantize_expert_weights(cache_key)
cache["done"] = True
self._moe_weight_cache.pop(cache_key, None)
# Force synchronization every 4 experts
self._expert_batch_count += 1
if self._expert_batch_count % 4 == 0:
torch.cuda.synchronize() # Force synchronization
# print(f"[Rank {rank}] Completed batch of {self._expert_batch_count} experts")
def _quantize_expert_weights(self, cache_key):
"""Quantize the complete weights of an expert (supports TP sharding)"""
module_path, expert_id = cache_key
cache = self._moe_weight_cache[cache_key]
target_module = cache["target_module"]
# Get TP config
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
# Get actual shapes
E, twoN, H = target_module.w13_weight.shape
_, H2, N = target_module.w2_weight.shape
qmax = 127.0
# Process w13_weight: concatenate gate and up
gate_weight = cache["w13_shards"][0] # [768, 2048]
up_weight = cache["w13_shards"][1] # [768, 2048]
# TP sharding
if tp_size > 1:
# Calculate shard for each TP rank
gate_per_rank = gate_weight.shape[0] // tp_size
up_per_rank = up_weight.shape[0] // tp_size
gate_start = tp_rank * gate_per_rank
gate_end = (tp_rank + 1) * gate_per_rank
up_start = tp_rank * up_per_rank
up_end = (tp_rank + 1) * up_per_rank
gate_weight = gate_weight[gate_start:gate_end, :] # [192, 2048]
up_weight = up_weight[up_start:up_end, :] # [192, 2048]
w13_complete = torch.cat([gate_weight, up_weight], dim=0) # [384, 2048]
# Quantize w13_weight
w13_f = w13_complete.float()
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [384]
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [384]
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [384, 1]
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
# Write w13_weight
target_module.w13_weight.data[expert_id, :, :].copy_(
w13_q.to(target_module.w13_weight.device)
)
# Update w13_scale - pre-multiply 127
s = getattr(target_module, "w13_weight_scale")
s.data[expert_id, :].copy_((w13_scale_2d * 127.0).to(s.device))
# Process w2_weight
w2_weight = cache["w2_weight"] # [2048, 768]
# TP sharding for w2 weight
if tp_size > 1:
w2_per_rank = w2_weight.shape[1] // tp_size
w2_start = tp_rank * w2_per_rank
w2_end = (tp_rank + 1) * w2_per_rank
w2_weight = w2_weight[:, w2_start:w2_end] # [2048, 192]
w2_f = w2_weight.float() # [2048, 192]
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [2048]
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [2048]
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [2048, 1]
w2_q = torch.round(w2_f / w2_scale_3d).clamp_(-128, 127).to(torch.int8)
# Write w2_weight
w2_param = getattr(target_module, "w2_weight")
w2_param.data[expert_id, :, :].copy_(w2_q.to(w2_param.device))
# Update w2_scale - pre-multiply 127
w2_s = getattr(target_module, "w2_weight_scale")
w2_s.data[expert_id, :].copy_((w2_scale_2d * 127.0).to(w2_s.device))
# Clear cache
cache["w13_shards"].clear()
cache["w2_weight"] = None
def _is_int8_moe_target_module(self, module_path: str) -> bool:
"""Check if a module_path is a FusedMoE target using INT8(W8A8).
Determine by the actual existing parameters and dtype, not relying on quant_config names.
"""
try:
mod = self.get_submodule(module_path)
except Exception:
return False
# Need to have both int8 weights and float32 scales, and dimensions come from CompressedTensorsW8A8 path
if not (
hasattr(mod, "w13_weight")
and hasattr(mod, "w2_weight")
and hasattr(mod, "w13_weight_scale")
and hasattr(mod, "w2_weight_scale")
):
return False
try:
return (
mod.w13_weight.dtype == torch.int8
and mod.w2_weight.dtype == torch.int8
and mod.w13_weight_scale.dtype == torch.float32
and mod.w2_weight_scale.dtype == torch.float32
)
except Exception:
return False
def _should_stream_quantize(self, param_name: str) -> bool:
"""Only when (1) the parameter name corresponds to the MoE weights we defined; and
(2) the MoE layer is indeed the INT8 path (exists int8 weights + scales)
Stream quantization is enabled; otherwise, it falls back to the default loading.
"""
# First, determine if it is the MoE weight name we want to process (w13_weight / w2_weight)
if not self._is_moe_weight(param_name):
return False
# Then, check if the module containing this param is the INT8 path
module_path = ".".join(param_name.split(".")[:-1])
return self._is_int8_moe_target_module(module_path)
class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Qwen3MoeModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
kv_caches: list[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@@ -0,0 +1,21 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq

View File

@@ -0,0 +1,597 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
import torch
import os
from typing import Optional, List, Dict
import vllm.envs as envs
import os
import ctypes
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import xtorch_ops
logger.info(f"Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
_per_token_smooth_quant = True
def is_per_token_smooth_quant():
"""is per token smooth quant"""
return _per_token_smooth_quant
class KunlunOps:
"""KunlunOps"""
# Attention ops
@staticmethod
def paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
context_lens_cpu,
is_context,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False,
):
"""PagedAttentionV1"""
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
block_tables=block_tables,
context_lens_cpu=context_lens_cpu,
context_lens_xpu=context_lens,
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128,
)
@staticmethod
def paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
context_lens_cpu,
is_context,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False,
):
"""PagedAttentionV2"""
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
block_tables=block_tables,
context_lens_cpu=context_lens_cpu,
context_lens_xpu=context_lens,
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128,
)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
"""silu and mul"""
xtorch_ops.silu_and_mul(
x,
axis=-1,
turn=True,
out=out,
)
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
"""quick gelu"""
xtorch_ops.quick_gelu(
x,
out=out,
)
# Layernorm
@staticmethod
def rms_norm(
out,
x,
weight,
epsilon,
):
"""rms_norm"""
xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
@staticmethod
def fused_add_rms_norm(
x,
residual,
weight,
epsilon,
):
"""fused_add_rms_norm"""
output = torch.empty_like(x)
xtorch_ops.add_rmsnorm(
x, residual, weight.to(torch.float32), epsilon, out=output
)
fused_input = x + residual
residual.copy_(fused_input, non_blocking=True)
x.copy_(output)
# Rotary embedding
@staticmethod
def rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
query_x = query.contiguous()
key_x = key.contiguous()
query_x_dim = query_x.dim()
if not is_neox_style:
if cos_sin_cache.dtype == torch.float16:
cos_sin_cache = cos_sin_cache.to(torch.float32)
positions = positions.to(torch.int)
if positions.dim() == 1:
positions = positions.unsqueeze(0)
query_x = query_x.unsqueeze(0)
key_x = key_x.unsqueeze(0)
xtorch_ops.rotary_embedding_gptj(
positions, query_x, key_x, head_size, cos_sin_cache
)
query.data = query_x
key.data = key_x
if query_x_dim != query_x.dim():
query_x = query_x.unsqueeze(0)
key_x = key_x.unsqueeze(0)
return query, key
# TODO: need opt
if cos_sin_cache.dim() == 4:
max_seq_len = cos_sin_cache.shape[2]
head_dim = cos_sin_cache.shape[3]
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(
0
) # Remove the first two dimensions [1,1,L,D] -> [L,D]
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
# Reshape query and key
num_tokens = query_x.shape[0]
num_heads = query_x.shape[1] // head_size
num_kv_heads = key_x.shape[1] // head_size
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
# query_x = query_x.view(num_tokens, num_heads, head_size)
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
# # Ensure shapes are correct
# assert query_x.shape == (num_tokens, num_heads, head_size), \
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
torch.ops._C.rotary_embedding(
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
)
query_x = query_x.view(num_tokens, num_heads * head_size)
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
# query.data = query_x
# key.data = key_x
return query_x, key_x
# Rotary embedding
@staticmethod
def mrotary_embedding(
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
query_x = query.contiguous()
key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style
xtorch_ops.mrotary_embedding_neox(
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
)
query.data = query_x
key.data = key_x
return query, key
@staticmethod
def swap_blocks(src, dst, block_mapping):
"""swap_blocks"""
xtorch_ops.swap_blocks(src, dst, block_mapping)
@staticmethod
def copy_blocks(key_caches, value_caches, block_mapping):
"""copy_blocks"""
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
xtorch_ops.copy_blocks(
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
):
"""reshape_and_cache"""
# slot_mapping_cast = slot_mapping.to(torch.int32)
xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def multi_query_kv_attention(
usual_seq_lod_xpu: torch.Tensor,
usual_seq_lod_cpu: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kargs,
) -> torch.Tensor:
"""
query: shape = [num_prompt_tokens, num_heads, head_size]
"""
if query.dim() == 3:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
output = torch.empty_like(query)
alibi_slopes = kargs.get("alibi_slopes", None)
mask = kargs.get("mask", None)
is_causal = kargs.get("is_causal", True)
is_lvsl = kargs.get("is_lvsl", True)
B, T, Qh, Hd = query.shape
KVh = key.size(2)
if KVh != Qh:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
xtorch_ops.attention(
q=query,
k_cache=key,
v_cache=value,
out=output,
is_causal=True,
is_prefill=True,
context_seq_lod_cpu=usual_seq_lod_cpu,
context_seq_lod_xpu=usual_seq_lod_xpu,
)
return output
@staticmethod
def quant_fusedresidual_rmsnorm_op(
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_fusedresidual_rmsnorm(
x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
if residual is None:
return out, out_scale
return out, out_scale, residual
@staticmethod
def quant_rmsnorm_op(
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
return out, out_scale
@staticmethod
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)
out = out.view(*output_bs_shape, weight_shape[0])
return out
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
linear_weights: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
expert_num = linear_weights.shape[0]
torch.ops._C.moe_ffn_block(
x=hidden_states,
gate_w=linear_weights,
inter_w=w1,
output_w=w2,
expert_num=expert_num,
moe_top_k=topk,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=num_expert_group,
out=output,
)
return output
@staticmethod
def fused_moe_ep(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
gating_output: torch.Tensor,
linear_weights: torch.Tensor,
ep_rank: int,
top_k: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
topk_weights = topk_weights.to(x.dtype)
out = torch.zeros(batch * top_k, hidden_size, dtype=x.dtype, device=x.device)
repeat_x = x.repeat_interleave(top_k, dim=0)
topk_ids_flat = topk_ids.flatten()
for i in range(num_local_experts):
experts_id = ep_rank * num_local_experts + i
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
out[selected_token] = up_gate @ w2_weight[i].T
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
return output
@staticmethod
def fused_multi_head_latent_page_attention(
hidden_states: torch.Tensor,
q_lora_rank: int,
kv_lora_rank: int,
q_a_proj_w: torch.Tensor,
q_a_layernorm_w: torch.Tensor,
q_b_proj_w: torch.Tensor,
q_proj_w: torch.Tensor,
kv_a_proj_w: torch.Tensor,
kv_a_layernorm_w: torch.Tensor,
kv_b_proj_w: torch.Tensor,
o_proj_w: torch.Tensor,
head_num: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
max_context_len: int,
layernorm_eps: float,
scale: float,
is_causal: bool,
is_context: bool,
mp_size: int,
local_rank: int,
rotary_pos_embedding: torch.Tensor,
pa_block_tables: torch.Tensor,
position: torch.Tensor,
context_lens_cpu: torch.Tensor,
slot_mapping: torch.Tensor,
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
xtorch_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
kv_lora_rank,
q_a_proj_w,
q_a_layernorm_w,
q_b_proj_w,
q_proj_w,
kv_a_proj_w,
kv_a_layernorm_w,
kv_b_proj_w,
o_proj_w,
head_num,
qk_nope_head_dim,
qk_rope_head_dim,
v_head_dim,
max_context_len,
layernorm_eps,
scale,
is_causal,
is_context,
mp_size,
local_rank,
rotary_pos_embedding,
pa_block_tables,
position,
None,
context_lens_cpu,
slot_mapping,
None,
prompt_lods_cpu,
out=output,
k_cache=k_cache,
v_cache=v_cache,
)
return output

View File

@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
"""Custom activation functions."""
import torch
import torch.nn.functional as F
from vllm.model_executor.custom_op import CustomOp
@CustomOp.register("kunlun_silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
torch.ops._C.swiglu(x, out)
return out

View File

@@ -0,0 +1,3 @@
# from .backends import KunlunMetadata
# __all__ = ['KunlunMetadata']

View File

@@ -0,0 +1,3 @@
# from .kunlun_attn import KunlunMetadata
# __all__ = ['KunlunMetadata']

View File

@@ -0,0 +1,803 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
# Email: baoqian@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun attention wrapper for context and decode"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from itertools import accumulate
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionType,
)
from .utils import CommonAttentionState, CommonMetadataBuilder
from vllm.attention.backends.utils import (
is_block_tables_empty,
compute_slot_mapping_start_idx,
compute_slot_mapping,
)
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm.attention.backends.abstract import AttentionLayer
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d
logger = init_logger(__name__)
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
accept_output_buffer = False
@staticmethod
def get_name() -> str:
return "KUNLUN_ATTENTION"
@staticmethod
def get_impl_cls() -> Type["KunlunAttentionImpl"]:
"""get_impl_cls"""
return KunlunAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["KunlunMetadata"]:
"""get_metadata_cls"""
return KunlunMetadata
@staticmethod
def get_builder_cls() -> Type["KunlunMetadataBuilder"]:
"""get_builder_cls"""
return KunlunMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size
)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
"""KunlunMetadata"""
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# Max number of key/value length in the batch, especially for prefix cache
max_kv_len: Optional[int] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
query_start_loc_host: Optional[torch.Tensor] = None
# serve only for prefix cache
kv_prefix_start_loc_host: Optional[torch.Tensor] = None
kv_prefix_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["KunlunMetadata"] = None
_cached_decode_metadata: Optional["KunlunMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
@property
def is_all_encoder_attn_metadata_set(self):
"""
All attention metadata required for encoder attention is set.
"""
return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)
)
@property
def is_all_cross_attn_metadata_set(self):
"""
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
"""
return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)
)
@property
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
"""prefill_metadata"""
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
# Compute some attn_metadata fields which default to None
query_start_loc = (
None
if self.query_start_loc is None
else self.query_start_loc[: self.num_prefills + 1]
)
# flash attention needs both lod information on host and device
query_start_loc_host = (
None
if self.query_start_loc_host is None
else self.query_start_loc_host[: self.num_prefills + 1]
)
kv_prefix_start_loc_host = (
None
if self.kv_prefix_start_loc_host is None
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
+ query_start_loc_host
)
kv_prefix_start_loc = (
None
if kv_prefix_start_loc_host is None
else kv_prefix_start_loc_host.cuda()
)
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[: self.num_prefill_tokens]
)
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[: self.num_prefills]
)
context_lens_tensor = (
None
if self.context_lens_tensor is None
else self.context_lens_tensor[: self.num_prefills]
)
# for prefix cache, block table only contains blocks that hit
# if self.block_tables is None:
# block_tables = None
# elif self.block_tables.shape[1] == 0:
# block_tables = self.block_tables[:self.num_prefills]
# else:
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
block_tables = (
None
if self.block_tables is None
else self.block_tables[: self.num_prefills]
)
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_kv_len=self.max_kv_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
query_start_loc_host=query_start_loc_host,
kv_prefix_start_loc=kv_prefix_start_loc,
kv_prefix_start_loc_host=kv_prefix_start_loc_host,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
seq_start_loc=self.seq_start_loc,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["KunlunMetadata"]:
"""decode_metadata"""
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
# Compute some attn_metadata fields which default to None
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[self.num_prefill_tokens :]
)
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[self.num_prefills :]
)
seq_lens_tensor_cpu = (
None
if self.seq_lens_tensor_cpu is None
else self.seq_lens_tensor_cpu[self.num_prefills :]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[self.num_prefills :]
)
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata(
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens_tensor,
seq_lens_tensor_cpu=seq_lens_tensor_cpu,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
)
return self._cached_decode_metadata
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
"""KunlunMetadataBuilder"""
_metadata_cls = KunlunMetadata
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
super().__init__(input_builder)
self.prefix_cache_kv_lens: List[int] = []
def prepare(self):
"""prepare"""
super().prepare()
self.prefix_cache_kv_lens = list()
def _add_seq_group(
self,
inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool,
):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (
seq_id,
token_len,
seq_len,
curr_seq_len,
query_len,
context_len,
curr_sliding_window_block,
) in zip(
inter_data.seq_ids,
[len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens,
inter_data.seq_lens,
inter_data.query_lens,
inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert (
query_len == 1
), "seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len
)
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
block_table = []
assert (
not chunked_prefill_enabled
), "chunk prefill not supported for kunlun attention"
if inter_data.prefix_cache_hit:
assert context_len != 0
assert context_len % self.block_size == 0
# block_table = block_tables[seq_id]
block_table = block_tables[seq_id][: context_len // self.block_size]
elif (not is_prompt) and block_tables is not None:
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
if is_prompt:
self.prefix_cache_kv_lens.append(context_len)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window
)
compute_slot_mapping(
is_profile_run,
self.slot_mapping,
seq_id,
seq_len,
context_len,
start_idx,
self.block_size,
inter_data.block_tables,
)
def build(
self,
seq_lens: List[int],
query_lens: List[int],
cuda_graph_pad_size: int,
batch_size: int,
):
"""build"""
attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
query_start_loc = list(accumulate(query_lens, initial=0))
query_start_loc_host = torch.tensor(
query_start_loc, dtype=torch.int32, device="cpu"
)
attn_meta.query_start_loc_host = query_start_loc_host
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
# If kv cache is included and there is a hit
if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
self.prefix_cache_kv_lens = list(
accumulate(self.prefix_cache_kv_lens, initial=0)
)
prefix_cache_kv_lens_tensor = torch.tensor(
self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
)
attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
return attn_meta
def _get_seq_len_block_table_args(
attn_metadata: KunlunMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
"""
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
"""
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables,
)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
None,
)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError("kunlunAttention does not support block-sparse attention.")
# if logits_soft_cap is not None:
# raise ValueError(
# "kunlunAttention does not support attention logits soft capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}."
)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: "KunlunAttnMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with KunlunAttn and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* KunlunAttnImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
Used for encoder branch of encoder-decoder models.
* ENCODER_ONLY: no kv_caching, uses the normal attention
attributes (seq_lens/seq_lens_tensor/max_seq_len).
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if attn_type == AttentionType.ENCODER and (
not attn_metadata.is_all_encoder_attn_metadata_set
):
raise AttributeError(
"Encoder attention requires setting " "encoder metadata attributes."
)
elif attn_type == AttentionType.ENCODER_DECODER and (
not attn_metadata.is_all_cross_attn_metadata_set
):
raise AttributeError(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
updated_slot_mapping = attn_metadata.slot_mapping
value = value.contiguous()
KunlunOps.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
)
if attn_type == AttentionType.ENCODER:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_encoder_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
elif attn_type == AttentionType.DECODER:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
else: # attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
if attn_metadata.num_encoder_tokens is not None:
num_encoder_tokens = attn_metadata.num_encoder_tokens
else:
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
if key is not None and value is not None:
key = key[:num_encoder_tokens]
value = value[:num_encoder_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
out = KunlunOps.multi_query_kv_attention(
prefill_meta.query_start_loc,
prefill_meta.query_start_loc_host,
query,
key,
value,
alibi_slopes=self.alibi_slopes,
).view_as(query)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
assert (
attn_type != AttentionType.ENCODER_ONLY
), "Encoder-only models should not have decode metadata."
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
block_tables_arg,
seq_lens_arg,
decode_meta.seq_lens_tensor_cpu,
False,
max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)

View File

@@ -0,0 +1,604 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
import numpy as np
import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
PAD_SLOT_ID = -1
# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
return (isinstance(block_tables, dict)
and all(value is None for value in block_tables.values()))
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
start_idx = max(0, query_len - sliding_window)
return start_idx
def _compute_slot_mapping_python(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
for i in range(range_start, range_end):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
def _compute_slot_mapping_numpy(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
block_table_array = np.array(block_table)
idx = np.arange(range_start, range_end)
block_offset = idx % block_size
idx //= block_size
seq_slot_mapping_array = block_table_array[idx]
seq_slot_mapping_array *= block_size
seq_slot_mapping_array += block_offset
slot_mapping.extend(seq_slot_mapping_array)
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
block_tables: Dict[int, List[int]]):
"""
Compute slot mapping.
"""
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
padding_mask_len = max(0, start_idx - context_len)
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
range_start = max(start_idx, context_len)
range_end = seq_len
numel = range_end - range_start
block_table = block_tables[seq_id]
# numpy implementation will be faster than python if we have
# many elements, otherwise it will be slower.
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
range_end, block_size)
else:
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
range_end, block_size)
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
"""CommonMetadataBuilder"""
_metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
"""prepare"""
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class CommonAttentionState(AttentionState):
"""CommonAttentionState"""
def __init__(self, runner: "ModelRunnerBase"):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
"""graph_capture"""
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_seq_lens_cpu = self._graph_seq_lens.to('cpu')
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_seq_lens_cpu
del self._graph_block_tables
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
"""graph_clone"""
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
"""graph_capture_get_metadata_for_batch"""
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
seq_lens_tensor_cpu=self._graph_seq_lens_cpu[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
return attn_metadata
def get_graph_input_buffers(
self,
attn_metadata,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""get_graph_input_buffers"""
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"seq_lens_tensor_cpu": attn_metadata.decode_metadata.seq_lens_tensor_cpu,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
self._add_additional_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
def prepare_graph_input_buffers(
self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False) -> None:
"""prepare_graph_input_buffers"""
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
def begin_forward(self, model_input) -> None:
"""begin_forward"""
return
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
attn_metadata):
"""
Updates the attention metadata parameters for CUDA graph capture in an
encoder-decoder model.
This method modifies attention-related tensors and metadata required
for CUDA graph capture in encoder-decoder models. Specifically, it
updates the cross-attention and encoder sequence tensors in the
AttentionMetadata object.
"""
# During decode phase the cross_slot_mapping will be empty. Hence set
# an empty tensor for CUDA Graph capture.
attn_metadata.cross_slot_mapping = torch.tensor(
[], dtype=torch.int).cuda()
attn_metadata.cross_block_tables = torch.full(
(batch_size, self.runner.get_max_block_per_batch()),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.num_encoder_tokens = 0
def _add_additional_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]):
"""
Saves additional input buffers specific to the encoder-decoder model
from the attention metadata.
This method extracts and stores encoder-decoder related input buffers
from the `attn_metadata` into the `input_buffers` dictionary. The
buffers include encoder sequence lengths, cross-slot mappings, and
cross-block tables, which are essential for the encoder-decoder model
during CUDA graph replay.
"""
input_buffers["encoder_seq_lens_tensor"] = (
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
input_buffers["seq_lens_tensor_cpu"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor_cpu, non_blocking=True)
input_buffers["cross_slot_mapping"] = (
attn_metadata.decode_metadata.cross_slot_mapping)
input_buffers["cross_block_tables"] = (
attn_metadata.decode_metadata.cross_block_tables)
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
input_buffers: Dict[str,
Any]):
"""
Populates input buffers with data from the encoder-decoder model's
attention metadata.
This method fills the input buffers with encoder-decoder specific
tensors. It copies data from the `attn_metadata` and keyword arguments
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
The copied data includes attention-related metadata as well as input
IDs and positional information for the encoder.
"""
input_buffers["encoder_seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
non_blocking=True)
input_buffers["cross_slot_mapping"].copy_(
attn_metadata.decode_metadata.cross_slot_mapping,
non_blocking=True)
input_buffers["cross_block_tables"].copy_(
attn_metadata.decode_metadata.cross_block_tables,
non_blocking=True)
def is_all_encoder_attn_metadata_set(attn_metadata):
'''
All attention metadata required for encoder attention is set.
'''
return ((attn_metadata.encoder_seq_lens is not None)
and (attn_metadata.encoder_seq_lens_tensor is not None)
and (attn_metadata.max_encoder_seq_len is not None))
def is_all_cross_attn_metadata_set(attn_metadata):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (attn_metadata.is_all_encoder_attn_metadata_set
and (attn_metadata.cross_slot_mapping is not None)
and (attn_metadata.cross_block_tables is not None))
def get_seq_len_block_table_args(
attn_metadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_num_prefill_decode_query_kv_tokens(
attn_metadata,
attn_type: str,
) -> Tuple[int, int, int]:
"""
Calculate the number of prefill and decode tokens for query, key/value
based on the attention metadata and the specified attention type.
Args:
attn_metadata (AttentionMetadata): Attention Metadata object.
attn_type (AttentionType): The type of attention being used.
Returns:
Tuple[int, int, int]: A tuple containing three integers:
- The number of prefill query tokens.
- The number of prefill key/value tokens.
- The number of decode query tokens.
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
num_prefill_kv_tokens = 0
if attn_type == AttentionType.ENCODER:
# Encoder attention is only invoked during prefill phase.
# The same input servers a both query and key.
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_encoder_tokens
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = 0
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
# The key is the encoder/cross-attention.
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
else: # attn_type == AttentionType.DECODER or
# attn_type == AttentionType.ENCODER_ONLY
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)

View File

@@ -0,0 +1,274 @@
"""layer.py"""
import torch
import torch.nn.functional as F
from typing import Optional, List, Dict, Any
from vllm.attention import AttentionType
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.attention import Attention as VllmAttention
from vllm.attention.layer import MultiHeadAttention as VllmMultiHeadAttention
from torch.library import custom_op, impl
from vllm.platforms import _Backend
class Attention(VllmAttention):
"""Attention"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=logits_soft_cap,
per_layer_sliding_window=per_layer_sliding_window,
use_mla=use_mla,
prefix=prefix,
attn_type=attn_type,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
**extra_impl_args,
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
"""forward"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
# processed differently.
if not self.use_mla:
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
)
else:
torch.ops.vllm.unified_attention_with_output_kunlun(
query, key, value, output, self.layer_name
)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
else:
return unified_attention(query, key, value, self.layer_name)
#
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
class MultiHeadAttention(VllmMultiHeadAttention):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
):
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
)
# kunlun only supports flash_attn
self.attn_backend = _Backend.FLASH_ATTN
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA3
bsz, q_len, _ = query.size()
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
# kunlun only supports flash_attn
if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_func
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out.reshape(bsz, q_len, -1)
def wait_for_kv_layer_from_connector(layer_name: str):
"""wait_for_kv_layer_from_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str, kv_cache_layer: List[torch.Tensor]
):
"""maybe_save_kv_layer_to_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
def unified_attention_with_output_kunlun(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def _fake_unified_attention_with_output_kunlun(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return None
unified_attention_with_output_kunlun.register_fake(
_fake_unified_attention_with_output_kunlun
)
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
"""unified_attention"""
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output

View File

View File

@@ -0,0 +1,310 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Author: Dong Xinyu, Chen Zhennan, Bao Qian, Yuan Jizhong
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""layer.py"""
import torch
from typing import Callable, Optional
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.layer import (
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
"""FusedMoEMethodBase"""
moe: FusedMoEConfig
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
"""UnquantizedFusedMoEMethod"""
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
linear_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""apply"""
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
return self.forward_kunlun(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
linear_weights=linear_weights)
def forward_kunlun(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if self.moe.use_ep:
return ops.fused_moe_ep(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group
)
else:
return ops.fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
linear_weights,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group
)
class FusedMoE(VllmFusedMoE):
"""FusedMoE"""
def __init__(self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = 0,
topk_group: Optional[int] = 0,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
):
super().__init__(
num_experts=num_experts, # Global number of experts
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
)
vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
# assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm_kunlun.ops.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor = None,
linear_weights: torch.Tensor = None):
"""forward"""
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[self.layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, linear_weights)
# return torch.ops.vllm.moe_forward(hidden_states, router_logits,
# self.layer_name)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor = None):
"""forward_impl"""
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
linear_weights=linear_weights
)
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states

View File

@@ -0,0 +1,60 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from typing import Optional, Union
import xtorch_ops
def vllm_kunlun_forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if x.is_contiguous() == False:
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x,
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
)
return out
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda

Some files were not shown because too many files have changed in this diff Show More