Add minimal vLLM 0.16.1 build repo for BI-V150

2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

Dockerfile Normal file (11 lines)

@@ -0,0 +1,11 @@
FROM registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8
# Keep the runtime stack from the known-good v8 image, but replace the
# installed Python package with the repository's patched 0.16.1rc0 sources.
WORKDIR /tmp
RUN rm -rf /usr/local/lib/python3.12/dist-packages/vllm \
/usr/local/lib/python3.12/dist-packages/vllm-*.dist-info
COPY vllm /usr/local/lib/python3.12/dist-packages/vllm
COPY vllm-0.16.1rc0+corex.4.4.0.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.16.1rc0+corex.4.4.0.dist-info

README.md Normal file (60 lines)

@@ -0,0 +1,60 @@
# bi_150-vllm
A `vLLM 0.16.1rc0` build repository based on `registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8`, used to produce a directly runnable image for the BI-V150 virtual-machine environment.

## Changes

This repository keeps only the minimum content needed to build the image:

- `vllm/`: the current runtime code
- `vllm-0.16.1rc0+corex.4.4.0.dist-info/`: the corresponding package metadata
- `Dockerfile`: builds the final image

Compared with the base image, the key code change kept in this repository is:

- Fixed the CUDA platform detection logic in `vllm/platforms/__init__.py`
  - When NVML is unavailable and an error such as `NVML Shared Library Not Found` occurs, the platform is no longer immediately treated as non-CUDA
  - Instead, detection falls back to `torch.cuda.is_available()` and `torch.cuda.device_count()` to decide whether CUDA is usable (a sketch of this fallback follows the error message below)

This fix resolves the following startup failure:
```text
RuntimeError: Failed to infer device type
```
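For reference, here is a minimal sketch of the fallback idea described above. The function name and structure are illustrative only and do not reproduce the exact patched code in `vllm/platforms/__init__.py`:

```python
# Illustrative sketch only: the real patch lives in vllm/platforms/__init__.py
# and is structured differently; this only shows the fallback behaviour.
import torch

def cuda_platform_detected() -> bool:
    """Return True if the CUDA platform should be selected."""
    try:
        import pynvml  # NVML bindings used by the stock detection path
        pynvml.nvmlInit()
        try:
            return pynvml.nvmlDeviceGetCount() > 0
        finally:
            pynvml.nvmlShutdown()
    except Exception:
        # e.g. "NVML Shared Library Not Found" on BI-V150: do not conclude
        # "not CUDA" yet; ask torch instead, which still reports the Corex
        # devices through its CUDA backend.
        return torch.cuda.is_available() and torch.cuda.device_count() > 0
```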
## Building the image

Run the following in the repository root:
```bash
docker build -t bi_150_vllm:0.16.1 .
```
## Running the image
```bash
docker run -dit \
--name iluvatar_sut_submit_325811 \
-p 38047:8000 \
--privileged \
-v /lib/modules:/lib/modules \
-v /dev:/dev \
-v /usr/src:/usr/src \
-v /mnt/gpfs/leaderboard/modelHubXC/Amu/t1-1.5B:/model \
-e CUDA_VISIBLE_DEVICES=0 \
--entrypoint vllm \
bi_150_vllm:0.16.1 \
serve /model \
--port 8000 \
--served-model-name llm \
--max-model-len 2048 \
--enforce-eager \
--trust-remote-code \
-tp 1
```
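Once the container is up, the OpenAI-compatible endpoint can be smoke-tested from the host. A minimal example, assuming `requests` is installed locally and the port mapping shown above:

```python
# Quick smoke test against the served model (host port 38047 maps to 8000
# inside the container; "llm" matches --served-model-name above).
import requests

resp = requests.post(
    "http://localhost:38047/v1/completions",
    json={"model": "llm", "prompt": "Hello from BI-V150:", "max_tokens": 16},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```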

vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER Normal file (1 line)

@@ -0,0 +1 @@
pip

vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA Normal file (211 lines)

@@ -0,0 +1,211 @@
Metadata-Version: 2.4
Name: vllm
Version: 0.16.1rc0+corex.4.4.0
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Author: vLLM Team
License-Expression: Apache-2.0
Project-URL: Homepage, https://github.com/vllm-project/vllm
Project-URL: Documentation, https://docs.vllm.ai/en/latest/
Project-URL: Slack, https://slack.vllm.ai/
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Information Technology
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Information Analysis
Requires-Python: <3.14,>=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: regex
Requires-Dist: cachetools
Requires-Dist: psutil
Requires-Dist: sentencepiece
Requires-Dist: numpy==1.26.4
Requires-Dist: requests>=2.26.0
Requires-Dist: tqdm
Requires-Dist: blake3
Requires-Dist: py-cpuinfo
Requires-Dist: transformers<5,>=4.56.0
Requires-Dist: tokenizers>=0.21.1
Requires-Dist: protobuf!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*,>=5.29.6
Requires-Dist: fastapi[standard]>=0.115.0
Requires-Dist: aiohttp>=3.13.3
Requires-Dist: openai>=1.99.1
Requires-Dist: pydantic>=2.12.0
Requires-Dist: prometheus_client>=0.18.0
Requires-Dist: pillow
Requires-Dist: prometheus-fastapi-instrumentator>=7.0.0
Requires-Dist: tiktoken>=0.6.0
Requires-Dist: lm-format-enforcer==0.11.3
Requires-Dist: llguidance<1.4.0,>=1.3.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" or platform_machine == "s390x" or platform_machine == "ppc64le"
Requires-Dist: outlines_core==0.2.11
Requires-Dist: diskcache==5.6.3
Requires-Dist: lark==1.2.2
Requires-Dist: xgrammar==0.1.29; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le"
Requires-Dist: typing_extensions>=4.10
Requires-Dist: filelock>=3.16.1
Requires-Dist: partial-json-parser
Requires-Dist: pyzmq>=25.0.0
Requires-Dist: msgspec
Requires-Dist: gguf>=0.17.0
Requires-Dist: mistral_common[image]>=1.9.1
Requires-Dist: pyyaml
Requires-Dist: six>=1.16.0; python_version > "3.11"
Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11"
Requires-Dist: einops
Requires-Dist: compressed-tensors==0.13.0
Requires-Dist: depyf==0.20.0
Requires-Dist: cloudpickle
Requires-Dist: watchfiles
Requires-Dist: python-json-logger
Requires-Dist: ninja
Requires-Dist: pybase64
Requires-Dist: cbor2
Requires-Dist: ijson
Requires-Dist: setproctitle
Requires-Dist: openai-harmony>=0.0.3
Requires-Dist: anthropic>=0.71.0
Requires-Dist: model-hosting-container-standards<1.0.0,>=0.1.13
Requires-Dist: mcp
Requires-Dist: grpcio
Requires-Dist: grpcio-reflection
Requires-Dist: opentelemetry-sdk>=1.27.0
Requires-Dist: opentelemetry-api>=1.27.0
Requires-Dist: opentelemetry-exporter-otlp>=1.27.0
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1
Requires-Dist: numba==0.61.2
Requires-Dist: ray[cgraph]>=2.48.0
Provides-Extra: bench
Requires-Dist: pandas; extra == "bench"
Requires-Dist: matplotlib; extra == "bench"
Requires-Dist: seaborn; extra == "bench"
Requires-Dist: datasets; extra == "bench"
Requires-Dist: scipy; extra == "bench"
Provides-Extra: tensorizer
Requires-Dist: tensorizer==2.10.1; extra == "tensorizer"
Provides-Extra: fastsafetensors
Requires-Dist: fastsafetensors>=0.2.2; extra == "fastsafetensors"
Provides-Extra: runai
Requires-Dist: runai-model-streamer[gcs,s3]>=0.15.3; extra == "runai"
Provides-Extra: audio
Requires-Dist: librosa; extra == "audio"
Requires-Dist: scipy; extra == "audio"
Requires-Dist: soundfile; extra == "audio"
Requires-Dist: mistral_common[audio]; extra == "audio"
Provides-Extra: video
Provides-Extra: flashinfer
Provides-Extra: otel
Requires-Dist: opentelemetry-sdk>=1.26.0; extra == "otel"
Requires-Dist: opentelemetry-api>=1.26.0; extra == "otel"
Requires-Dist: opentelemetry-exporter-otlp>=1.26.0; extra == "otel"
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1; extra == "otel"
Dynamic: license-file
Dynamic: provides-extra
Dynamic: requires-dist
<!-- markdownlint-disable MD001 MD041 -->
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png">
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-light.png" width=55%>
</picture>
</p>
<h3 align="center">
Easy, fast, and cheap LLM serving for everyone
</h3>
<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
</p>
🔥 We have built a vllm website to help you get started with vllm. Please visit [vllm.ai](https://vllm.ai) to learn more.
For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
---
## About
vLLM is a fast and easy-to-use library for LLM inference and serving.
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
vLLM is fast with:
- State-of-the-art serving throughput
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
- Speculative decoding
- Chunked prefill
vLLM is flexible and easy to use with:
- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor, pipeline, data and expert parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
- Prefix caching support
- Multi-LoRA support
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
- Embedding Models (e.g., E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
## Getting Started
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
```bash
pip install vllm
```
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
## Contributing
We welcome and value any contributions and collaborations.
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
## Citation
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
```bibtex
@inproceedings{kwon2023efficient,
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
year={2023}
}
```
## Contact Us
<!-- --8<-- [start:contact-us] -->
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
- For collaborations and partnerships, please contact us at [collaboration@vllm.ai](mailto:collaboration@vllm.ai)
<!-- --8<-- [end:contact-us] -->
## Media Kit
- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)

vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD Normal file (diff suppressed because it is too large)

vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL Normal file (5 lines)

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (82.0.0)
Root-Is-Purelib: true
Tag: py3-none-any

vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json Normal file (1 line)

@@ -0,0 +1 @@
{"archive_info": {"hash": "sha256=f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1", "hashes": {"sha256": "f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1"}}, "url": "file:///workspace/vllm-0.16.1rc0%2Bcorex.4.4.0-py3-none-any.whl"}

vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt Normal file (6 lines)

@@ -0,0 +1,6 @@
[console_scripts]
vllm = vllm.entrypoints.cli.main:main
[vllm.general_plugins]
lora_filesystem_resolver = vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
lora_hf_hub_resolver = vllm.plugins.lora_resolvers.hf_hub_resolver:register_hf_hub_resolver

vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE Normal file (201 lines)

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt Normal file (1 line)

@@ -0,0 +1 @@
vllm

vllm/.gitignore vendored Normal file (244 lines)

@@ -0,0 +1,244 @@
# version file generated by setuptools-scm
/vllm/_version.py
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*
# FlashMLA interface copied from source
vllm/third_party/flashmla/flash_mla_interface.py
# triton jit
.triton
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/.deps/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# generated files
**/generated/**
# uv
uv.lock
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
docs/argparse
docs/examples/*
!docs/examples/README.md
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VSCode
.vscode/
# Claude
.claude/
# Codex
.codex/
# Cursor
.cursor/
# DS Store
.DS_Store
# Results
*.csv
# Python pickle files
*.pkl
# Sphinx documentation
_build/
# vim swap files
*.swo
*.swp
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h
# Benchmark dataset
benchmarks/**/*.json
# Linting
actionlint
shellcheck*/
# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
# Ignore ep_kernels_workspace folder
ep_kernels_workspace/
# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
!vllm/benchmarks/lib/
# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
# Ignore generated cpu headers
csrc/cpu/cpu_attn_dispatch_generated.h

vllm/__init__.py Normal file (107 lines)

@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
# The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__ # isort:skip
import typing
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
"bc_linter_skip": "._bc_linter:bc_linter_skip",
"bc_linter_include": "._bc_linter:bc_linter_include",
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
"LLMEngine": ".engine.llm_engine:LLMEngine",
"LLM": ".entrypoints.llm:LLM",
"initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster",
"PromptType": ".inputs:PromptType",
"TextPrompt": ".inputs:TextPrompt",
"TokensPrompt": ".inputs:TokensPrompt",
"ModelRegistry": ".model_executor.models:ModelRegistry",
"SamplingParams": ".sampling_params:SamplingParams",
"PoolingParams": ".pooling_params:PoolingParams",
"ClassificationOutput": ".outputs:ClassificationOutput",
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
"CompletionOutput": ".outputs:CompletionOutput",
"EmbeddingOutput": ".outputs:EmbeddingOutput",
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
"PoolingOutput": ".outputs:PoolingOutput",
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
"RequestOutput": ".outputs:RequestOutput",
"ScoringOutput": ".outputs:ScoringOutput",
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}
if typing.TYPE_CHECKING:
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (
ClassificationOutput,
ClassificationRequestOutput,
CompletionOutput,
EmbeddingOutput,
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
ScoringOutput,
ScoringRequestOutput,
)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.executor.ray_utils import initialize_ray_cluster
from ._bc_linter import bc_linter_include, bc_linter_skip
else:
def __getattr__(name: str) -> typing.Any:
from importlib import import_module
if name in MODULE_ATTRS:
module_name, attr_name = MODULE_ATTRS[name].split(":")
module = import_module(module_name, __package__)
return getattr(module, attr_name)
else:
raise AttributeError(f"module {__package__} has no attribute {name}")
__all__ = [
"__version__",
"bc_linter_skip",
"bc_linter_include",
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
"PoolingOutput",
"PoolingRequestOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"ClassificationOutput",
"ClassificationRequestOutput",
"ScoringOutput",
"ScoringRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
]
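For orientation, a minimal usage sketch of the API re-exported above; the model path is a placeholder matching the mount point used in this repository's README:

```python
# Minimal usage sketch of the lazily re-exported top-level API.
# "/model" is a placeholder path; point it at a real model directory.
from vllm import LLM, SamplingParams

llm = LLM(model="/model", enforce_eager=True, trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, BI-V150!"], params)
print(outputs[0].outputs[0].text)
```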

vllm/_aiter_ops.py Normal file (1810 lines)

File diff suppressed because it is too large.

vllm/_bc_linter.py Normal file (54 lines)

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from collections.abc import Callable
from typing import Any, TypeVar, overload
T = TypeVar("T")
@overload
def bc_linter_skip(obj: T) -> T: ...
@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...
def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
"""
No-op decorator to mark symbols/files for BC-linter suppression.
Usage:
@bc_linter_skip
def legacy_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
@overload
def bc_linter_include(obj: T) -> T: ...
@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...
def bc_linter_include(obj: Any = None, *, reason: str | None = None):
"""
Usage:
@bc_linter_include
def public_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
__all__ = ["bc_linter_skip", "bc_linter_include"]

vllm/_custom_ops.py Normal file (4238 lines)

File diff suppressed because it is too large.

vllm/_oink_ops.py Normal file (96 lines)

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Small helper wrappers for external Oink Blackwell custom ops.
vLLM does not depend on the external Oink repository/package. When an external
plugin registers torch.library.custom_op entrypoints under the `oink::`
namespace (e.g. via vLLM's general_plugins mechanism) and
`VLLM_USE_OINK_OPS=1` is set, vLLM can route eligible calls to those ops.
This module provides:
- A single place to probe Oink op availability at module init time
(outside torch.compile tracing), and
- Thin wrappers around the torch.ops entrypoints for use in CUDA fast paths,
without introducing graph breaks.
Important:
Do not call the availability helpers in a compiled region. They may call
functions decorated with `torch._dynamo.disable` to safely check
conditions that should not be traced.
"""
from __future__ import annotations
from collections.abc import Callable
import torch
try:
from torch._dynamo import disable as _dynamo_disable # type: ignore[attr-defined]
except Exception: # pragma: no cover
def _dynamo_disable(fn: Callable): # type: ignore[misc]
return fn
def _has_oink_op(op_name: str) -> bool:
"""Check if a specific oink op is registered."""
return hasattr(torch.ops, "oink") and hasattr(torch.ops.oink, op_name)
@_dynamo_disable
def is_oink_available_for_device(device_index: int) -> bool:
"""Return True if Oink ops are registered and device is SM100+.
This function is intended to be called during module initialization
(e.g., in RMSNorm.__init__), not in the forward path.
External plugins are expected to gate registration on SM100+ and
VLLM_USE_OINK_OPS=1, so if the ops are present they should be usable.
"""
if not torch.cuda.is_available():
return False
try:
major, minor = torch.cuda.get_device_capability(device_index)
sm = 10 * major + minor
if sm < 100:
return False
except Exception:
return False
return _has_oink_op("rmsnorm")
def has_fused_add_rms_norm() -> bool:
"""Return True if the in-place fused op is registered."""
return _has_oink_op("fused_add_rms_norm")
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
"""Call `torch.ops.oink.rmsnorm`.
This wrapper is safe to call in torch.compile regions.
"""
return torch.ops.oink.rmsnorm(x, weight, eps)
def fused_add_rms_norm_(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float,
) -> None:
"""Call `torch.ops.oink.fused_add_rms_norm` (mutates x and residual)."""
torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps)
def fused_add_rms_norm(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convenience wrapper returning (x, residual) after in-place mutation."""
fused_add_rms_norm_(x, residual, weight, eps)
return x, residual
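As a usage illustration of the pattern the docstring describes (probe availability at init time, call the thin wrappers in the forward path), a hypothetical layer might look like the sketch below; the class is not part of vLLM:

```python
# Hypothetical consumer of the wrappers above; not part of vLLM itself.
import torch
from vllm import _oink_ops

class TinyRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        # Probe once, outside any torch.compile tracing, as required above.
        self.use_oink = torch.cuda.is_available() and (
            _oink_ops.is_oink_available_for_device(torch.cuda.current_device())
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_oink:
            # Safe to call inside compiled regions per the wrapper docstring.
            return _oink_ops.rmsnorm(x, self.weight, self.eps)
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight
```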

vllm/_xpu_ops.py Normal file (159 lines)

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
def register_fake(fn):
return lambda name: fn
else:
try:
from torch.library import register_fake
except ImportError:
from torch.library import impl_abstract as register_fake
if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
@register_fake("_xpu_C::fp8_gemm_w8a16")
def _fp8_gemm_w8a16_fake(
input: torch.Tensor,
q_weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
input_2d = input.view(-1, input.shape[-1])
M = input_2d.size(0)
N = q_weight.size(1)
return torch.empty((M, N), dtype=input.dtype, device=input.device)
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
@register_fake("_xpu_C::int4_gemm_w4a16")
def _int4_gemm_w4a16_fake(
input: torch.Tensor,
q_weight: torch.Tensor,
bias: torch.Tensor | None,
weight_scale: torch.Tensor,
qzeros: torch.Tensor,
group_size: int,
group_idx: torch.Tensor | None = None,
) -> torch.Tensor:
input_2d = input.view(-1, input.shape[-1])
M = input_2d.size(0)
N = q_weight.size(1)
return torch.empty((M, N), dtype=input.dtype, device=input.device)
class xpu_ops:
@staticmethod
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float | None = None,
causal: bool = False,
out: torch.Tensor | None = None,
block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None,
softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
# The following parameters are not used in xpu kernel currently,
# we keep API compatible to CUDA's.
scheduler_metadata=None,
fa_version: int = 2,
q_descale=None,
k_descale=None,
v_descale=None,
num_splits=0,
return_softmax_lse: bool | None = False,
s_aux: torch.Tensor | None = None,
):
assert cu_seqlens_k is not None or seqused_k is not None, (
"cu_seqlens_k or seqused_k must be provided"
)
assert cu_seqlens_k is None or seqused_k is None, (
"cu_seqlens_k and seqused_k cannot be provided at the same time"
)
assert block_table is None or seqused_k is not None, (
"when enable block_table, seqused_k is needed"
)
assert block_table is not None or cu_seqlens_k is not None, (
"when block_table is disabled, cu_seqlens_k is needed"
)
if out is None:
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1]) # noqa: F841
# In encode attention, k and v maybe not contiguous and current
# kernel can't handle it
if block_table is None:
k = k.contiguous()
v = v.contiguous()
return flash_attn_varlen_func(
out=out,
q=q.contiguous(),
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
seqused_k=seqused_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
block_table=block_table,
s_aux=s_aux,
window_size=real_window_size,
# alibi_slopes = alibi_slopes,
# softcap=softcap,
return_softmax_lse=return_softmax_lse,
)
@staticmethod
def get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: torch.Tensor | None = None,
cu_seqlens_k_new: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = None,
page_size: int | None = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
) -> None:
logger.warning_once(
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
)
return None

vllm/assets/__init__.py Normal file (0 lines)

vllm/assets/audio.py Normal file (43 lines)

@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urljoin
import numpy.typing as npt
from vllm.utils.import_utils import PlaceholderModule
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
ASSET_DIR = "multimodal_asset"
AudioAssetName = Literal["winning_call", "mary_had_lamb"]
@dataclass(frozen=True)
class AudioAsset:
name: AudioAssetName
@property
def filename(self) -> str:
return f"{self.name}.ogg"
@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None)
def get_local_path(self) -> Path:
return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR)
@property
def url(self) -> str:
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")

vllm/assets/base.py Normal file (40 lines)

@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import lru_cache
from pathlib import Path
import vllm.envs as envs
from vllm.connections import global_http_connection
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
def get_cache_dir() -> Path:
"""Get the path to the cache for storing downloaded assets."""
path = Path(envs.VLLM_ASSETS_CACHE)
path.mkdir(parents=True, exist_ok=True)
return path
@lru_cache
def get_vllm_public_assets(filename: str, s3_prefix: str | None = None) -> Path:
"""
Download an asset file from `s3://vllm-public-assets`
and return the path to the downloaded file.
"""
asset_directory = get_cache_dir() / "vllm_public_assets"
asset_directory.mkdir(parents=True, exist_ok=True)
asset_path = asset_directory / filename
if not asset_path.exists():
if s3_prefix is not None:
filename = s3_prefix + "/" + filename
global_http_connection.download_file(
f"{VLLM_S3_BUCKET_URL}/{filename}",
asset_path,
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
return asset_path

vllm/assets/image.py Normal file (62 lines)

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
ImageAssetName = Literal[
"stop_sign",
"cherry_blossom",
"hato",
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
"Grayscale_8bits_palette_sample_image",
"1280px-Venn_diagram_rgb",
"RGBA_comp",
"237-400x300",
"231-200x300",
"27-500x500",
"17-150x600",
"handelsblatt-preview",
"paper-11",
]
@dataclass(frozen=True)
class ImageAsset:
name: ImageAssetName
def get_path(self, ext: str) -> Path:
"""
Return s3 path for given image.
"""
return get_vllm_public_assets(
filename=f"{self.name}.{ext}", s3_prefix=VLM_IMAGES_DIR
)
@property
def pil_image(self) -> Image.Image:
return self.pil_image_ext(ext="jpg")
def pil_image_ext(self, ext: str) -> Image.Image:
image_path = self.get_path(ext=ext)
return Image.open(image_path)
@property
def image_embeds(self) -> torch.Tensor:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = self.get_path("pt")
return torch.load(image_path, map_location="cpu", weights_only=True)
def read_bytes(self, ext: str) -> bytes:
p = Path(self.get_path(ext))
return p.read_bytes()

vllm/assets/video.py Normal file (149 lines)

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, ClassVar, Literal
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.utils.import_utils import PlaceholderModule
from .base import get_cache_dir
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
@lru_cache
def download_video_asset(filename: str) -> str:
"""
Download and open an image from huggingface
repo: raushan-testing-hf/videos-test
"""
video_directory = get_cache_dir() / "video-example-data"
video_directory.mkdir(parents=True, exist_ok=True)
video_path = video_directory / filename
video_path_str = str(video_path)
if not video_path.exists():
video_path_str = hf_hub_download(
repo_id="raushan-testing-hf/videos-test",
filename=filename,
repo_type="dataset",
cache_dir=video_directory,
)
return video_path_str
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
import cv2
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frames = []
num_frames = num_frames if num_frames > 0 else total_frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
for idx in range(total_frames):
ok = cap.grab() # next img
if not ok:
break
if idx in frame_indices: # only decompress needed
ret, frame = cap.retrieve()
if ret:
# OpenCV uses BGR format, we need to convert it to RGB
# for PIL and transformers compatibility
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
frames = np.stack(frames)
if len(frames) < num_frames:
raise ValueError(
f"Could not read enough frames from video file {path}"
f" (expected {num_frames} frames, got {len(frames)})"
)
return frames
def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]:
frames = video_to_ndarrays(path, num_frames)
return [Image.fromarray(frame) for frame in frames]
def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
import cv2
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames / fps if fps > 0 else 0
if num_frames == -1 or num_frames > total_frames:
num_frames = total_frames
metadata = {
"total_num_frames": num_frames,
"fps": duration / num_frames,
"duration": duration,
"video_backend": "opencv",
"frames_indices": list(range(num_frames)),
# extra field used to control hf processor's video
# sampling behavior
"do_sample_frames": num_frames == total_frames,
}
return metadata
VideoAssetName = Literal["baby_reading"]
@dataclass(frozen=True)
class VideoAsset:
name: VideoAssetName
num_frames: int = -1
_NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
"baby_reading": "sample_demo_1.mp4",
}
@property
def filename(self) -> str:
return self._NAME_TO_FILE[self.name]
@property
def video_path(self) -> str:
return download_video_asset(self.filename)
@property
def pil_images(self) -> list[Image.Image]:
ret = video_to_pil_images_list(self.video_path, self.num_frames)
return ret
@property
def np_ndarrays(self) -> npt.NDArray:
ret = video_to_ndarrays(self.video_path, self.num_frames)
return ret
@property
def metadata(self) -> dict[str, Any]:
ret = video_get_metadata(self.video_path, self.num_frames)
return ret
def get_audio(self, sampling_rate: float | None = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
return librosa.load(self.video_path, sr=sampling_rate)[0]

vllm/beam_search.py Normal file (109 lines)

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from vllm.inputs import TokenInputs, token_inputs
from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
orig_prompt: TokenInputs | MultiModalInputs
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: LoRARequest | None = None
cum_logprob: float = 0.0
text: str | None = None
finish_reason: str | None = None
stop_reason: int | str | None = None
def get_prompt(self):
prompt = self.orig_prompt
prompt_text = prompt.get("prompt")
cache_salt = prompt.get("cache_salt")
if prompt["type"] == "token":
return token_inputs(
self.tokens,
prompt=prompt_text,
cache_salt=cache_salt,
)
return mm_inputs(
prompt_token_ids=self.tokens,
mm_kwargs=prompt["mm_kwargs"],
mm_hashes=prompt["mm_hashes"],
mm_placeholders=prompt["mm_placeholders"],
prompt=prompt_text,
cache_salt=cache_salt,
)
@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: list[BeamSearchSequence]
class BeamSearchInstance:
def __init__(
self,
prompt: TokenInputs | MultiModalInputs,
lora_request: LoRARequest | None = None,
logprobs: list[dict[int, Logprob]] | None = None,
**kwargs,
):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt["prompt_token_ids"],
logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request,
**kwargs,
)
]
self.completed: list[BeamSearchSequence] = []
def get_beam_search_score(
tokens: list[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1
return cumulative_logprob / (seq_len**length_penalty)
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(
x.tokens, x.cum_logprob, eos_token_id, length_penalty
)
return sort_beams_key
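A small worked example of the length-penalised score computed above; the token ids and logprob value are made up for illustration:

```python
# Worked example for get_beam_search_score; values are illustrative only.
tokens = [11, 7, 42, 2]   # ends with EOS id 2, so the effective length is 3
score = get_beam_search_score(tokens, cumulative_logprob=-3.0, eos_token_id=2)
assert abs(score - (-1.0)) < 1e-9   # -3.0 / 3 ** 1.0
```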


vllm/benchmarks/datasets.py Normal file (3453 lines)

File diff suppressed because it is too large.

vllm/benchmarks/latency.py Normal file (172 lines)

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from typing import Any
import numpy as np
from tqdm import tqdm
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={"latency": results["latencies"]},
extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
)
if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument(
"--n",
type=int,
default=1,
help="Number of generated sequences per prompt.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-iters-warmup",
type=int,
default=10,
help="Number of iterations to run for warmup.",
)
parser.add_argument(
"--num-iters", type=int, default=30, help="Number of iterations to run."
)
parser.add_argument(
"--profile",
action="store_true",
help="profile the generation process of a single batch",
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the latency results in JSON format.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=(
"Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
)
parser = EngineArgs.add_cli_args(parser)
# V1 enables prefix caching by default which skews the latency
# numbers. We need to disable prefix caching by default.
parser.set_defaults(enable_prefix_caching=False)
def main(args: argparse.Namespace):
engine_args = EngineArgs.from_cli_args(args)
# Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.output_len
), (
"Please ensure that max_model_len is greater than"
" the sum of input_len and output_len."
)
sampling_params = SamplingParams(
n=args.n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
dummy_prompt_token_ids = np.random.randint(
10000, size=(args.batch_size, args.input_len)
)
dummy_prompts: list[PromptType] = [
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
]
def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
ignore_eos=True,
),
)
def run_to_completion(do_profile: bool = False):
if do_profile:
llm.start_profile()
llm_generate()
llm.stop_profile()
else:
start_time = time.perf_counter()
llm_generate()
end_time = time.perf_counter()
latency = end_time - start_time
return latency
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(do_profile=False)
if args.profile:
profiler_config = engine_args.profiler_config
if profiler_config.profiler == "torch":
print(
"Profiling with torch profiler (results will be saved to"
f" {profiler_config.torch_profiler_dir})..."
)
elif profiler_config.profiler == "cuda":
print("Profiling with cuda profiler ...")
run_to_completion(do_profile=True)
return
# Benchmark.
latencies = []
for _ in tqdm(range(args.num_iters), desc="Bench iterations"):
latencies.append(run_to_completion(do_profile=False))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90, 99]
percentiles = np.percentile(latencies, percentages)
print(f"Avg latency: {np.mean(latencies)} seconds")
for percentage, percentile in zip(percentages, percentiles):
print(f"{percentage}% percentile latency: {percentile} seconds")
# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)


@@ -0,0 +1,539 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark multimodal processor latency.
This benchmark measures the latency of the mm processor module
using multimodal prompts from datasets.
MM processor stats are automatically enabled.
Run:
vllm bench mm-processor \
--model <your_model> \
--dataset-name random-mm \
--num-prompts 10 \
"""
import argparse
import dataclasses
import json
import time
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
from vllm.benchmarks.datasets import (
MultiModalConversationDataset,
VisionArenaDataset,
)
from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
if TYPE_CHECKING: # Avoid having to mock during docs build
from vllm.v1.engine.llm_engine import LLMEngine
else:
LLMEngine = object
def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
"""
Get all multimodal timing stats from the LLM engine.
Collects both preprocessing stats (HF processor, hashing, cache lookup,
prompt update) and encoder forward pass timing, merged by request_id.
Args:
llm_engine: The LLM engine (has input_processor and workers).
Returns:
Dictionary mapping request_id to merged stats dict containing
both preprocessing and encoder timing metrics.
Example:
{
'request-123': {
'get_mm_hashes_secs': 0.02,
'get_cache_missing_items_secs': 0.01,
'apply_hf_processor_secs': 0.45,
'merge_mm_kwargs_secs': 0.01,
'apply_prompt_updates_secs': 0.03,
'preprocessor_total_secs': 0.51,
'encoder_forward_secs': 0.23,
'num_encoder_calls': 1
}
}
"""
observability_config = llm_engine.vllm_config.observability_config
if not observability_config or not observability_config.enable_mm_processor_stats:
return {}
renderer = llm_engine.renderer
mm_processor_stats = renderer._mm_timing_registry.stat()
encoder_stats = dict[str, dict[str, float]]()
for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
if not worker_stats:
continue
for request_id, stats_dict in worker_stats.items():
if request_id not in encoder_stats:
encoder_stats[request_id] = dict(stats_dict)
else:
# Aggregate timing metrics across workers
current_time = encoder_stats[request_id].get(
"encoder_forward_secs", 0.0
)
new_time = stats_dict.get("encoder_forward_secs", 0.0)
encoder_stats[request_id]["encoder_forward_secs"] = max(
current_time, new_time
)
current_calls = encoder_stats[request_id].get("num_encoder_calls", 0)
new_calls = stats_dict.get("num_encoder_calls", 0)
encoder_stats[request_id]["num_encoder_calls"] = max(
current_calls, new_calls
)
merged_stats = dict[str, dict[str, float]]()
for request_id, prep_dict in mm_processor_stats.items():
merged_stats[request_id] = dict(prep_dict)
for request_id, enc_dict in encoder_stats.items():
if request_id in merged_stats:
merged_stats[request_id].update(enc_dict)
continue
# In V1 engine, the request_id in encoder_stats has a suffix
# appended to the original request_id (which is used in
# preprocessing_stats).
# We try to strip the suffix to find the matching request.
possible_original_id = request_id.rpartition("-")[0]
if possible_original_id and possible_original_id in merged_stats:
merged_stats[possible_original_id].update(enc_dict)
else:
merged_stats[request_id] = dict(enc_dict)
return merged_stats
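# Minimal illustration of the merge above (request ids are made up): the V1
# engine appends a suffix to the original request id, so encoder stats are
# reconciled with preprocessing stats by stripping everything after the last
# "-", e.g.
#
#   mm_processor_stats = {"req-abc": {"apply_hf_processor_secs": 0.45}}
#   encoder_stats      = {"req-abc-0": {"encoder_forward_secs": 0.23}}
#   "req-abc-0".rpartition("-")[0] == "req-abc"  -> the two dicts are merged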
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
"""
Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds).
"""
all_stats = get_timing_stats_from_engine(llm_engine)
stats_by_stage = defaultdict[str, list[float]](list)
for stats_dict in all_stats.values():
for stat_key, stat_val in stats_dict.items():
stats_by_stage[stat_key].append(stat_val)
return stats_by_stage
def calculate_mm_processor_metrics(
stats_by_stage: dict[str, list[float]],
selected_percentiles: list[float],
*,
unit: Literal["us", "ms", "s"] = "ms",
) -> dict[str, dict[str, float]]:
"""
Calculate aggregate metrics from stats by stage.
"""
unit2mult = {"us": 1000000, "ms": 1000, "s": 1}
unit_mult = unit2mult[unit]
metrics = {}
for stage, times in stats_by_stage.items():
stage_name = stage.replace("_secs", "_" + unit)
if not times:
metrics[stage_name] = {
"mean": 0.0,
"median": 0.0,
"std": 0.0,
**{f"p{p}": 0.0 for p in selected_percentiles},
}
continue
is_count_metric = stage == "num_encoder_calls"
values = times if is_count_metric else [t * unit_mult for t in times]
metrics[stage_name] = {
"mean": float(np.mean(values)),
"median": float(np.median(values)),
"std": float(np.std(values)),
**{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
}
return metrics
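# Worked example (hypothetical timings) of the conversion above: with the
# default unit="ms" each per-request timing in seconds is multiplied by 1000,
# while the "num_encoder_calls" counter is reported as-is, e.g.
#
#   stats_by_stage = {"apply_hf_processor_secs": [0.40, 0.50],
#                     "num_encoder_calls": [1, 1]}
#   calculate_mm_processor_metrics(stats_by_stage, [99])
#   -> {"apply_hf_processor_ms": {"mean": 450.0, "median": 450.0, ...},
#       "num_encoder_calls": {"mean": 1.0, "median": 1.0, ...}}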
def validate_args(args):
"""
Validate command-line arguments for mm_processor benchmark.
"""
if not getattr(args, "tokenizer", None):
args.tokenizer = args.model
if not hasattr(args, "dataset_path"):
args.dataset_path = None
if not hasattr(args, "lora_path"):
args.lora_path = None
if not hasattr(args, "max_loras"):
args.max_loras = None
if args.dataset_name == "hf" and not args.dataset_path:
raise ValueError(
"--dataset-path is required when using --dataset-name hf. "
"For multimodal benchmarking, specify a dataset like "
"'lmarena-ai/VisionArena-Chat'."
)
if args.dataset_name == "hf":
supported_mm_datasets = (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
)
if args.dataset_path not in supported_mm_datasets:
raise ValueError(
f"{args.dataset_path} is not a supported multimodal dataset. "
f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
)
def benchmark_multimodal_processor(
args: argparse.Namespace,
) -> dict[str, Any]:
"""
Run the multimodal processor benchmark.
"""
from vllm import LLM, SamplingParams
validate_args(args)
if args.seed is None:
args.seed = 0
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
tokenizer = llm.get_tokenizer()
requests = get_requests(args, tokenizer)
assert all(
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests."
)
prompts = [request.prompt for request in requests]
expected_output_lens = [request.expected_output_len for request in requests]
sampling_params = [
SamplingParams(
n=1,
temperature=0.0,
max_tokens=output_len,
detokenize=True,
)
for output_len in expected_output_lens
]
selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
]
freeze_gc_heap()
num_warmups = getattr(args, "num_warmups", 0)
if num_warmups > 0:
print(f"Processing {num_warmups} warmup requests...")
# Create a temporary args object for warmup requests
warmup_args = argparse.Namespace(**vars(args))
warmup_args.num_prompts = num_warmups
warmup_args.seed += 1
warmup_requests = get_requests(warmup_args, tokenizer)
warmup_prompts = [req.prompt for req in warmup_requests]
warmup_output_lens = [req.expected_output_len for req in warmup_requests]
warmup_sampling_params = [
SamplingParams(max_tokens=output_len) for output_len in warmup_output_lens
]
llm.chat(
warmup_prompts,
warmup_sampling_params,
use_tqdm=not getattr(args, "disable_tqdm", False),
)
# Clear stats from warmup requests
collect_mm_processor_stats(llm.llm_engine)
print(f"Processing {len(prompts)} requests...")
start_time = time.perf_counter()
outputs = llm.chat(
prompts, sampling_params, use_tqdm=not getattr(args, "disable_tqdm", False)
)
end_time = time.perf_counter()
total_time = end_time - start_time
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)
if not any(mm_stats_by_stage.values()):
print(
"\n⚠️ Warning: No MM processor stats found in registry.\n"
" This may indicate that:\n"
" - No multimodal requests were processed\n"
" - Stats were already retrieved (registry is cleared after retrieval)\n"
)
mm_processor_metrics = calculate_mm_processor_metrics(
mm_stats_by_stage, selected_percentiles
)
completed = len([o for o in outputs if o.finished])
failed = len(outputs) - completed
e2el_times = []
for output in outputs:
if not output.finished or output.metrics is None:
continue
metrics = output.metrics
# Calculate E2E latency as: TTFT + (last_token_ts - first_token_ts)
if (
getattr(metrics, "first_token_latency", None) is not None
and getattr(metrics, "last_token_ts", None) is not None
and getattr(metrics, "first_token_ts", None) is not None
):
ttft = metrics.first_token_latency
# Decode time is the duration between the first and last token generation
decode_time = max(0.0, metrics.last_token_ts - metrics.first_token_ts)
e2el_times.append((ttft + decode_time) * 1000)
if not e2el_times and completed > 0:
print(
"\n⚠️ Warning: Detailed end-to-end latency metrics not available.\n"
" Falling back to average request latency "
"(total_time / num_completed_requests).\n"
)
avg_time_per_request = total_time / completed
e2el_times = [avg_time_per_request * 1000] * completed
if e2el_times:
mean_e2el_ms = float(np.mean(e2el_times))
median_e2el_ms = float(np.median(e2el_times))
std_e2el_ms = float(np.std(e2el_times))
percentiles_e2el_ms = [
(p, float(np.percentile(e2el_times, p))) for p in selected_percentiles
]
else:
mean_e2el_ms = 0.0
median_e2el_ms = 0.0
std_e2el_ms = 0.0
percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]
encoder_summary = {}
if (
"num_encoder_calls" in mm_stats_by_stage
and mm_stats_by_stage["num_encoder_calls"]
):
encoder_calls = mm_stats_by_stage["num_encoder_calls"]
encoder_summary = {
"total_encoder_calls": int(sum(encoder_calls)),
"num_requests_with_encoder_calls": len(encoder_calls),
}
benchmark_result = {
"completed": completed,
"failed": failed,
"mean_e2el_ms": mean_e2el_ms,
"median_e2el_ms": median_e2el_ms,
"std_e2el_ms": std_e2el_ms,
"percentiles_e2el_ms": percentiles_e2el_ms,
"mm_processor_stats": mm_processor_metrics,
"encoder_summary": encoder_summary,
}
return benchmark_result
def add_cli_args(parser: argparse.ArgumentParser) -> None:
"""Add CLI arguments for the multimodal processor benchmark."""
from vllm.engine.arg_utils import EngineArgs
EngineArgs.add_cli_args(parser)
parser.set_defaults(enable_mm_processor_stats=True)
parser.add_argument(
"--dataset-name",
type=str,
default="random-mm",
choices=["random-mm", "hf"],
help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=10,
help="Number of prompts to process.",
)
parser.add_argument(
"--num-warmups",
type=int,
default=1,
help="Number of warmup prompts to process.",
)
from vllm.benchmarks.datasets import (
add_random_dataset_base_args,
add_random_multimodal_dataset_args,
)
add_random_dataset_base_args(parser)
add_random_multimodal_dataset_args(parser)
# HuggingFace dataset arguments
parser.add_argument(
"--dataset-path",
type=str,
default=None,
help="Path to the dataset file or HuggingFace dataset name "
"(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
)
parser.add_argument(
"--hf-subset",
type=str,
default=None,
help="Subset of the HuggingFace dataset (optional).",
)
parser.add_argument(
"--hf-split",
type=str,
default=None,
help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. "
"Overrides the default output lengths from the dataset.",
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the benchmark results in JSON format.",
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Disable tqdm progress bar.",
)
def main(args: argparse.Namespace) -> None:
"""Main entry point for the multimodal processor benchmark."""
print("Starting multimodal processor benchmark...")
result = benchmark_multimodal_processor(args)
print("\n" + "=" * 80)
print("Multimodal Processor Benchmark Results")
print("=" * 80)
if "mm_processor_stats" in result:
print("\nMM Processor Metrics:")
selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
]
mm_data = []
for stage, metrics in result["mm_processor_stats"].items():
row = {
"Stage": stage,
"Mean": f"{metrics['mean']:.2f}",
"Median": f"{metrics['median']:.2f}",
"Std": f"{metrics['std']:.2f}",
}
for p in selected_percentiles:
row[f"P{p}"] = f"{metrics.get(f'p{p}', 0.0):.2f}"
mm_data.append(row)
mm_df = pd.DataFrame(mm_data)
print(mm_df.to_string(index=False))
if "encoder_summary" in result and result["encoder_summary"]:
total_calls = result["encoder_summary"]["total_encoder_calls"]
num_requests = result["encoder_summary"]["num_requests_with_encoder_calls"]
print(
f"\nSummary: {total_calls} total encoder calls "
f"across {num_requests} requests."
)
if "mean_e2el_ms" in result:
print("\nEnd-to-End Latency (ms):")
selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
]
e2el_data = [
{"Metric": "Mean", "Value (ms)": f"{result['mean_e2el_ms']:.2f}"},
{"Metric": "Median", "Value (ms)": f"{result['median_e2el_ms']:.2f}"},
{"Metric": "Std", "Value (ms)": f"{result['std_e2el_ms']:.2f}"},
]
for p in selected_percentiles:
percentile_value = next(
(val for pct, val in result["percentiles_e2el_ms"] if pct == p),
0.0,
)
e2el_data.append(
{
"Metric": f"P{p}",
"Value (ms)": f"{percentile_value:.2f}",
}
)
e2el_df = pd.DataFrame(e2el_data)
print(e2el_df.to_string(index=False))
if args.output_json:
result["config"] = {
"model": args.model,
"num_prompts": args.num_prompts,
"input_len": getattr(args, "random_input_len", None),
"output_len": getattr(args, "random_output_len", None),
}
result["timestamp"] = datetime.now().isoformat()
with open(args.output_json, "w") as f:
json.dump(result, f, indent=2)
print(f"\nResults saved to {args.output_json}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark mm processor latency")
add_cli_args(parser)
args = parser.parse_args()
main(args)

1816
vllm/benchmarks/serve.py Normal file

File diff suppressed because it is too large Load Diff

321
vllm/benchmarks/startup.py Normal file
View File

@@ -0,0 +1,321 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the cold and warm startup time of vLLM models.
This script measures total startup time (including model loading, compilation,
and cache operations) for both cold and warm scenarios:
- Cold startup: Fresh start with no caches (temporary cache directories)
- Warm startup: Using cached compilation and model info
"""
import argparse
import dataclasses
import json
import multiprocessing
import os
import shutil
import tempfile
import time
from contextlib import contextmanager
from typing import Any
import numpy as np
from tqdm import tqdm
from vllm.benchmarks.lib.utils import (
convert_to_pytorch_benchmark_format,
write_to_json,
)
from vllm.engine.arg_utils import EngineArgs
@contextmanager
def cold_startup():
"""
Context manager to measure cold startup time:
1. Uses a temporary directory for vLLM cache to avoid any pollution
between cold startup iterations.
2. Uses inductor's fresh_cache to clear torch.compile caches.
"""
from torch._inductor.utils import fresh_cache
# Use temporary directory for caching to avoid any pollution between cold startups
original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
try:
os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
with fresh_cache():
yield
finally:
# Clean up temporary cache directory
shutil.rmtree(temp_cache_dir, ignore_errors=True)
if original_cache_root:
os.environ["VLLM_CACHE_ROOT"] = original_cache_root
else:
os.environ.pop("VLLM_CACHE_ROOT", None)
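# Usage sketch for the benchmark loop below: every cold iteration is wrapped
# in this context manager so that both the vLLM cache root and the
# torch.compile (inductor) caches start empty, e.g.
#
#   with cold_startup():
#       metrics = create_llm_and_measure_startup()
#
# Warm iterations skip the context manager and reuse whatever was cached by
# the warmup runs.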
def run_startup_in_subprocess(engine_args, result_queue):
"""
Run LLM startup in a subprocess and return timing metrics via a queue.
This ensures complete isolation between iterations.
"""
try:
# Import inside the subprocess to avoid issues with forking
from vllm import LLM
# Measure total startup time
start_time = time.perf_counter()
llm = LLM(**dataclasses.asdict(engine_args))
total_startup_time = time.perf_counter() - start_time
# Extract compilation time if available
compilation_time = 0.0
if hasattr(llm.llm_engine, "vllm_config"):
vllm_config = llm.llm_engine.vllm_config
if (
hasattr(vllm_config, "compilation_config")
and vllm_config.compilation_config is not None
):
compilation_time = vllm_config.compilation_config.compilation_time
result_queue.put(
{
"total_startup_time": total_startup_time,
"compilation_time": compilation_time,
}
)
except Exception as e:
result_queue.put(None)
result_queue.put(str(e))
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any]
) -> None:
base_name = os.path.splitext(args.output_json)[0]
cold_startup_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_cold_startup_time": [results["avg_cold_startup_time"]],
},
extra_info={
"cold_startup_times": results["cold_startup_times"],
"cold_startup_percentiles": results["cold_startup_percentiles"],
},
)
if cold_startup_records:
write_to_json(f"{base_name}.cold_startup.pytorch.json", cold_startup_records)
cold_compilation_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_cold_compilation_time": [results["avg_cold_compilation_time"]],
},
extra_info={
"cold_compilation_times": results["cold_compilation_times"],
"cold_compilation_percentiles": results["cold_compilation_percentiles"],
},
)
if cold_compilation_records:
write_to_json(
f"{base_name}.cold_compilation.pytorch.json", cold_compilation_records
)
warm_startup_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_warm_startup_time": [results["avg_warm_startup_time"]],
},
extra_info={
"warm_startup_times": results["warm_startup_times"],
"warm_startup_percentiles": results["warm_startup_percentiles"],
},
)
if warm_startup_records:
write_to_json(f"{base_name}.warm_startup.pytorch.json", warm_startup_records)
warm_compilation_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_warm_compilation_time": [results["avg_warm_compilation_time"]],
},
extra_info={
"warm_compilation_times": results["warm_compilation_times"],
"warm_compilation_percentiles": results["warm_compilation_percentiles"],
},
)
if warm_compilation_records:
write_to_json(
f"{base_name}.warm_compilation.pytorch.json", warm_compilation_records
)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-iters-cold",
type=int,
default=3,
help="Number of cold startup iterations.",
)
parser.add_argument(
"--num-iters-warmup",
type=int,
default=1,
help="Number of warmup iterations before benchmarking warm startups.",
)
parser.add_argument(
"--num-iters-warm",
type=int,
default=3,
help="Number of warm startup iterations.",
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the startup time results in JSON format.",
)
parser = EngineArgs.add_cli_args(parser)
return parser
def main(args: argparse.Namespace):
# Set multiprocessing start method to 'spawn' for clean process isolation
# This ensures each subprocess starts fresh without inheriting state
multiprocessing.set_start_method("spawn", force=True)
engine_args = EngineArgs.from_cli_args(args)
def create_llm_and_measure_startup():
"""
Create LLM instance in a subprocess and measure startup time.
Returns timing metrics, using subprocess for complete isolation.
"""
# Create a queue for inter-process communication
result_queue = multiprocessing.Queue()
process = multiprocessing.Process(
target=run_startup_in_subprocess,
args=(
engine_args,
result_queue,
),
)
process.start()
process.join()
if not result_queue.empty():
result = result_queue.get()
if result is None:
if not result_queue.empty():
error_msg = result_queue.get()
raise RuntimeError(f"Subprocess failed: {error_msg}")
else:
raise RuntimeError("Subprocess failed with unknown error")
return result
else:
raise RuntimeError("Subprocess did not return a result")
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")
print("Measuring cold startup time...\n")
cold_startup_times = []
cold_compilation_times = []
for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
with cold_startup():
metrics = create_llm_and_measure_startup()
cold_startup_times.append(metrics["total_startup_time"])
cold_compilation_times.append(metrics["compilation_time"])
# Warmup for warm startup
print("\nWarming up for warm startup measurement...\n")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
create_llm_and_measure_startup()
print("\nMeasuring warm startup time...\n")
warm_startup_times = []
warm_compilation_times = []
for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
metrics = create_llm_and_measure_startup()
warm_startup_times.append(metrics["total_startup_time"])
warm_compilation_times.append(metrics["compilation_time"])
# Calculate statistics
cold_startup_array = np.array(cold_startup_times)
cold_compilation_array = np.array(cold_compilation_times)
warm_startup_array = np.array(warm_startup_times)
warm_compilation_array = np.array(warm_compilation_times)
avg_cold_startup = np.mean(cold_startup_array)
avg_cold_compilation = np.mean(cold_compilation_array)
avg_warm_startup = np.mean(warm_startup_array)
avg_warm_compilation = np.mean(warm_compilation_array)
percentages = [10, 25, 50, 75, 90, 99]
cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)
print("\n" + "=" * 60)
print("STARTUP TIME BENCHMARK RESULTS")
print("=" * 60)
# Cold startup statistics
print("\nCOLD STARTUP:")
print(f"Avg total startup time: {avg_cold_startup:.2f} seconds")
print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds")
print("Startup time percentiles:")
for percentage, percentile in zip(percentages, cold_startup_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("Compilation time percentiles:")
for percentage, percentile in zip(percentages, cold_compilation_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
# Warm startup statistics
print("\nWARM STARTUP:")
print(f"Avg total startup time: {avg_warm_startup:.2f} seconds")
print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds")
print("Startup time percentiles:")
for percentage, percentile in zip(percentages, warm_startup_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("Compilation time percentiles:")
for percentage, percentile in zip(percentages, warm_compilation_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("=" * 60)
# Output JSON results if specified
if args.output_json:
results = {
"avg_cold_startup_time": float(avg_cold_startup),
"avg_cold_compilation_time": float(avg_cold_compilation),
"cold_startup_times": cold_startup_times,
"cold_compilation_times": cold_compilation_times,
"cold_startup_percentiles": dict(
zip(percentages, cold_startup_percentiles.tolist())
),
"cold_compilation_percentiles": dict(
zip(percentages, cold_compilation_percentiles.tolist())
),
"avg_warm_startup_time": float(avg_warm_startup),
"avg_warm_compilation_time": float(avg_warm_compilation),
"warm_startup_times": warm_startup_times,
"warm_compilation_times": warm_compilation_times,
"warm_startup_percentiles": dict(
zip(percentages, warm_startup_percentiles.tolist())
),
"warm_compilation_percentiles": dict(
zip(percentages, warm_compilation_percentiles.tolist())
),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs
from .plot import main as plot_main
from .plot_pareto import SweepPlotParetoArgs
from .plot_pareto import main as plot_pareto_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
from .startup import SweepStartupArgs
from .startup import main as startup_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
(SweepServeSLAArgs, serve_sla_main),
(SweepStartupArgs, startup_main),
(SweepPlotArgs, plot_main),
(SweepPlotParetoArgs, plot_pareto_main),
)
def add_cli_args(parser: argparse.ArgumentParser):
subparsers = parser.add_subparsers(required=True, dest="sweep_type")
for cmd, entrypoint in SUBCOMMANDS:
cmd_subparser = subparsers.add_parser(
cmd.parser_name,
description=cmd.parser_help,
usage=f"vllm bench sweep {cmd.parser_name} [options]",
)
cmd_subparser.set_defaults(dispatch_function=entrypoint)
cmd.add_cli_args(cmd_subparser)
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=f"sweep {cmd.parser_name}"
)
def main(args: argparse.Namespace):
args.dispatch_function(args)

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from typing import Any
class ParameterSweep(list["ParameterSweepItem"]):
@classmethod
def read_json(cls, filepath: os.PathLike):
with open(filepath, "rb") as f:
data = json.load(f)
# Support both list and dict formats
if isinstance(data, dict):
return cls.read_from_dict(data)
return cls.from_records(data)
@classmethod
def read_from_dict(cls, data: dict[str, dict[str, object]]):
"""
Read parameter sweep from a dict format where keys are names.
Example:
{
"experiment1": {"max_tokens": 100, "temperature": 0.7},
"experiment2": {"max_tokens": 200, "temperature": 0.9}
}
"""
records = [{"_benchmark_name": name, **params} for name, params in data.items()]
return cls.from_records(records)
@classmethod
def from_records(cls, records: list[dict[str, object]]):
if not isinstance(records, list):
raise TypeError(
f"The parameter sweep should be a list of dictionaries, "
f"but found type: {type(records)}"
)
# Validate that all _benchmark_name values are unique if provided
names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
if names and len(names) != len(set(names)):
duplicates = [name for name in names if names.count(name) > 1]
raise ValueError(
f"Duplicate _benchmark_name values found: {set(duplicates)}. "
f"All _benchmark_name values must be unique."
)
return cls(ParameterSweepItem.from_record(record) for record in records)
class ParameterSweepItem(dict[str, object]):
@classmethod
def from_record(cls, record: dict[str, object]):
if not isinstance(record, dict):
raise TypeError(
f"Each item in the parameter sweep should be a dictionary, "
f"but found type: {type(record)}"
)
return cls(record)
def __or__(self, other: dict[str, Any]):
return type(self)(super().__or__(other))
@property
def name(self) -> str:
"""
Get the name for this parameter sweep item.
Returns the '_benchmark_name' field if present, otherwise returns a text
representation of all parameters.
"""
if "_benchmark_name" in self:
return str(self["_benchmark_name"])
return self.as_text(sep="-")
# In JSON, we prefer "_"
def _iter_param_key_candidates(self, param_key: str):
# Inner config arguments are not converted by the CLI
if "." in param_key:
prefix, rest = param_key.split(".", 1)
for prefix_candidate in self._iter_param_key_candidates(prefix):
yield prefix_candidate + "." + rest
return
yield param_key
yield param_key.replace("-", "_")
yield param_key.replace("_", "-")
# In CLI, we prefer "-"
def _iter_cmd_key_candidates(self, param_key: str):
for k in reversed(tuple(self._iter_param_key_candidates(param_key))):
yield "--" + k
def _normalize_cmd_key(self, param_key: str):
return next(self._iter_cmd_key_candidates(param_key))
def has_param(self, param_key: str) -> bool:
return any(k in self for k in self._iter_param_key_candidates(param_key))
def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
"""
Normalize a key-value pair into command-line arguments.
Returns a list containing either:
- A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
- Two elements for key-value pairs (e.g., ['--key', 'value'])
"""
if isinstance(v, bool):
# For nested params (containing "."), use =true/false syntax
if "." in k:
return [f"{self._normalize_cmd_key(k)}={'true' if v else 'false'}"]
else:
return [self._normalize_cmd_key(k if v else "no-" + k)]
else:
return [self._normalize_cmd_key(k), str(v)]
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
cmd = list(cmd)
for k, v in self.items():
# Skip the '_benchmark_name' field, not a parameter
if k == "_benchmark_name":
continue
# Serialize dict values as JSON
if isinstance(v, dict):
v = json.dumps(v)
for k_candidate in self._iter_cmd_key_candidates(k):
try:
k_idx = cmd.index(k_candidate)
# Replace existing parameter
normalized = self._normalize_cmd_kv_pair(k, v)
if len(normalized) == 1:
# Boolean flag
cmd[k_idx] = normalized[0]
else:
# Key-value pair
cmd[k_idx] = normalized[0]
cmd[k_idx + 1] = normalized[1]
break
except ValueError:
continue
else:
# Add new parameter
cmd.extend(self._normalize_cmd_kv_pair(k, v))
return cmd
def as_text(self, sep: str = ", ") -> str:
return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")
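# Worked example (hypothetical parameters) of applying a sweep item to a
# command: keys may use "-" or "_", booleans become flags, values already
# present in the command are replaced in place, and new ones are appended:
#
#   item = ParameterSweepItem({"max_num_seqs": 128,
#                              "enable-prefix-caching": False})
#   item.apply_to_cmd(["vllm", "serve", "my-model", "--max-num-seqs", "64"])
#   -> ["vllm", "serve", "my-model", "--max-num-seqs", "128",
#       "--no-enable-prefix-caching"]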

View File

@@ -0,0 +1,683 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import TracebackType
from typing import ClassVar
from typing_extensions import Self, override
from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .utils import sanitize_filename
try:
import matplotlib.pyplot as plt
except ImportError:
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
try:
import seaborn as sns
except ImportError:
sns = PlaceholderModule("seaborn")
@dataclass
class PlotFilterBase(ABC):
var: str
target: str
@classmethod
def parse_str(cls, s: str):
for op_key in PLOT_FILTERS:
if op_key in s:
key, value = s.split(op_key)
return PLOT_FILTERS[op_key](
key,
value.removeprefix(op_key).strip("'").strip('"'),
)
else:
raise ValueError(
f"Invalid operator for plot filter '{s}'. "
f"Valid operators are: {sorted(PLOT_FILTERS)}",
)
@abstractmethod
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this filter to a DataFrame."""
raise NotImplementedError
@dataclass
class PlotEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
try:
target = float(self.target)
except ValueError:
target = self.target
return df[df[self.var] == target]
@dataclass
class PlotNotEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
try:
target = float(self.target)
except ValueError:
target = self.target
return df[df[self.var] != target]
@dataclass
class PlotLessThan(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] < float(self.target)]
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] <= float(self.target)]
@dataclass
class PlotGreaterThan(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] > float(self.target)]
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] >= float(self.target)]
# NOTE: The ordering is important! Match longer op_keys first
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
"==": PlotEqualTo,
"!=": PlotNotEqualTo,
"<=": PlotLessThanOrEqualTo,
">=": PlotGreaterThanOrEqualTo,
"<": PlotLessThan,
">": PlotGreaterThan,
}
class PlotFilters(list[PlotFilterBase]):
@classmethod
def parse_str(cls, s: str):
if not s:
return cls()
return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self:
df = item.apply(df)
return df
@dataclass
class PlotBinner:
var: str
bin_size: float
@classmethod
def parse_str(cls, s: str):
for op_key in PLOT_BINNERS:
if op_key in s:
key, value = s.split(op_key)
return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key)))
else:
raise ValueError(
f"Invalid operator for plot binner '{s}'. "
f"Valid operators are: {sorted(PLOT_BINNERS)}",
)
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this binner to a DataFrame."""
df = df.copy()
df[self.var] = df[self.var] // self.bin_size * self.bin_size
return df
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
"%": PlotBinner,
}
class PlotBinners(list[PlotBinner]):
@classmethod
def parse_str(cls, s: str):
if not s:
return cls()
return cls(PlotBinner.parse_str(e) for e in s.split(","))
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self:
df = item.apply(df)
return df
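# A short sketch of the two mini-languages parsed above (values are
# illustrative): filters keep only the rows that satisfy every comparison,
# while binners snap a column down to multiples of the bin size:
#
#   PlotFilters.parse_str("max_concurrency<1000,max_num_batched_tokens<=4096")
#   PlotBinners.parse_str("request_throughput%5")  # 12.3 -> 10.0, 17.9 -> 15.0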
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
with path.open("rb") as f:
return json.load(f)
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
"""
Convert string values "inf", "-inf", and "nan" to their float equivalents.
This handles the case where JSON serialization represents inf/nan as strings.
"""
converted_data = []
for record in data:
converted_record = {}
for key, value in record.items():
if isinstance(value, str):
if value in ["inf", "-inf", "nan"]:
converted_record[key] = float(value)
else:
converted_record[key] = value
else:
converted_record[key] = value
converted_data.append(converted_record)
return converted_data
def _get_metric(run_data: dict[str, object], metric_key: str):
try:
return run_data[metric_key]
except KeyError as exc:
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
def _get_group(run_data: dict[str, object], group_keys: list[str]):
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
parts = list[str]()
# Start with figure name (always provided, defaults to "FIGURE")
parts.append(fig_name)
# Always append group data if present
if group:
parts.extend(f"{k}={v}" for k, v in group)
return fig_dir / sanitize_filename("-".join(parts) + ".png")
class DummyExecutor:
map = map
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
return None
def _plot_fig(
fig_dir: Path,
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
row_by: list[str],
col_by: list[str],
curve_by: list[str],
*,
var_x: str,
var_y: str,
filter_by: PlotFilters,
bin_by: PlotBinners,
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str,
error_bars: bool,
fig_height: float,
fig_dpi: int,
):
fig_group, fig_data = fig_group_data
row_groups = full_groupby(
fig_data,
key=lambda item: _get_group(item, row_by),
)
num_rows = len(row_groups)
num_cols = max(
len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
for _, row_data in row_groups
)
fig_path = _get_fig_path(fig_dir, fig_group, fig_name)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
print(f"Grid: {num_rows} rows x {num_cols} cols")
print(f"Output file: {fig_path}")
if dry_run:
print("[END FIGURE]")
return
# Convert string "inf", "-inf", and "nan" to their float equivalents
fig_data = _convert_inf_nan_strings(fig_data)
df = pd.DataFrame.from_records(fig_data)
if var_x not in df.columns:
raise ValueError(
f"Cannot find {var_x=!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
if var_y not in df.columns:
raise ValueError(
f"Cannot find {var_y=!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in row_by:
if k not in df.columns:
raise ValueError(
f"Cannot find row_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in col_by:
if k not in df.columns:
raise ValueError(
f"Cannot find col_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in curve_by:
if k not in df.columns:
raise ValueError(
f"Cannot find curve_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
df = filter_by.apply(df)
df = bin_by.apply(df)
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
df["row_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in row_by],
axis=1,
).agg("\n".join, axis=1)
if row_by
else "(All)"
)
df["col_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in col_by],
axis=1,
).agg("\n".join, axis=1)
if col_by
else "(All)"
)
if len(curve_by) <= 3:
hue, style, size, *_ = (*curve_by, None, None, None)
g = sns.relplot(
df,
x=var_x,
y=var_y,
hue=hue,
style=style,
size=size,
markers=True,
errorbar="sd" if error_bars else None,
kind="line",
row="row_group",
col="col_group",
height=fig_height,
)
else:
df["curve_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in curve_by],
axis=1,
).agg("\n".join, axis=1)
if curve_by
else "(All)"
)
g = sns.relplot(
df,
x=var_x,
y=var_y,
hue="curve_group",
markers=True,
errorbar="sd" if error_bars else None,
kind="line",
row="row_group",
col="col_group",
height=fig_height,
)
if row_by and col_by:
g.set_titles("{row_name}\n{col_name}")
elif row_by:
g.set_titles("{row_name}")
elif col_by:
g.set_titles("{col_name}")
else:
g.set_titles("")
if scale_x:
g.set(xscale=scale_x)
if scale_y:
g.set(yscale=scale_y)
g.savefig(fig_path, dpi=fig_dpi)
plt.close(g.figure)
print("[END FIGURE]")
def plot(
output_dir: Path,
fig_dir: Path,
fig_by: list[str],
row_by: list[str],
col_by: list[str],
curve_by: list[str],
*,
var_x: str,
var_y: str,
filter_by: PlotFilters,
bin_by: PlotBinners,
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str = "FIGURE",
error_bars: bool = True,
fig_height: float = 6.4,
fig_dpi: int = 300,
):
all_data = [
run_data
for path in output_dir.rglob("**/summary.json")
for run_data in _json_load_bytes(path)
]
if not all_data:
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
fig_dir.mkdir(parents=True, exist_ok=True)
fig_groups = full_groupby(
all_data,
key=lambda item: _get_group(item, fig_by),
)
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
# Resolve the iterable to ensure that the workers are run
all(
executor.map(
partial(
_plot_fig,
fig_dir,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=var_x,
var_y=var_y,
filter_by=filter_by,
bin_by=bin_by,
scale_x=scale_x,
scale_y=scale_y,
dry_run=dry_run,
fig_name=fig_name,
error_bars=error_bars,
fig_height=fig_height,
fig_dpi=fig_dpi,
),
fig_groups,
)
)
@dataclass
class SweepPlotArgs:
output_dir: Path
fig_dir: Path
fig_by: list[str]
row_by: list[str]
col_by: list[str]
curve_by: list[str]
var_x: str
var_y: str
filter_by: PlotFilters
bin_by: PlotBinners
scale_x: str | None
scale_y: str | None
dry_run: bool
fig_name: str = "FIGURE"
error_bars: bool = True
fig_height: float = 6.4
fig_dpi: int = 300
parser_name: ClassVar[str] = "plot"
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
return cls(
output_dir=output_dir,
fig_dir=output_dir / args.fig_dir,
fig_by=fig_by,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by),
bin_by=PlotBinners.parse_str(args.bin_by),
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=not args.no_error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
)
parser.add_argument(
"--fig-dir",
type=str,
default="",
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
"By default, the same directory is used.",
)
parser.add_argument(
"--fig-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate figure "
"is created for each combination of these variables.",
)
parser.add_argument(
"--row-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate row "
"is created for each combination of these variables.",
)
parser.add_argument(
"--col-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate column "
"is created for each combination of these variables.",
)
parser.add_argument(
"--curve-by",
type=str,
default=None,
help="A comma-separated list of variables, such that a separate curve "
"is created for each combination of these variables.",
)
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_ttft_ms",
help="The variable for the y-axis",
)
parser.add_argument(
"--filter-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to filter by. "
"This is useful to remove outliers. "
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
"plot only the points where `max_concurrency` is less than 1000 and "
"`max_num_batched_tokens` is no greater than 4096.",
)
parser.add_argument(
"--bin-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to bin by. "
"This is useful to avoid plotting points that are too close together. "
"Example: `request_throughput%%1` means "
"use a bin size of 1 for the `request_throughput` variable.",
)
parser.add_argument(
"--scale-x",
type=str,
default=None,
help="The scale to use for the x-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--scale-y",
type=str,
default=None,
help="The scale to use for the y-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--fig-name",
type=str,
default="FIGURE",
help="Name prefix for the output figure file. "
"Group data is always appended when present. "
"Default: 'FIGURE'. Example: --fig-name my_performance_plot",
)
parser.add_argument(
"--no-error-bars",
action="store_true",
help="If set, disables error bars on the plot. "
"By default, error bars are shown.",
)
parser.add_argument(
"--fig-height",
type=float,
default=6.4,
help="Height of each subplot in inches. Default: 6.4",
)
parser.add_argument(
"--fig-dpi",
type=int,
default=300,
help="Resolution of the output figure in dots per inch. Default: 300",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the information about each figure to plot, "
"then exits without drawing them.",
)
return parser
def run_main(args: SweepPlotArgs):
return plot(
output_dir=args.output_dir,
fig_dir=args.fig_dir,
fig_by=args.fig_by,
row_by=args.row_by,
col_by=args.col_by,
curve_by=args.curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=args.filter_by,
bin_by=args.bin_by,
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=args.error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
def main(args: argparse.Namespace):
run_main(SweepPlotArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
SweepPlotArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,399 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import ClassVar
from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .plot import DummyExecutor, _json_load_bytes
from .utils import sanitize_filename
try:
import matplotlib.pyplot as plt
except ImportError:
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
try:
import seaborn as sns
except ImportError:
sns = PlaceholderModule("seaborn")
def _first_present(run_data: dict[str, object], keys: list[str]):
for key in keys:
for candidate in {key, key.replace("_", "-"), key.replace("-", "_")}:
if candidate in run_data:
return run_data[candidate]
return None
def _get_numeric(
run_data: dict[str, object],
keys: list[str],
*,
allow_zero: bool = True,
) -> float | None:
value = _first_present(run_data, keys)
if value is None:
return None
try:
numeric = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(
f"Expected numeric value for one of {keys}, "
f"but found {value!r} in {run_data=}"
) from exc
if not allow_zero and numeric == 0:
return None
return numeric
def _infer_user_count(
run_data: dict[str, object],
user_count_var: str | None,
) -> float | None:
candidates = [user_count_var] if user_count_var else []
candidates.extend(["request_rate"])
user_count = _get_numeric(run_data, candidates, allow_zero=False)
if user_count is not None:
return user_count
# Fallback to the observed peak if configured value is missing.
return _get_numeric(run_data, ["max_concurrent_requests"], allow_zero=False)
def _infer_gpu_count(
run_data: dict[str, object],
gpu_count_var: str | None,
) -> float:
direct_candidates = [gpu_count_var] if gpu_count_var else []
direct_gpu_count = _get_numeric(run_data, direct_candidates, allow_zero=False)
if direct_gpu_count:
return direct_gpu_count
tp_size = _get_numeric(run_data, ["tensor_parallel_size", "tp"])
pp_size = _get_numeric(run_data, ["pipeline_parallel_size", "pp"])
dp_size = _get_numeric(run_data, ["data_parallel_size", "dp"])
world_size = 1.0
if tp_size:
world_size *= tp_size
if pp_size:
world_size *= pp_size
if dp_size:
world_size *= dp_size
return world_size
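# Example (hypothetical run data) of the fallback above: with no explicit
# GPU-count key, the world size is derived from the parallelism settings,
# and 1.0 is assumed when none of them are present, e.g.
#
#   _infer_gpu_count({"tensor_parallel_size": 2, "data_parallel_size": 2}, None)
#   -> 4.0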
def _get_throughput(
run_data: dict[str, object],
throughput_var: str,
) -> float:
throughput = _get_numeric(run_data, [throughput_var])
if throughput is None:
raise ValueError(
f"Cannot find throughput metric {throughput_var!r} in run data. "
f"Available keys: {sorted(run_data)}"
)
return throughput
def _prepare_records(
all_data: list[dict[str, object]],
*,
user_count_var: str | None,
gpu_count_var: str | None,
) -> tuple[list[dict[str, object]], int]:
prepared = []
skipped_missing_users = 0
for record in all_data:
throughput = _get_throughput(record, "output_throughput")
user_count = _infer_user_count(record, user_count_var)
if user_count is None:
skipped_missing_users += 1
continue
gpu_count = _infer_gpu_count(record, gpu_count_var)
tokens_per_user = throughput / user_count
tokens_per_gpu = throughput / gpu_count
prepared.append(
{
**record,
"tokens_per_user": tokens_per_user,
"tokens_per_gpu": tokens_per_gpu,
"user_count_estimate": user_count,
"gpu_count": gpu_count,
}
)
return prepared, skipped_missing_users
def _pareto_frontier(
df: "pd.DataFrame",
x_col: str,
y_col: str,
*,
epsilon: float = 1e-9,
) -> "pd.DataFrame":
sorted_df = df.sort_values([x_col, y_col], ascending=[False, False])
frontier_indices = []
best_y = -math.inf
for idx, row in sorted_df.iterrows():
y_val = row[y_col]
if y_val >= best_y - epsilon:
frontier_indices.append(idx)
best_y = max(best_y, y_val)
return df.loc[frontier_indices]
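# Tiny illustration (made-up points) of the frontier selection above: after
# sorting by tokens/s/user in descending order, a point is kept only if its
# tokens/s/GPU is at least as high (within epsilon) as that of every point
# with a higher tokens/s/user, e.g.
#
#   (x=30, y=100) kept, (x=20, y=90) dropped, (x=10, y=120) kept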
def _get_fig_path(
fig_dir: Path,
fig_group: tuple[tuple[str, str], ...],
) -> Path:
parts = ["PARETO"]
if fig_group:
parts.extend(f"{k}={v}" for k, v in fig_group)
filename = sanitize_filename("-".join(parts) + ".png")
return fig_dir / filename
def _plot_fig(
fig_dir: Path,
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
label_by: list[str],
*,
dry_run: bool,
):
fig_group, fig_data = fig_group_data
fig_path = _get_fig_path(fig_dir, fig_group)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
print(f"Output file: {fig_path}")
if dry_run:
print("[END FIGURE]")
return
df = pd.DataFrame.from_records(fig_data)
df = df.dropna(subset=["tokens_per_user", "tokens_per_gpu"])
if df.empty:
print("No data points available after filtering; skipping.")
print("[END FIGURE]")
return
frontier = _pareto_frontier(df, "tokens_per_user", "tokens_per_gpu")
frontier = frontier.sort_values("tokens_per_user")
fig, ax = plt.subplots()
sns.scatterplot(
data=df,
x="tokens_per_user",
y="tokens_per_gpu",
color="0.5",
alpha=0.6,
ax=ax,
label="All runs",
)
sns.lineplot(
data=frontier,
x="tokens_per_user",
y="tokens_per_gpu",
marker="o",
ax=ax,
label="Pareto frontier",
)
if label_by:
for _, row in frontier.iterrows():
label_parts = []
for key in label_by:
if key in row:
label_parts.append(f"{key}={row[key]}")
if label_parts:
ax.text(
row["tokens_per_user"],
row["tokens_per_gpu"],
"\n".join(label_parts),
fontsize=8,
)
ax.set_xlabel("Tokens/s/user")
ax.set_ylabel("Tokens/s/GPU")
ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
fig.tight_layout()
fig.savefig(fig_path)
plt.close(fig)
print(
f"Plotted {len(df)} points; Pareto frontier size: {len(frontier)}.",
)
print("[END FIGURE]")
def plot_pareto(
output_dir: Path,
user_count_var: str | None,
gpu_count_var: str | None,
label_by: list[str],
*,
dry_run: bool,
):
fig_dir = output_dir / "pareto"
raw_data = [
run_data
for path in output_dir.rglob("**/summary.json")
for run_data in _json_load_bytes(path)
]
if not raw_data:
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
fig_dir.mkdir(parents=True, exist_ok=True)
prepared_data, skipped_missing_users = _prepare_records(
raw_data,
user_count_var=user_count_var,
gpu_count_var=gpu_count_var,
)
if skipped_missing_users:
print(
f"Skipped {skipped_missing_users} runs without a user count "
"(`max_concurrency` or `max_concurrent_requests`).",
)
if not prepared_data:
raise ValueError(
"No data points with both throughput and user count available "
"to plot Pareto frontier.",
)
fig_groups = full_groupby(
prepared_data,
key=lambda item: tuple(),
)
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
all(
executor.map(
partial(
_plot_fig,
fig_dir,
label_by=label_by,
dry_run=dry_run,
),
fig_groups,
)
)
@dataclass
class SweepPlotParetoArgs:
output_dir: Path
user_count_var: str | None
gpu_count_var: str | None
label_by: list[str]
dry_run: bool
parser_name: ClassVar[str] = "plot_pareto"
parser_help: ClassVar[str] = (
"Plot Pareto frontier between tokens/s/user and tokens/s/GPU "
"from parameter sweep results."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
label_by = [] if not args.label_by else args.label_by.split(",")
return cls(
output_dir=output_dir,
user_count_var=args.user_count_var,
gpu_count_var=args.gpu_count_var,
label_by=label_by,
dry_run=args.dry_run,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(
"--user-count-var",
type=str,
default="max_concurrency",
help="Result key that stores concurrent user count. "
"Falls back to max_concurrent_requests if missing.",
)
parser.add_argument(
"--gpu-count-var",
type=str,
default=None,
help="Result key that stores GPU count. "
"If not provided, falls back to num_gpus/gpu_count "
"or tensor_parallel_size * pipeline_parallel_size.",
)
parser.add_argument(
"--label-by",
type=str,
default="max_concurrency,gpu_count",
help="Comma-separated list of fields to annotate on Pareto frontier "
"points.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the figures to plot without drawing them.",
)
return parser
def run_main(args: SweepPlotParetoArgs):
return plot_pareto(
output_dir=args.output_dir,
user_count_var=args.user_count_var,
gpu_count_var=args.gpu_count_var,
label_by=args.label_by,
dry_run=args.dry_run,
)
def main(args: argparse.Namespace):
run_main(SweepPlotParetoArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepPlotParetoArgs.parser_help)
SweepPlotParetoArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,498 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import shlex
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@contextlib.contextmanager
def run_server(
serve_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_overrides: ParameterSweepItem,
dry_run: bool,
server_ready_timeout: int = 300,
):
server_cmd = serve_overrides.apply_to_cmd(serve_cmd)
print("[BEGIN SERVER]")
print(f"Server overrides: {serve_overrides}")
print(f"Server command: {server_cmd}")
if dry_run:
yield None
print("[END SERVER]")
return
with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
server.wait_until_ready(timeout=server_ready_timeout)
yield server
print("[END SERVER]")
def _update_run_data(
run_data: dict[str, object],
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
):
run_data["run_number"] = run_number
run_data.update(serve_overrides)
run_data.update(bench_overrides)
return run_data
def run_benchmark(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
output_path: Path,
dry_run: bool,
):
benchmark_cmd = [
*bench_overrides.apply_to_cmd(bench_cmd),
"--percentile-metrics",
"ttft,tpot,itl,e2el",
"--save-result",
"--result-dir",
str(output_path.parent),
"--result-filename",
output_path.name,
]
print("[BEGIN BENCHMARK]")
print(f"Benchmark overrides: {bench_overrides}")
print(f"Run Number: {run_number}")
print(f"Benchmark command: {benchmark_cmd}")
print(f"Output file: {output_path}")
run_data: dict[str, object]
if output_path.exists():
print("Found existing results.")
print("[SKIPPED BENCHMARK]")
with output_path.open("rb") as f:
run_data = json.load(f)
return _update_run_data(
run_data,
serve_overrides,
bench_overrides,
run_number,
)
if server is None:
if not dry_run:
raise ValueError(f"Cannot find results at {output_path}")
print("[END BENCHMARK]")
return None
output_path.parent.mkdir(parents=True, exist_ok=True)
server.run_subcommand(benchmark_cmd)
server.after_bench()
with output_path.open("rb") as f:
run_data = json.load(f)
run_data = _update_run_data(
run_data,
serve_overrides,
bench_overrides,
run_number,
)
with output_path.open("w") as f:
json.dump(run_data, f, indent=4)
print("[END BENCHMARK]")
return run_data
def _get_comb_base_path(
output_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.name))
return output_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None):
if run_number is None:
return base_path / "summary.json"
return base_path / f"run={run_number}.json"
def _comb_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
output_dir: Path,
):
for bench_comb in bench_combs:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
if not _get_comb_run_path(base_path, run_number=None).exists():
return True
return False
def server_ctx(
serve_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_comb: ParameterSweepItem,
bench_params: ParameterSweep,
output_dir: Path,
dry_run: bool,
server_ready_timeout: int = 300,
):
if not _comb_needs_server(serve_comb, bench_params, output_dir):
return contextlib.nullcontext()
return run_server(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
serve_overrides=serve_comb,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
)
def _comb_is_valid(
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
) -> bool:
return all(
serve_key in serve_comb
and bench_key in bench_comb
and serve_comb[serve_key] == bench_comb[bench_key]
for serve_key, bench_key in link_vars
)
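# Illustrative example (hypothetical parameter names) of `link_vars`: linking
# the server-side "max-num-seqs" to the benchmark-side "max-concurrency" means
# a (serve_comb, bench_comb) pair is only run when both values match, pruning
# the Cartesian product of the two sweeps to the matching pairs:
#
#   link_vars  = [("max-num-seqs", "max-concurrency")]
#   serve_comb = {"max-num-seqs": 128}, bench_comb = {"max-concurrency": 128}  -> run
#   serve_comb = {"max-num-seqs": 128}, bench_comb = {"max-concurrency": 256}  -> skipped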
def run_comb(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
return None
comb_data = list[dict[str, object]]()
for run_number in range(num_runs):
run_data = run_benchmark(
server,
bench_cmd,
serve_overrides=serve_comb,
bench_overrides=bench_comb,
run_number=run_number,
output_path=_get_comb_run_path(base_path, run_number),
dry_run=dry_run,
)
if run_data is not None:
comb_data.append(run_data)
if dry_run:
return None
with _get_comb_run_path(base_path, run_number=None).open("w") as f:
json.dump(comb_data, f, indent=4)
return comb_data
def run_combs(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
) as server:
for bench_comb in bench_params:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeArgs:
serve_cmd: list[str]
bench_cmd: list[str]
after_bench_cmd: list[str]
show_stdout: bool
server_ready_timeout: int
serve_params: ParameterSweep
bench_params: ParameterSweep
output_dir: Path
num_runs: int
dry_run: bool
resume: str | None
link_vars: list[tuple[str, str]]
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
serve_cmd = shlex.split(args.serve_cmd)
bench_cmd = shlex.split(args.bench_cmd)
after_bench_cmd = (
[] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
)
if args.serve_params:
serve_params = ParameterSweep.read_json(args.serve_params)
else:
# i.e.: run serve_cmd without any modification
serve_params = ParameterSweep.from_records([{}])
if args.bench_params:
bench_params = ParameterSweep.read_json(args.bench_params)
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
link_vars = cls.parse_link_vars(args.link_vars)
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
return cls(
serve_cmd=serve_cmd,
bench_cmd=bench_cmd,
after_bench_cmd=after_bench_cmd,
show_stdout=args.show_stdout,
serve_params=serve_params,
bench_params=bench_params,
output_dir=Path(args.output_dir),
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
link_vars=link_vars,
server_ready_timeout=args.server_ready_timeout,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--serve-cmd",
type=str,
required=True,
help="The command used to run the server: `vllm serve ...`",
)
parser.add_argument(
"--bench-cmd",
type=str,
required=True,
help="The command used to run the benchmark: `vllm bench serve ...`",
)
parser.add_argument(
"--after-bench-cmd",
type=str,
default=None,
help="After a benchmark run is complete, invoke this command instead of "
"the default `ServerWrapper.clear_cache()`.",
)
parser.add_argument(
"--show-stdout",
action="store_true",
help="If set, logs the standard output of subcommands. "
"Useful for debugging but can be quite spammy.",
)
parser.add_argument(
"--server-ready-timeout",
type=int,
default=300,
help="Timeout in seconds to wait for the server to become ready.",
)
parser.add_argument(
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm serve` command. Can be either a list of dicts or a dict "
"where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"--bench-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm bench serve` command. Can be either a list of dicts or "
"a dict where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
)
parser.add_argument(
"--num-runs",
type=int,
default=3,
help="Number of runs per parameter combination.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the commands to run, "
"then exits without executing them.",
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
)
return parser
@staticmethod
def parse_link_vars(s: str) -> list[tuple[str, str]]:
if not s:
return []
pairs = []
for item in s.split(","):
a, b = item.split("=")
pairs.append((a.strip(), b.strip()))
return pairs
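# Example: parse_link_vars("max_num_seqs=max_concurrency,max_model_len=random_input_len")
# returns [("max_num_seqs", "max_concurrency"), ("max_model_len", "random_input_len")],
# while parse_link_vars("") returns []. (Illustrative values only.)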
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_combs(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
SweepServeArgs.add_cli_args(parser)
main(parser.parse_args())
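# Example invocation (hypothetical file name, model path and parameter files):
#   python serve.py \
#       --serve-cmd "vllm serve /model --max-model-len 2048" \
#       --bench-cmd "vllm bench serve --model /model --num-prompts 200" \
#       --serve-params serve_params.json \
#       --bench-params bench_params.json \
#       --num-runs 3 -o results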

View File

@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
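# Worked example (hypothetical numbers): for a run with request_throughput == 8.0 req/s,
# sla_variable == "request_rate" estimates 8.0, while "max_concurrency" combined with
# mean_e2el_ms == 500.0 estimates 8.0 * 500.0 / 1000 == 4.0 in-flight requests
# (Little's law); _estimate_sla_avg then averages these per-run estimates.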
def run_comb_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
sla_variable: SLAVariable,
sla_value: int,
) -> list[dict[str, object]] | None:
bench_comb_sla = bench_comb | {sla_variable: sla_value}
return run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb_sla,
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
def explore_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
print("[SLA START]")
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
print(f"Number of SLA iterations: {sla_iters}")
if sla_iters < 2:
raise ValueError("`sla_iters` should be at least 2")
serial_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=1,
)
batch_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
)
if serial_comb_data is None or batch_comb_data is None:
if dry_run:
print("Omitting intermediate SLA iterations.")
print("[SLA END]")
return
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
print(f"Serial inference: {sla_variable}={serial_sla_value}")
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
print(f"Batch inference: {sla_variable}={batch_sla_value}")
# Avoid duplicated runs for intermediate values if the range between
# `serial_sla_value` and `batch_sla_value` is small
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
inter_sla_values = sorted(set(map(round, inter_sla_values)))
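# Worked example (hypothetical values): with serial_sla_value == 2, batch_sla_value == 20
# and sla_iters == 4, np.linspace(2, 20, 4) gives [2, 8, 14, 20]; dropping the endpoints
# and rounding leaves [8, 14] as the intermediate values explored below.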
inter_combs_data: list[dict[str, object]] = []
for inter_sla_value in inter_sla_values:
print(f"Exploring: {sla_variable}={inter_sla_value}")
inter_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=inter_sla_value,
)
if inter_comb_data is not None:
inter_combs_data.extend(inter_comb_data)
print("[SLA END]")
return serial_comb_data + inter_combs_data + batch_comb_data
def run_slas(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
"since it is supposed to be determined automatically."
)
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
server_ready_timeout=server_ready_timeout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
dry_run=dry_run,
) as server:
for bench_comb in bench_params:
comb_data = explore_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_variable=sla_variable,
sla_iters=sla_iters,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
sla_variable: SLAVariable
sla_iters: int
parser_name: ClassVar[str] = "serve_sla"
parser_help: ClassVar[str] = (
"Explore the latency-throughput space for determining SLAs."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
sla_group = parser.add_argument_group("sla options")
sla_group.add_argument(
"--sla-variable",
type=str,
choices=get_args(SLAVariable),
default="request_rate",
help="The variable to adjust in each iteration.",
)
sla_group.add_argument(
"--sla-iters",
type=int,
default=10,
help="Number of iterations used to explore the latency-throughput space. "
"This includes the first two iterations used to interpolate the value of "
"`sla_variable` for remaining iterations.",
)
return parser
def run_main(args: SweepServeSLAArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_slas(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeSLAArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import signal
import subprocess
import time
from types import TracebackType
import requests
from typing_extensions import Self
class ServerProcess:
VLLM_RESET_CACHE_ENDPOINTS = [
"/reset_prefix_cache",
"/reset_mm_cache",
"/reset_encoder_cache",
]
def __init__(
self,
server_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
) -> None:
super().__init__()
self.server_cmd = server_cmd
self.after_bench_cmd = after_bench_cmd
self.show_stdout = show_stdout
def __enter__(self) -> Self:
self.start()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.stop()
def start(self):
# Create new process for clean termination
self._server_process = subprocess.Popen(
self.server_cmd,
start_new_session=True,
stdout=None if self.show_stdout else subprocess.DEVNULL,
# Need `VLLM_SERVER_DEV_MODE=1` for `reset_caches`
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
)
def stop(self):
server_process = self._server_process
if server_process.poll() is None:
# In case only some processes have been terminated
with contextlib.suppress(ProcessLookupError):
# We need to kill both API Server and Engine processes
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
def run_subcommand(self, cmd: list[str]):
return subprocess.run(
cmd,
stdout=None if self.show_stdout else subprocess.DEVNULL,
check=True,
)
def after_bench(self) -> None:
if not self.after_bench_cmd:
self.reset_caches()
return
self.run_subcommand(self.after_bench_cmd)
def _get_vllm_server_address(self) -> str:
server_cmd = self.server_cmd
for host_key in ("--host",):
if host_key in server_cmd:
host = server_cmd[server_cmd.index(host_key) + 1]
break
else:
host = "localhost"
for port_key in ("-p", "--port"):
if port_key in server_cmd:
port = int(server_cmd[server_cmd.index(port_key) + 1])
break
else:
port = 8000 # The default value in vllm serve
return f"http://{host}:{port}"
def is_server_ready(self) -> bool:
server_address = self._get_vllm_server_address()
try:
response = requests.get(f"{server_address}/health")
return response.status_code == 200
except requests.RequestException:
return False
def wait_until_ready(self, timeout: int) -> None:
start_time = time.monotonic()
while not self.is_server_ready():
# Check if server process has crashed
if self._server_process.poll() is not None:
returncode = self._server_process.returncode
raise RuntimeError(
f"Server process crashed with return code {returncode}"
)
if time.monotonic() - start_time > timeout:
raise TimeoutError(
f"Server failed to become ready within {timeout} seconds."
)
time.sleep(1)
def reset_caches(self) -> None:
server_cmd = self.server_cmd
# Use `.endswith()` to match `/bin/...`
if server_cmd[0].endswith("vllm"):
server_address = self._get_vllm_server_address()
print(f"Resetting caches at {server_address}")
for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
res = requests.post(server_address + endpoint)
res.raise_for_status()
elif server_cmd[0].endswith("infinity_emb"):
if "--vector-disk-cache" in server_cmd:
raise NotImplementedError(
"Infinity server uses caching but does not expose a method "
"to reset the cache"
)
else:
raise NotImplementedError(
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
"Please specify a custom command via `--after-bench-cmd`."
)
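# Minimal usage sketch (hypothetical command): the class is intended to be used as a
# context manager so the whole process group is torn down even on errors, e.g.
#     with ServerProcess(["vllm", "serve", "/model"], [], show_stdout=False) as server:
#         server.wait_until_ready(timeout=300)
#         ...  # run benchmarks, then server.after_bench()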

View File

@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import shlex
import subprocess
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import ClassVar
from vllm.benchmarks.startup import add_cli_args as add_startup_cli_args
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@lru_cache(maxsize=1)
def _get_supported_startup_keys() -> set[str]:
parser = FlexibleArgumentParser(add_help=False)
add_startup_cli_args(parser)
supported: set[str] = {"config"}
for action in parser._actions:
if action.dest and action.dest is not argparse.SUPPRESS:
supported.add(action.dest)
for option in action.option_strings:
if option.startswith("--"):
supported.add(option.lstrip("-").replace("-", "_"))
return supported
def _is_supported_param(param_key: str, supported: set[str]) -> bool:
if param_key == "_benchmark_name":
return True
prefix = param_key.split(".", 1)[0]
normalized = prefix.replace("-", "_")
return normalized in supported
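# Example (hypothetical option): if `vllm bench startup` exposes --max-model-len, both its
# argparse dest and its normalized option string reduce to "max_model_len", so sweep keys
# written as "max-model-len" or "max_model_len" pass _is_supported_param; dotted keys such
# as "compilation_config.level" are checked only by their "compilation_config" prefix, and
# "_benchmark_name" is always accepted.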
def _filter_params(
params: ParameterSweep, *, supported: set[str], strict: bool
) -> ParameterSweep:
filtered = []
for item in params:
kept: dict[str, object] = {}
dropped: list[str] = []
for key, value in item.items():
if _is_supported_param(key, supported):
kept[key] = value
else:
dropped.append(key)
if dropped:
label = item.get("_benchmark_name") or item.as_text()
message = (
"Ignoring unsupported startup params"
f"{' for ' + str(label) if label else ''}: "
f"{', '.join(sorted(dropped))}"
)
if strict:
raise ValueError(message)
print(message)
filtered.append(ParameterSweepItem.from_record(kept))
return ParameterSweep(filtered)
def _update_run_data(
run_data: dict[str, object],
serve_overrides: ParameterSweepItem,
startup_overrides: ParameterSweepItem,
run_number: int,
) -> dict[str, object]:
run_data["run_number"] = run_number
run_data.update(serve_overrides)
run_data.update(startup_overrides)
return run_data
def _strip_arg(cmd: list[str], keys: tuple[str, ...]) -> list[str]:
stripped: list[str] = []
skip_next = False
for arg in cmd:
if skip_next:
skip_next = False
continue
if arg in keys:
skip_next = True
continue
if any(arg.startswith(f"{key}=") for key in keys):
continue
stripped.append(arg)
return stripped
def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
keys = ("--output-json", "--output_json")
cmd = _strip_arg(cmd, keys)
return [*cmd, keys[0], str(output_path)]
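# Example (hypothetical paths): _apply_output_json(["vllm", "bench", "startup",
# "--output-json", "old.json"], Path("run=0.json")) first strips the existing flag
# (including the "--output-json=old.json" form) and returns
# ["vllm", "bench", "startup", "--output-json", "run=0.json"].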
def _get_comb_base_path(
output_dir: Path,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
) -> Path:
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if startup_comb:
parts.extend(("STARTUP-", startup_comb.name))
return output_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
if run_number is None:
return base_path / "summary.json"
return base_path / f"run={run_number}.json"
def run_benchmark(
startup_cmd: list[str],
*,
serve_overrides: ParameterSweepItem,
startup_overrides: ParameterSweepItem,
run_number: int,
output_path: Path,
show_stdout: bool,
dry_run: bool,
) -> dict[str, object] | None:
cmd = serve_overrides.apply_to_cmd(startup_cmd)
cmd = startup_overrides.apply_to_cmd(cmd)
cmd = _apply_output_json(cmd, output_path)
print("[BEGIN BENCHMARK]")
print(f"Serve overrides: {serve_overrides}")
print(f"Startup overrides: {startup_overrides}")
print(f"Run Number: {run_number}")
print(f"Benchmark command: {cmd}")
print(f"Output file: {output_path}")
if output_path.exists():
print("Found existing results.")
print("[SKIPPED BENCHMARK]")
with output_path.open("r", encoding="utf-8") as f:
run_data = json.load(f)
return _update_run_data(
run_data, serve_overrides, startup_overrides, run_number
)
if dry_run:
print("[END BENCHMARK]")
return None
output_path.parent.mkdir(parents=True, exist_ok=True)
subprocess.run(
cmd,
stdout=None if show_stdout else subprocess.DEVNULL,
check=True,
)
with output_path.open("r", encoding="utf-8") as f:
run_data = json.load(f)
run_data = _update_run_data(
run_data, serve_overrides, startup_overrides, run_number
)
with output_path.open("w", encoding="utf-8") as f:
json.dump(run_data, f, indent=4)
print("[END BENCHMARK]")
return run_data
def run_comb(
startup_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
base_path: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
) -> list[dict[str, object]] | None:
comb_data = list[dict[str, object]]()
for run_number in range(num_runs):
run_data = run_benchmark(
startup_cmd,
serve_overrides=serve_comb,
startup_overrides=startup_comb,
run_number=run_number,
output_path=_get_comb_run_path(base_path, run_number),
show_stdout=show_stdout,
dry_run=dry_run,
)
if run_data is not None:
comb_data.append(run_data)
if dry_run:
return None
with _get_comb_run_path(base_path, run_number=None).open(
"w", encoding="utf-8"
) as f:
json.dump(comb_data, f, indent=4)
return comb_data
def run_combs(
startup_cmd: list[str],
*,
serve_params: ParameterSweep,
startup_params: ParameterSweep,
output_dir: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
) -> "pd.DataFrame | None":
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
for startup_comb in startup_params:
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
comb_data = run_comb(
startup_cmd,
serve_comb=serve_comb,
startup_comb=startup_comb,
base_path=base_path,
num_runs=num_runs,
show_stdout=show_stdout,
dry_run=dry_run,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepStartupArgs:
startup_cmd: list[str]
serve_params: ParameterSweep
startup_params: ParameterSweep
output_dir: Path
num_runs: int
show_stdout: bool
dry_run: bool
resume: str | None
strict_params: bool
parser_name: ClassVar[str] = "startup"
parser_help: ClassVar[str] = (
"Benchmark vLLM startup time over parameter combinations."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
startup_cmd = shlex.split(args.startup_cmd)
if args.serve_params:
serve_params = ParameterSweep.read_json(args.serve_params)
else:
serve_params = ParameterSweep.from_records([{}])
if args.startup_params:
startup_params = ParameterSweep.read_json(args.startup_params)
else:
startup_params = ParameterSweep.from_records([{}])
supported = _get_supported_startup_keys()
serve_params = _filter_params(
serve_params, supported=supported, strict=args.strict_params
)
startup_params = _filter_params(
startup_params, supported=supported, strict=args.strict_params
)
if args.num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
return cls(
startup_cmd=startup_cmd,
serve_params=serve_params,
startup_params=startup_params,
output_dir=Path(args.output_dir),
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
resume=args.resume,
strict_params=args.strict_params,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--startup-cmd",
type=str,
default="vllm bench startup",
help="The command used to run the startup benchmark.",
)
parser.add_argument(
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm serve` command. Only parameters supported by "
"`vllm bench startup` will be applied.",
)
parser.add_argument(
"--startup-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm bench startup` command.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
)
parser.add_argument(
"--num-runs",
type=int,
default=1,
help="Number of runs per parameter combination.",
)
parser.add_argument(
"--show-stdout",
action="store_true",
help="If set, logs the standard output of subcommands.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the commands to run, "
"then exits without executing them.",
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--strict-params",
action="store_true",
help="If set, unknown parameters in sweep files raise an error "
"instead of being ignored.",
)
return parser
def run_main(args: SweepStartupArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_combs(
startup_cmd=args.startup_cmd,
serve_params=args.serve_params,
startup_params=args.startup_params,
output_dir=output_dir,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepStartupArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepStartupArgs.parser_help)
SweepStartupArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def sanitize_filename(filename: str) -> str:
return filename.replace("/", "_").replace("..", "__").strip("'").strip('"')

View File

@@ -0,0 +1,946 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import warnings
from typing import Any
import torch
import uvloop
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (
AIMODataset,
BurstGPTDataset,
ConversationDataset,
InstructCoderDataset,
MultiModalConversationDataset,
PrefixRepetitionRandomDataset,
RandomDataset,
RandomDatasetForReranking,
RandomMultiModalDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
add_random_dataset_base_args,
add_random_multimodal_dataset_args,
)
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.utils.async_utils import merge_async_iterators
def run_vllm(
requests: list[SampleRequest],
n: int,
engine_args: EngineArgs,
do_profile: bool,
disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine.
prompts: list[TextPrompt | TokensPrompt] = []
sampling_params: list[SamplingParams] = []
for request in requests:
prompt = (
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
if "prompt_token_ids" in request.prompt
else TextPrompt(prompt=request.prompt)
)
if request.multi_modal_data:
assert isinstance(request.multi_modal_data, dict)
prompt["multi_modal_data"] = request.multi_modal_data
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
)
)
lora_requests: list[LoRARequest] | None = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]
use_beam_search = False
outputs = None
if not use_beam_search:
start = time.perf_counter()
if do_profile:
llm.start_profile()
outputs = llm.generate(
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
)
if do_profile:
llm.stop_profile()
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0].expected_output_len
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
if do_profile:
llm.start_profile()
llm.beam_search(
prompts,
BeamSearchParams(
beam_width=n,
max_tokens=output_len,
ignore_eos=True,
),
)
if do_profile:
llm.stop_profile()
end = time.perf_counter()
return end - start, outputs
def run_vllm_chat(
requests: list[SampleRequest],
n: int,
engine_args: EngineArgs,
do_profile: bool,
disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
"""
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
multimodal models as it properly handles multimodal inputs and chat
formatting. For non-multimodal models, use run_vllm() instead.
"""
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests."
)
prompts = []
sampling_params: list[SamplingParams] = []
for request in requests:
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
)
)
start = time.perf_counter()
if do_profile:
llm.start_profile()
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
if do_profile:
llm.stop_profile()
end = time.perf_counter()
return end - start, outputs
async def run_vllm_async(
requests: list[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
do_profile: bool,
disable_frontend_multiprocessing: bool = False,
disable_detokenize: bool = False,
) -> float:
from vllm import SamplingParams
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
async with build_async_engine_client_from_engine_args(
engine_args,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
) as llm:
model_config = llm.model_config
assert all(
model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine.
prompts: list[TextPrompt | TokensPrompt] = []
sampling_params: list[SamplingParams] = []
lora_requests: list[LoRARequest | None] = []
for request in requests:
prompt = (
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
if "prompt_token_ids" in request.prompt
else TextPrompt(prompt=request.prompt)
)
if request.multi_modal_data:
assert isinstance(request.multi_modal_data, dict)
prompt["multi_modal_data"] = request.multi_modal_data
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
)
)
prompts.append(prompt)
lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
if do_profile:
await llm.start_profile()
for i, (prompt, sp, lr) in enumerate(
zip(prompts, sampling_params, lora_requests)
):
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
if do_profile:
await llm.stop_profile()
end = time.perf_counter()
return end - start
def run_hf(
requests: list[SampleRequest],
model: str,
tokenizer: TokenizerLike,
n: int,
max_batch_size: int,
trust_remote_code: bool,
disable_detokenize: bool = False,
) -> float:
assert isinstance(tokenizer, PreTrainedTokenizerBase), (
"the hf backend only supports HF tokenizers"
)
llm = AutoModelForCausalLM.from_pretrained(
model, dtype=torch.float16, trust_remote_code=trust_remote_code
)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
pbar = tqdm(total=len(requests))
start = time.perf_counter()
batch: list[str] = []
max_prompt_len = 0
max_output_len = 0
for i in range(len(requests)):
prompt = requests[i].prompt
prompt_len = requests[i].prompt_len
output_len = requests[i].expected_output_len
# Add the prompt to the batch.
batch.append(prompt)
max_prompt_len = max(max_prompt_len, prompt_len)
max_output_len = max(max_output_len, output_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
next_prompt_len = requests[i + 1].prompt_len
next_output_len = requests[i + 1].expected_output_len
if (
max(max_prompt_len, next_prompt_len)
+ max(max_output_len, next_output_len)
) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_output_len,
)
if not disable_detokenize:
# Include the decoding time.
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
pbar.update(len(batch))
# Clear the batch.
batch = []
max_prompt_len = 0
max_output_len = 0
end = time.perf_counter()
return end - start
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"requests_per_second": [results["requests_per_second"]],
"tokens_per_second": [results["tokens_per_second"]],
},
extra_info={
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
},
)
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def get_requests(args, tokenizer):
# Common parameters for all dataset types.
common_kwargs = {
"dataset_path": args.dataset_path,
"random_seed": args.seed,
}
sample_kwargs = {
"tokenizer": tokenizer,
"lora_path": args.lora_path,
"max_loras": args.max_loras,
"num_requests": args.num_prompts,
}
if args.dataset_name == "random" or (
args.dataset_path is None
and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
):
sample_kwargs["range_ratio"] = args.random_range_ratio
# prefer random_* arguments, fall back to regular arguments
random_prefix_len = getattr(args, "random_prefix_len", None)
sample_kwargs["prefix_len"] = (
random_prefix_len if random_prefix_len is not None else args.prefix_len
)
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len if random_input_len is not None else args.input_len
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len if random_output_len is not None else args.output_len
)
dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset
if args.backend == "vllm-chat":
sample_kwargs["enable_multimodal_chat"] = True
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset."
)
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
if args.input_len is not None:
sample_kwargs["input_len"] = args.input_len
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs["dataset_subset"] = None
common_kwargs["dataset_split"] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
common_kwargs["dataset_split"] = "train"
elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = MultiModalConversationDataset
common_kwargs["dataset_subset"] = args.hf_subset
common_kwargs["dataset_split"] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
common_kwargs["dataset_subset"] = args.hf_subset
common_kwargs["dataset_split"] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset
common_kwargs["dataset_subset"] = None
common_kwargs["dataset_split"] = "train"
elif args.dataset_name == "prefix_repetition":
dataset_cls = PrefixRepetitionRandomDataset
sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
sample_kwargs["output_len"] = args.prefix_repetition_output_len
elif args.dataset_name == "random-mm":
dataset_cls = RandomMultiModalDataset
# prefer random_* arguments, fall back to regular arguments
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len
if random_input_len is not None
else getattr(args, "input_len", None)
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len
if random_output_len is not None
else getattr(args, "output_len", None)
)
sample_kwargs["base_items_per_request"] = getattr(
args, "random_mm_base_items_per_request", None
)
sample_kwargs["num_mm_items_range_ratio"] = getattr(
args, "random_mm_num_mm_items_range_ratio", None
)
sample_kwargs["limit_mm_per_prompt"] = getattr(
args, "random_mm_limit_mm_per_prompt", None
)
sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
sample_kwargs["enable_multimodal_chat"] = True
random_prefix_len = getattr(args, "random_prefix_len", None)
prefix_len = getattr(args, "prefix_len", None)
sample_kwargs["prefix_len"] = (
random_prefix_len if random_prefix_len is not None else prefix_len
)
sample_kwargs["range_ratio"] = args.random_range_ratio
elif args.dataset_name == "random-rerank":
dataset_cls = RandomDatasetForReranking
# prefer random_* arguments, fall back to regular arguments
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len
if random_input_len is not None
else getattr(args, "input_len", None)
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len
if random_output_len is not None
else getattr(args, "output_len", None)
)
sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
sample_kwargs["range_ratio"] = args.random_range_ratio
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
requests = filter_requests_for_dp(requests, args.data_parallel_size)
return requests
def filter_requests_for_dp(requests, data_parallel_size):
# Note(zhuohan): The way we get data_parallel_rank is hacky and only
# works for external launcher mode. Should be cleaned up and deprecated
# in the future with a better vLLM distributed process design.
if data_parallel_size == 1:
return requests
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
data_parallel_rank = global_rank // (world_size // data_parallel_size)
return [
r
for i, r in enumerate(requests)
if i % data_parallel_size == data_parallel_rank
]
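# Worked example (hypothetical env): with WORLD_SIZE=8 and data_parallel_size=4, each DP
# group spans 8 // 4 == 2 ranks; a process launched with RANK=5 gets
# data_parallel_rank == 5 // 2 == 2 and keeps only the requests whose index i satisfies
# i % 4 == 2.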
def validate_args(args):
"""
Validate command-line arguments.
"""
# === Deprecation and Defaulting ===
if args.dataset is not None:
warnings.warn(
"The '--dataset' argument will be deprecated in the next release. "
"Please use '--dataset-name' and '--dataset-path' instead.",
stacklevel=2,
)
args.dataset_path = args.dataset
if not getattr(args, "tokenizer", None):
args.tokenizer = args.model
# === Backend Validation ===
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
if args.backend not in valid_backends:
raise ValueError(f"Unsupported backend: {args.backend}")
# === Dataset Configuration ===
if (
not args.dataset
and not args.dataset_path
and args.dataset_name not in {"prefix_repetition"}
):
print("When dataset path is not set, it will default to random dataset")
args.dataset_name = "random"
random_input_len = getattr(args, "random_input_len", None)
if args.input_len is None and random_input_len is None:
raise ValueError(
"Either --input-len or --random-input-len must be provided "
"for a random dataset"
)
# === Dataset Name Specific Checks ===
# --hf-subset and --hf-split: only used
# when dataset_name is 'hf'
if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None
):
warnings.warn(
"--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2,
)
elif args.dataset_name == "hf":
if args.dataset_path in (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
| ConversationDataset.SUPPORTED_DATASET_PATHS
):
assert args.backend == "vllm-chat", (
f"{args.dataset_path} needs to use vllm-chat as the backend."
)
elif args.dataset_path in (
InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS
):
assert args.backend == "vllm", (
f"{args.dataset_path} needs to use vllm as the backend."
)
else:
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random',
# 'random-mm', or 'random-rerank'
if (
args.dataset_name not in {"random", "random-mm", "random-rerank"}
and args.random_range_ratio is not None
):
warnings.warn(
"--random-range-ratio will be ignored since \
--dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
stacklevel=2,
)
# --random-batch-size: only used when dataset_name is 'random-rerank'
if (
args.dataset_name != "random-rerank"
and getattr(args, "random_batch_size", None) is not None
) and args.random_batch_size != 1:
warnings.warn(
"--random-batch-size will be ignored since \
--dataset-name is not 'random-rerank'.",
stacklevel=2,
)
# --no-reranker: only used when dataset_name is 'random-rerank'
if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
warnings.warn(
"--no-reranker will be ignored since \
--dataset-name is not 'random-rerank'.",
stacklevel=2,
)
# --prefix-len: only used when dataset_name is 'random', 'random-mm',
# 'sonnet', or not set.
if (
args.dataset_name not in {"random", "random-mm", "sonnet", None}
and args.prefix_len is not None
):
warnings.warn(
"--prefix-len will be ignored since --dataset-name\
is not 'random', 'random-mm', 'sonnet', or not set.",
stacklevel=2,
)
# === Random Dataset Argument Conflict Detection ===
# Check for conflicts between regular and random arguments when using
# random datasets
if args.dataset_name in {"random", "random-mm", "random-rerank"}:
random_input_len = getattr(args, "random_input_len", None)
random_output_len = getattr(args, "random_output_len", None)
random_prefix_len = getattr(args, "random_prefix_len", None)
if args.input_len is not None and random_input_len is not None:
warnings.warn(
"Both --input-len and --random-input-len are specified. "
"The random version (--random-input-len) will be preferred "
"in this run.",
stacklevel=2,
)
if args.output_len is not None and random_output_len is not None:
warnings.warn(
"Both --output-len and --random-output-len are specified. "
"The random version (--random-output-len) will be preferred "
"in this run.",
stacklevel=2,
)
if args.prefix_len is not None and random_prefix_len is not None:
warnings.warn(
"Both --prefix-len and --random-prefix-len are specified. "
"The random version (--random-prefix-len) will be preferred "
"in this run.",
stacklevel=2,
)
# === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
if getattr(args, "enable_lora", False) and args.lora_path is None:
raise ValueError("LoRA path must be provided when enable_lora is True")
# === Backend-specific Validations ===
if args.backend == "hf" and args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend")
if args.backend != "hf" and args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if (
args.backend in {"hf", "mii"}
and getattr(args, "quantization", None) is not None
):
raise ValueError("Quantization is only for vLLM backend.")
if args.backend == "mii" and args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.backend == "mii" and args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.backend == "mii" and args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII backend.")
if args.data_parallel_size > 1 and (
args.distributed_executor_backend != "external_launcher" or args.async_engine
):
# --data-parallel is not fully supported yet.
# Old issue: https://github.com/vllm-project/vllm/issues/16222
# Currently we only support data parallel with external launcher
# mode (i.e., launch with torchrun).
raise ValueError(
"Data parallel is only supported with external launcher mode "
"with synchronous engine in offline benchmark, "
"please use benchmark serving instead"
)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--backend",
type=str,
choices=["vllm", "hf", "mii", "vllm-chat"],
default="vllm",
)
parser.add_argument(
"--dataset-name",
type=str,
choices=[
"sharegpt",
"random",
"sonnet",
"burstgpt",
"hf",
"prefix_repetition",
"random-mm",
"random-rerank",
],
help="Name of the dataset to benchmark on.",
default="sharegpt",
)
parser.add_argument(
"--dataset",
type=str,
default=None,
help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: "
"list[dict[..., value: <prompt_or_response>]]]]",
)
parser.add_argument(
"--dataset-path", type=str, default=None, help="Path to the dataset"
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.",
)
parser.add_argument(
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
parser.add_argument(
"--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.",
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the throughput results in JSON format.",
)
parser.add_argument(
"--async-engine",
action="store_true",
default=False,
help="Use vLLM async engine rather than LLM class.",
)
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
default=False,
help="Disable decoupled async engine frontend.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=(
"Do not detokenize the response (i.e. do not include "
"detokenization time in the measurement)"
),
)
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.",
)
parser.add_argument(
"--prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before the random "
"context in a request (default: 0).",
)
# hf dataset
parser.add_argument(
"--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.",
)
parser.add_argument(
"--hf-split",
type=str,
default=None,
help="Split of the HF dataset.",
)
parser.add_argument(
"--profile",
action="store_true",
default=False,
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
# prefix repetition dataset
parser.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=None,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
parser.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=None,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
parser.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=None,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
parser.add_argument(
"--prefix-repetition-output-len",
type=int,
default=None,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
# (random, random-mm, random-rerank)
add_random_dataset_base_args(parser)
add_random_multimodal_dataset_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
validate_args(args)
if args.seed is None:
args.seed = 0
random.seed(args.seed)
# Sample the requests.
if (
args.backend == "hf" or args.backend == "mii"
) and args.tokenizer_mode == "auto":
# mistral_common tokenizer is only supported on vllm and vllm-chat backends;
# for hf and mii backends, we use hf tokenizer
args.tokenizer_mode = "hf"
tokenizer = get_tokenizer(
args.tokenizer,
tokenizer_mode=args.tokenizer_mode,
trust_remote_code=args.trust_remote_code,
)
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
request_outputs: list[RequestOutput] | None = None
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
disable_detokenize=args.disable_detokenize,
do_profile=args.profile,
)
)
else:
elapsed_time, request_outputs = run_vllm(
requests,
args.n,
EngineArgs.from_cli_args(args),
disable_detokenize=args.disable_detokenize,
do_profile=args.profile,
)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
if args.profile:
raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
elapsed_time = run_hf(
requests,
args.model,
tokenizer,
args.n,
args.hf_max_batch_size,
args.trust_remote_code,
args.disable_detokenize,
)
elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat(
requests,
args.n,
EngineArgs.from_cli_args(args),
disable_detokenize=args.disable_detokenize,
do_profile=args.profile,
)
else:
raise ValueError(f"Unknown backend: {args.backend}")
if request_outputs:
# Note: with the vllm and vllm-chat backends,
# we have request_outputs, which we use to count tokens.
total_prompt_tokens = 0
total_output_tokens = 0
for ro in request_outputs:
if not isinstance(ro, RequestOutput):
continue
total_prompt_tokens += (
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
)
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
total_num_tokens = total_prompt_tokens + total_output_tokens
else:
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
total_output_tokens = sum(r.expected_output_len for r in requests)
total_prompt_tokens = total_num_tokens - total_output_tokens
if is_multi_modal and args.backend != "vllm-chat":
print(
"\033[91mWARNING\033[0m: Multi-modal request with "
f"{args.backend} backend detected. The "
"following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details."
)
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
# vllm-chat backend counts the image tokens now
print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
)
print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)

851
vllm/collect_env.py Normal file
View File

@@ -0,0 +1,851 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
import datetime
import locale
import os
import subprocess
import sys
# Unlike the rest of PyTorch, this file must be Python 2 compliant.
# This script outputs relevant system environment info
# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
from collections import namedtuple
import regex as re
from vllm.envs import environment_variables
try:
import torch
TORCH_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError):
TORCH_AVAILABLE = False
# System Environment Information
SystemEnv = namedtuple(
"SystemEnv",
[
"torch_version",
"is_debug_build",
"cuda_compiled_version",
"gcc_version",
"clang_version",
"cmake_version",
"os",
"libc_version",
"python_version",
"python_platform",
"is_cuda_available",
"cuda_runtime_version",
"cuda_module_loading",
"nvidia_driver_version",
"nvidia_gpu_models",
"cudnn_version",
"pip_version", # 'pip' or 'pip3'
"pip_packages",
"conda_packages",
"hip_compiled_version",
"hip_runtime_version",
"miopen_runtime_version",
"caching_allocator_config",
"is_xnnpack_available",
"cpu_info",
"rocm_version", # vllm specific field
"vllm_version", # vllm specific field
"vllm_build_flags", # vllm specific field
"gpu_topo", # vllm specific field
"env_vars",
],
)
DEFAULT_CONDA_PATTERNS = {
"torch",
"numpy",
"cudatoolkit",
"soumith",
"mkl",
"magma",
"triton",
"optree",
"nccl",
"transformers",
"zmq",
"nvidia",
"pynvml",
"flashinfer-python",
"helion",
}
DEFAULT_PIP_PATTERNS = {
"torch",
"numpy",
"mypy",
"flake8",
"triton",
"optree",
"onnx",
"nccl",
"transformers",
"zmq",
"nvidia",
"pynvml",
"flashinfer-python",
"helion",
}
def run(command):
"""Return (return-code, stdout, stderr)."""
shell = True if type(command) is str else False
try:
p = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
)
raw_output, raw_err = p.communicate()
rc = p.returncode
if get_platform() == "win32":
enc = "oem"
else:
enc = locale.getpreferredencoding()
output = raw_output.decode(enc)
if command == "nvidia-smi topo -m":
# don't remove the leading whitespace of `nvidia-smi topo -m`
# because they are meaningful
output = output.rstrip()
else:
output = output.strip()
err = raw_err.decode(enc)
return rc, output, err.strip()
except FileNotFoundError:
cmd_str = command if isinstance(command, str) else command[0]
return 127, "", f"Command not found: {cmd_str}"
def run_and_read_all(run_lambda, command):
"""Run command using run_lambda; reads and returns entire output if rc is 0."""
rc, out, _ = run_lambda(command)
if rc != 0:
return None
return out
def run_and_parse_first_match(run_lambda, command, regex):
"""Run command using run_lambda, returns the first regex match if it exists."""
rc, out, _ = run_lambda(command)
if rc != 0:
return None
match = re.search(regex, out)
if match is None:
return None
return match.group(1)
def get_conda_packages(run_lambda, patterns=None):
if patterns is None:
patterns = DEFAULT_CONDA_PATTERNS
conda = os.environ.get("CONDA_EXE", "conda")
out = run_and_read_all(run_lambda, [conda, "list"])
if out is None:
return out
return "\n".join(
line
for line in out.splitlines()
if not line.startswith("#") and any(name in line for name in patterns)
)
def get_gcc_version(run_lambda):
return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
def get_clang_version(run_lambda):
return run_and_parse_first_match(
run_lambda, "clang --version", r"clang version (.*)"
)
def get_cmake_version(run_lambda):
return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
def get_nvidia_driver_version(run_lambda):
if get_platform() == "darwin":
cmd = "kextstat | grep -i cuda"
return run_and_parse_first_match(
run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]"
)
smi = get_nvidia_smi()
return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
def get_gpu_info(run_lambda):
if get_platform() == "darwin" or (
TORCH_AVAILABLE
and hasattr(torch.version, "hip")
and torch.version.hip is not None
):
if TORCH_AVAILABLE and torch.cuda.is_available():
if torch.version.hip is not None:
prop = torch.cuda.get_device_properties(0)
if hasattr(prop, "gcnArchName"):
gcnArch = " ({})".format(prop.gcnArchName)
else:
gcnArch = "NoGCNArchNameOnOldPyTorch"
else:
gcnArch = ""
return torch.cuda.get_device_name(None) + gcnArch
return None
smi = get_nvidia_smi()
uuid_regex = re.compile(r" \(UUID: .+?\)")
rc, out, _ = run_lambda(smi + " -L")
if rc != 0:
return None
# Anonymize GPUs by removing their UUID
return re.sub(uuid_regex, "", out)
def get_running_cuda_version(run_lambda):
return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
def get_cudnn_version(run_lambda):
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
if get_platform() == "win32":
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
where_cmd = os.path.join(system_root, "System32", "where")
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
elif get_platform() == "darwin":
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
# https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
# https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
else:
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
rc, out, _ = run_lambda(cudnn_cmd)
# find will return 1 if there are permission errors or if not found
if len(out) == 0 or (rc != 1 and rc != 0):
l = os.environ.get("CUDNN_LIBRARY")
if l is not None and os.path.isfile(l):
return os.path.realpath(l)
return None
files_set = set()
for fn in out.split("\n"):
fn = os.path.realpath(fn) # eliminate symbolic links
if os.path.isfile(fn):
files_set.add(fn)
if not files_set:
return None
# Alphabetize the result because the order is non-deterministic otherwise
files = sorted(files_set)
if len(files) == 1:
return files[0]
result = "\n".join(files)
return "Probably one of the following:\n{}".format(result)
def get_nvidia_smi():
# Note: nvidia-smi is currently available only on Windows and Linux
smi = "nvidia-smi"
if get_platform() == "win32":
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
legacy_path = os.path.join(
program_files_root, "NVIDIA Corporation", "NVSMI", smi
)
new_path = os.path.join(system_root, "System32", smi)
smis = [new_path, legacy_path]
for candidate_smi in smis:
if os.path.exists(candidate_smi):
smi = '"{}"'.format(candidate_smi)
break
return smi
def get_rocm_version(run_lambda):
"""Returns the ROCm version if available, otherwise 'N/A'."""
return run_and_parse_first_match(
run_lambda, "hipcc --version", r"HIP version: (\S+)"
)
def get_vllm_version():
from vllm import __version__, __version_tuple__
if __version__ == "dev":
return "N/A (dev)"
version_str = __version_tuple__[-1]
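# Illustrative values (assuming a setuptools-scm style local segment): a
# version_str of "g1a2b3c4" is a clean dev checkout, while "g1a2b3c4.d20250101"
# also carries a date suffix; the branches below split out the sha and date.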
if isinstance(version_str, str) and version_str.startswith("g"):
# it's a dev build
if "." in version_str:
# it's a dev build containing local changes
git_sha = version_str.split(".")[0][1:]
date = version_str.split(".")[-1][1:]
return f"{__version__} (git sha: {git_sha}, date: {date})"
else:
# it's a dev build without local changes
git_sha = version_str[1:] # type: ignore
return f"{__version__} (git sha: {git_sha})"
return __version__
def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return "CUDA Archs: {}; ROCm: {}".format(
os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"),
"Enabled" if os.environ.get("ROCM_HOME") else "Disabled",
)
def get_gpu_topo(run_lambda):
output = None
if get_platform() == "linux":
output = run_and_read_all(run_lambda, "nvidia-smi topo -m")
if output is None:
output = run_and_read_all(run_lambda, "rocm-smi --showtopo")
return output
# example outputs of CPU infos
# * linux
# Architecture: x86_64
# CPU op-mode(s): 32-bit, 64-bit
# Address sizes: 46 bits physical, 48 bits virtual
# Byte Order: Little Endian
# CPU(s): 128
# On-line CPU(s) list: 0-127
# Vendor ID: GenuineIntel
# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
# CPU family: 6
# Model: 106
# Thread(s) per core: 2
# Core(s) per socket: 32
# Socket(s): 2
# Stepping: 6
# BogoMIPS: 5799.78
# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
# Virtualization features:
# Hypervisor vendor: KVM
# Virtualization type: full
# Caches (sum of all):
# L1d: 3 MiB (64 instances)
# L1i: 2 MiB (64 instances)
# L2: 80 MiB (64 instances)
# L3: 108 MiB (2 instances)
# NUMA:
# NUMA node(s): 2
# NUMA node0 CPU(s): 0-31,64-95
# NUMA node1 CPU(s): 32-63,96-127
# Vulnerabilities:
# Itlb multihit: Not affected
# L1tf: Not affected
# Mds: Not affected
# Meltdown: Not affected
# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
# Retbleed: Not affected
# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
# Srbds: Not affected
# Tsx async abort: Not affected
# * win32
# Architecture=9
# CurrentClockSpeed=2900
# DeviceID=CPU0
# Family=179
# L2CacheSize=40960
# L2CacheSpeed=
# Manufacturer=GenuineIntel
# MaxClockSpeed=2900
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
# ProcessorType=3
# Revision=27142
#
# Architecture=9
# CurrentClockSpeed=2900
# DeviceID=CPU1
# Family=179
# L2CacheSize=40960
# L2CacheSpeed=
# Manufacturer=GenuineIntel
# MaxClockSpeed=2900
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
# ProcessorType=3
# Revision=27142
def get_cpu_info(run_lambda):
rc, out, err = 0, "", ""
if get_platform() == "linux":
rc, out, err = run_lambda("lscpu")
elif get_platform() == "win32":
rc, out, err = run_lambda(
"wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE"
)
elif get_platform() == "darwin":
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
cpu_info = "None"
if rc == 0:
cpu_info = out
else:
cpu_info = err
return cpu_info
def get_platform():
if sys.platform.startswith("linux"):
return "linux"
elif sys.platform.startswith("win32"):
return "win32"
elif sys.platform.startswith("cygwin"):
return "cygwin"
elif sys.platform.startswith("darwin"):
return "darwin"
else:
return sys.platform
def get_mac_version(run_lambda):
return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
def get_windows_version(run_lambda):
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic")
findstr_cmd = os.path.join(system_root, "System32", "findstr")
return run_and_read_all(
run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd)
)
def get_lsb_version(run_lambda):
return run_and_parse_first_match(
run_lambda, "lsb_release -a", r"Description:\t(.*)"
)
def check_release_file(run_lambda):
return run_and_parse_first_match(
run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"'
)
def get_os(run_lambda):
from platform import machine
platform = get_platform()
if platform == "win32" or platform == "cygwin":
return get_windows_version(run_lambda)
if platform == "darwin":
version = get_mac_version(run_lambda)
if version is None:
return None
return "macOS {} ({})".format(version, machine())
if platform == "linux":
# Ubuntu/Debian based
desc = get_lsb_version(run_lambda)
if desc is not None:
return "{} ({})".format(desc, machine())
# Try reading /etc/*-release
desc = check_release_file(run_lambda)
if desc is not None:
return "{} ({})".format(desc, machine())
return "{} ({})".format(platform, machine())
# Unknown platform
return platform
def get_python_platform():
import platform
return platform.platform()
def get_libc_version():
import platform
if get_platform() != "linux":
return "N/A"
return "-".join(platform.libc_ver())
def is_uv_venv():
if os.environ.get("UV"):
return True
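# A uv-created virtualenv typically records a line like "uv = <version>" in
# pyvenv.cfg (assumed format); that is what the startswith() check below keys on.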
pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg")
if os.path.exists(pyvenv_cfg_path):
with open(pyvenv_cfg_path, "r") as f:
return any(line.startswith("uv = ") for line in f)
return False
def get_pip_packages(run_lambda, patterns=None):
"""Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
if patterns is None:
patterns = DEFAULT_PIP_PATTERNS
def run_with_pip():
try:
import importlib.util
pip_spec = importlib.util.find_spec("pip")
pip_available = pip_spec is not None
except ImportError:
pip_available = False
if pip_available:
cmd = [sys.executable, "-mpip", "list", "--format=freeze"]
elif is_uv_venv():
print("uv is set")
cmd = ["uv", "pip", "list", "--format=freeze"]
else:
raise RuntimeError(
"Could not collect pip list output (pip or uv module not available)"
)
out = run_and_read_all(run_lambda, cmd)
return "\n".join(
line for line in out.splitlines() if any(name in line for name in patterns)
)
pip_version = "pip3" if sys.version[0] == "3" else "pip"
out = run_with_pip()
return pip_version, out
def get_cachingallocator_config():
ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
return ca_config
def get_cuda_module_loading_config():
if TORCH_AVAILABLE and torch.cuda.is_available():
torch.cuda.init()
config = os.environ.get("CUDA_MODULE_LOADING", "")
return config
else:
return "N/A"
def is_xnnpack_available():
if TORCH_AVAILABLE:
import torch.backends.xnnpack
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
else:
return "N/A"
def get_env_vars():
env_vars = ""
secret_terms = ("secret", "token", "api", "access", "password")
report_prefix = (
"TORCH",
"NCCL",
"PYTORCH",
"CUDA",
"CUBLAS",
"CUDNN",
"OMP_",
"MKL_",
"NVIDIA",
)
for k, v in os.environ.items():
if any(term in k.lower() for term in secret_terms):
continue
if k in environment_variables:
env_vars = env_vars + "{}={}".format(k, v) + "\n"
if k.startswith(report_prefix):
env_vars = env_vars + "{}={}".format(k, v) + "\n"
return env_vars
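# Illustrative outcome: variables such as CUDA_VISIBLE_DEVICES or
# TORCH_CUDA_ARCH_LIST get reported, while anything whose name contains
# "secret", "token", "api", "access" or "password" is skipped entirely.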
def get_env_info():
run_lambda = run
pip_version, pip_list_output = get_pip_packages(run_lambda)
if TORCH_AVAILABLE:
version_str = torch.__version__
debug_mode_str = str(torch.version.debug)
cuda_available_str = str(torch.cuda.is_available())
cuda_version_str = torch.version.cuda
if (
not hasattr(torch.version, "hip") or torch.version.hip is None
): # cuda version
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
else: # HIP version
def get_version_or_na(cfg, prefix):
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
return _lst[0] if _lst else "N/A"
cfg = torch._C._show_config().split("\n")
hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
cuda_version_str = "N/A"
hip_compiled_version = torch.version.hip
else:
version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A"
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
sys_version = sys.version.replace("\n", " ")
conda_packages = get_conda_packages(run_lambda)
rocm_version = get_rocm_version(run_lambda)
vllm_version = get_vllm_version()
vllm_build_flags = summarize_vllm_build_flags()
gpu_topo = get_gpu_topo(run_lambda)
return SystemEnv(
torch_version=version_str,
is_debug_build=debug_mode_str,
python_version="{} ({}-bit runtime)".format(
sys_version, sys.maxsize.bit_length() + 1
),
python_platform=get_python_platform(),
is_cuda_available=cuda_available_str,
cuda_compiled_version=cuda_version_str,
cuda_runtime_version=get_running_cuda_version(run_lambda),
cuda_module_loading=get_cuda_module_loading_config(),
nvidia_gpu_models=get_gpu_info(run_lambda),
nvidia_driver_version=get_nvidia_driver_version(run_lambda),
cudnn_version=get_cudnn_version(run_lambda),
hip_compiled_version=hip_compiled_version,
hip_runtime_version=hip_runtime_version,
miopen_runtime_version=miopen_runtime_version,
pip_version=pip_version,
pip_packages=pip_list_output,
conda_packages=conda_packages,
os=get_os(run_lambda),
libc_version=get_libc_version(),
gcc_version=get_gcc_version(run_lambda),
clang_version=get_clang_version(run_lambda),
cmake_version=get_cmake_version(run_lambda),
caching_allocator_config=get_cachingallocator_config(),
is_xnnpack_available=is_xnnpack_available(),
cpu_info=get_cpu_info(run_lambda),
rocm_version=rocm_version,
vllm_version=vllm_version,
vllm_build_flags=vllm_build_flags,
gpu_topo=gpu_topo,
env_vars=get_env_vars(),
)
env_info_fmt = """
==============================
System Info
==============================
OS : {os}
GCC version : {gcc_version}
Clang version : {clang_version}
CMake version : {cmake_version}
Libc version : {libc_version}
==============================
PyTorch Info
==============================
PyTorch version : {torch_version}
Is debug build : {is_debug_build}
CUDA used to build PyTorch : {cuda_compiled_version}
ROCM used to build PyTorch : {hip_compiled_version}
==============================
Python Environment
==============================
Python version : {python_version}
Python platform : {python_platform}
==============================
CUDA / GPU Info
==============================
Is CUDA available : {is_cuda_available}
CUDA runtime version : {cuda_runtime_version}
CUDA_MODULE_LOADING set to : {cuda_module_loading}
GPU models and configuration : {nvidia_gpu_models}
Nvidia driver version : {nvidia_driver_version}
cuDNN version : {cudnn_version}
HIP runtime version : {hip_runtime_version}
MIOpen runtime version : {miopen_runtime_version}
Is XNNPACK available : {is_xnnpack_available}
==============================
CPU Info
==============================
{cpu_info}
==============================
Versions of relevant libraries
==============================
{pip_packages}
{conda_packages}
""".strip()
# both the above code and the following code use `strip()` to
# remove leading/trailing whitespaces, so we need to add a newline
# in between to separate the two sections
env_info_fmt += "\n\n"
env_info_fmt += """
==============================
vLLM Info
==============================
ROCM Version : {rocm_version}
vLLM Version : {vllm_version}
vLLM Build Flags:
{vllm_build_flags}
GPU Topology:
{gpu_topo}
==============================
Environment Variables
==============================
{env_vars}
""".strip()
def pretty_str(envinfo):
def replace_nones(dct, replacement="Could not collect"):
for key in dct.keys():
if dct[key] is not None:
continue
dct[key] = replacement
return dct
def replace_bools(dct, true="Yes", false="No"):
for key in dct.keys():
if dct[key] is True:
dct[key] = true
elif dct[key] is False:
dct[key] = false
return dct
def prepend(text, tag="[prepend]"):
lines = text.split("\n")
updated_lines = [tag + line for line in lines]
return "\n".join(updated_lines)
def replace_if_empty(text, replacement="No relevant packages"):
if text is not None and len(text) == 0:
return replacement
return text
def maybe_start_on_next_line(string):
# If `string` is multiline, prepend a \n to it.
if string is not None and len(string.split("\n")) > 1:
return "\n{}\n".format(string)
return string
mutable_dict = envinfo._asdict()
# If nvidia_gpu_models is multiline, start on the next line
mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(
envinfo.nvidia_gpu_models
)
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
dynamic_cuda_fields = [
"cuda_runtime_version",
"nvidia_gpu_models",
"nvidia_driver_version",
]
all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
all_dynamic_cuda_fields_missing = all(
mutable_dict[field] is None for field in dynamic_cuda_fields
)
if (
TORCH_AVAILABLE
and not torch.cuda.is_available()
and all_dynamic_cuda_fields_missing
):
for field in all_cuda_fields:
mutable_dict[field] = "No CUDA"
if envinfo.cuda_compiled_version is None:
mutable_dict["cuda_compiled_version"] = "None"
# Replace True with Yes, False with No
mutable_dict = replace_bools(mutable_dict)
# Replace all None objects with 'Could not collect'
mutable_dict = replace_nones(mutable_dict)
# If either of these are '', replace with 'No relevant packages'
mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
# Tag conda and pip packages with a prefix
# If they were previously None, they'll show up as ie '[conda] Could not collect'
if mutable_dict["pip_packages"]:
mutable_dict["pip_packages"] = prepend(
mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)
)
if mutable_dict["conda_packages"]:
mutable_dict["conda_packages"] = prepend(
mutable_dict["conda_packages"], "[conda] "
)
mutable_dict["cpu_info"] = envinfo.cpu_info
return env_info_fmt.format(**mutable_dict)
def get_pretty_env_info():
return pretty_str(get_env_info())
def main():
print("Collecting environment information...")
output = get_pretty_env_info()
print(output)
if (
TORCH_AVAILABLE
and hasattr(torch, "utils")
and hasattr(torch.utils, "_crash_handler")
):
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
if sys.platform == "linux" and os.path.exists(minidump_dir):
dumps = [
os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)
]
latest = max(dumps, key=os.path.getctime)
ctime = os.path.getctime(latest)
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
"%Y-%m-%d %H:%M:%S"
)
msg = (
"\n*** Detected a minidump at {} created on {}, ".format(
latest, creation_time
)
+ "if this is related to your bug please include it when you file a report ***"
)
print(msg, file=sys.stderr)
if __name__ == "__main__":
main()

1131
vllm/compilation/backends.py Normal file
View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Protocol
from vllm.config import CUDAGraphMode, VllmConfig
class AbstractStaticGraphWrapper(Protocol):
"""
StaticGraphWrapper interface that allows platforms to wrap a callable
to be captured as a static graph.
"""
def __init__(
self,
runnable: Callable[..., Any],
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
**kwargs: Any,
) -> None:
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
Args:
runnable (Callable): The callable to be wrapped and captured.
vllm_config (VllmConfig): Global configuration for vLLM.
runtime_mode (CUDAGraphMode): The style of the static
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the enum members `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime modes for cudagraph dispatching.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
"""
raise NotImplementedError
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Executes the wrapped callable.
If the current runtime mode in the ForwardContext matches the runtime
mode of this instance, it replays the CUDAGraph or captures it using
the callable if it hasn't been captured yet. Otherwise, it calls the
original callable directly.
Args:
*args: Variable length input arguments to be passed into the
callable.
**kwargs: Keyword arguments to be passed into the callable.
Returns:
Any: Output of the executed callable.
"""
raise NotImplementedError
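# Illustrative sketch (hypothetical, not part of this module): a platform
# backend would satisfy this protocol with something like
#   class MyGraphWrapper:
#       def __init__(self, runnable, vllm_config, runtime_mode, **kwargs): ...
#       def __call__(self, *args, **kwargs): ...
# capturing the wrapped callable once and replaying it when the runtime
# mode in the forward context matches this instance's mode.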

516
vllm/compilation/caching.py Normal file
View File

@@ -0,0 +1,516 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect
import os
import pickle
from collections.abc import Callable, Sequence
from typing import Any, Literal
from unittest.mock import patch
import torch
from torch.utils import _pytree as pytree
import vllm.envs as envs
from vllm.compilation.compiler_interface import get_inductor_factors
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
try:
from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
SerializableCallable = object
assert isinstance(SerializableCallable, type)
logger = init_logger(__name__)
class StandaloneCompiledArtifacts:
"""Storage for standalone compiled artifacts with content-based deduplication.
Deduplication works via a two-level indirection:
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
When inserting, we compute the SHA256 hash of the bytes. If the hash
already exists in `submodule_bytes_store`, we reuse the existing entry
rather than storing duplicate bytes. This is common because submodules
often compile to identical artifacts (e.g., identical transformer layers
produced by splitting the graph at attention ops).
"""
def __init__(self) -> None:
# dict from submodule name to byte hash
self.submodule_bytes: dict[str, str] = {}
# dict from byte hash to bytes
self.submodule_bytes_store: dict[str, bytes] = {}
# dict from byte hash to loaded module
self.loaded_submodule_store: dict[str, Any] = {}
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
hasher = hashlib.sha256()
hasher.update(entry)
hex_digest = hasher.hexdigest()
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
if hex_digest not in self.submodule_bytes_store:
self.submodule_bytes_store[hex_digest] = entry
logger.debug(
"inserting new artifact for submod %s with shape %s "
"(%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
else:
logger.debug(
"reusing existing cache artifact for submod %s "
"with shape %s (%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
def get(self, submod_name: str, shape: str) -> bytes:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.submodule_bytes_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def get_loaded(self, submod_name: str, shape: str) -> Any:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.loaded_submodule_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def size_bytes(self) -> int:
return sum(len(entry) for entry in self.submodule_bytes_store.values())
def num_artifacts(self) -> int:
return len(self.submodule_bytes_store)
def num_entries(self) -> int:
return len(self.submodule_bytes)
def submodule_names(self) -> list[str]:
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
return list(dict.fromkeys(names))
def load_all(self) -> None:
import concurrent.futures
# check already loaded
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
return
from torch._inductor.standalone_compile import AOTCompiledArtifact
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry)
with concurrent.futures.ThreadPoolExecutor() as executor:
entries = list(self.submodule_bytes_store.values())
loaded_entries = list(executor.map(_load_entry, entries))
for i, k in enumerate(self.submodule_bytes_store.keys()):
self.loaded_submodule_store[k] = loaded_entries[i]
logger.debug("loaded all %s submodules", self.num_artifacts())
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
return {
"submodule_bytes": self.submodule_bytes,
"submodule_bytes_store": self.submodule_bytes_store,
}
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {}
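# Note: the deserialized (loaded) submodules are deliberately excluded from
# pickling above; they are rebuilt lazily via load_all() after unpickling.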
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a function compiled by vLLM. It forwards the tensor
inputs to the compiled function and returns the result.
It also implements a serialization interface to support PyTorch precompile
with a custom backend, so that the compiled function can be saved to and
loaded from disk. Wrapping is unnecessary in cases where the compiled
function does not need to be serialized.
Right now, serialization for the custom backend is done by serializing
the Dynamo FX graph plus its example inputs.
"""
def __init__(
self,
graph_module: torch.fx.GraphModule,
example_inputs: Sequence[Any],
prefix: str,
optimized_call: Callable[..., Any],
is_encoder: bool = False,
vllm_backend: Any | None = None,
sym_tensor_indices: list[int] | None = None,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
self.example_inputs = example_inputs
self.prefix = prefix
self.optimized_call = optimized_call
self.is_encoder = is_encoder
self.shape_env = None
self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices
sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
)
if sym_input is not None:
self.shape_env = sym_input.node.shape_env
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.optimized_call(*args, **kwargs)
@classmethod
def serialize_compile_artifacts(
cls, compiled_fn: "VllmSerializableFunction"
) -> bytes:
import sympy
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
state = compiled_fn.__dict__.copy()
state.pop("optimized_call")
state.pop("shape_env")
state.pop("vllm_backend", None)
for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
graph_reducer_override = GraphPickler.reducer_override
def _graph_reducer_override(
self: GraphPickler, obj: Any
) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
if (
inspect.isclass(obj)
and issubclass(obj, sympy.Function)
and hasattr(obj, "_torch_unpickler")
):
return obj._torch_unpickler, (obj._torch_handler_name,)
if isinstance(obj, FakeTensorMode):
return type(None), ()
return graph_reducer_override(self, obj)
if state.get("sym_tensor_indices"):
# put tensor inputs on meta device since their data
# isn't needed, yet we need the meta for make_copy_and_call
state["example_inputs"] = pytree.tree_map_only(
torch.Tensor,
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
else:
# mask off all tensor inputs since they are large and not needed.
state["example_inputs"] = pytree.tree_map_only(
torch.Tensor,
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
if compiled_fn.vllm_backend:
(
standalone_compile_artifacts,
sym_shape_indices_map,
returns_tuple_map,
) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
state["standalone_compile_artifacts"] = standalone_compile_artifacts
state["sym_shape_indices_map"] = sym_shape_indices_map
state["returns_tuple_map"] = returns_tuple_map
return pickle.dumps(state)
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
from torch._guards import TracingContext, tracing
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler
from torch.fx.experimental.symbolic_shapes import ShapeEnv
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
returns_tuple_map = state.pop("returns_tuple_map", {})
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
assert standalone_compile_artifacts is not None
submod_names = standalone_compile_artifacts.submodule_names()
num_submods = len(submod_names)
num_artifacts = standalone_compile_artifacts.num_artifacts()
logger.info(
"reconstructing serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
standalone_compile_artifacts=standalone_compile_artifacts,
vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map,
)
logger.info(
"reconstructed serializable fn from standalone compile artifacts"
)
return fn
# Fall back to standard VllmBackend
from vllm.compilation.backends import VllmBackend
is_encoder = state.get("is_encoder", False)
vllm_backend: VllmBackend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
def optimized_call(*example_inputs: Any) -> Any:
"""
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
call returns, we just do a one-time replacement of the optimized
call with the compiled function, so that subsequent calls are on
the AOT compiled path.
"""
compile_inputs = [
inp if inp is not None else example_inputs[i]
for i, inp in enumerate(fn.example_inputs)
]
with tracing(TracingContext(fake_mode)):
fn.optimized_call = vllm_backend(
state["graph_module"], compile_inputs
).optimized_call
return fn.optimized_call(*example_inputs)
fn = cls(**state, optimized_call=optimized_call)
return fn
@property
def co_name(self) -> Literal["VllmSerializableFunction"]:
"""
Used for depyf debugging.
"""
return "VllmSerializableFunction"
def reconstruct_serializable_fn_from_mega_artifact(
state: dict[str, Any],
standalone_compile_artifacts: "StandaloneCompiledArtifacts",
vllm_config: VllmConfig,
sym_shape_indices_map: dict[str, list[int]],
returns_tuple_map: dict[str, bool],
) -> "VllmSerializableFunction":
"""Construct a VllmSerializableFunction from cached inductor artifacts.
This function reconstructs a callable model from pre-compiled inductor
artifacts without re-running the compilation. It:
1. Loads all cached artifacts
2. Builds compiled callables for each submodule/shape
3. Creates PiecewiseBackend instances that dispatch to cached artifacts
4. Wraps with cudagraph if needed
5. Returns the final VllmSerializableFunction
Note: This function shares similar logic with PiecewiseCompileInterpreter
in backends.py. Both create PiecewiseBackend instances and wrap them with
cudagraph. The key difference is:
- this function: PiecewiseBackend receives pre-compiled runnables
(compiled_runnables is set, graph is None)
- PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
to compile (graph is set, compiled_runnables is None)
If modifying the backend creation/wrapping logic, consider updating both.
Args:
state: Deserialized state dict containing graph_module, example_inputs,
prefix, sym_tensor_indices, is_encoder, etc.
standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
pre-compiled artifacts for each submodule/shape combination.
vllm_config: The vLLM configuration.
sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
returns_tuple_map: Mapping from submod_name to returns_tuple.
Returns:
A VllmSerializableFunction that can be called directly.
"""
from vllm.compilation.backends import (
VllmBackend,
make_copy_and_call,
wrap_with_cudagraph_if_needed,
)
from vllm.compilation.piecewise_backend import PiecewiseBackend
prefix = state["prefix"]
is_encoder = state.get("is_encoder", False)
split_gm = state["graph_module"]
compilation_config = vllm_config.compilation_config
standalone_compile_artifacts.load_all()
submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
for cache_key in standalone_compile_artifacts.submodule_bytes:
submod_name, shape_str = cache_key.rsplit("_", 1)
compiled_callables.setdefault(submod_name, {})[shape_str] = (
standalone_compile_artifacts.get_loaded(submod_name, shape_str)
)
vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
os.makedirs(dummy_cache_dir, exist_ok=True)
vllm_backend.compiler_manager.initialize_cache(
cache_dir=dummy_cache_dir,
disable_cache=True,
prefix=prefix,
)
# spot check that cached submodules exist in the graph structure
graph_children = {name for name, _ in split_gm.named_children()}
missing = set(submod_names) - graph_children
assert not missing, (
f"artifacts reference submodules not in graph: {missing}. "
f"graph has: {sorted(graph_children)}"
)
for i, submod_name in enumerate(submod_names):
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
sym_shape_indices = sym_shape_indices_map[submod_name]
returns_tuple = returns_tuple_map[submod_name]
runnables = compiled_callables[submod_name]
piecewise_backend = PiecewiseBackend(
graph=None, # not needed for cached artifacts
vllm_config=vllm_config,
piecewise_compile_index=i,
total_piecewise_compiles=len(submod_names),
sym_shape_indices=sym_shape_indices,
vllm_backend=vllm_backend,
returns_tuple=returns_tuple,
compiled_runnables=runnables,
)
is_first = i == 0
is_last = i == len(submod_names) - 1
wrapped_backend = wrap_with_cudagraph_if_needed(
piecewise_backend,
vllm_config,
compilation_config,
is_first,
is_last,
)
split_gm.__dict__[submod_name] = wrapped_backend
logger.debug(
"Replaced submodule %s with piecewise backend from cache",
submod_name,
)
if compilation_config.cudagraph_copy_inputs:
sym_tensor_indices = state["sym_tensor_indices"]
input_buffers = [
torch.empty_like(
state["example_inputs"][idx], device=vllm_config.device_config.device
)
for idx in sym_tensor_indices
]
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
else:
optimized_call = split_gm
fn = VllmSerializableFunction(
**state,
optimized_call=optimized_call,
vllm_backend=None,
)
return fn
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = hash_factors(envs.compile_factors())
factors.append(env_hash)
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
config_hash = vllm_config.compute_hash()
factors.append(config_hash)
# 2. inductor factors if applicable
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
factors.extend(get_inductor_factors())
return factors
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
items = list(sorted(file_contents.items(), key=lambda x: x[0]))
hash_content = []
for filepath, content in items:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
hash_content.append(content)
result: str = safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()
return result
def _compute_code_hash(files: set[str]) -> str:
logger.debug(
"Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
)
file_contents = {}
for filepath in files:
# Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
if not os.path.isfile(filepath):
file_contents[filepath] = ""
else:
with open(filepath) as f:
file_contents[filepath] = f.read()
return _compute_code_hash_with_content(file_contents)

View File

@@ -0,0 +1,660 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
import os
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any, Literal
from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
class CompilerInterface:
"""
The interface for a compiler that can be used by vLLM.
"""
# The name of the compiler, e.g. inductor.
# This is a class-level attribute.
name: str
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
) -> None:
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory.
prefix can be used in combination with cache_dir to figure out the base
cache directory, e.g. when there are multiple parts of the model being
compiled but we want to share the same base cache directory for all of them.
e.g.
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
"""
pass
def compute_hash(self, vllm_config: VllmConfig) -> str:
"""
Gather all the relevant information from the vLLM config,
to compute a hash so that we can cache the compiled model.
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
to check what information
is already considered by default. This function should only
consider the information that is specific to the compiler.
"""
return ""
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs,
it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
or a range [5, 8].
Right now we only support one variable in ranges for all inputs,
which is the batchsize (number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
The function should return a compiled callable function, as well as
a handle that can be used to directly load the compiled function.
The handle should be a plain Python object, preferably a string or a
file path for readability.
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
`key` is required for InductorStandaloneAdaptor; it specifies where to
save the compiled artifact. The compiled artifact gets saved to
`cache_dir/key`.
"""
return None, None
def load(
self,
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable[..., Any]:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
The handle is the second return value of the `compile` function.
"""
raise NotImplementedError("caching is not supported")
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods were obtained by trial and error
until the cache lookup worked.
"""
def __init__(self) -> None:
self.guards: list[Any] = []
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
return True
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
return []
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
return ""
def get_inductor_factors() -> list[Any]:
factors: list[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
return factors
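# These factors summarize the host system (via CacheBase.get_system()) and the
# installed torch build (via torch_key()); they feed the compile-cache hash so
# artifacts built on a different machine or torch version are not reused.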
def is_compile_cache_enabled(
vllm_additional_inductor_config: dict[str, Any],
) -> bool:
vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
"force_disable_caches", False
)
# TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
# with torch.compiler.config.force_disable_caches when minimum PyTorch
# version reaches 2.10
return (
not envs.VLLM_DISABLE_COMPILE_CACHE
and not torch._inductor.config.force_disable_caches
and not vllm_inductor_config_disable_cache
)
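# Example: setting VLLM_DISABLE_COMPILE_CACHE=1 in the environment makes this
# return False regardless of the inductor config flags checked above.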
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
Requires PyTorch 2.8+.
This is not on by default yet, but we plan to turn it on by default for
PyTorch 2.8.
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
"""
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str: str = safe_hash(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
return hash_str
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
) -> None:
self.cache_dir = cache_dir
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, compile_range)
set_functorch_config()
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
else:
dynamic_shapes = "from_graph"
from torch._inductor import standalone_compile
supports_aot = is_torch_equal_or_newer("2.10.0")
if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
logger.error(
"CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
"is enabled but PyTorch version does not support 'aot' "
"parameter in standalone_compile. This requires PyTorch "
"2.10.0+. Falling back to non-AOT mode."
)
compile_kwargs = {
"dynamic_shapes": dynamic_shapes,
"options": {
"config_patches": current_config,
},
}
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
# only add 'aot' parameter if both supported and enabled...
# this will set bundled_autograd_cache
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
if use_aot:
compile_kwargs["aot"] = True # type: ignore[assignment]
# Inductor's pre-grad passes don't do anything for vLLM.
# The pre-grad passes get run even on cache-hit and negatively impact
# vllm cold compile times by O(1s)
# Can remove this after the following issue gets fixed
# https://github.com/pytorch/pytorch/issues/174502
if envs.VLLM_ENABLE_PREGRAD_PASSES:
ctx: Any = contextlib.nullcontext()
else:
ctx = patch(
"torch._inductor.compile_fx._recursive_pre_grad_passes",
lambda gm, _: gm,
)
with ctx:
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:
from torch._inductor.standalone_compile import AOTCompiledArtifact
assert isinstance(compiled_graph, AOTCompiledArtifact)
assert hasattr(compiled_graph, "serialize")
# just return the compiled graph and a key
# since we can serialize the bytes using to_bytes
# and reload it using the key when reading
return compiled_graph, None
# Save the compiled artifact to disk in the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
def is_saveable_2_10(compiled_artifact):
# can just use compiled_artifact.is_saveable in 2.11
if compiled_artifact._artifacts is None:
return False
_, cache_info = compiled_artifact._artifacts
return len(cache_info.aot_autograd_artifacts) == 1
if is_compile_cache_enabled(compiler_config):
if not is_saveable_2_10(compiled_graph):
raise RuntimeError(
"The compiled artifact is not serializable. This usually means "
"that the model code has something that is not serializable "
"by torch.compile in it. You can fix this by either "
"figuring out what is not serializable and rewriting it, "
"filing a bug report, "
"or suppressing this error by "
"disabling vLLM's compilation cache via "
"VLLM_DISABLE_COMPILE_CACHE=1 "
"(this will greatly increase vLLM server warm start times)."
)
compiled_graph.save(path=path, format=self.save_format)
compilation_counter.num_compiled_artifacts_saved += 1
return compiled_graph, (key, path)
def load(
self,
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
path = handle[1]
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format=self.save_format
)
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
if returns_tuple:
return graph_output
else:
return graph_output[0]
return compiled_graph_wrapper
class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
"""
name = "inductor"
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str: str = safe_hash(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
return hash_str
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
) -> None:
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
if disable_cache:
return
# redirect the cache directory to a subdirectory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, compile_range)
set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
hash_str, file_path = None, None
from torch._inductor.codecache import compiled_fx_graph_hash
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
if inductor_compiled_graph is not None:
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if (
not file_path.startswith(self.base_cache_dir)
and compiled_fn.__closure__ is not None
):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
code = cell.cell_contents.__code__
if code.co_filename.startswith(self.base_cache_dir):
# this is the real file path
# compiled from Inductor
file_path = code.co_filename
break
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
# for hijacking the hash of the compiled graph
stack.enter_context(
patch(
"torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash,
)
)
# for providing a dummy shape environment
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env,
)
)
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env,
)
)
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache,
)
)
# Dynamo metrics context, see method for more details.
stack.enter_context(self.metrics_context())
# Disable remote caching. When these are on, on remote cache-hit,
# the monkey-patched functions never actually get called.
# vLLM today assumes and requires the monkey-patched functions to
# get hit.
# TODO(zou3519): we're going to replace this all with
# standalone_compile sometime.
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False)
)
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# InductorStandaloneAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False)
)
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config,
)
# Turn off the checks if we disable the compilation cache.
if is_compile_cache_enabled(compiler_config):
if hash_str is None:
raise RuntimeError(
"vLLM failed to compile the model. The most "
"likely reason for this is that a previous compilation "
"failed, leading to a corrupted compilation artifact. "
"We recommend trying to "
"remove ~/.cache/vllm/torch_compile_cache and try again "
"to see the real issue. "
)
assert file_path is not None, (
"failed to get the file path of the compiled graph"
)
return compiled_graph, (hash_str, file_path)
def load(
self,
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
hash_str = handle[0]
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor.codecache import FxGraphCache
with ExitStack() as exit_stack:
exit_stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv(),
)
)
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
exit_stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv(),
)
)
# Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context())
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
constants = CompiledFxGraphConstantsWithGm(graph)
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, None, constants
)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove "
f"the cache directory and try again." # noqa
)
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
Present in torch>=2.6, it's used inside FxGraphCache in
torch==2.6 (but not after). It might also be used in various other
torch.compile internal functions.
Because it is re-entrant, we always set it (even if entering via Dynamo
and the context was already entered). We might want to revisit if it
should be set at a different mode of compilation.
This is likely a bug in PyTorch: public APIs should not rely on
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
"""
if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
else:
return contextlib.nullcontext()
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
config["coordinate_descent_tuning"] = (
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
)
def set_functorch_config() -> None:
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
torch._functorch.config.bundled_autograd_cache = False
class EagerAdaptor(CompilerInterface):
name = "eager"
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_eager_compiles += 1
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
return graph, None
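# Illustrative sketch of the adaptor contract above (hypothetical names such
# as `adaptor`, `gm`, `inputs`, and `cfg`): compile() returns a callable plus
# an opaque handle, and load() rehydrates the artifact from that handle on a
# warm start instead of recompiling. For InductorAdaptor the handle is the
# (hash_str, file_path) tuple validated in compile().
#
#     fn, handle = adaptor.compile(gm, inputs, cfg, compile_range)
#     fn = adaptor.load(handle, gm, inputs, graph_index, compile_range)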

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
@dataclasses.dataclass
class CompilationCounter:
num_models_seen: int = 0
num_graphs_seen: int = 0
# including the splitting ops
num_piecewise_graphs_seen: int = 0
# not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0
num_backend_compilations: int = 0
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
num_gpu_runner_capture_triggers: int = 0
# Number of CUDAGraphs captured
num_cudagraph_captured: int = 0
# InductorAdaptor.compile calls
num_inductor_compiles: int = 0
# EagerAdaptor.compile calls
num_eager_compiles: int = 0
# The number of times vLLM's compiler cache entry was updated
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0
def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)
@contextmanager
def expect(self, **kwargs: Any) -> Generator[None, None, None]:
old = self.clone()
yield
for k, v in kwargs.items():
assert getattr(self, k) - getattr(old, k) == v, (
f"{k} not as expected, before it is {getattr(old, k)}"
f", after it is {getattr(self, k)}, "
f"expected diff is {v}"
)
compilation_counter = CompilationCounter()
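# Illustrative usage sketch (assuming some code path triggers exactly one
# eager compile; `run_model_once` is a hypothetical helper): `expect` snapshots
# the counters, runs the body, and asserts the per-field deltas afterwards.
#
#     with compilation_counter.expect(num_eager_compiles=1):
#         run_model_once()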

View File

@@ -0,0 +1,332 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from unittest.mock import patch
import torch
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
num_unpadded_tokens: int
num_padded_tokens: int
num_paddings: int
runtime_mode: str
class CUDAGraphLogging:
"""Aggregate and log cudagraph metrics"""
COLUMN_HEADERS = [
"Unpadded Tokens",
"Padded Tokens",
"Num Paddings",
"Runtime Mode",
"Count",
]
def __init__(
self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
) -> None:
self.reset()
self.cg_mode = str(cg_mode)
self.cg_capture_sizes = str(cg_capture_sizes or [])
self.settings_header = (
"**CUDAGraph Config Settings:**\n\n"
f"- Mode: {self.cg_mode}\n"
f"- Capture sizes: {self.cg_capture_sizes}\n\n"
"**CUDAGraph Stats:**\n\n"
)
def reset(self) -> None:
self.stats: list[CUDAGraphStat] = []
def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
self.stats.append(cudagraph_stat)
def generate_metric_table(self) -> str:
stats_counts = Counter(self.stats)
# Convert stats to rows of strings, in descending order of observed frequencies
rows = []
for stat, count in sorted(
stats_counts.items(), key=lambda item: item[1], reverse=True
):
rows.append(
[
str(stat.num_unpadded_tokens),
str(stat.num_padded_tokens),
str(stat.num_paddings),
stat.runtime_mode,
str(count),
]
)
# Calculate column widths (max of header and data)
col_widths = []
for i, header_text in enumerate(self.COLUMN_HEADERS):
max_width = len(header_text)
for row in rows:
max_width = max(max_width, len(row[i]))
col_widths.append(max_width)
table_header_list = [
h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
]
table_header = "| " + " | ".join(table_header_list) + " |\n"
table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"
# Create data rows with proper alignment
data_rows = []
for row in rows:
formatted_row = [
str(val).ljust(width) for val, width in zip(row, col_widths)
]
data_rows.append("| " + " | ".join(formatted_row) + " |")
return (
self.settings_header
+ table_header
+ table_separator
+ "\n".join(data_rows)
+ "\n"
)
def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
if not self.stats:
return
log_fn(self.generate_metric_table())
self.reset()
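# Illustrative usage sketch: stats are aggregated per unique
# (unpadded, padded, paddings, mode) tuple and rendered as a markdown table,
# most frequent rows first. The values below are made up for illustration.
#
#     cg_log = CUDAGraphLogging(CUDAGraphMode.PIECEWISE, [1, 2, 4, 8])
#     cg_log.observe(CUDAGraphStat(3, 4, 1, "PIECEWISE"))
#     cg_log.observe(CUDAGraphStat(3, 4, 1, "PIECEWISE"))
#     cg_log.log()  # logs the table via logger.info and resets the stats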
@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor
cudagraph: torch.cuda.CUDAGraph | None = None
output: Any | None = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: list[int] | None = None
@dataclasses.dataclass
class CUDAGraphOptions:
debug_log_enable: bool = True
gc_disable: bool = False
weak_ref_output: bool = True
class CUDAGraphWrapper:
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
provide attribute access to the underlying `runnable` via `__getattr__`.
The workflow of this wrapper in the cudagraph dispatching is as follows:
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
PIECEWISE).
2. At runtime, the wrapper receives a runtime_mode and a
batch_descriptor(key) from the forward context and blindly trust them
for cudagraph dispatching.
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
wrapper, just call the runnable directly.
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
the wrapper will perform cudagraph capture(if key does not exist, create
a new entry and cache it) or replay (if key exists in the cache).
Note: CUDAGraphWrapper does not store persistent buffers or copy any
runtime inputs into those buffers for replay. We assume that is handled
outside of the wrapper, because we make no assumptions about the dynamic
shape (batch size) of the runtime inputs, as a trade-off for staying
orthogonal to the compilation logic. Nevertheless, the input addresses are
recorded during capture and checked for consistency during replay whenever
VLLM_LOGGING_LEVEL == "DEBUG".
"""
def __init__(
self,
runnable: Callable[..., Any],
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
) -> None:
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = current_platform.get_global_graph_pool()
if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
self.cudagraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(
f"Attribute {key} not exists in the runnable of "
f"cudagraph wrapper: {self.runnable}"
)
def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable.
return self.runnable
def weak_ref_tensors_with_intermediate(self, output: Any) -> Any:
    # IntermediateTensors (pipeline parallel) wraps a dict of tensors, so
    # weak-ref each contained tensor; plain tensors fall through unchanged.
    if isinstance(output, IntermediateTensors):
        intermediate_states = IntermediateTensors(
            tensors={
                key: weak_ref_tensors(value)
                for key, value in output.tensors.items()
            }
        )
        return intermediate_states
    return weak_ref_tensors(output)
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
if (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode != self.runtime_mode
):
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without cudagraphs.
# We do not trigger capture/replay if the runtime mode does not
# match. This enables properly dispatching to the correct
# CUDAGraphWrapper when nesting multiple instances with different
# runtime modes.
return self.runnable(*args, **kwargs)
assert batch_descriptor is not None
if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
batch_descriptor=batch_descriptor
)
entry = self.concrete_cudagraph_entries[batch_descriptor]
if entry.cudagraph is None:
if self.cudagraph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
logger.debug(
"Capturing a cudagraph on (%s,%s)",
self.runtime_mode.name,
entry.batch_descriptor,
)
# validate that cudagraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if self.cudagraph_options.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=current_stream(),
):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
# Join offloader's copy stream after forward to avoid
# unjoined stream error. The last layer's start_prefetch
# forks copy_stream, but wait_prefetch only happens in
# the next forward pass.
get_offloader().join_after_forward()
if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
# output = weak_ref_tensors(output)
output = self.weak_ref_tensors_with_intermediate(output)
# here we always use weak ref for the output
# to save memory
# entry.output = weak_ref_tensors(output)
entry.output = self.weak_ref_tensors_with_intermediate(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}"
)
# Sync offloader before replay - ensures any external dependencies
# from pre-capture prefetches are satisfied.
get_offloader().sync_prev_onload()
entry.cudagraph.replay()
return entry.output
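# Illustrative usage sketch (hypothetical `model` and `input_ids`): the wrapper
# is constructed once per runnable and dispatches purely on the runtime mode
# and batch descriptor taken from the forward context; the first matching call
# captures, later calls with the same descriptor replay.
#
#     wrapped = CUDAGraphWrapper(model.forward, vllm_config,
#                                runtime_mode=CUDAGraphMode.FULL)
#     out = wrapped(input_ids)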

View File

@@ -0,0 +1,657 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import hashlib
import inspect
import os
import sys
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, TypeVar, overload
from unittest.mock import patch
import torch
import torch.nn as nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
CompilationMode,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.config.compilation import DynamicShapesType
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .monitor import start_monitoring_torch_compile
if TYPE_CHECKING:
# Only added in nightly/2.10, so wrap the import
try:
from torch._dynamo.package import SourceInfo
except ImportError:
# Fallback for older versions that do not provide SourceInfo
SourceInfo = Any
logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=nn.Module)
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
"""
A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has
a support_torch_compile decorator, but we don't want to
compile the class `cls` that inherits the parent class.
This only ignores compiling the forward of the class the
decorator is applied to.
If the parent has ignore_torch_compile but the child has
support_torch_compile, the child will still be compiled.
If the class has one or more submodules
that have support_torch_compile decorator applied, compile will
not be ignored for those submodules.
"""
setattr(cls, IGNORE_COMPILE_KEY, True)
return cls
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
return getattr(cls, IGNORE_COMPILE_KEY, False)
@overload
def support_torch_compile(
*,
enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(
*,
dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(
*,
mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(
*,
dynamic_arg_dims: dict[str, int | list[int]] | None,
mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
def support_torch_compile(
cls: type[_T] | None = None,
*,
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
"""
A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
of the `forward` method, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors`, the first
dimension of all the tensors in the intermediate tensors
will be marked as dynamic.
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer (can be negative), the corresponding dimension
of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
- otherwise, it will raise an error.
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
`enable_if` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
`mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
`shape_invariants` is a function that gets compiled right before forward.
The function should have the torch._check calls that are needed to set
the relationships between different input sizes. For example:
torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
This enforces constraints on the symbolic shapes without hardcoding
specific values. It is needed for some models to avoid data dependent
errors.
"""
def cls_decorator_helper(cls: type[_T]) -> type[_T]:
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile`
if not hasattr(cls, "forward"):
raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward)
inferred_dynamic_arg_dims = dynamic_arg_dims
if inferred_dynamic_arg_dims is None:
inferred_dynamic_arg_dims = {}
for k, v in sig.parameters.items():
if v.annotation in [
torch.Tensor,
torch.Tensor | None,
IntermediateTensors,
IntermediateTensors | None,
]:
inferred_dynamic_arg_dims[k] = 0
logger.debug(
("Inferred dynamic dimensions for forward method of %s: %s"),
cls,
list(inferred_dynamic_arg_dims.keys()),
)
if len(inferred_dynamic_arg_dims) == 0:
raise ValueError(
"No dynamic dimensions found in the forward method of "
f"{cls}. Please provide dynamic_arg_dims explicitly."
)
for k in inferred_dynamic_arg_dims:
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
cls,
inferred_dynamic_arg_dims,
mark_unbacked_dims,
enable_if,
shape_invariants,
)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
assert isinstance(cls, type)
return cls_decorator_helper(cls)
return cls_decorator_helper
def _model_hash_key(fn: Callable[..., Any]) -> str:
import vllm
sha256_hash = hashlib.sha256()
sha256_hash.update(vllm.__version__.encode())
sha256_hash.update(fn.__qualname__.encode())
sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
return sha256_hash.hexdigest()
def _verify_source_unchanged(
source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
from .caching import _compute_code_hash, _compute_code_hash_with_content
file_contents = {}
for source in source_info.inlined_sources:
module = sys.modules[source.module]
file = inspect.getfile(module)
vllm_config.compilation_config.traced_files.add(file)
file_contents[file] = source.content
expected_checksum = _compute_code_hash_with_content(file_contents)
actual_checksum = _compute_code_hash(set(file_contents.keys()))
if expected_checksum != actual_checksum:
raise RuntimeError(
"Source code has changed since the last compilation. Recompiling the model."
)
def _support_torch_compile(
cls: type[_T],
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> type[_T]:
"""
A decorator to add support for compiling the forward method of a class.
"""
if TorchCompileWithNoGuardsWrapper in cls.__bases__:
# support decorating multiple times
return cls
# take care of method resolution order
# make sure super().__init__ is called on the base class
# other than TorchCompileWithNoGuardsWrapper
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
old_init = cls.__init__
setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__(
self: _T,
*,
vllm_config: VllmConfig | None = None,
prefix: str = "",
**kwargs: Any,
) -> None:
if vllm_config is None:
vllm_config = get_current_vllm_config()
# NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch
# it
sig = inspect.signature(old_init)
if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config
if "prefix" in sig.parameters:
kwargs["prefix"] = prefix
old_init(self, **kwargs)
self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE, the upper-level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or _should_ignore_torch_compile(self.__class__)
or not enable_compile
)
if self.do_not_compile:
return
self._check_shape_invariants = shape_invariants
self.was_aot_compile_fn_loaded_from_disk = False
compilation_counter.num_models_seen += 1
self.compiled = False
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
def _mark_dynamic_inputs(
mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED:
if is_torch_equal_or_newer("2.10.0"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
else:
torch._dynamo.mark_dynamic(arg, dims)
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}."
)
if mark_unbacked_dims:
for k, dims in mark_unbacked_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
if is_torch_equal_or_newer("2.10.0"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
# enc-dec models where tensor shapes/types vary across invocations, preventing
# the capture of a single computational graph.
if is_forward_context_available() and get_forward_context().skip_compiled:
return self.forward(*args, **kwargs)
# if aot_compiled_fn is set, call it with partition wrapper context.
# The partition wrapper must be active at runtime for CUDA graph
# capture to work correctly with inductor graph partitioning.
if getattr(self, "aot_compiled_fn", None) is not None:
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
"""
When using torch.compile in AOT mode, we store the cache artifacts
under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}
The {hash} contains all of the factors except for the source files
being traced through, because we don't actually know which source
files to check at this point (before dynamo runs).
On loading we will actually look at the source files being traced
through. If any source file has changed (compared with the
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
from .caching import aot_compile_hash_factors
factors: list[str] = aot_compile_hash_factors(self.vllm_config)
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_compile_cache",
"torch_aot_compile",
hash_key,
)
rank = self.vllm_config.parallel_config.rank
dp_rank = self.vllm_config.parallel_config.data_parallel_index
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model")
try:
with (
set_current_vllm_config(self.vllm_config),
open(aot_compilation_path, "rb") as f,
):
start_monitoring_torch_compile(self.vllm_config)
loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=self.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e:
if os.path.exists(aot_compilation_path):
if isinstance(e, EOFError):
message = "Compile cache file corrupted."
else:
message = str(e)
logger.warning(
"Compiling model again due to a load failure from %s, "
"reason: %s",
aot_compilation_path,
message,
)
if envs.VLLM_FORCE_AOT_LOAD:
raise e
if getattr(self, "aot_compiled_fn", None) is not None:
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
# Apply partition wrapper context for proper CUDA graph capture
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
if self.compiled:
assert (
not envs.VLLM_USE_AOT_COMPILE
or self.vllm_config.compilation_config.backend == "eager"
)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs(
self,
ds_type,
*args,
**kwargs,
)
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
original_code_object = self.original_code_object()
logger.debug("Start compiling function %s", original_code_object)
# we do not want to delete the original code object entries since
# we depend on them now to look up cached compiled functions.
# torch._dynamo.eval_frame.remove_from_cache(original_code_object)
# collect all relevant files traced by Dynamo,
# so that the compilation cache can trigger re-compilation
# properly when any of these files change.
# 1. the file containing the top-level forward function
self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call_
def patched_inline_call(self_: Any) -> Any:
code = self_.f_code
self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can improve guard overhead. But, since
# vLLM skips guards anyway, setting this flag to False can improve
# compile time.
dynamo_config_patches = {}
try:
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
except AttributeError:
# Note: this config is not available in torch 2.6, we can skip
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
# Prepare backed_size_oblivious config patch if needed
fx_config_patches = {}
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0+
inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0"):
inductor_config_patches["assume_32bit_indexing"] = (
self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
)
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches),
torch._inductor.config.patch(**inductor_config_patches),
):
use_aot_compile = envs.VLLM_USE_AOT_COMPILE
if self.vllm_config.compilation_config.backend == "eager":
logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False
if use_aot_compile:
from vllm.compilation.backends import set_on_compilation_complete
# store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
self.compiled = True
return output
# triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self: type[_T]) -> None:
if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save")
return
assert (
self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
)
logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
try:
os.makedirs(self._aot_cache_dir, exist_ok=True)
# File saving should be atomic, so we will save to a temporary location
# first. Should be upstreamed to PyTorch 2.12 as well.
tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
self.aot_compiled_fn.save_compiled_function(tmp_file)
os.replace(tmp_file, self._aot_compilation_path)
logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
except Exception as e:
logger.warning(
"unable to save AOT compiled function to %s: %s",
self._aot_compilation_path,
e,
)
cls.__call__ = __call__
cls.save_aot_compiled_function = save_aot_compiled_function
return cls
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
vllm_config: VllmConfig,
) -> Generator[None, None, None]:
"""
Context manager to set/unset customized cudagraph partition wrappers.
If we're using Inductor-based graph partitioning, we currently have the
whole `fx.Graph` before Inductor lowering and the piecewise
splitting happens after all graph passes and fusions. Here, we add
a custom hook for Inductor to wrap each partition with our static
graph wrapper class to maintain more control over static graph
capture and replay.
"""
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if (
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
and compilation_config.use_inductor_graph_partition
):
from torch._inductor.utils import CUDAGraphWrapperMetadata
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.platforms import current_platform
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls()
)
def customized_cudagraph_wrapper(
f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
) -> Any:
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions
return static_graph_wrapper_class(
runnable=f,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=partition_id == 0,
gc_disable=partition_id != 0,
weak_ref_output=partition_id == num_partitions - 1,
),
)
torch._inductor.utils.set_customized_partition_wrappers(
customized_cudagraph_wrapper
)
yield
if (
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
and compilation_config.use_inductor_graph_partition
):
torch._inductor.utils.set_customized_partition_wrappers(None)

View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
context_manager = None
torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
global torch_compile_start_time
torch_compile_start_time = time.perf_counter()
compilation_config: CompilationConfig = vllm_config.compilation_config
path = vllm_config.compile_debug_dump_path()
if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
import depyf
path.mkdir(parents=True, exist_ok=True)
logger.debug("Dumping depyf output to %s", path)
global context_manager
context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__()
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
compilation_config: CompilationConfig = vllm_config.compilation_config
total_compile_time: float = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once(
"torch.compile takes %.2f s in total",
total_compile_time,
scope="local",
)
global context_manager
if context_manager is not None:
context_manager.__exit__(None, None, None)
context_manager = None
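# Illustrative usage sketch: the two hooks above bracket a compilation phase,
# and the wall time between them is what gets logged as the total
# torch.compile time.
#
#     start_monitoring_torch_compile(vllm_config)
#     # ... first (compiling) forward pass ...
#     end_monitoring_torch_compile(vllm_config)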
cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled() -> None:
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
global cudagraph_capturing_enabled
if not cudagraph_capturing_enabled:
raise RuntimeError(
"CUDA graph capturing detected at an inappropriate "
"time. This operation is currently disabled."
)
def set_cudagraph_capturing_enabled(enabled: bool) -> None:
global cudagraph_capturing_enabled
cudagraph_capturing_enabled = enabled
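# Illustrative sketch: capture is gated by the module-level flag, so callers
# disable it once warmup/capture is finished, and any late capture attempt
# that goes through validate_cudagraph_capturing_enabled() raises.
#
#     set_cudagraph_capturing_enabled(False)
#     # ... serving; validate_cudagraph_capturing_enabled() would now raise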

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Generator
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
"""
Check if a node should be split for dynamo graph partition.
It operates on dynamo graph, so the node.target can be anything.
We need to check and split only on OpOverload and OpOverloadPacket.
"""
if node.op != "call_function":
return False
target = node.target
if isinstance(target, torch._ops.OpOverloadPacket):
# Example: "aten::add"
return target._qualified_op_name in splitting_ops
if isinstance(target, torch._ops.OpOverload):
# Example: "aten::add"
packet_name = target.name()
# Example: "aten::add.default"
op_overload_name = f"{packet_name}.{target._overloadname}"
return op_overload_name in splitting_ops or packet_name in splitting_ops
return False
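# Illustrative sketch, based on the examples in the comments above: with
# splitting_ops = ["aten::add"], a call_function node targeting the packet
# torch.ops.aten.add or any of its overloads (e.g. "aten::add.default")
# counts as a split point, while nodes that are not call_function never do.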
@contextlib.contextmanager
def inductor_partition_rule_context(
splitting_ops: list[str] | None,
) -> Generator[None, None, None]:
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the
Inductor scheduler to partition the graph at these operators. The rules
are automatically restored to their previous state on exit.
Args:
splitting_ops: List of operator names to partition on.
"""
if not splitting_ops:
logger.debug("No partition ops provided; skipping rule registration.")
yield
return
# Save current state before registering
saved_splitting_ops: list[str] = list(
torch._inductor.config.custom_should_partition_ops
)
torch._inductor.config.custom_should_partition_ops = splitting_ops
logger.debug(
"Registered inductor partition rules for %d operators", len(splitting_ops)
)
try:
yield
finally:
# Clear and restore previous state
torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
logger.debug("Restored previous partition rules state.")

View File

View File

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
)
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
FUSED_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
}
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
"""
The base class for Activation+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
quant_key: QuantKey,
) -> None:
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, (
f"unsupported quantization scheme {self.quant_key}"
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
assert self.quant_key in FUSED_OPS, (
f"unsupported fusion scheme {self.quant_key}"
)
self.FUSED_OP = FUSED_OPS[self.quant_key]
self.silu_and_mul_matcher = MatcherSiluAndMul()
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self) -> None:
super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
scale = self.quant_matcher.inputs()[1]
return [
*self.silu_and_mul_matcher.inputs(), # input
scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]
def replacement(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP, result=result, input=input, scale=scale
)
return at[1]
inps = self.get_inputs()
pattern(*inps)
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
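# Conceptually, the replacement above rewrites the unfused sequence
#     y = silu_and_mul(x); q, _ = static_fp8_quant(y, scale)
# into a single fused call
#     q = silu_and_mul_quant(x, scale)
# so the 16-bit intermediate does not have to be written out and re-read
# (op names here are shorthand for the torch.ops._C custom ops used above).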
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def __init__(self) -> None:
super().__init__(kNvfp4Dynamic)
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
output_scale = empty_i32(128, 4)
input_ = empty_bf16(5, 64)
scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized(
self.QUANT_OP,
output=result,
input=result_silu_mul,
output_scale=output_scale,
input_scale=scale,
is_sf_swizzled_layout=True,
)
return at[1], at[2]
def replacement(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = auto_functionalized(
self.FUSED_OP,
result=result,
result_block_scale=output_scale,
input=input,
input_global_scale=scale,
)
return at[1], at[2]
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="activation_quant_fusion_pass"
)
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
pattern_silu_mul_fp8.register(self.patterns)
if silu_and_mul_nvfp4_quant_supported:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern,
)

View File

@@ -0,0 +1,862 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from importlib.util import find_spec
from types import ModuleType
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
try:
import flashinfer.comm as _flashinfer_comm
if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
_flashinfer_comm, "create_allreduce_fusion_workspace"
):
flashinfer_comm = _flashinfer_comm
except ImportError:
pass
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
90: {
2: 64, # 64MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 64, # 64MB
4: 32, # 32MB
8: 1, # 1MB
},
}
# Max size of the input tensor per world size per device capability
# to use flashinfer one shot fused allreduce
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
90: {
2: 32, # 32MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 32, # 32MB
4: 4, # 4MB
8: 1, # 1MB
},
}
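# Illustrative worked example for the table above (shapes are made up): on a
# compute-capability-100 device with world_size=4 the one-shot cap is 4 MB,
# so a bf16 allreduce input of shape [512, 4096] (512 * 4096 * 2 B = 4 MiB)
# still uses the one-shot path, while [2048, 4096] (16 MiB) falls back to
# two-shot (but remains under the 32 MB overall fusion limit above).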
if flashinfer_comm is not None:
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_size: int,
launch_with_pdl: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
) -> None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_tensor_size = max_token_num * hidden_size * element_size
assert current_tensor_size <= max_tensor_size, (
f"Current tensor size {current_tensor_size} is larger than "
f"max token num {max_token_num} * hidden size {hidden_size} * "
f"element size {element_size}"
)
curr_device = current_platform.get_device_capability()
device_capability = curr_device.to_int() if curr_device is not None else None
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, # type: ignore[arg-type, unused-ignore]
{},
).get(world_size, None)
# Use one shot if no max size is specified
use_oneshot = (
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
)
# Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
):
workspace = get_fi_ar_quant_workspace()
else:
workspace = get_fi_ar_workspace()
assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer"
)
assert flashinfer_comm is not None
if norm_out is None:
norm_out = allreduce_in
residual_out = residual
else:
# return residual_out as allreduce_out with zeroed residual_in
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out = allreduce_in
layout_code = None
# layout_code only supported by trtllm backend
if workspace.backend == "trtllm":
# in vllm we only support swizzled layout
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=allreduce_in,
workspace=workspace,
pattern=pattern_code,
launch_with_pdl=launch_with_pdl,
output=None,
residual_out=residual_out,
norm_out=norm_out,
quant_out=quant_out,
scale_out=scale_out,
residual_in=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor,
layout_code=layout_code,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
)
def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_size: int,
launch_with_pdl: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
) -> None:
pass
direct_register_custom_op(
op_name="flashinfer_trtllm_fused_allreduce_norm",
op_func=call_trtllm_fused_allreduce_norm,
mutates_args=[
"allreduce_in",
"residual",
"norm_out",
"quant_out",
"scale_out",
],
fake_impl=call_trtllm_fused_allreduce_norm_fake,
)
flashinfer_trtllm_fused_allreduce_norm = (
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
world_size: int,
max_token_num: int = 1024,
) -> None:
self.world_size = world_size
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
return {
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
}
# TODO(luka): unify
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class AllReduceRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
with fused flashinfer implementation.
Applies to allreduce + rmsnorm before attn in the first Transformer block.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(allreduce_output, weight)
return rms, allreduce_output
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=rms_result,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# rms_result, allreduce_in
return allreduce[3], allreduce[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedAddRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input, residual, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# allreduce_in, residual
return allreduce[1], allreduce[2]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
# Same pattern, but only return the output and not residual
# (helpful for end of graph where residual is not used again)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call]
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
+ static fp8 quant with fused flashinfer implementation.
Applies to allreduce + rmsnorm + quant before attn
in the first Transformer block.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
),
scale_factor=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, allreduce_output
return allreduce[4], allreduce[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
+ static fp8 quant with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn + quant and
mlp + rmsnorm + quant before attn.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant, _ = self.quant_matcher(rms, scale)
return quant, res
def replacement(
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
),
scale_factor=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, rms_norm_residual
return allreduce[4], allreduce[2]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
+ static nvfp4 quant with fused flashinfer implementation.
Applies to allreduce + rmsnorm + quant before attn
in the first Transformer block.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [input, quant_result, weight, input_global_scale, output_scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
is_sf_swizzled_layout=True,
)
# quant_out, allreduce_output, output_scale
return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
def replacement(
input: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=quant_result,
scale_out=output_scale,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
),
scale_factor=input_global_scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, allreduce_output, output_scale
return allreduce[4], allreduce[1], allreduce[5]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
+ static nvfp4 quant with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn + quant and
mlp + rmsnorm + quant before attn.
"""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [
quant_result,
residual,
input,
output_scale,
weight,
input_global_scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
is_sf_swizzled_layout=True,
)
# quant_out, allreduce_output, output_scale
return quant_out_tuple[1], residual, quant_out_tuple[2]
def replacement(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=quant_result,
scale_out=output_scale,
rms_gamma=weight,
rms_eps=self.epsilon,
# We don't use norm_out afterwards
pattern_code=(
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
),
scale_factor=input_global_scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# quant_out, rms_norm_residual, output_scale
return allreduce[4], allreduce[2], allreduce[5]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusionPass(VllmPatternMatcherPass):
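# Summary of this pass (descriptive note): on tensor-parallel graphs it
# rewrites allreduce -> RMSNorm, optionally followed by static FP8 or NVFP4
# quant, into FlashInfer's flashinfer_trtllm_fused_allreduce_norm op,
# provided a FlashInfer workspace can be created for the current world size.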
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.disabled = True
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size <= 1:
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
return
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="all_reduce_fusion_pass"
)
if config.model_config is None:
logger.warning_once(
"AllReduce fusion pass is disabled for missing model_config."
)
return
self.hidden_dim = config.model_config.get_hidden_size()
self.group = get_tp_group().device_group
rank = get_tensor_model_parallel_rank()
if flashinfer_comm is None:
logger.warning(
"Flashinfer is not installed or comm module not found, "
"skipping allreduce fusion pass"
)
return
max_size = config.compilation_config.pass_config.flashinfer_max_size(
self.tp_size
)
if max_size is None:
# Flashinfer doesn't support current world size
logger.warning(
"Flashinfer allreduce fusion is not supported for world size %s"
" or max size is not provided",
self.tp_size,
)
return
element_size = torch.tensor([], dtype=self.model_dtype).element_size()
self.max_token_num = max_size // (self.hidden_dim * element_size)
# take the min to save workspace size and we'll never use more
# than max_num_batched_tokens anyways
self.max_token_num = min(
self.max_token_num, config.scheduler_config.max_num_batched_tokens
)
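# Illustrative sizing (example numbers, not taken from any real config):
# with a 64 MiB flashinfer max size, hidden_dim=4096 and a bf16 model
# (element_size=2), max_size // (hidden_dim * element_size)
# = 67108864 // 8192 = 8192 tokens, further capped by the min() above
# with max_num_batched_tokens.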
logger.debug_once(
f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
"Maximal number of tokens used by "
f"Flashinfer Allreduce Fusion: {self.max_token_num}",
scope="global",
)
for workspace_init_fn in [
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
)
return
self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size,
max_token_num=self.max_token_num,
)
self.register_patterns()
self.dump_patterns(config, self.patterns)
@enable_fake_mode
def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]:
if supports_quantization:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.disabled = False
def is_applicable_for_range(self, compile_range: Range) -> bool:
if self.disabled:
logger.warning_once("AllReduce fusion pass is disabled.")
return False
return bool(compile_range.end <= self.max_token_num)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
if self.disabled:
logger.debug("AllReduceFusionPass disabled")
return
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def __del__(self) -> None:
if getattr(self, "disabled", True):
return
with contextlib.suppress(Exception):
destroy_fi_ar_workspace()


@@ -0,0 +1,374 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from ..fx_utils import is_func
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)
P = ParamSpec("P")
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionQuantPattern(ABC):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
) -> None:
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
self.dtype = dtype
assert self.quant_key in QUANT_OPS, (
f"unsupported quantization scheme {self.quant_key}"
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
dtype: torch.dtype,
symmetric: bool = True,
) -> None:
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
return self.quant_matcher(attn_out_view, scale)[0]
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device,
)
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
self.empty(0), # kv_cache_dummy_dep
]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Dynamic, dtype)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale,
is_sf_swizzled_layout=True,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device,
)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
return output, at2[2]
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
empty_i32(
128, round_up(self.num_heads * self.head_size // 16, 4)
), # output_scale
empty_fp32(1, 1), # input_scale
self.empty(0), # kv_cache_dummy_dep
]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
class AttnFusionPass(VllmPatternMatcherPass):
"""
This pass fuses post-attention quantization onto attention if supported.
It uses the pattern matcher and matches each layer manually, as strings
cannot be wildcarded. This also lets us check support on attention layers
upon registration instead of during pattern matching.
Currently, static fp8 and nvfp4 quant are supported, but patterns could easily
be added for other quant schemes and dtypes. The bigger hurdle for wider
support is the attention kernels, which need to support fused output quant.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(
layer, config.model_config.dtype
)
pattern_fp8.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(
layer, config.model_config.dtype
)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered."
)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern,
)


@@ -0,0 +1,423 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
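# Rewrites mm -> vllm.reduce_scatter into
# torch.ops.symm_mem.fused_matmul_reduce_scatter (async tensor parallelism).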
def get_inputs(self) -> list[torch.Tensor]:
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [mul, mm_weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
mm = torch.ops.aten.mm.default(mul, mm_weight)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
"avg",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherGEMMPattern(BasePattern):
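# Rewrites vllm.all_gather -> mm into
# torch.ops.symm_mem.fused_all_gather_matmul.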
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [x, weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return torch.ops.aten.mm.default(all_gather, weight)
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
x,
[weight],
gather_dim=0,
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class ScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [input, mm_weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
scaled_mm = torch.ops.aten._scaled_mm.default(
input,
mat2=mat2,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
scaled_mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [x, weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
return torch.ops.aten._scaled_mm.default(
all_gather,
mat2=weight,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class CutlassScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=cutlass_mm_output,
a=input,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
cutlass_scaled_mm[1],
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherCutlassScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
s2 = weight.shape[1]
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
return [x, weight, scale_a, scale_b, output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=output,
a=all_gather,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
return cutlass_scaled_mm[1]
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Enable symmetric memory for the TP process group
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="async_tp_pass"
)
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if self.model_dtype == torch.bfloat16:
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
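# Illustrative example (assumed numbers): with tp_size=4, a compile range
# pinned to a single size of 512 tokens qualifies (512 % 4 == 0), while a
# dynamic range or a single size of 510 does not.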
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size()
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)


@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool) -> None:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.enabled = enabled
self.forward = self.forward_custom if enabled else self.forward_native
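# For example, MatcherRMSNorm with the custom op enabled traces
# torch.ops._C.rms_norm, while with it disabled it traces the native
# RMSNorm.forward_static decomposition, so registered patterns match
# whichever lowering the model actually uses.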
@abstractmethod
def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
pass
@abstractmethod
def forward_native(self, *args: Any, **kwargs: Any) -> Any:
pass
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kwargs)
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
raise NotImplementedError
class MatcherRotaryEmbedding(MatcherCustomOp):
def __init__(
self,
is_neox: bool,
head_size: int,
num_heads: int,
num_kv_heads: int,
use_flashinfer: bool = False,
match_rocm_aiter: bool | None = None,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RotaryEmbedding.enabled()
if match_rocm_aiter is None:
match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()
super().__init__(enabled)
self.is_neox = is_neox
self.head_size = head_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_size
self.kv_size = self.num_kv_heads * self.head_size
self.rotary_dim = head_size
if use_flashinfer:
self.rotary_op = FLASHINFER_ROTARY_OP
elif match_rocm_aiter:
self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
else:
self.rotary_op = ROTARY_OP
def inputs(self) -> list[torch.Tensor]:
positions = self.empty_int64(5)
query = self.empty(5, self.q_size)
key = self.empty(5, self.kv_size)
cos_sin_cache = self.empty(4096, self.rotary_dim)
return [positions, query, key, cos_sin_cache]
def forward_custom(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
result = auto_functionalized(
self.rotary_op,
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
)
query_out = result[1]
key_out = result[2] if len(result) > 2 else None
return query_out, key_out
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
result: tuple[torch.Tensor, torch.Tensor | None] = (
RotaryEmbedding.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
cos_sin_cache,
self.is_neox,
)
)
return result
class MatcherRMSNorm(MatcherCustomOp):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self._rmsnorm_op = RMS_OP
self.match_rocm_aiter = match_rocm_aiter
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return self._rmsnorm_op(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight)
result = torch.empty_like(input)
_, result = auto_functionalized(
self._rmsnorm_op,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self.match_rocm_aiter = match_rocm_aiter
self._rmsnorm_op = RMS_ADD_OP
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self._rmsnorm_op( # type: ignore[no-any-return]
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight, residual)
_, result, residual = auto_functionalized(
self._rmsnorm_op,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
return result, residual
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
return result
class MatcherQuantFP8(MatcherCustomOp):
def __init__(
self,
quant_key: QuantKey,
enabled: bool | None = None,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
match_rocm_aiter: bool = False,
is_tma_aligned: bool = False,
) -> None:
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
self.match_rocm_aiter = match_rocm_aiter
self.is_tma_aligned = is_tma_aligned
if match_rocm_aiter:
assert not quant_key.scale.group_shape.is_per_tensor(), (
"ROCm aiter fusion pass does not support per tensor quantization"
)
if quant_key.scale.group_shape.is_per_token():
self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
else:
assert quant_key.scale.group_shape.col == 128, (
"ROCm aiter fusion pass currently supports "
"quantization operation with group_size 128"
)
if current_platform.is_fp8_fnuz():
self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
else:
self.QUANT_OP = (
torch.ops.vllm.triton_per_token_group_quant_fp8.default
)
else:
assert quant_key in QUANT_OPS, (
f"unsupported quantization scheme {quant_key}"
)
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"MatcherQuantFP8 only supports the platform fp8 dtype"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(
quant_key.scale.static,
quant_key.scale.group_shape,
column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0,
tma_aligned_scales=self.is_tma_aligned,
compile_native=False,
)
def forward_rocm_aiter(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
quant_key_group_shape = self.quant_key.scale.group_shape
if quant_key_group_shape == GroupShape.PER_TOKEN:
return self.QUANT_OP( # type: ignore[no-any-return]
x=input,
quant_dtype=self.quant_key.dtype,
scale=scale,
)
else:
return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, scale)
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.group_shape.is_per_group():
# For TMA-aligned scales, the scale must be passed into forward_custom;
# the TMA-aligned fusion then matches on the custom op arguments.
if not self.is_tma_aligned:
assert scale is None
scale = self.make_scale(input, transposed=self.has_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.QUANT_OP,
input=input,
output_q=result,
output_s=scale,
group_size=self.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale
)
return result, scale
else:
assert scale is None
scale = self.make_scale(input)
_, result, scale = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
)
return result, scale
def forward_native(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale) # type: ignore[no-any-return]
def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
scale_shape = (
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
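# Shape example (illustrative): for an input of shape [5, 16] with per-token
# scales (group shape normalized to (1, 16)), scale_shape is (5, 1); with
# per-tensor scales (normalized to (5, 16)) it is (1, 1).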
if transposed:
scale_shape = tuple(reversed(scale_shape))
return torch.empty(
scale_shape, device=input.device, dtype=torch.float32
).permute(-1, -2)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
if self.quant_key.scale.static:
return [input, self.empty_f32(1, 1)]
return [input]
class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None) -> None:
if enabled is None:
enabled = SiluAndMul.enabled()
super().__init__(enabled)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 4)
return [input]
def forward_custom(
self,
x: torch.Tensor,
) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
return result[1]
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
return SiluAndMul.forward_native(x)


@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
P = ParamSpec("P")
class QkNormRopePattern:
"""
Match the unfused sequence in attention blocks and replace with the fused op.
Unfused (conceptually):
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
qh = reshape(q, [-1, num_heads, head_dim])
kh = reshape(k, [-1, num_kv_heads, head_dim])
qn = rms_norm(qh, q_weight, eps)
kn = rms_norm(kh, k_weight, eps)
qf = reshape(qn, [-1, num_heads * head_dim])
kf = reshape(kn, [-1, num_kv_heads * head_dim])
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
return qf, kf, v
Fused replacement:
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
eps, q_weight, k_weight, cos_sin_cache, is_neox,
positions.view(-1))
return split(qkv, [qsz, kvsz, kvsz], -1)
"""
def __init__(
self,
head_dim: int,
num_heads: int,
num_kv_heads: int,
eps: float,
is_neox: bool,
rope_flashinfer: bool = False,
) -> None:
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=is_neox,
head_size=self.head_dim,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
use_flashinfer=self.rope_flashinfer,
)
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
positions = empty_i64(T)
q_weight = empty_bf16(1, self.head_dim)
k_weight = empty_bf16(1, self.head_dim)
if self.rope_flashinfer:
cos_sin_cache = empty_fp32(4096, self.head_dim)
else:
cos_sin_cache = empty_bf16(4096, self.head_dim)
return [
qkv,
positions,
q_weight,
k_weight,
cos_sin_cache,
]
@staticmethod
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Q path: view -> RMS -> view back to q.shape
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k
q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
return q_rope, k_rope, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Run fused qk_norm_rope op
result = auto_functionalized(
FUSED_QK_ROPE_OP,
qkv=qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.eps,
q_weight=q_weight,
k_weight=k_weight,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
position_ids=positions.view(-1),
)
result_qkv = result[1]
# Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
pm.register_replacement(
pattern,
replacement,
self.get_inputs(),
QkNormRopePattern.wrap_trace_fn(
pm.fwd_only,
QkNormRopePattern.fx_view_to_reshape,
),
pm_pass,
)
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass"
)
dtype = config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.warning_once(
"QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
)
return
# use one attn layer to get meta (such as head_dim) for QkNormRopePattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
config, Attention
)
if len(attn_layers) == 0:
logger.warning_once(
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
)
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
for rope_flashinfer in [False, True]:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
rope_flashinfer=rope_flashinfer,
).register(self.patterns)
else:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, QkNormRopePattern)


@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, NamedTuple
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
"""
Named tuple for identifying the type of RMSNorm + quant fusion.
quant: type of quantization
fused_add: does the op also perform the residual add
"""
quant: QuantKey
fused_add: bool
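# For instance, FusedRMSQuantKey(kFp8StaticTensorSym, fused_add=True) keys
# torch.ops._C.fused_add_rms_norm_static_fp8_quant in FUSED_OPS below.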
def __str__(self) -> str:
return (
f"FusedRMSQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)"
)
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(
kFp8StaticTensorSym, False
): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8StaticTensorSym, True
): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8DynamicTokenSym, False
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8DynamicTokenSym, True
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic128Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, False
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
FusedRMSQuantKey(
kFp8Dynamic64Sym, True
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
}
class RMSNormQuantPattern:
def __init__(
self,
epsilon: float,
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
is_tma_aligned: bool = False,
) -> None:
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(
key.quant,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
fused_key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
),
)
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass) -> None:
# Cannot use methods, as the self argument affects tracing
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
)
# result
return at[1]
inputs = [
# input, weight
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
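# Sketch of the rewrite registered above (illustrative comment only):
#   y = rms_norm(x, weight, eps); q = static_scaled_fp8_quant(y, scale)
# is replaced with the single fused call
#   q = rms_norm_static_fp8_quant(x, weight, scale, eps)
# so the unquantized intermediate activation no longer has to be written out.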
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
),
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
return result, residual
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon,
)
# result, residual
return at[1], at[2]
inputs = [
# input, weight, residual
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
)
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric: bool = True,
is_e8m0: bool = False,
has_col_major_scales: bool = True,
is_tma_aligned: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
self.is_e8m0 = is_e8m0
self.has_col_major_scales = has_col_major_scales
self.is_tma_aligned = is_tma_aligned
super().__init__(
epsilon,
key,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result = torch.empty(
result_rms.shape,
device=result_rms.device,
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=result_rms,
output_q=result,
output_s=scale,
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.quant_matcher.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, residual, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
group_size=self.group_shape[1],
is_scale_transposed=self.has_col_major_scales,
)
# result, residual, scale
return at[1], at[3], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only,
pm_pass,
)
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric: bool = True,
is_e8m0: bool = False,
has_col_major_scales: bool = True,
is_tma_aligned: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
self.has_col_major_scales = has_col_major_scales
self.is_tma_aligned = is_tma_aligned
super().__init__(
epsilon,
key,
has_col_major_scales=self.has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result = torch.empty(
result_rms.shape,
device=result_rms.device,
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=result_rms,
output_q=result,
output_s=scale,
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.quant_matcher.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
group_size=self.group_shape[1],
is_scale_transposed=self.has_col_major_scales,
)
# result, scale
return at[1], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only,
pm_pass,
)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
)
# result, scale
return at[1], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
)
# result, residual, scale
return at[1], at[3], at[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return self.hash_source(
self,
RMSNormGroupQuantPattern,
RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern,
FusedAddRMSNormGroupQuantPattern,
)
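# Minimal usage sketch (assumes a fully populated VllmConfig; illustrative only):
#   fusion_pass = RMSNormQuantFusionPass(vllm_config)
#   fusion_pass(graph)           # rewrites matching rms_norm + quant subgraphs
#   fusion_pass.matched_count    # number of patterns that were replaced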

View File

@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul,
)
from .rms_quant_fusion import (
FusedRMSQuantKey,
)
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
class AiterRMSNormQuantPattern:
def __init__(
self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
)
self.quant_matcher = MatcherQuantFP8(
key.quant,
match_rocm_aiter=match_aiter_quant,
)
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm + Dynamic Quantization pattern."""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
weight=weight,
epsilon=self.epsilon,
quant_dtype=self.quant_dtype,
)
return result[0], result[1]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual_out, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
quant_dtype=self.quant_dtype,
)
return result[0], result[1], result[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
)
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into an aiter rms_norm_with_add_group_fp8_quant op.
"""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual_out, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
)
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse aiter rms_norm + aiter dynamic group fp8 quant
AiterRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, GroupShape(1, 128)
).register(self.patterns)
# Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, GroupShape(1, 128)
).register(self.patterns)
for match_aiter_quant in [True, False]:
# Fuse aiter rms_norm + (aiter / vllm built-in)
# dynamic per-token fp8 quant
AiterRMSNormDynamicQuantPattern(
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
).register(self.patterns)
# Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
# dynamic per-token fp8 quant
AiterFusedAddRMSNormDynamicQuantPattern(
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
fusion_patterns = [
AiterRMSNormDynamicQuantPattern,
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def get_inputs(self) -> list[torch.Tensor]:
return [
self.silu_and_mul_matcher.inputs()[0],
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in self.QUANT_OPS:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
class AddAiterRMSNormPadPattern:
"""
This pattern replaces an aiter_rmsnorm_with_add & a pad op
with a custom triton_add_rmsnorm_pad op from AITER.
"""
AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()
def __init__(
self,
epsilon: float,
hidden_size: int,
x_pad_to_multiple: int,
):
self.epsilon = epsilon
self.hidden_size = hidden_size
self.x_pad_to_multiple = x_pad_to_multiple
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
def get_inputs(self) -> list[torch.Tensor]:
input, weight, residual = self.rmsnorm_matcher.inputs()
router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
return [input, weight, residual, router_weight, router_bias]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pad_size = self.x_pad_to_multiple - (
self.hidden_size % self.x_pad_to_multiple
)
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_rms, router_weight, router_bias
)
result = torch.nn.functional.pad(
result_rms, (0, pad_size), mode="constant", value=0.0
)
return result, residual_out, router_logits
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
residual=residual,
x_pad_to_multiple=self.x_pad_to_multiple,
)
result_padded = at[0]
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_padded[:, : self.hidden_size], router_weight, router_bias
)
residual_out = at[1]
return result_padded, residual_out, router_logits
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
"""
This pass replaces an AITER CK RMSNorm + residual add and a pad op
with a triton_add_rmsnorm_pad op from AITER.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
)
# gpt-oss has hidden size 2880
# padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
hidden_size = 2880
for epsilon in [1e-5, 1e-6]:
for x_pad_to_multiple in [128, 256]:
AddAiterRMSNormPadPattern(
epsilon, hidden_size, x_pad_to_multiple
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops import auto_functionalized
from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.attention import (
Attention,
get_attention_context,
)
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherRotaryEmbedding,
)
from .rms_quant_fusion import (
empty_bf16,
empty_i64,
)
logger = init_logger(__name__)
def fused_rope_and_unified_kv_cache_update_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
) -> torch.Tensor:
"""
This impl fetches the KV cache and slot mapping from the forward context,
then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
attn_layer.impl.do_rope_and_kv_cache_update(
attn_layer,
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def fused_rope_and_unified_kv_cache_update_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
) -> torch.Tensor:
return torch.empty(0, device=query.device, dtype=query.dtype)
direct_register_custom_op(
op_name="fused_rope_and_unified_kv_cache_update",
op_func=fused_rope_and_unified_kv_cache_update_impl,
mutates_args=["query", "key"],
fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
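# Once registered, the fused op is reachable as
# torch.ops.vllm.fused_rope_and_unified_kv_cache_update, e.g. (illustrative):
#   torch.ops.vllm.fused_rope_and_unified_kv_cache_update(
#       q, k, v, positions, cos_sin_cache, is_neox=True, layer_name="...")
# The pattern class below rewrites the separate rotary-embedding and
# unified_kv_cache_update calls into an auto_functionalized call to this op.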
class RopeReshapeKVCachePattern:
"""
This pattern matches the following unfused inplace ops:
q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)
and replaces it with the fused inplace op:
kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
q, k, v, positions, cos_sin_cache, is_neox, layer_name
)
"""
FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
def __init__(
self,
layer: Attention,
is_neox: bool,
) -> None:
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.num_kv_heads = layer.num_kv_heads
self.head_size = layer.head_size
self.head_size_v = layer.head_size_v
self.is_neox = is_neox
self.q_size = self.num_heads * self.head_size
self.k_size = self.num_kv_heads * self.head_size
self.v_size = self.num_kv_heads * self.head_size_v
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=self.is_neox,
head_size=self.head_size,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
L = 4096
qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
positions = empty_i64(T)
cos_sin_cache = empty_bf16(L, self.head_size)
return [
qkv,
positions,
cos_sin_cache,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
return dummy, q, k, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
results = auto_functionalized(
self.FUSED_OP,
query=q,
key=k,
value=v,
positions=positions,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
layer_name=self.layer_name,
)
return results[0], results[1], results[2], v
# NOTE: use view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
gm = pm.fwd_only(*args, **kwargs)
view_to_reshape(gm)
return gm
pm.register_replacement(
pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
)
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
"""
This pass fuses the rotary embedding and KV cache update operations
into a single fused kernel if available.
It uses the pattern matcher and matches each layer manually, as strings
cannot be wildcarded. This also lets us check support on attention layers
upon registration instead of during pattern matching.
This fusion eliminates the need for separate kernel launches and
intermediate memory operations between the RoPE and cache update steps.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rope_kv_cache_fusion_pass"
)
cc = config.compilation_config
self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num
attn_layers = get_layers_from_vllm_config(config, Attention)
for _, layer in attn_layers.items():
if layer.impl.fused_rope_kvcache_supported():
for is_neox in [True, False]:
RopeReshapeKVCachePattern(
layer=layer,
is_neox=is_neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass works best for the small-batch decode setting.
# For large-batch settings, e.g. prefill, it is better to use two separate kernels
# since they are compute bound and the fused kernels require further tuning.
return compile_range.end <= self.max_token_num
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
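# Worked example of the gating above (illustrative numbers): with
# pass_config.rope_kvcache_fusion_max_token_num = 256, a compile range ending
# at 128 tokens enables the fusion, while a range ending at 1024 keeps the
# separate RoPE and KV-cache-update kernels.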

View File

@@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable, Sequence
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
logger = init_logger(__name__)
# Min hidden size per device capability for sequence parallelism
# Only apply sequence parallelism for models with hidden_size >= threshold
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
90: 8192, # H100: only for models with hidden_size >= 8192
}
# Min size per GPU per device capability for sequence parallelism
# Total min size = min_per_gpu_size * tp_size
# This ensures the threshold scales appropriately with tensor parallelism
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
90: 8, # 8MB per GPU for H100
}
def get_sequence_parallelism_threshold(
hidden_size: int,
tp_size: int,
element_size: int,
) -> int | None:
"""
Calculate the minimum token threshold for applying sequence parallelism.
Returns None if sequence parallelism should not be applied based on model size.
Branching logic based on device capability:
- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
- If not, returns None (SP disabled for small models on this device)
- If yes, calculates threshold based on per-GPU size
Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
(hidden_size * element_size)
"""
from vllm.platforms import current_platform
if not current_platform.is_cuda():
return None
capability = current_platform.get_device_capability()
if capability is None:
return None
device_capability = capability.to_int()
# Check if device has configured thresholds
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
if min_hidden_size is None or min_per_gpu_size_mb is None:
return None
# Only apply sequence parallelism for models meeting the size threshold
if hidden_size < min_hidden_size:
return None
MiB = 1024 * 1024
min_size = min_per_gpu_size_mb * MiB * tp_size
return int(min_size // (hidden_size * element_size))
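# Worked example (illustrative; values follow the tables above): on an H100
# (capability 90) with hidden_size=8192, tp_size=8 and bf16 activations
# (element_size=2):
#   min_size      = 8 * 1024 * 1024 * 8 = 67_108_864 bytes
#   min_token_num = 67_108_864 // (8192 * 2) = 4096
# so sequence parallelism is only considered for batches of >= 4096 tokens;
# models with hidden_size < 8192 return None and keep SP disabled.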
def get_first_out_wrapper(
fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def wrapper(*args: Any) -> torch.Tensor:
return fn(*args)[0]
return wrapper
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
) -> None:
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x)
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.reduce_scatter.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
arg3_1: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
return rmsnorm, all_reduce
def replacement(
input: torch.Tensor,
arg3_1: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
return rmsnorm[0], rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
residual = residual[0 : reduce_scatter.size(0), ...]
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
all_gather = self._all_gather(rmsnorm[0])
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, rmsnorm[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
pm.register_replacement(
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
return all_gather, reduce_scatter
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rms, residual_out = self.rmsnorm_matcher(
all_reduce, rms_norm_weights, residual
)
quant, _ = self.quant_matcher(rms, scale)
return quant, residual_out
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
residual = residual[0 : reduce_scatter.size(0), ...]
rms, residual_out = self.rmsnorm_matcher(
reduce_scatter, rms_norm_weights, residual
)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, residual_out
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
pm.register_replacement(
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class SequenceParallelismPass(VllmPatternMatcherPass):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
an RMSNorm (or RMSNorm and then Quantization) operation.
These patterns are replaced with a ReduceScatter operation, followed by
a local RMSNorm/Quantization, and then an AllGather operation.
The general transformation is:
Input -> AllReduce -> RMSNorm -> Output
becomes
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
While this pass itself does not directly yield performance improvements,
it lays the groundwork for subsequent fusion passes, such as
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
what have now become no-op slices.
Note that an older version of the pass did not need this, as it operated only on
the custom rms_norm and fused_add_rms_norm ops, which did not complain about
mismatched shapes during replacement. This approach therefore keeps the same
assumption: correctness is only maintained if all rms_norm operations are split
across ranks.
Correctness-wise, this approach is strictly better than before: previously the
graph was incorrect both semantically and shape-wise during the pass, whereas
now there is only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Get min_token_num threshold
# Read min_token_num from config (calculated during config init)
self.min_token_num = None
if config.model_config is not None:
pass_config = config.compilation_config.pass_config
self.min_token_num = pass_config.sp_min_token_num
if self.min_token_num is not None:
# Take the min to avoid exceeding max_num_batched_tokens
max_batched = config.scheduler_config.max_num_batched_tokens
if max_batched is not None:
self.min_token_num = min(self.min_token_num, max_batched)
logger.debug_once(
f"Sequence parallelism min token threshold: {self.min_token_num}",
scope="global",
)
# Used to clean up redundant views created temporarily
# to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config)
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass"
)
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Determines if sequence parallelism should be applied for the given
compile range.
SP is only beneficial for larger batch sizes where the communication
overhead is amortized. For small batches, the overhead of splitting
and gathering tensors across TP ranks outweighs the benefits.
Returns False (SP disabled) when:
- Using piecewise compilation with non-concrete or TP-indivisible sizes
- min_token_num is None (SP disabled for this device/config)
- The compile range starts below the minimum token threshold
"""
# For piecewise compilation (not using inductor graph partition),
# we need concrete sizes that are divisible by TP for correct splitting
if (
not self.compilation_config.use_inductor_graph_partition
and self.compilation_config.splitting_ops
):
tp_size = get_tensor_model_parallel_world_size()
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
return False
# min_token_num is None when SP is disabled for this device/config
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
if self.min_token_num is None:
return False
# Only apply SP when batch size meets the minimum threshold
return compile_range.start >= self.min_token_num
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
self.noop_cleanup(graph)
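# Illustrative identity behind the rewrite (assuming the token dimension is
# divisible by the TP world size):
#   all_reduce(x) == all_gather(reduce_scatter(x, dim=0), dim=0)
# The pass exploits this to run RMSNorm (and quant) on the per-rank shard
# between the reduce-scatter and all-gather, which later passes can fuse with
# the surrounding GEMMs.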

View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx.node import Target
def is_func(node: fx.Node, target: Target) -> bool:
return bool(node.op == "call_function" and node.target == target)
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
return is_func(node, auto_functionalized) and node.args[0] == op
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
# An auto-functionalization-aware utility for finding nodes with a specific op
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op:
yield n
# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
assert len(node.users) == 1
return next(iter(node.users))
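# Illustrative usage of these helpers (hypothetical snippet, not upstream code):
#   node = find_auto_fn(graph.nodes, torch.ops._C.rms_norm_static_fp8_quant.default)
#   result = find_getitem(node, 1)   # getitem(auto_functionalized(...), 1)
# i.e. locate a functionalized fused op in the graph and then the node that
# extracts its quantized output.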

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import functools
import hashlib
import inspect
import json
import types
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
if TYPE_CHECKING:
from vllm.config.utils import Range
from torch._inductor.custom_graph_pass import CustomGraphPass
_pass_context = None
P = ParamSpec("P")
R = TypeVar("R")
class PassContext:
def __init__(self, compile_range: Range):
self.compile_range: Range = compile_range
def get_pass_context() -> PassContext:
"""Get the current pass context."""
assert _pass_context is not None
return _pass_context
@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(compile_range)
try:
yield
finally:
_pass_context = prev_context
class InductorPass(CustomGraphPass): # type: ignore[misc]
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> str:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
pass result in recompilation.
By default, the object source is hashed.
"""
return InductorPass.hash_source(self)
@staticmethod
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return: A sha256 hex digest of the combined sources.
"""
hasher = hashlib.sha256()
for src in srcs:
if isinstance(src, str):
src_str = src
elif isinstance(src, (types.FunctionType, type)):
src_str = inspect.getsource(src)
else:
# object instance
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_range(self, compile_range: Range) -> bool:
return True
class CallableInductorPass(InductorPass):
"""
This class is a wrapper for a callable that automatically provides an
implementation of the UUID.
"""
def __init__(
self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
) -> None:
self.callable = callable
self._uuid = self.hash_source(callable) if uuid is None else uuid
def __call__(self, graph: torch.fx.Graph) -> None:
self.callable(graph)
def uuid(self) -> Any:
return self._uuid
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new
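# Minimal sketch of wrapping an ad-hoc graph transform as an Inductor pass
# (illustrative; `my_transform` is a hypothetical callable):
#   def my_transform(graph: fx.Graph) -> None:
#       ...
#   pass_ = CallableInductorPass(my_transform)
#   pass_.uuid()   # stable hash of my_transform's source, used for cache keys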

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from .fusion.rocm_aiter_fusion import (
RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass,
)
if current_platform.is_cuda_alike():
from .fusion.act_quant_fusion import ActivationQuantFusionPass
from .fusion.attn_quant_fusion import AttnFusionPass
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from .fusion.sequence_parallelism import SequenceParallelismPass
from .utility.scatter_split_replace import ScatterSplitReplacementPass
from .utility.split_coalescing import SplitCoalescingPass
if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
Used to avoid logging builtin Inductor pattern matching.
"""
@functools.wraps(fn)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
return fn(*args, **kwargs)
return fn(*args, **kwargs)
return wrapper
class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It supports uuid for the Inductor code cache. That includes torch<2.6
support using pickling (in .inductor_pass.CustomGraphPass).
The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (NoopEliminationPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
"""
def __init__(self) -> None:
self.passes: list[InductorPass] = []
@with_pattern_match_debug
def __call__(self, graph: fx.Graph) -> None:
VllmInductorPass.dump_prefix = 0 # reset dump index
compile_range = get_pass_context().compile_range
for pass_ in self.passes:
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
VllmInductorPass.dump_prefix += 1
else:
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# always run fix_functionalization last
self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig) -> None:
self.pass_config = config.compilation_config.pass_config
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.eliminate_noops:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sp:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.fuse_gemm_comms:
self.passes += [AsyncTPPass(config)]
if self.pass_config.fuse_allreduce_rms:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [
RocmAiterRMSNormQuantFusionPass(config),
]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
if self.pass_config.fuse_rope_kvcache:
self.passes += [SplitCoalescingPass(config)]
self.passes += [ScatterSplitReplacementPass(config)]
self.passes += [RopeKVCacheFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_qk_norm_rope_fusion:
self.passes += [SplitCoalescingPass(config)]
self.passes += [QKNormRoPEFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass) -> None:
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def uuid(self) -> str:
"""
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
passes = []
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
for pass_ in self.passes:
passes.append(pass_.uuid())
passes.append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
state["passes"] = passes
return InductorPass.hash_dict(state)
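# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The caching idea behind uuid(): Inductor keys its code cache on a hash of
# the pass config plus every registered pass's uuid, so changing any pass or
# option invalidates previously compiled artifacts. hash_dict below is a
# stand-in written for this sketch, not the real InductorPass.hash_dict.
import hashlib
import json

def hash_dict(state: dict) -> str:
    encoded = json.dumps(state, sort_keys=True).encode()
    return hashlib.sha256(encoded).hexdigest()

state = {
    "pass_config": "<pass-config-hash>",
    "compile_range": "[1, 8192]",
    "passes": ["<noop-elimination-uuid>", "<rmsnorm-quant-fusion-uuid>"],
}
print(hash_dict(state))   # stable as long as config and passes are unchanged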


@@ -0,0 +1,301 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class FixFunctionalizationPass(VllmInductorPass):
"""
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
After this pass, DCE (dead-code elimination) should never be run,
as de-functionalized nodes may appear as dead code.
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug(
"XPU platform does not support fix functionalizationpass currently."
)
return
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
mm_node = query.args[0].args[0]
for user in getitem_nodes.values():
for user_of_getitem in user.users:
if is_func(
user_of_getitem, torch.ops.aten.slice_scatter.default
):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
else:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args = {1: "query", 2: "key"}
self.defunctionalize(graph, node, mutated_args)
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: "input", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "scale", 3: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
and at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input")
)
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input", "scale")
)
elif (
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
):
mutated_args = {1: "result", 2: "result_block_scale"}
self.defunctionalize(
graph,
node,
mutated_args,
args=(
"result",
"result_block_scale",
"input",
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
elif (
hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
and at_target
== torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
):
mutated_args = {
1: "query",
2: "key",
}
self.defunctionalize(graph, node, mutated_args=mutated_args)
# only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
elif (
hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
and at_target
== torch.ops.vllm.function_with_mutated_args_and_return.default
):
mutated_args = {1: "x"}
self.defunctionalize(graph, node, mutated_args=mutated_args)
else:
continue # skip the count
count += 1
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug(
"De-functionalized %s nodes, removed %s nodes", count, count_removed
)
self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
if isinstance(node_or_nodes, torch.fx.Node):
self.nodes_to_remove.append(node_or_nodes)
else:
self.nodes_to_remove.extend(node_or_nodes)
def defunctionalize(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
mutated_args: dict[int, torch.fx.Node | str],
args: tuple[torch.fx.Node | str, ...] | None = None,
) -> None:
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
See replace_users_with_mutated_args and insert_defunctionalized.
"""
self.replace_users_with_mutated_args(node, mutated_args)
self.insert_defunctionalized(graph, node, args=args)
self._remove(node)
def replace_users_with_mutated_args(
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
) -> None:
"""
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
# Some functionalized nodes may return both a result at getitem[0]
# as well as mutated args at getitem[1:...]
if idx == 0:
assert idx not in mutated_args, (
f"result at getitem[0] should not be in mutated_args for {node}"
)
continue
arg = mutated_args[idx]
arg = node.kwargs[arg] if isinstance(arg, str) else arg
user.replace_all_uses_with(arg)
self._remove(user)
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
"""
Returns the operator.getitem users of the auto-functionalized node,
indexed by the index they are getting.
"""
users = {}
for user in node.users:
if is_func(user, operator.getitem):
idx = user.args[1]
users[idx] = user
return users
def insert_defunctionalized(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
args: tuple[torch.fx.Node | str, ...] | None = None,
) -> None:
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), (
f"node must be auto-functionalized, is {node} instead"
)
# Create a new call to the original function
with graph.inserting_before(node):
function = node.args[0]
if args is None:
fn_node = graph.call_function(function, kwargs=node.kwargs)
else:
# Args passed as strings refer to items in node.kwargs
args = tuple(
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
)
fn_node = graph.call_function(function, args=args)
# If the function returns a value as well as mutating args inplace,
# the functionalized node will have a getitem[0] user that holds this value
# Replace getitem[0] user of the auto-functionalized node
# with the new defunctionalized node directly if it exists
users = self.getitem_users(node)
if 0 in users:
user = users[0]
user.replace_all_uses_with(fn_node)
self._remove(user)
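# --- Illustrative sketch (not part of the vLLM source above) ----------------
# A toy torch.fx graph showing the getitem pattern that getitem_users() walks:
# a node returning a tuple gets operator.getitem users, one per index. The
# traced function below is made up for illustration.
import operator
import torch
from torch import fx

def toy(x):
    parts = x.split([4, 2], dim=-1)        # tuple-returning node
    return parts[0] * 2 + parts[1].sum()   # creates getitem users

gm = fx.symbolic_trace(toy)
for node in gm.graph.nodes:
    getitems = {u.args[1]: u for u in node.users if u.target is operator.getitem}
    if getitems:
        print(node.name, "->", sorted(getitems))   # e.g. split -> [0, 1]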


@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class NoOpEliminationPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape/slice operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D case. Additionally, torch's internal no-op elimination pass does
not handle certain slice variants.
Cases handled:
1. A chain of reshapes is equivalent to the last reshape called on the
base tensor (input of the first reshape).
2. A reshape that produces the shape of the input is redundant
3. A slice that produces the shape of the input is redundant
Example graph 1:
mul_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])
Can be replaced with:
mul_1: "f16[s0, 4096]" = ...
view_3: "f16[s0, 128, 32]" = ...
Example graph 2:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Can be replaced with:
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Example graph 3:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
Can be replaced with:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
# Case 1: rewrite reshape chains to reshapes on the base tensor
input = node.args[0]
# If the input is a reshape, rebind to that node
if is_func(input, torch.ops.aten.reshape.default):
# The new input is guaranteed not to be a reshape,
# because we process nodes in order
node.update_arg(0, input.args[0])
if len(input.users) == 0:
graph.erase_node(input)
count += 1
# remove reshape/slice if it produces the original shape
if is_func(node, torch.ops.aten.reshape.default) or is_func(
node, torch.ops.aten.slice.Tensor
):
input = node.args[0]
input_shape = input.meta["val"].shape
output_shape = node.meta["val"].shape
if self.all_dims_equivalent(input_shape, output_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
if self.all_dims_equivalent(base_shape, view_shape):
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
# ---------------------- Shape comparison helpers ----------------------
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The dimensions both correspond to the same SymInt
"""
# statically_known_true handles both cases (equal ints and identical SymInts)
return statically_known_true(dim == i_dim)  # type: ignore[no-any-return]
def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
) -> bool:
dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
# Different ranks can't be equivalent
return False
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
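# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The rewrite this pass performs, shown on concrete tensors (no SymInts): a
# chain of reshapes that ends in the input's own shape is a no-op, so every
# user can read from the base tensor directly. Shapes are made up.
import torch

x = torch.randn(8, 4096)
y = x.reshape(-1, 128, 32).reshape(-1, 4096)   # chain ends at x's shape
assert y.shape == x.shape                      # output shape == input shape
assert y.data_ptr() == x.data_ptr()            # reshape of a contiguous tensor aliases it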


@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from ..vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
"""
This pass performs cleanup after custom passes.
It topologically sorts the graph and removes unused nodes.
This is needed because the pattern matcher does not guarantee producing
a topologically sorted graph, and there may be unused nodes left around.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
from torch._inductor.pattern_matcher import stable_topological_sort
stable_topological_sort(graph)
graph.eliminate_dead_code()
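# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The two cleanup steps on a toy graph: pattern rewrites can leave dead nodes
# behind, and eliminate_dead_code() removes them. The toy function is made up.
import torch
from torch import fx
from torch._inductor.pattern_matcher import stable_topological_sort

def toy(x):
    unused = x * 3      # has no users in the traced graph
    return x + 1

gm = fx.symbolic_trace(toy)
before = len(gm.graph.nodes)
stable_topological_sort(gm.graph)   # reorder nodes topologically (no-op here)
gm.graph.eliminate_dead_code()
gm.recompile()
print(before, "->", len(gm.graph.nodes))   # the dead mul node is removed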


@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
assignment if there are no users for the inplace tensor written to by
the slice_scatter call.
The inplace rotary_embedding custom op takes in mutable query and key inputs
that are split+getitem outputs of a single qkv tensor.
When functionalized, we fetch the rotated query and key from the functionalized op
using `getitem` calls. However, we also write to the qkv tensor inplace using a
`slice_scatter`, then split the inplace tensor to get the output tensors again.
Instead, if the inplace tensor has no subsequent users, we can just replace the
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
This is already done in fix_functionalization::FixFunctionalizationPass, but
writing a custom pass for it before defunctionalization allows matching against the
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
"""
import operator
import torch
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class ScatterSplitReplacementPass(VllmInductorPass):
"""Replace getitem+slice_scatter+split nodes with a single getitem when
the inplace subtensor written to by the slice_scatter has no other users.
Here's an example graph with q_size = 512, kv_size = 64:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
q = operator.getitem(split_with_sizes_2, 0)
k = operator.getitem(split_with_sizes_2, 1)
v = operator.getitem(split_with_sizes_2, 2)
After this pass, this sequence of nodes is replaced with:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
v = operator.getitem(split_with_sizes_1, 2)
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
target_ops = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue
kwargs = node.kwargs
at_target = node.args[0]
if at_target in target_ops:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = {}
for user in node.users:
if is_func(user, operator.getitem):
getitem_nodes[user.args[1]] = user
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of a qkv tensor.
# While functionalized, results at [1] and [2] are scattered
# back into qkv, then split again to get query and key.
# If the inplace tensor has no other users, we can replace
# the slice_scatter+split nodes with the original results.
for user in getitem_nodes[1].users:
slice_scatter_1_node = user
if not is_func(
slice_scatter_1_node, torch.ops.aten.slice_scatter.default
):
continue
for user in getitem_nodes[2].users:
slice_scatter_2_node = user
if not is_func(
slice_scatter_2_node, torch.ops.aten.slice_scatter.default
):
continue
for user in slice_scatter_2_node.users:
split_node = user
if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
continue
split_getitem_users = {}
for user in split_node.users:
if is_func(user, operator.getitem):
split_getitem_users[user.args[1]] = user
# Replace query node
split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
graph.erase_node(split_getitem_users[0])
# Replace key node
split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
graph.erase_node(split_getitem_users[1])
# Redirect value node to original qkv tensor
split_getitem_users[2].replace_input_with(split_node, query.args[0])
# Erase unused nodes
graph.erase_node(split_node)
graph.erase_node(slice_scatter_2_node)
graph.erase_node(slice_scatter_1_node)
count += 1
logger.debug("Eliminated %d slice_scatter+split nodes", count)
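# --- Illustrative sketch (not part of the vLLM source above) ----------------
# Why the scatter+re-split round trip is redundant, shown on plain tensors:
# writing q/k back into qkv and splitting again just reproduces the tensors we
# already had (q_rot, k_rot) plus the untouched v. Sizes and the "* 2"
# stand-in for the in-place rotary embedding are made up.
import torch

qkv = torch.randn(3, 640)
q, k, v = qkv.split([512, 64, 64], dim=-1)
q_rot, k_rot = q * 2, k * 2   # stand-in for rotary_embedding
qkv2 = torch.slice_scatter(qkv, q_rot, dim=-1, start=0, end=512)
qkv2 = torch.slice_scatter(qkv2, k_rot, dim=-1, start=512, end=576)
q2, k2, v2 = qkv2.split([512, 64, 64], dim=-1)
assert torch.equal(q2, q_rot) and torch.equal(k2, k_rot) and torch.equal(v2, v)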


@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
input tensor with the same split sizes.
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
graph may contain multiple ``split_with_sizes`` calls on the same tensor
that CSE fails to merge. This pass detects and replaces the duplicates
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
see a single split node with all users attached.
See also:
- vLLM #33295 (original issue)
- PyTorch #174472 (upstream CSE gap)
"""
import operator
import torch
from torch import fx
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class SplitCoalescingPass(VllmInductorPass):
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
node when they share the same input tensor and split sizes."""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
# Map from input tensor node -> list of split nodes seen so far.
split_nodes: dict[fx.Node, list[fx.Node]] = {}
for node in graph.nodes:
if not is_func(node, torch.ops.aten.split_with_sizes.default):
continue
if not all(is_func(user, operator.getitem) for user in node.users):
continue
arg_node, split_sizes = node.args[:2]
if arg_node not in split_nodes:
split_nodes[arg_node] = [node]
continue
# Find existing node with same split_sizes
canonical = next(
(
n
for n in split_nodes[arg_node]
if list(n.args[1]) == list(split_sizes)
),
None,
)
if canonical is not None:
node.replace_all_uses_with(canonical)
graph.erase_node(node)
count += 1
else:
split_nodes[arg_node].append(node)
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
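# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The coalescing rule on plain data: two split_with_sizes calls on the same
# input with the same sizes are interchangeable, so later duplicates can have
# their users rerouted to the first (canonical) one. Names are made up.
seen: dict[tuple, str] = {}
calls = [
    ("qkv", (512, 64, 64), "split_1"),
    ("qkv", (512, 64, 64), "split_2"),   # duplicate of split_1
    ("qkv", (512, 128), "split_3"),      # different sizes -> kept
]
for inp, sizes, name in calls:
    canonical = seen.setdefault((inp, sizes), name)
    if canonical != name:
        print(f"{name} duplicates {canonical}: reroute users and erase it")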


@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar
import regex as re
import torch
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .inductor_pass import InductorPass
logger = init_logger(__name__)
@dataclass
class InductorCompilationConfig:
splitting_ops: list[str] | None = None
use_inductor_graph_partition: bool = False
class VllmInductorPass(InductorPass):
"""
An inductor pass with access to vLLM PassConfig.
It provides timing, logging, and dumping utilities.
"""
dump_prefix: ClassVar[int | None] = None
"""Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig):
# Get only the necessary CompilationConfig for the inductor pass, since
# full `CompilationConfig` contains pointer to model which is unsafe.
self.compilation_config = InductorCompilationConfig(
splitting_ops=config.compilation_config.splitting_ops,
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
)
self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device: str | None = (
config.device_config.device if config.device_config else None
)
self.pass_name = self.__class__.__name__
@staticmethod
def time_and_log(
call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
@functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
self.begin()
self.dump_graph(graph, "before")
call_fn(self, graph)
self.dump_graph(graph, "after")
self.end_and_log()
return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}"
lazy_format_graph_code(
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
)
def begin(self) -> None:
self._start_time = time.perf_counter_ns()
def end_and_log(self) -> None:
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class VllmPatternMatcherPass(VllmInductorPass):
"""
A VllmInductorPass that uses the Inductor pattern matcher.
Its main use is providing the dump_patterns utility that dumps the
Inductor pattern matcher patterns into a file, which greatly aids debugging.
TODO(luka) move more utilities to this pass.
"""
matched_count: int = 0
"""The number of matched patterns in the pass."""
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
)
def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return str(
self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
)
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
This method does its best to print something that looks like Python code
for easier debugging and potentially navigation. If any errors appear in
the output, please add to this method.
TODO(luka): use pattern object to manually produce pattern graph
"""
debug_dump_path = config.compile_debug_dump_path()
if not debug_dump_path:
return
debug_dump_path.mkdir(parents=True, exist_ok=True)
from vllm.utils.system_utils import unique_filepath
file_path = unique_filepath(
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
)
with file_path.open("w") as f:
print(
f"# This file was produced by VllmPatternMatcherPass."
f"dump_patterns for {self.pass_name}.\n"
f"# It does its best to produce valid-Python-looking code but"
f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f,
)
for node, patterns in pm_pass.patterns.items():
# fix the operator.getitem repr
if node[1] == operator.getitem:
node_repr = f"({repr(node[0])}, operator.getitem)"
else:
node_repr = repr(node)
node_repr = self._replace_op_overloads(node_repr)
print(f"\n\n# Patterns for op: {node_repr}", file=f)
for i, pattern in enumerate(patterns):
# reserve auto_functionalized ahead of time
pp = PatternPrettyPrinter()
pp.namespace.create_name("auto_functionalized", None)
# Assemble pattern
out_node = pp.pretty_print(pattern.pattern)
pattern_repr = "\n".join(
[f"def pattern_{i}():"]
+ [
f"{pp.memoized_objs_names[key]} = "
f"{pp.memoized_objs_pp[key]}"
for key in pp.memoized_objs_names
]
+ [f"return {out_node}"]
).replace("\n", "\n ")
pattern_repr = self._replace_op_overloads(pattern_repr)
print(f"{pattern_repr}\n", file=f)
class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig) -> None:
super().__init__(config)
self.name = name
def __call__(self, graph: torch.fx.Graph) -> None:
self.dump_graph(graph, self.name)
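# --- Illustrative sketch (not part of the vLLM source above) ----------------
# Quick check of the OpOverload-repr rewrite used by dump_patterns: the regex
# turns the default "<OpOverload(...)>" repr into a torch.ops path that reads
# as valid Python. The sample string is made up; `regex` is a vLLM dependency.
import regex as re

pat = re.compile(r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
sample = "CallFunction(<OpOverload(op='_C.rms_norm', overload='default')>, arg)"
print(pat.sub(lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", sample))
# -> CallFunction(torch.ops._C.rms_norm.default, arg)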


@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import io
import json
import pickle
import time
from collections.abc import Callable
from pickle import Pickler
from typing import Any
import torch._functorch.config
import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclasses.dataclass
class RangeEntry:
compile_range: Range
compiled: bool = False
runnable: Callable[..., Any] = None # type: ignore
class PiecewiseBackend:
def __init__(
self,
graph: fx.GraphModule | None,
vllm_config: VllmConfig,
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
vllm_backend: VllmBackend,
returns_tuple: bool,
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
submod_name: str = "",
):
"""
The backend for piecewise compilation.
It mainly handles the compilation of static shapes and
dispatching based on runtime shape.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
This class supports two mutually exclusive modes:
1. Compilation (graph is set, compiled_runnables is None):
Used during initial compilation when we have the FX graph
and need to compile it for each shape range.
2. Precompilation (graph is None, compiled_runnables is set):
Used when loading from cache/AOT artifacts where we already
have pre-compiled callables and don't need the original graph.
Exactly one of graph or compiled_runnables must be provided.
"""
assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
"exactly one of graph and compiled_runnables should be set."
)
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
self.compiled_runnables = compiled_runnables
self.submod_name = submod_name
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1
self.is_encoder_compilation = vllm_backend.is_encoder
self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (
last_compile_range.end
== vllm_config.scheduler_config.max_num_batched_tokens
)
self.compile_ranges[-1] = Range(
start=last_compile_range.start, end=max_int32
)
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)
self.compile_sizes = self.compilation_config.compile_sizes
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices
self.returns_tuple = returns_tuple
# the entries for the ranges (and concrete compile sizes) that we may
# need to compile and dispatch to at runtime
self.range_entries: dict[Range, RangeEntry] = {}
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly.
if self.compile_sizes is not None:
for size in self.compile_sizes:
if isinstance(size, str):
assert size == "cudagraph_capture_sizes"
raise NotImplementedError(
"cudagraph_capture_sizes not supported in compile_sizes."
"This should be handled in `post_init_cudagraph_sizes`."
)
else:
assert isinstance(size, int)
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
# Track whether we've logged the graph for this subgraph (only log once)
self._graph_logged = False
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
from vllm.compilation.backends import _on_compilation_complete_callback
self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper(
self, compiled_graph: Callable[..., Any]
) -> Callable[..., Any]:
def compiled_graph_wrapper(*args: Any) -> Any:
graph_output = compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
return graph_output
else:
return graph_output[0]
return compiled_graph_wrapper
def check_for_ending_compilation(self) -> None:
if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
time_before_saving = time.perf_counter()
self.vllm_backend.compiler_manager.save_to_file()
elapsed = time.perf_counter() - time_before_saving
if elapsed > 1:
logger.info_once(
"Saved compiler manager cache in %.2f seconds.",
elapsed,
scope="local",
)
end_monitoring_torch_compile(self.vllm_config)
# Call the completion callback (e.g., to save AOT compiled function)
if self.on_compilation_complete is not None:
self.on_compilation_complete()
def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj: object) -> Any:
if isinstance(obj, CachingAutotuner):
obj.prepare_for_pickle()
return pickle.loads, (
pickle.dumps(
obj,
),
)
return NotImplemented
def serialize(fn: Callable[..., Any]) -> bytes:
assert hasattr(fn, "serialize"), "fn must have serialize method"
with torch._functorch.config.patch("bundled_autograd_cache", True):
entry = fn.serialize()
f = io.BytesIO()
StandaloneCompiledArtifactsPickler(f).dump(entry)
result = f.getvalue()
return result
out = {}
for range_key, entry in self.range_entries.items():
if not entry.compiled:
logger.debug(
"entry with range %s not compiled, so cannot get its bytes",
range_key,
)
continue
if hasattr(entry.runnable, "serialize"):
out[str(range_key)] = serialize(entry.runnable)
return out
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
# We need to pass fake example_inputs; otherwise torch.compile
# will fakify the example_inputs itself, potentially causing some
# non-dynamic dimension to be duck-shaped to other existing shapes
# whose hints match its value.
# This is a problem because it can lead to unintended specializations:
# if the wrongly-dynamic dim is later specialized,
# it will force specializing the whole shape.
# torch.compile probably should not accept non-fake tensors
# as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899
fake_example_inputs = []
assert self.graph is not None
for node in self.graph.graph.nodes:
# All place holders come first
if node.op == "placeholder":
fake_example_inputs.append(node.meta["example_value"])
else:
break
assert len(fake_example_inputs) == len(args)
return fake_example_inputs
def _log_compile_start(self, compile_range: Range):
"""Log compilation event for TORCH_TRACE/tlparse."""
is_cudagraph_size = (
self.compile_sizes is not None and compile_range.start in self.compile_sizes
)
subgraph_index = self.piecewise_compile_index
submod_name = self.submod_name
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_piecewise_compile_start",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
{
"piecewise_index": subgraph_index,
"submod_name": submod_name,
"total_piecewise_compiles": self.total_piecewise_compiles,
"compile_range_start": compile_range.start,
"compile_range_end": compile_range.end,
"is_single_size": compile_range.is_single_size(),
"is_cudagraph_capture_size": is_cudagraph_size,
}
),
)
# Log the subgraph graph dump only once per subgraph (not per size)
# to reduce log file size. The graph code is the same for all sizes.
if not self._graph_logged:
self._graph_logged = True
assert self.graph is not None
trace_structured(
"graph_dump",
metadata_fn=lambda: {
"name": f"vllm_{submod_name}",
},
payload_fn=lambda: self.graph.print_readable(print_output=False),
)
def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
if not range_entry.compiled:
if self.compiled_runnables is not None:
range_entry.runnable = self.get_compiled_graph_wrapper(
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
self._log_compile_start(range_entry.compile_range)
# args are real arguments
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args_list = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else list(args)
)
with (
torch._functorch.config.patch("bundled_autograd_cache", True),
):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if self.compile_sizes is None:
return None
if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
else:
for range in self.compile_ranges:
if runtime_shape in range:
return self.range_entries[range]
return None
def __call__(self, *args: Any) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
)
self._maybe_compile_for_range_entry(range_entry, args)
return range_entry.runnable(*args)
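# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The dispatch rule in _find_range_for_shape on plain data: an exact compile
# size wins, otherwise the first range containing the runtime shape is used.
# Ranges are modelled here as inclusive (start, end) tuples; this is an
# assumption made for the sketch, not vLLM's Range class.
compile_sizes = {1, 8}
compile_ranges = [(1, 512), (513, 8192)]

def find_range(runtime_shape: int):
    if runtime_shape in compile_sizes:
        return (runtime_shape, runtime_shape)   # concrete-size entry
    for start, end in compile_ranges:
        if start <= runtime_shape <= end:
            return (start, end)                 # containing range entry
    return None

print(find_range(8), find_range(300), find_range(9000))
# (8, 8) (1, 512) None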

321
vllm/compilation/wrapper.py Normal file

@@ -0,0 +1,321 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
from abc import abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any, ParamSpec, TypeVar
import torch
import torch._C._dynamo.guards
import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__)
R = TypeVar("R")
P = ParamSpec("P")
def _noop_add_global_state_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the GLOBAL_STATE guard entirely"""
pass
def _noop_add_torch_function_mode_stack_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass
@contextmanager
def _compilation_context() -> Generator[None, None, None]:
"""Context manager for compilation settings and patches.
This manager:
1. Sets higher dynamo cache limits for compilation. (Needed for
qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
Generally a recompilation can happen whenever we use a new
backend instance in torch.compile.
2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
3. Patches out add_torch_function_mode_stack_guard to skip
TORCH_FUNCTION_MODE_STACK guards.
4. Restores everything when compilation completes
"""
# Save original values
original_global_state_guard = (
torch._C._dynamo.guards.GuardManager.add_global_state_guard
)
original_torch_function_mode_stack_guard = (
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
)
original_cache_size = torch._dynamo.config.cache_size_limit
original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
try:
# Set higher cache limits for compilation
torch._dynamo.config.cache_size_limit = 2048
torch._dynamo.config.accumulated_cache_size_limit = 8192
# Patch guard manager
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
_noop_add_global_state_guard
)
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
_noop_add_torch_function_mode_stack_guard
)
yield
finally:
# Restore original values
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
original_global_state_guard
)
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
original_torch_function_mode_stack_guard
)
torch._dynamo.config.cache_size_limit = original_cache_size
torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
class TorchCompileWithNoGuardsWrapper:
"""
A wrapper class for torch.compile, it ensures that all guards are dropped
when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
When guards are dropped, the first time __call__ is invoked, a single
compilation is triggered. Dynamo should never be traced again after that
since we drop all guards.
"""
def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
assert hasattr(self, "_check_shape_invariants")
self._check_shape_invariants(*args, **kwargs)
return self.forward(*args, **kwargs)
def _call_with_optional_nvtx_range(
self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Any:
if self.layerwise_nvtx_tracing_enabled:
args_list = list(args)
kwargs_dict = dict(kwargs)
with layerwise_nvtx_marker_context(
"Torch Compiled Module (input):{}".format(self.__class__.__name__),
self,
in_tensor=args_list,
kwargs=kwargs_dict,
) as ctx:
ctx.result = callable_fn(*args, **kwargs)
return ctx.result
return callable_fn(*args, **kwargs)
def __init__(self) -> None:
self.compiled = False
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
mode = vllm_config.compilation_config.mode
self.layerwise_nvtx_tracing_enabled = (
vllm_config.observability_config.enable_layerwise_nvtx_tracing
)
if mode is None:
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
backend = vllm_config.compilation_config.init_backend(vllm_config)
options = {}
if isinstance(backend, str) and backend == "inductor":
options = vllm_config.compilation_config.inductor_compile_config
self.first_compile = True
self.evaluate_guards = (
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
)
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
if mode != CompilationMode.STOCK_TORCH_COMPILE:
# Drop all the guards.
if self.evaluate_guards:
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)
options["guard_filter_fn"] = lambda x: [
entry.guard_type == "SHAPE_ENV" for entry in x
]
else:
options["guard_filter_fn"] = lambda x: [False for _ in x]
compiled_ptr: Any = self.forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
if ds_type == DynamicShapesType.UNBACKED:
# The reason is that the bytecode-hook path calls torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation, and if we use
# compiled_ptr = self.check_invariants_and_forward
# it will reset all entries.
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
)
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
compiled_ptr = self.check_invariants_and_forward
aot_context = nullcontext()
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
else:
msg = "torch._dynamo.config.enable_aot_compile is not "
msg += "available. AOT compile is disabled and please "
msg += "upgrade PyTorch version to use AOT compile."
logger.warning(msg)
with aot_context:
self._compiled_callable = torch.compile(
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
options=options,
)
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
self._compiled_bytecode: CodeType | None = None
def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
if not hasattr(self._compiled_callable, "aot_compile"):
raise RuntimeError(
"aot_compile is not supported by the current configuration. "
"Please make sure torch.compile is enabled with the latest "
f"version of PyTorch (current using torch: {torch.__version__})"
)
return self._compiled_callable.aot_compile((args, kwargs))
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if envs.VLLM_USE_BYTECODE_HOOK:
if (
self.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
return self._compiled_callable(*args, **kwargs)
if not self._compiled_bytecode:
# Make sure a compilation is triggered by clearing dynamo
# cache.
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
else:
with self._dispatch_to_compiled_code():
return self._call_with_optional_nvtx_range(
self.forward, *args, **kwargs
)
else:
ctx = (
nullcontext()
if self.first_compile or not self.evaluate_guards
else torch.compiler.set_stance("fail_on_recompile")
)
self.first_compile = False
with _compilation_context(), ctx:
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
@abstractmethod
def forward(self, *args: Any, **kwargs: Any) -> Any: ...
def original_code_object(self) -> CodeType:
"""Return the original code object of the forward method."""
return self.__class__.forward.__code__
def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
"""Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object():
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame = sys._getframe()
while frame and frame.f_back:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
if code_name == "_compile" and file_name == "convert_frame.py":
break
frame = frame.f_locals["frame"]
assert frame.f_code == old_code
if frame.f_locals["self"] is not self:
return
self._compiled_bytecode = new_code
path = self.vllm_config.compile_debug_dump_path()
if path:
decompiled_file = path / "transformed_code.py"
if not decompiled_file.exists():
try:
# usually the decompilation will succeed for most models,
# as we guarantee a full-graph compilation in Dynamo.
# but there's no 100% guarantee, since decompilation is
# not a reversible process.
import depyf
src = depyf.decompile(new_code)
with open(decompiled_file, "w") as f:
f.write(src)
logger.debug("Dynamo transformed code saved to %s", decompiled_file)
except Exception:
pass
if (
self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and "update" in new_code.co_names
):
import depyf
src = depyf.decompile(new_code)
msg = (
"Assigning / modifying buffers of nn.Module during forward pass is not "
"allowed when using cudagraph inside the compiler because it will "
"cause silent errors. Please use eager mode or fix the code. The "
"following code contains clues about which buffer is being modified "
f"(please search for the usage of the function `update`):\n{src}"
)
raise RuntimeError(msg)
@contextmanager
def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
# noqa: E501
"""
Context manager to dispatch to internally compiled code for torch<2.8.
Why does this work? Because Dynamo guarantees that the compiled
bytecode has exactly the same arguments, cell variables, and free
variables as the original code. Therefore we can directly switch
the code object in the function and call it.
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa: E501 line too long
original = self.original_code_object()
assert self._compiled_bytecode is not None
self.__class__.forward.__code__ = self._compiled_bytecode
try:
yield
finally:
self.__class__.forward.__code__ = original
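# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The mechanism behind _dispatch_to_compiled_code on a toy class: because the
# replacement code object has the same arguments and free variables, a
# method's __code__ can be swapped temporarily to reroute calls, then restored.
class Toy:
    def forward(self, x):
        return x + 1

def transformed(self, x):   # stand-in for the Dynamo-generated bytecode
    return x + 2

toy = Toy()
original_code = Toy.forward.__code__
Toy.forward.__code__ = transformed.__code__
try:
    print(toy.forward(1))   # 3 -> routed to the replacement code object
finally:
    Toy.forward.__code__ = original_code
print(toy.forward(1))       # 2 -> original behaviour restored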

130
vllm/config/__init__.py Normal file

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.attention import AttentionConfig
from vllm.config.cache import CacheConfig
from vllm.config.compilation import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
PassConfig,
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.kernel import KernelConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.config.model import (
ModelConfig,
iter_architecture_defaults,
str_dtype_to_torch_dtype,
try_match_architecture_defaults,
)
from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.offload import (
OffloadBackend,
OffloadConfig,
PrefetchOffloadConfig,
UVAOffloadConfig,
)
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.config.structured_outputs import StructuredOutputsConfig
from vllm.config.utils import (
ConfigType,
SupportsMetricsInfo,
config,
get_attr_docs,
is_init_field,
replace,
update_config,
)
from vllm.config.vllm import (
VllmConfig,
get_cached_compilation_config,
get_current_vllm_config,
get_current_vllm_config_or_none,
get_layers_from_vllm_config,
set_current_vllm_config,
)
from vllm.config.weight_transfer import WeightTransferConfig
# __all__ should only contain classes and functions.
# Types and globals should be imported from their respective modules.
__all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache
"CacheConfig",
# From vllm.config.compilation
"CompilationConfig",
"CompilationMode",
"CUDAGraphMode",
"PassConfig",
# From vllm.config.device
"DeviceConfig",
# From vllm.config.ec_transfer
"ECTransferConfig",
# From vllm.config.kernel
"KernelConfig",
# From vllm.config.kv_events
"KVEventsConfig",
# From vllm.config.kv_transfer
"KVTransferConfig",
# From vllm.config.load
"LoadConfig",
# From vllm.config.lora
"LoRAConfig",
# From vllm.config.model
"ModelConfig",
"iter_architecture_defaults",
"str_dtype_to_torch_dtype",
"try_match_architecture_defaults",
# From vllm.config.multimodal
"MultiModalConfig",
# From vllm.config.observability
"ObservabilityConfig",
# From vllm.config.offload
"OffloadBackend",
"OffloadConfig",
"PrefetchOffloadConfig",
"UVAOffloadConfig",
# From vllm.config.parallel
"EPLBConfig",
"ParallelConfig",
# From vllm.config.pooler
"PoolerConfig",
# From vllm.config.scheduler
"SchedulerConfig",
# From vllm.config.speculative
"SpeculativeConfig",
# From vllm.config.speech_to_text
"SpeechToTextConfig",
# From vllm.config.structured_outputs
"StructuredOutputsConfig",
# From vllm.config.profiler
"ProfilerConfig",
# From vllm.config.utils
"ConfigType",
"SupportsMetricsInfo",
"config",
"get_attr_docs",
"is_init_field",
"replace",
"update_config",
# From vllm.config.vllm
"VllmConfig",
"get_cached_compilation_config",
"get_current_vllm_config",
"get_current_vllm_config_or_none",
"set_current_vllm_config",
"get_layers_from_vllm_config",
"WeightTransferConfig",
]

69
vllm/config/attention.py Normal file

@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from pydantic import field_validator
from vllm.config.utils import config
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@config
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for attention instead of
the unified triton kernel."""
flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""
use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""
use_trtllm_ragged_deepseek_prefill: bool = True
"""Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention: bool | None = None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
disable_flashinfer_prefill: bool = False
"""Whether to disable flashinfer prefill."""
disable_flashinfer_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8."""
use_prefill_query_quantization: bool = False
"""If set, quantize query for attention in prefill."""
def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
from vllm.config.utils import get_hash_factors, hash_factors
ignored_factors: list[str] = []
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
@field_validator("backend", mode="before")
@classmethod
def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string."""
if isinstance(value, str):
return AttentionBackendEnum[value.upper()]
return value
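# --- Illustrative sketch (not part of the vLLM source above) ----------------
# The string-to-enum parsing idea behind validate_backend_before, on a toy
# pydantic model: a mode="before" validator upper-cases strings and looks them
# up by enum member name, so both "flash_attn" and the enum member validate.
# The Backend enum and its members are made up for this sketch.
import enum
from pydantic import BaseModel, field_validator

class Backend(enum.Enum):
    FLASH_ATTN = "flash_attn"
    TRITON = "triton"

class ToyAttentionConfig(BaseModel):
    backend: Backend | None = None

    @field_validator("backend", mode="before")
    @classmethod
    def _parse_backend(cls, value):
        if isinstance(value, str):
            return Backend[value.upper()]
        return value

print(ToyAttentionConfig(backend="flash_attn").backend)   # Backend.FLASH_ATTN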

250
vllm/config/cache.py Normal file

@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, SkipValidation, field_validator
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import format_gib, get_cpu_memory
if TYPE_CHECKING:
from vllm.config.parallel import ParallelConfig
else:
ParallelConfig = Any
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@config
class CacheConfig:
"""Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
only block sizes up to 32 are supported.
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current
platform."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
utilization. If unspecified, will use the default value of 0.9. This is a
per-instance limit, and only applies to the current vLLM instance. It does
not matter if you have another vLLM instance running on the same GPU. For
example, if you have two vLLM instances running on the same GPU, you can
set the GPU memory utilization to 0.5 for each instance."""
swap_space: float = Field(default=4, ge=0)
"""Size of the CPU swap space per GPU (in GiB)."""
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to use
bfloat16 instead. bfloat16 is an invalid option for models that do not default
to fp8.
"""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
num_gpu_blocks_override: int | None = None
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
if specified. Does nothing if `None`. Used for testing preemption."""
sliding_window: int | None = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: bool = True
"""Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing. This is the
current default, as SHA256 is the most secure choice to avoid potential
hash collisions.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256.\n
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
non-cryptographic hashing. Requires the optional ``xxhash`` package.
IMPORTANT: Use of a hashing algorithm that is not considered
cryptographically secure theoretically increases the risk of hash collisions,
which can cause undefined behavior or even leak private information in
multi-tenant environments. Even if collisions are still very unlikely, it is
important to consider your security risk tolerance against the performance
benefits before turning this on.\n
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
reproducible hashing. Requires the optional ``xxhash`` package."""
cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
DEPRECATED: This field is deprecated and will be removed in v0.16.
Please use OffloadConfig.uva.cpu_offload_gb instead.
"""
cpu_offload_params: set[str] = Field(default_factory=set)
"""The set of parameter name segments to target for CPU offloading.
DEPRECATED: This field is deprecated and will be removed in v0.16.
Please use OffloadConfig.uva.cpu_offload_params instead.
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""
cpu_kvcache_space_bytes: int | None = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: int | None = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
mamba_block_size: int | None = Field(default=None, gt=0)
"""Size of a contiguous cache block in number of tokens for mamba cache.
Can be set only when prefix caching is enabled.
Value must be a multiple of 8 to align with causal_conv1d kernel."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model
config."""
mamba_ssm_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
mamba_cache_mode: MambaCacheMode = "none"
"""The cache strategy for Mamba layers.
- "none": set when prefix caching is disabled.
- "all": cache the mamba state of all tokens at position i * block_size. This is
the default behavior (for models that support it) when prefix caching is
enabled.
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
"""The number of blocks to allocate for GPU memory."""
num_cpu_blocks: int | None = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""
kv_sharing_fast_prefill: bool = False
"""This feature is work in progress and no prefill optimization takes place
with this flag enabled currently.
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overridden with metadata
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""
kv_cache_memory_bytes: int | None = None
"""Size of KV Cache per GPU in bytes. By default, this is set to None
and vllm can automatically infer the kv cache size based on
gpu_memory_utilization. However, users may want to manually specify
the kv cache memory size. kv_cache_memory_bytes allows more fine-grained
control of how much memory gets used when compared with using
gpu_memory_utilization. Note that when kv_cache_memory_bytes is not None,
gpu_memory_utilization is ignored."""
kv_offloading_size: float | None = None
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
the total buffer size summed across all TP ranks. By default, this is set
to None, which means no KV offloading is enabled. When set, vLLM will
enable KV cache offloading to CPU using the kv_offloading_backend."""
kv_offloading_backend: KVOffloadingBackend = "native"
"""The backend to use for KV cache offloading. Supported backends include
'native' (vLLM native CPU offloading), 'lmcache'.
KV offloading is only activated when kv_offloading_size is set."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
ignored_factors = {
# Runtime/derived knobs that don't affect compiled graph shape
"gpu_memory_utilization",
"swap_space",
"is_attention_free",
"num_gpu_blocks_override",
"enable_prefix_caching",
"prefix_caching_hash_algo",
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
# Post-init/derived counters
"num_gpu_blocks",
"num_cpu_blocks",
# WIP feature toggle not impacting compiled graph shape
"kv_sharing_fast_prefill",
}
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}
@field_validator("cache_dtype", mode="after")
@classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
if cache_dtype.startswith("fp8"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor."
)
return cache_dtype
def verify_with_parallel_config(
self,
parallel_config: ParallelConfig,
) -> None:
swap_space_bytes = math.ceil(self.swap_space * GiB_bytes)
total_cpu_memory = get_cpu_memory()
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
# group are in the same node. However, the GPUs may span multiple nodes.
num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = swap_space_bytes * num_gpus_per_node
msg = (
f"{format_gib(cpu_memory_usage)} GiB out of the "
f"{format_gib(total_cpu_memory)} GiB total CPU memory "
"is allocated for the swap space."
)
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. %s", msg)

1196
vllm/config/compilation.py Normal file

File diff suppressed because it is too large

73
vllm/config/device.py Normal file
View File

@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import field
from typing import Any, Literal
import torch
from pydantic import ConfigDict, SkipValidation
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""
device: SkipValidation[Device | torch.device | None] = "auto"
"""Device type for vLLM execution.
This parameter is deprecated and will be
removed in a future release.
It will now be set automatically based
on the current platform."""
device_type: str = field(init=False)
"""Device type from the current platform. This is set in
`__post_init__`."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
if not self.device_type:
raise RuntimeError(
"Failed to infer device type, please set "
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
"to turn on verbose logging to help debug the issue."
)
else:
# Device type is assigned explicitly
if isinstance(self.device, str):
self.device_type = self.device
elif isinstance(self.device, torch.device):
self.device_type = self.device.type
# Some device types require processing inputs on CPU
if self.device_type in ["tpu"]:
self.device = None
else:
# Set device with device type
self.device = torch.device(self.device_type)
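
The `__post_init__` above is where the `RuntimeError: Failed to infer device type` originates when platform detection returns an empty device type. A minimal sketch of both branches, assuming `torch` and this package import cleanly:

```python
import torch
from vllm.config.device import DeviceConfig

# "auto" asks vllm.platforms.current_platform for the device type and raises
# RuntimeError("Failed to infer device type ...") if it comes back empty.
auto_cfg = DeviceConfig(device="auto")
print(auto_cfg.device_type)  # e.g. "cuda" when CUDA detection succeeds

# An explicit device string skips platform detection entirely.
cuda_cfg = DeviceConfig(device="cuda")
assert isinstance(cuda_cfg.device, torch.device)
```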

107
vllm/config/ec_transfer.py Normal file
View File

@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
from vllm.config.utils import config
ECProducer = Literal["ec_producer", "ec_both"]
ECConsumer = Literal["ec_consumer", "ec_both"]
ECRole = Literal[ECProducer, ECConsumer]
@config
class ECTransferConfig:
"""Configuration for distributed EC cache transfer."""
ec_connector: str | None = None
"""The EC connector for vLLM to transmit EC caches between vLLM instances.
"""
engine_id: str | None = None
"""The engine id for EC transfers."""
ec_buffer_device: str | None = "cuda"
"""The device used by ec connector to buffer the EC cache.
Currently only support 'cuda'."""
ec_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
ec_role: ECRole | None = None
"""Whether this vLLM instance produces, consumes EC cache, or both. Choices
are 'ec_producer', 'ec_consumer', 'ec_both'."""
ec_rank: int | None = None
"""The rank of this vLLM instance in the EC cache transfer. Typical value:
0 for encoder, 1 for pd instance.
Currently only 1P1D is supported."""
ec_parallel_size: int = 1
"""The number of parallel instances for EC cache transfer. For
PyNcclConnector, this should be 2."""
ec_ip: str = "127.0.0.1"
"""The EC connector ip, used to build distributed connection."""
ec_port: int = 14579
"""The EC connector port, used to build distributed connection."""
ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
ec_connector_module_path: str | None = None
"""The Python module path to dynamically load the EC connector from.
Only supported in V1."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.ec_role is not None and self.ec_role not in get_args(ECRole):
raise ValueError(
f"Unsupported ec_role: {self.ec_role}. "
f"Supported roles are {get_args(ECRole)}"
)
if self.ec_connector is not None and self.ec_role is None:
raise ValueError(
"Please specify ec_role when ec_connector "
f"is set, supported roles are {get_args(ECRole)}"
)
@property
def is_ec_transfer_instance(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECRole)
@property
def is_ec_producer(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECProducer)
@property
def is_ec_consumer(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.ec_connector_extra_config.get(key, default)
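
A hedged sketch of the role validation and derived properties above; the connector name is a placeholder, since `ec_connector` is only stored as a string here:

```python
from vllm.config.ec_transfer import ECTransferConfig

# Producer-side instance; role helpers are derived from ec_role via get_args().
producer = ECTransferConfig(ec_connector="SomeECConnector",  # placeholder name
                            ec_role="ec_producer", ec_rank=0)
assert producer.is_ec_transfer_instance and producer.is_ec_producer
assert not producer.is_ec_consumer

# A connector without a role is rejected in __post_init__.
try:
    ECTransferConfig(ec_connector="SomeECConnector")
except ValueError as e:
    print(e)  # asks to specify ec_role when ec_connector is set
```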

76
vllm/config/kernel.py Normal file
View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Literal
from pydantic import Field, field_validator
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
MoEBackend = Literal[
"auto",
"triton",
"deep_gemm",
"cutlass",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_cutedsl",
"marlin",
"aiter",
]
@config
class KernelConfig:
"""Configuration for kernel selection and warmup behavior."""
enable_flashinfer_autotune: bool = Field(default=None)
"""If True, run FlashInfer autotuning during kernel warmup."""
moe_backend: MoEBackend = "auto"
"""Backend for MoE expert computation kernels. Available options:
- "auto": Automatically select the best backend based on model and hardware\n
- "triton": Use Triton-based fused MoE kernels\n
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
- "cutlass": Use vLLM CUTLASS kernels\n
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
- "marlin": Use Marlin kernels (weight-only quantization)\n
- "aiter": Use AMD AITer kernels (ROCm only)"""
@field_validator("moe_backend", mode="before")
@classmethod
def _normalize_moe_backend(cls, value: Any) -> Any:
if isinstance(value, str):
return value.lower().replace("-", "_")
return value
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("enable_flashinfer_autotune", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialization is delayed."""
if value is None:
return value
return handler(value)
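
A small sketch (not part of the diff) of the two validators above, assuming the package is importable:

```python
from vllm.config.kernel import KernelConfig

# Dashes become underscores and the value is lower-cased before the Literal
# check, so "FlashInfer-TRTLLM" resolves to the "flashinfer_trtllm" backend.
cfg = KernelConfig(moe_backend="FlashInfer-TRTLLM")
assert cfg.moe_backend == "flashinfer_trtllm"

# enable_flashinfer_autotune may stay None until initialization completes;
# the wrap validator skips validation for None.
assert KernelConfig().enable_flashinfer_autotune is None
```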

54
vllm/config/kv_events.py Normal file
View File

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal
from pydantic import Field
from vllm.config.utils import config
@config
class KVEventsConfig:
"""Configuration for KV event publishing."""
enable_kv_cache_events: bool = False
"""If True, enable KV cache events for tracking block storage and removal.
Events can be published externally by zmq using the event publisher config.
"""
publisher: Literal["null", "zmq"] = Field(default=None)
"""The publisher to use for publishing kv events. Can be "null", "zmq".
"""
endpoint: str = "tcp://*:5557"
"""The zmq endpoint to use for publishing kv events.
"""
replay_endpoint: str | None = None
"""The zmq endpoint to use for replaying kv events.
"""
buffer_steps: int = 10_000
"""The number of steps to cache for replay endpoint. Will only save
events from the last N steps for the replay endpoint.
"""
hwm: int = 100_000
"""The zmq high water mark for the event publisher. After queueing N events,
events will start dropping if the consumer is not keeping up.
"""
max_queue_size: int = 100_000
"""The maximum number of events to queue while waiting for publishing.
"""
topic: str = ""
"""The topic to use for the event publisher. Consumers can subscribe to
this topic to receive events.
"""
def __post_init__(self):
if self.publisher is None:
self.publisher = "zmq" if self.enable_kv_cache_events else "null"
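
The `__post_init__` defaulting above in a two-line sketch (hedged, assuming the package imports cleanly):

```python
from vllm.config.kv_events import KVEventsConfig

# publisher defaults to "zmq" only when event publishing is enabled,
# otherwise it falls back to the "null" publisher.
assert KVEventsConfig(enable_kv_cache_events=True).publisher == "zmq"
assert KVEventsConfig().publisher == "null"
```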

116
vllm/config/kv_transfer.py Normal file
View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]
@config
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""
kv_connector: str | None = None
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: str | None = None
"""The engine id for KV transfers."""
kv_buffer_device: str = "cuda"
"""The device used by kv connector to buffer the KV cache. Choices are
'cuda' and 'cpu'."""
kv_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
kv_role: KVRole | None = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
kv_rank: int | None = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
0 for prefill instance, 1 for decode instance.
Currently only 1P1D is supported."""
kv_parallel_size: int = 1
"""The number of parallel instances for KV cache transfer. For
P2pNcclConnector, this should be 2."""
kv_ip: str = "127.0.0.1"
"""The KV connector ip, used to build distributed connection."""
kv_port: int = 14579
"""The KV connector port, used to build distributed connection."""
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
kv_connector_module_path: str | None = None
"""The Python module path to dynamically load the KV connector from.
Only supported in V1."""
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "fail"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks
'fail': immediately fail the request with an error finish reason (default)"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(
f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}"
)
if self.kv_connector is not None and self.kv_role is None:
raise ValueError(
"Please specify kv_role when kv_connector "
f"is set, supported roles are {get_args(KVRole)}"
)
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and self.kv_role in get_args(KVRole)
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and self.kv_role in get_args(KVProducer)
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and self.kv_role in get_args(KVConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
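
A hedged sketch of a prefill-side (producer) configuration in the 1P1D setup described above; the connector name is taken from the `kv_parallel_size` docstring and may need to match whatever connector is actually registered:

```python
from vllm.config.kv_transfer import KVTransferConfig

prefill = KVTransferConfig(
    kv_connector="P2pNcclConnector",  # connector name from the docstring above
    kv_role="kv_producer",
    kv_rank=0,              # 0 = prefill instance, 1 = decode instance
    kv_parallel_size=2,
)
assert prefill.is_kv_transfer_instance and prefill.is_kv_producer
assert not prefill.is_kv_consumer
assert prefill.engine_id is not None  # auto-filled with a UUID in __post_init__
```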

122
vllm/config/load.py Normal file
View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.model_executor.model_loader import LoadFormats
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
else:
LoadFormats = Any
TensorizerConfig = Any
logger = init_logger(__name__)
@config
class LoadConfig:
"""Configuration for loading the model weights."""
load_format: str | LoadFormats = "auto"
"""The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n
- "pt" will load the weights in the pytorch bin format.\n
- "safetensors" will load the weights in the safetensors format.\n
- "npcache" will load the weights in pytorch format and store a numpy cache
to speed up the loading.\n
- "dummy" will initialize the weights with random values, which is mainly
for profiling.\n
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
loading. See the Tensorize vLLM Model script in the Examples section for
more information.\n
- "runai_streamer" will load the Safetensors weights using Run:ai Model
Streamer.\n
- "runai_streamer_sharded" will load weights from pre-sharded checkpoint
files using Run:ai Model Streamer.\n
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
- "sharded_state" will load weights from pre-sharded checkpoint files,
supporting efficient loading of tensor-parallel models.\n
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by
Mistral models.
- Other custom values can be supported via plugins."""
download_dir: str | None = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
safetensors_load_strategy: str = "lazy"
"""Specifies the loading strategy for safetensors weights.
- "lazy" (default): Weights are memory-mapped from the file. This enables
on-demand loading and is highly efficient for models on local storage.
- "eager": The entire file is read into CPU memory upfront before loading.
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
- "torchao": Weights are loaded in upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors.
Needs torchao >= 0.14.0
"""
model_loader_extra_config: dict | TensorizerConfig = Field(default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format."""
device: str | None = None
"""Device to which model weights will be loaded, default to
device_config.device"""
ignore_patterns: list[str] | str = Field(default_factory=lambda: ["original/**/*"])
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
use_tqdm_on_load: bool = True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
pt_load_map_location: str | dict[str, str] = "cpu"
"""
pt_load_map_location: the map location for loading a pytorch checkpoint, to
support checkpoints that can only be loaded on certain devices like
"cuda"; this is equivalent to {"": "cuda"}. Another supported format is
mapping from one device to another, e.g. from GPU 1 to GPU 0:
{"cuda:1": "cuda:0"}. Note that when passed from the command line, the strings
in the dictionary need to be double-quoted for JSON parsing. For more details,
see the original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("load_format", mode="after")
def _lowercase_load_format(cls, load_format: str) -> str:
return load_format.lower()
@field_validator("ignore_patterns", mode="after")
def _validate_ignore_patterns(
cls, ignore_patterns: list[str] | str
) -> list[str] | str:
if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
ignore_patterns,
)
return ignore_patterns
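
A brief hedged sketch of the `load_format` normalization above:

```python
from vllm.config.load import LoadConfig

# load_format is lower-cased after validation, so "SafeTensors" and
# "safetensors" are equivalent; "eager" reads the whole file into CPU memory
# before loading, which suits network filesystems.
cfg = LoadConfig(load_format="SafeTensors", safetensors_load_strategy="eager")
assert cfg.load_format == "safetensors"
```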

107
vllm/config/lora.py Normal file
View File

@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config.cache import CacheConfig
else:
ModelConfig = Any
CacheConfig = Any
logger = init_logger(__name__)
LoRADType = Literal["auto", "float16", "bfloat16"]
MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
LoRAExtraVocabSize = Literal[256, 512]
@config(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
"""Configuration for LoRA."""
max_lora_rank: MaxLoRARanks = 16
"""Max LoRA rank."""
max_loras: int = Field(default=1, ge=1)
"""Max number of LoRAs in a single batch."""
fully_sharded_loras: bool = False
"""By default, only half of the LoRA computation is sharded with tensor
parallelism. Enabling this will use the fully sharded layers. At high
sequence length, max rank or tensor parallel size, this is likely faster.
"""
max_cpu_loras: int | None = None
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
model always expects a LoRA to be active when a given modality is present.
Note that currently, if a request provides multiple additional
modalities, each of which have their own LoRA, we do NOT apply
default_mm_loras because we currently only support one lora adapter
per prompt. When run in offline mode, the lora IDs for n modalities
will be automatically assigned to 1-n with the names of the modalities
in alphabetical order."""
enable_tower_connector_lora: bool = False
"""If `True`, LoRA support for the tower (vision encoder) and connector
of multimodal models will be enabled. This is an experimental feature and
currently only supports some MM models such as the Qwen VL series. The default
is False."""
specialize_active_lora: bool = False
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
When set to True, separate cuda graphs will be captured for different counts
of active LoRAs (powers of 2 up to max_loras), which can improve performance
for variable LoRA usage patterns at the cost of increased startup time and
memory usage. Only takes effect when cudagraph_specialize_lora is True.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = []
factors.append(self.max_lora_rank)
factors.append(self.max_loras)
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.enable_tower_connector_lora)
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")
def _validate_lora_config(self) -> Self:
if self.max_cpu_loras is None:
self.max_cpu_loras = self.max_loras
elif self.max_cpu_loras < self.max_loras:
raise ValueError(
f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_loras ({self.max_loras})."
)
return self
def verify_with_model_config(self, model_config: ModelConfig):
if self.lora_dtype in (None, "auto"):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
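
A hedged sketch of the `max_cpu_loras` defaulting and validation above:

```python
from vllm.config.lora import LoRAConfig

# max_cpu_loras defaults to max_loras when unset ...
cfg = LoRAConfig(max_lora_rank=32, max_loras=4)
assert cfg.max_cpu_loras == 4

# ... and values below max_loras are rejected by the model validator.
try:
    LoRAConfig(max_loras=4, max_cpu_loras=2)
except ValueError as e:
    print(e)  # mentions: max_cpu_loras (2) must be >= max_loras (4)
```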

2056
vllm/config/model.py Normal file

File diff suppressed because it is too large

57
vllm/config/model_arch.py Normal file
View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelArchitectureConfig:
"""
Configuration for model architecture that required by vLLM runtime
"""
architectures: list[str] | None
"""List of model architecture class names (e.g., ['LlamaForCausalLM']).
It can be None upon calling `vllm_config.with_hf_config(config.text_config)`"""
model_type: str
"""Model type identifier (e.g., 'llama', 'gpt_oss')."""
text_model_type: str | None
"""Text model type identifier (e.g., 'llama4_text')."""
hidden_size: int
"""Hidden size of the model."""
total_num_hidden_layers: int
"""Number of hidden layers in the model."""
total_num_attention_heads: int
"""Number of attention heads in the model."""
head_size: int
"""Head dimension of the model."""
vocab_size: int
"""Vocabulary size of the model."""
total_num_kv_heads: int
"""Number of key value heads in the model."""
num_experts: int
"""Number of experts in the model."""
quantization_config: dict[str, Any] | None
"""Quantization configuration dictionary containing quantization parameters."""
is_deepseek_mla: bool
"""Whether the model is a DeepSeek MLA model."""
derived_max_model_len_and_key: tuple[float, str | None]
"""Derived maximum model length and key from the hf config."""

281
vllm/config/multimodal.py Normal file
View File

@@ -0,0 +1,281 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import Any, Literal, TypeAlias, TypedDict, final
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@dataclass
class BaseDummyOptions:
"""Base options for generating dummy data during profiling."""
count: int = Field(999, ge=0)
@dataclass(config=ConfigDict(extra="forbid"))
class VideoDummyOptions(BaseDummyOptions):
"""Options for generating dummy video data during profiling."""
num_frames: int | None = Field(None, gt=0)
width: int | None = Field(None, gt=0)
height: int | None = Field(None, gt=0)
@dataclass(config=ConfigDict(extra="forbid"))
class ImageDummyOptions(BaseDummyOptions):
"""Options for generating dummy image data during profiling."""
width: int | None = Field(None, gt=0)
height: int | None = Field(None, gt=0)
@dataclass(config=ConfigDict(extra="forbid"))
class AudioDummyOptions(BaseDummyOptions):
"""Options for generating dummy audio data during profiling."""
length: int | None = Field(None, gt=0)
@final
class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ImageDummyOptions
"""Options for dummy images."""
video: VideoDummyOptions
"""Options for dummy videos."""
audio: AudioDummyOptions
"""Options for dummy audios."""
MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
"""
A dictionary containing an entry for each modality type of dummy data.
The built-in modalities are defined by
[`MultiModalDummyOptionsBuiltins`][vllm.config.multimodal.MultiModalDummyOptionsBuiltins].
"""
@config
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
language_model_only: bool = False
"""If True, disables all multimodal inputs by setting all modality limits to 0.
Equivalent to setting `--limit-mm-per-prompt` to 0 for every modality."""
limit_per_prompt: MMDummyOptions = Field(default_factory=dict)
"""The maximum number of input items and options allowed per
prompt for each modality.
Defaults to 999 for each modality.
Legacy format (count only):
{"image": 16, "video": 2}
Configurable format (with options):
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
"image": {"count": 5, "width": 512, "height": 512}}
Mixed format (combining both):
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}}
"""
enable_mm_embeds: bool = False
"""If `True`, enables passing multimodal embeddings:
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`.
When enabled with `--limit-mm-per-prompt` set to 0 for a modality,
precomputed embeddings skip count validation for that modality,
saving memory by not loading encoder modules while still enabling
embeddings as an input. Limits greater than 0 still apply to embeddings.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'`"""
mm_processor_kwargs: dict[str, object] | None = None
"""Arguments to be forwarded to the model's processor for multi-modal data,
e.g., image processor. Overrides for the multi-modal processor obtained
from `transformers.AutoProcessor.from_pretrained`.
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision:
`{"num_crops": 4}`."""
mm_processor_cache_gb: float = Field(default=4, ge=0)
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
This cache is duplicated for each API process and engine core process,
resulting in a total memory usage of
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_processor_cache_type: MMCacheType = "lru"
"""Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
mm_shm_cache_max_object_size_mb: int = Field(default=128, ge=0)
"""Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is
`"shm"`."""
mm_encoder_only: bool = False
"""
When enabled, skips the language component of the model.
This is usually only valid in disaggregated Encoder process.
"""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using tensor
parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)\n
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
mm_encoder_attn_backend: AttentionBackendEnum | None = None
"""Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from
`vllm.v1.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string."""
skip_mm_profiling: bool = False
"""When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
estimating the peak memory usage of the activation of multimodal encoder and
embedding cache."""
video_pruning_rate: float | None = Field(default=None, ge=0.0, lt=1.0)
"""Sets pruning rate for video pruning via Efficient Video Sampling.
Value sits in the range [0, 1) and determines the fraction of media tokens
from each video to be pruned.
"""
@field_validator("limit_per_prompt", mode="before")
@classmethod
def _validate_limit_per_prompt(
cls,
value: dict[str, int | dict[str, int]],
) -> MMDummyOptions:
out: MMDummyOptions = {}
for k, v in value.items():
# Handle legacy format where only count is specified
if isinstance(v, int):
v = {"count": v}
# Convert to the appropriate DummyOptions subclass
if k == "video":
out[k] = VideoDummyOptions(**v)
elif k == "image":
out[k] = ImageDummyOptions(**v)
elif k == "audio":
out[k] = AudioDummyOptions(**v)
else:
out[k] = BaseDummyOptions(**v)
return out
@field_validator("mm_encoder_attn_backend", mode="before")
@classmethod
def _validate_mm_encoder_attn_backend(
cls, value: str | AttentionBackendEnum | None
) -> AttentionBackendEnum | None:
if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
if value is None or isinstance(value, AttentionBackendEnum):
return value
assert isinstance(value, str), (
"mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
)
return AttentionBackendEnum[value.upper()]
@model_validator(mode="after")
def _validate_multimodal_config(self):
if self.mm_processor_cache_type != "shm" and (
self.mm_shm_cache_max_object_size_mb
!= MultiModalConfig.mm_shm_cache_max_object_size_mb
):
raise ValueError(
"'mm_shm_cache_max_object_size_mb' should only be set when "
"'mm_processor_cache_type' is 'shm'."
)
return self
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = [
self.mm_encoder_attn_backend.name
if self.mm_encoder_attn_backend is not None
else None,
self.mm_encoder_tp_mode,
]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality (backward compatible).
"""
if self.language_model_only:
return 0
limit_data = self.limit_per_prompt.get(modality)
if limit_data is None:
# Unspecified modality is set to 999 by default
return 999
return limit_data.count
def merge_mm_processor_kwargs(
self,
inference_kwargs: Mapping[str, object],
) -> dict[str, object]:
"""
Get the keyword arguments to pass to the multi-modal processor
according to the extra arguments passed during inference.
"""
kwargs = self.mm_processor_kwargs or {}
return kwargs | dict(inference_kwargs)
def is_multimodal_pruning_enabled(self):
return self.video_pruning_rate is not None and self.video_pruning_rate > 0
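
A hedged sketch of the legacy and configurable `limit_per_prompt` formats described above:

```python
from vllm.config.multimodal import MultiModalConfig, VideoDummyOptions

# Legacy counts and the richer per-modality options can be mixed; both are
# normalized into *DummyOptions instances by `_validate_limit_per_prompt`.
mm = MultiModalConfig(
    limit_per_prompt={
        "image": 16,                              # legacy form: count only
        "video": {"count": 1, "num_frames": 32},  # configurable form
    }
)
assert mm.get_limit_per_prompt("image") == 16
assert isinstance(mm.limit_per_prompt["video"], VideoDummyOptions)
assert mm.get_limit_per_prompt("audio") == 999    # unspecified modality default
```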

152
vllm/config/observability.py Normal file
View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cached_property
from typing import Any, Literal, cast
from packaging.version import parse
from pydantic import Field, field_validator, model_validator
from vllm import version
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
DetailedTraceModules = Literal["model", "worker", "all"]
@config
class ObservabilityConfig:
"""Configuration for observability - metrics and tracing."""
show_hidden_metrics_for_version: str | None = None
"""Enable deprecated Prometheus metrics that have been hidden since the
specified version. For example, if a previously deprecated metric has been
hidden since the v0.7.0 release, you use
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
you migrate to new metrics. The metric is likely to be removed completely
in an upcoming release."""
@cached_property
def show_hidden_metrics(self) -> bool:
"""Check if the hidden metrics should be shown."""
if self.show_hidden_metrics_for_version is None:
return False
return version._prev_minor_version_was(self.show_hidden_metrics_for_version)
otlp_traces_endpoint: str | None = None
"""Target URL to which OpenTelemetry traces will be sent."""
collect_detailed_traces: list[DetailedTraceModules] | None = None
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
set, it will collect detailed traces for the specified modules. This
involves the use of possibly costly and/or blocking operations and hence might
have a performance impact.
Note that collecting detailed timing information for each request can be
expensive."""
kv_cache_metrics: bool = False
"""Enable KV cache residency metrics (lifetime, idle time, reuse gaps).
Uses sampling to minimize overhead.
Requires log stats to be enabled (i.e., --disable-log-stats not set)."""
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
cudagraph_metrics: bool = False
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
dispatch modes, and their observed frequencies at every logging interval)."""
enable_layerwise_nvtx_tracing: bool = False
"""Enable layerwise NVTX tracing. This traces the execution of each layer or
module in the model and attaches information such as input/output shapes to
NVTX range markers. Note that this doesn't work with CUDA graphs enabled."""
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
enable_mm_processor_stats: bool = False
"""Enable collection of timing statistics for multimodal processor operations.
This is for internal use only (e.g., benchmarks) and is not exposed as a CLI
argument."""
enable_logging_iteration_details: bool = False
"""Enable detailed logging of iteration details.
If set, the vLLM EngineCore will log iteration details.
This includes the number of context/generation requests and tokens
and the elapsed CPU time for the iteration."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""
return self.collect_detailed_traces is not None and (
"model" in self.collect_detailed_traces
or "all" in self.collect_detailed_traces
)
@cached_property
def collect_model_execute_time(self) -> bool:
"""Whether to collect model execute time for the request."""
return self.collect_detailed_traces is not None and (
"worker" in self.collect_detailed_traces
or "all" in self.collect_detailed_traces
)
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("show_hidden_metrics_for_version")
@classmethod
def _validate_show_hidden_metrics_for_version(cls, value: str | None) -> str | None:
if value is not None:
# Raises an exception if the string is not a valid version.
parse(value)
return value
@field_validator("otlp_traces_endpoint")
@classmethod
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
if value is not None:
from vllm.tracing import is_tracing_available, otel_import_error_traceback
if not is_tracing_available():
raise ValueError(
"OpenTelemetry is not available. Unable to configure "
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
f"installed. Original error:\n{otel_import_error_traceback}"
)
return value
@field_validator("collect_detailed_traces")
@classmethod
def _validate_collect_detailed_traces(
cls, value: list[DetailedTraceModules] | None
) -> list[DetailedTraceModules] | None:
"""Handle the legacy case where users might provide a comma-separated
string instead of a list of strings."""
if value is not None and len(value) == 1 and "," in value[0]:
value = cast(list[DetailedTraceModules], value[0].split(","))
return value
@model_validator(mode="after")
def _validate_tracing_config(self):
if self.collect_detailed_traces and not self.otlp_traces_endpoint:
raise ValueError(
"collect_detailed_traces requires `--otlp-traces-endpoint` to be set."
)
return self
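
A hedged sketch of the tracing-related validation above; it assumes the OpenTelemetry packages are installed (otherwise `_validate_otlp_traces_endpoint` rejects the endpoint), and the endpoint URL is a placeholder:

```python
from vllm.config.observability import ObservabilityConfig

# Detailed traces are only accepted together with an OTLP endpoint; the
# cached properties tell workers which timings to collect.
obs = ObservabilityConfig(
    otlp_traces_endpoint="http://localhost:4318/v1/traces",  # placeholder
    collect_detailed_traces=["model", "worker"],
)
assert obs.collect_model_forward_time   # "model" requested
assert obs.collect_model_execute_time   # "worker" requested
assert not ObservabilityConfig().show_hidden_metrics
```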

153
vllm/config/offload.py Normal file
View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Configuration for model weight offloading."""
import warnings
from typing import Literal
from pydantic import Field, model_validator
from vllm.config.utils import config
OffloadBackend = Literal["auto", "uva", "prefetch"]
@config
class UVAOffloadConfig:
"""Configuration for UVA (Unified Virtual Addressing) CPU offloading.
Uses zero-copy access from CPU-pinned memory. Simple but requires
fast CPU-GPU interconnect.
"""
cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
This uses UVA (Unified Virtual Addressing) for zero-copy access.
"""
cpu_offload_params: set[str] = Field(default_factory=set)
"""The set of parameter name segments to target for CPU offloading.
Unmatched parameters are not offloaded. If this set is empty, parameters
are offloaded non-selectively until the memory limit defined by
`cpu_offload_gb` is reached.
Examples:
- For parameter name "mlp.experts.w2_weight":
- "experts" or "experts.w2_weight" will match.
- "expert" or "w2" will NOT match (must be exact segments).
This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
"""
@config
class PrefetchOffloadConfig:
"""Configuration for prefetch-based CPU offloading.
Groups layers and uses async H2D prefetch to hide transfer latency.
"""
offload_group_size: int = Field(default=0, ge=0)
"""Group every N layers together. Offload last `offload_num_in_group`
layers of each group. Default is 0 (disabled).
Example: group_size=8, num_in_group=2 offloads layers 6,7,14,15,22,23,...
Unlike cpu_offload_gb, this uses explicit async prefetching to hide transfer
latency.
"""
offload_num_in_group: int = Field(default=1, ge=1)
"""Number of layers to offload per group.
Must be <= offload_group_size. Default is 1."""
offload_prefetch_step: int = Field(default=1, ge=0)
"""Number of layers to prefetch ahead.
Higher values hide more latency but use more GPU memory. Default is 1."""
offload_params: set[str] = Field(default_factory=set)
"""The set of parameter name segments to target for prefetch offloading.
Unmatched parameters are not offloaded. If this set is empty, ALL
parameters of each offloaded layer are offloaded.
Uses segment matching: "w13_weight" matches "mlp.experts.w13_weight"
but not "mlp.experts.w13_weight_scale".
"""
@config
class OffloadConfig:
"""Configuration for model weight offloading to reduce GPU memory usage."""
offload_backend: OffloadBackend = "auto"
"""The backend for weight offloading. Options:
- "auto": Selects based on which sub-config has non-default values
(prefetch if offload_group_size > 0, uva if cpu_offload_gb > 0).
- "uva": UVA (Unified Virtual Addressing) zero-copy offloading.
- "prefetch": Async prefetch with group-based layer offloading.
"""
uva: UVAOffloadConfig = Field(default_factory=UVAOffloadConfig)
"""Parameters for UVA offloading backend."""
prefetch: PrefetchOffloadConfig = Field(default_factory=PrefetchOffloadConfig)
"""Parameters for prefetch offloading backend."""
@model_validator(mode="after")
def validate_offload_config(self) -> "OffloadConfig":
"""Validate offload configuration constraints."""
if self.offload_backend == "prefetch" or self.prefetch.offload_group_size > 0:
if self.prefetch.offload_num_in_group > self.prefetch.offload_group_size:
raise ValueError(
f"offload_num_in_group ({self.prefetch.offload_num_in_group})"
f" must be <= offload_group_size"
f" ({self.prefetch.offload_group_size})"
)
if self.prefetch.offload_prefetch_step < 1:
raise ValueError(
f"offload_prefetch_step"
f" ({self.prefetch.offload_prefetch_step})"
f" must be >= 1 when prefetch offloading is enabled"
f" (offload_group_size > 0)"
)
# Warn if both backends have non-default values
uva_active = self.uva.cpu_offload_gb > 0
prefetch_active = self.prefetch.offload_group_size > 0
if self.offload_backend == "uva" and prefetch_active:
warnings.warn(
"Prefetch offload fields are set but offload_backend='uva'. "
"Prefetch settings will be ignored.",
stacklevel=2,
)
elif self.offload_backend == "prefetch" and uva_active:
warnings.warn(
"UVA offload fields are set but offload_backend='prefetch'. "
"UVA settings will be ignored.",
stacklevel=2,
)
elif self.offload_backend == "auto" and uva_active and prefetch_active:
warnings.warn(
"Both UVA and prefetch offload fields are set with "
"offload_backend='auto'. Prefetch backend will be selected. "
"Set offload_backend explicitly to suppress this warning.",
stacklevel=2,
)
return self
def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the offload configs.
All fields are included because PrefetchOffloader patches module
forwards and inserts custom ops (wait_prefetch, start_prefetch)
into the computation graph. Changing any offload setting can
alter which layers are hooked and how prefetch indices are
computed, so the compilation cache must distinguish them.
"""
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors=set())
hash_str = hash_factors(factors)
return hash_str
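
A hedged sketch of the prefetch-backend constraints enforced by `validate_offload_config` above:

```python
from vllm.config.offload import OffloadConfig, PrefetchOffloadConfig

# Group every 8 layers, offload the last 2 of each group, prefetch 1 ahead.
cfg = OffloadConfig(
    offload_backend="prefetch",
    prefetch=PrefetchOffloadConfig(offload_group_size=8, offload_num_in_group=2),
)

# offload_num_in_group larger than offload_group_size is rejected.
try:
    OffloadConfig(prefetch=PrefetchOffloadConfig(offload_group_size=4,
                                                 offload_num_in_group=6))
except ValueError as e:
    print(e)  # mentions: offload_num_in_group (6) must be <= offload_group_size (4)
```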

713
vllm/config/parallel.py Normal file
View File

@@ -0,0 +1,713 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import Field, field_validator, model_validator
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless
if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv
from ray.util.placement_group import PlacementGroup
from vllm.v1.executor import Executor
else:
RuntimeEnv = Any
PlacementGroup = Any
Executor = Any
logger = init_logger(__name__)
ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
All2AllBackend = Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@config
class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size: int = 1000
"""Window size for expert load recording."""
step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `window_size` steps will be used for rearranging experts.
"""
num_redundant_experts: int = Field(default=0, ge=0)
"""Number of redundant experts to use for expert parallelism."""
log_balancedness: bool = False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
log_balancedness_interval: int = 1
"""
Interval for logging the balancedness.
"""
use_async: bool = False
"""
Whether to use non-blocking EPLB.
"""
policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB)."""
@model_validator(mode="after")
def _validate_eplb_config(self) -> Self:
if self.use_async and self.policy != "default":
raise ValueError("Async EPLB is only supported with the default policy.")
if self.log_balancedness and self.log_balancedness_interval <= 0:
raise ValueError("log_balancedness_interval must be greater than 0.")
return self
@config
class ParallelConfig:
"""Configuration for the distributed execution."""
pipeline_parallel_size: int = 1
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local: int = 1
"""Number of local data parallel groups."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_rank_local: int | None = None
"""Local rank of the data parallel group,
set only in SPMD mode."""
data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_rpc_port: int = 29550
"""Port for data parallel messaging."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
data_parallel_backend: DataParallelBackend = "mp"
"""Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
is provided explicitly to vllm serve."""
data_parallel_hybrid_lb: bool = False
"""Whether to use "hybrid" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
is_moe_model: bool | None = None
"""Whether the deployed model is MoE (if known)."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
"""Expert parallelism configuration."""
expert_placement_strategy: ExpertPlacementStrategy = "linear"
"""The expert placement strategy for MoE layers:\n
- "linear": Experts are placed in a contiguous manner. For example, with 4
experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
experts [2, 3].\n
- "round_robin": Experts are placed in a round-robin manner. For example,
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts."""
all2all_backend: All2AllBackend = "allgather_reducescatter"
"""All2All backend for MoE expert parallel communication. Available options:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
parallel and large models."""
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
"""Number of ubatch size."""
dbo_decode_token_threshold: int = 32
"""The threshold for dual batch overlap for batches only containing decodes.
If the number of tokens in the request is greater than this threshold,
microbatching will be used. Otherwise, the request will be processed in a
single batch."""
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
"""The threshold for dual batch overlap for batches that contain one or more
prefills. If the number of tokens in the request is greater than this
threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch."""
disable_nccl_for_dp_synchronization: bool | None = Field(default=None)
"""Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
to use Gloo instead of NCCL for its all reduce.
Defaults to True when async scheduling is enabled, False otherwise.
"""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
ray_runtime_env: RuntimeEnv | None = None
"""Ray runtime environment to pass to distributed workers."""
placement_group: PlacementGroup | None = None
"""ray distributed model workers placement group."""
distributed_executor_backend: (
str | DistributedExecutorBackend | type[Executor] | None
) = None
"""Backend to use for distributed model workers, either "ray" or "mp"
(multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
is less than or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, an error will be raised. To use "mp"
you must also set nnodes, and to use "ray" you must manually set
distributed_executor_backend to "ray".
Note that TPU only supports Ray for distributed inference."""
worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class
will be determined based on the platform."""
sd_worker_cls: str = "auto"
"""The full name of the worker class to use for speculative decoding.
If "auto", the worker class will be determined based on the platform."""
worker_extension_cls: str = ""
"""The full name of the worker extension class to use. The worker extension
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
master_addr: str = "127.0.0.1"
"""distributed master address for multi-node distributed
inference when distributed_executor_backend is mp."""
master_port: int = 29501
"""distributed master port for multi-node distributed
inference when distributed_executor_backend is mp."""
node_rank: int = 0
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""
world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
rank: int = 0
"""Global rank in distributed setup."""
_data_parallel_master_port_list: list[int] = Field(default_factory=list)
"""List of open port auto-queried for data parallel messaging.
Set to be private as it's not intended to be configured by users.
"""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""
dcp_kv_cache_interleave_size: int = 1
"""
Interleave size of kv_cache storage while using DCP.
dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
and will be deprecated when PCP is fully supported.
"""
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
and `total_cp_world_size = pcp_world_size * dcp_world_size`.
store interleave_size tokens on total_cp_rank i,
then store next interleave_size tokens on total_cp_rank i+1.
Interleave_size=1: token-level alignment, where token `i` is stored on
total_cp_rank `i % total_cp_world_size`.
Interleave_size=block_size: block-level alignment, where tokens are
first populated to the preceding ranks. Tokens are then stored
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
Block_size should be greater than or equal to cp_kv_cache_interleave_size.
Block_size should be divisible by cp_kv_cache_interleave_size.
"""
data_parallel_index: int = Field(init=False)
"""Equal to the data parallel rank but not used for torch process groups
and not overridden for dense models."""
_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.
Note:
This is an internal config that is only valid for and
should only be set by API server scale-out.
"""
_api_process_rank: int = Field(default=0, ge=-1)
"""
The rank of this API process, or `-1` for engine core processes
under API server scale-out.
Note:
This is an internal config that is only valid for and
should only be set by API server scale-out.
"""
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
return None if value is None else handler(value)
@model_validator(mode="after")
def _validate_parallel_config(self) -> Self:
if self._api_process_rank >= self._api_process_count:
raise ValueError(
"Invalid value of `_api_process_rank`. "
f"Expected to be `-1` or `[0, {self._api_process_count})`, "
f"but found: {self._api_process_rank}"
)
if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
f"must be <= data_parallel_size ({self.data_parallel_size})"
)
if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
raise ValueError(
"data_parallel_external_lb can only be set when data_parallel_size > 1"
)
if self.enable_eplb:
if not current_platform.is_cuda_alike():
raise ValueError(
"Expert parallelism load balancing is only supported on "
"CUDA devices or ROCm devices now."
)
if not self.enable_expert_parallel:
raise ValueError("enable_expert_parallel must be True to use EPLB.")
if self.tensor_parallel_size * self.data_parallel_size <= 1:
raise ValueError(
"EPLB requires tensor_parallel_size or data_parallel_size "
f"to be greater than 1, but got "
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.eplb_config.num_redundant_experts != 0:
raise ValueError(
"num_redundant_experts is set to "
f"{self.eplb_config.num_redundant_experts} but EPLB is not "
"enabled. Either enable EPLB or unset "
"num_redundant_experts."
)
# Note(hc): In the current implementation of decode context
# parallel(DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change by dcp, it simply
# reuses the GPUs of TP group, and split one TP group into
# tp_size//dcp_size DCP groups.
if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
raise ValueError(
f"tp_size={self.tensor_parallel_size} must be divisible by"
f"dcp_size={self.decode_context_parallel_size}."
)
return self
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
return self.world_size * self.data_parallel_size
@property
def use_ubatching(self) -> bool:
return self.enable_dbo or self.ubatch_size > 1
@property
def num_ubatches(self) -> int:
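# Dual batch overlap always splits a batch into exactly two micro-batches;
# otherwise the explicitly configured ubatch_size is used.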
return 2 if self.enable_dbo else self.ubatch_size
@property
def local_engines_only(self) -> bool:
"""
Client manages local+remote EngineCores in pure internal LB case.
Client manages local EngineCores in hybrid and external LB case.
"""
return self.data_parallel_external_lb or self.data_parallel_hybrid_lb
def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups in multiple
processes that is related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
pop a new port from the prepared port list each time we need to
initialize a new process group related to data parallelism.
"""
if self._data_parallel_master_port_list:
answer = self._data_parallel_master_port_list.pop()
else:
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
return answer
def stateless_init_dp_group(self) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
# process binds the port the others will subsequently fail
# with `torch.distributed.DistNetworkError: EADDRINUSE`.
# To make the initialization more robust we retry a few times
# with a fresh port whenever this specific error is observed.
from torch.distributed import DistNetworkError
from vllm.distributed.utils import (
stateless_init_torch_distributed_process_group,
)
max_retries = 5
last_exc: Exception | None = None
for _ in range(max_retries):
try:
# use gloo since the engine process might not have cuda device
return stateless_init_torch_distributed_process_group(
self.data_parallel_master_ip,
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend=current_platform.dist_backend,
)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
if "EADDRINUSE" in str(e):
logger.warning("Address already in use. Retrying with a new port.")
last_exc = e
continue # try again with a new port
raise e
# If we get here all retries have failed.
assert last_exc is not None
raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (
self.all2all_backend
in (
"allgather_reducescatter",
"naive",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
)
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1
)
@property
def node_rank_within_dp(self) -> int:
return self.node_rank % self.nnodes_within_dp
@property
def nnodes_within_dp(self) -> int:
if self.nnodes == 1:
return 1
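# Each DP replica spans data_parallel_size // data_parallel_size_local nodes,
# so dividing the total node count by that gives the node count per replica.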
data_parallel_node_size = (
self.data_parallel_size // self.data_parallel_size_local
)
return self.nnodes // data_parallel_node_size
@property
def local_world_size(self) -> int:
return self.world_size // self.nnodes_within_dp
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
# dp rank 0: has_unfinished_seqs=True
# dp rank 1: has_unfinished_seqs=False
# aggregated: has_unfinished_seqs=True
# so this is an OR operation, i.e. MAX in integers
torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
aggregated_has_unfinished = bool(tensor.item())
return aggregated_has_unfinished
@staticmethod
def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
if kv_cache_memory == -1:
kv_cache_memory = torch.iinfo(torch.int64).max
tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu")
# we cannot use broadcast for stateless dp group since it depends
# on global rank
torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
return tensor.item()
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
This hash is also used for DP worker configuration validation
to prevent hangs from mismatched collective communication patterns.
"""
ignored_factors = {
# Derived/runtime topology, networking, or launch details
"data_parallel_rank",
"data_parallel_rank_local",
"data_parallel_size_local",
"data_parallel_index",
"data_parallel_backend",
"data_parallel_external_lb",
"data_parallel_hybrid_lb",
"data_parallel_master_ip",
"data_parallel_master_port",
"_data_parallel_master_port_list",
"data_parallel_rpc_port",
"rank",
"master_addr",
"master_port",
"node_rank",
"nnodes",
"max_parallel_loading_workers",
"disable_custom_all_reduce",
"ray_workers_use_nsight",
"ray_runtime_env",
"placement_group",
"distributed_executor_backend",
"worker_cls",
"sd_worker_cls",
"worker_extension_cls",
"_api_process_count",
"_api_process_rank",
}
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
def __post_init__(self) -> None:
# Continue with the rest of the initialization
self.world_size = (
self.pipeline_parallel_size
* self.tensor_parallel_size
* self.prefill_context_parallel_size
)
if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher":
# For external launcher,
# we need to set the data parallel rank automatically
self.data_parallel_rank = int(os.environ["RANK"]) // (
self.world_size // self.data_parallel_size
)
logger.info(
"Set data_parallel_rank to %d automatically.",
self.data_parallel_rank,
)
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
f"data_parallel_rank ({self.data_parallel_rank})"
f" must be in the range [0, {self.data_parallel_size})"
)
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_size > 1 and self.is_moe_model is False:
raise ValueError(
"Offline data parallel mode is not supported/useful"
" for dense models."
)
self.data_parallel_index = self.data_parallel_rank
if self.distributed_executor_backend == "external_launcher":
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from vllm.v1.executor import ray_utils
backend: DistributedExecutorBackend = "mp"
ray_found = ray_utils.ray_is_available()
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif current_platform.is_cuda() and self.nnodes > 1:
backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
):
gpu_count = cuda_device_count_stateless()
raise ValueError(
f"World size ({self.world_size}) is larger than the number of "
f"available GPUs ({gpu_count}) in this node. If this is "
"intentional and you are using:\n"
"- ray, set '--distributed-executor-backend ray'.\n"
"- multiprocessing, set '--nnodes' appropriately."
)
elif self.data_parallel_backend == "ray":
logger.info(
"Using ray distributed inference because "
"data_parallel_backend is ray"
)
backend = "ray"
elif ray_found:
if self.placement_group:
backend = "ray"
else:
from ray import is_initialized as ray_is_initialized
if ray_is_initialized():
from ray.util import get_current_placement_group
if get_current_placement_group():
backend = "ray"
self.distributed_executor_backend = backend
logger.debug("Defaulting to use %s for distributed inference", backend)
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
if self.max_parallel_loading_workers is not None:
logger.warning(
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
allowed_backends = ("mp", "uni", "external_launcher")
if (
self.distributed_executor_backend not in allowed_backends
and self.nnodes > 1
):
raise ValueError(
"nnodes > 1 can only be set when distributed executor "
"backend is mp, uni or external_launcher."
)
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and getattr(self.distributed_executor_backend, "uses_ray", False)
)
@model_validator(mode="after")
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.v1.executor import Executor
# Enable batch invariance settings if requested
if vllm_is_batch_invariant():
self.disable_custom_all_reduce = True
if (
self.distributed_executor_backend is not None
and not isinstance(self.distributed_executor_backend, str)
and not (
isinstance(self.distributed_executor_backend, type)
and issubclass(self.distributed_executor_backend, Executor)
)
):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher', "
" custom Executor subclass or its import path."
)
if self.use_ray:
from vllm.v1.executor import ray_utils
ray_utils.assert_ray_available()
if not current_platform.use_custom_allreduce():
self.disable_custom_all_reduce = True
logger.debug(
"Disabled the custom all-reduce kernel because it is not "
"supported on current platform."
)
if self.nnodes > 1:
self.disable_custom_all_reduce = True
logger.debug(
"Disabled the custom all-reduce since we are running on multi-node."
)
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."
)
return self

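Note: the size properties in `ParallelConfig` above (`world_size`, `world_size_across_dp`, `nnodes_within_dp`, `local_world_size`) are plain products and integer divisions. A standalone sketch with made-up example values, mirroring the formulas in the code:

```python
# Standalone sketch of the ParallelConfig size arithmetic shown above.
# Values are illustrative only; they are not a recommended BI-V150 topology.
pp, tp, pcp = 1, 2, 1          # pipeline / tensor / prefill-context parallel
dp, dp_local = 2, 1            # data parallel size and local DP size
nnodes = 2

world_size = pp * tp * pcp                     # workers per DP replica
world_size_across_dp = world_size * dp         # total workers
data_parallel_node_size = dp // dp_local       # nodes spanned by one DP group
nnodes_within_dp = nnodes // data_parallel_node_size
local_world_size = world_size // nnodes_within_dp

print(world_size, world_size_across_dp, nnodes_within_dp, local_world_size)
# 2 4 1 2
```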
146
vllm/config/pooler.py Normal file
View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal, get_args
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
TokenPoolingType = Literal["ALL", "STEP"]
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
@config
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: SequencePoolingType | TokenPoolingType | None = None
"""
The pooling method used for pooling.
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
with this field. Alternatively, users can set `seq_pooling_type` and
`tok_pooling_type` explicitly.
This field is mainly for user convenience. Internal code should always use
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
"""
seq_pooling_type: SequencePoolingType | None = None
"""
The pooling method used for sequence pooling.
"""
tok_pooling_type: TokenPoolingType | None = None
"""
The pooling method used for tokenwise pooling.
"""
use_activation: bool | None = None
"""
Whether to apply activation function to the pooler outputs.
`None` uses the pooler's default, which is `True` in most cases.
"""
## for embedding models
dimensions: int | None = None
"""
Reduce the dimensions of embeddings if the model
supports matryoshka representation. Defaults to None.
"""
enable_chunked_processing: bool = False
"""
Whether to enable chunked processing for long inputs that exceed the model's
maximum position embeddings. When enabled, long inputs will be split into
chunks, processed separately, and then aggregated using weighted averaging.
This allows embedding models to handle arbitrarily long text without CUDA
errors. Defaults to False.
"""
max_embed_len: int | None = None
"""
Maximum input length allowed for embedding generation. When set, inputs
longer than max_model_len (but not exceeding max_embed_len) are accepted
for embedding models.
When an input exceeds max_embed_len, it will be handled according to
the original max_model_len validation logic.
Defaults to None (i.e. set to max_model_len).
"""
## for classification models
logit_bias: float | None = None
"""
If provided, apply classification logit biases. Defaults to None.
"""
## for reward models
step_tag_id: int | None = None
"""
If set, only the score corresponding to the `step_tag_id` in the
generated sentence should be returned. Otherwise, the scores for all tokens
are returned.
"""
returned_token_ids: list[int] | None = None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of `good_token` and `bad_token` in the
`math-shepherd-mistral-7b-prm` model.
"""
def __post_init__(self) -> None:
if pooling_type := self.pooling_type:
if self.seq_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `seq_pooling_type`"
)
if self.tok_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `tok_pooling_type`"
)
if pooling_type in SEQ_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.seq_pooling_type = pooling_type
elif pooling_type in TOK_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.tok_pooling_type = pooling_type
else:
raise NotImplementedError(pooling_type)
def get_seq_pooling_type(self) -> SequencePoolingType:
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
return self.seq_pooling_type
def get_tok_pooling_type(self) -> TokenPoolingType:
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
return self.tok_pooling_type
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

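Note: `PoolerConfig.pooling_type` is only a convenience alias that `__post_init__` resolves into `seq_pooling_type` or `tok_pooling_type`. A minimal standalone sketch of that resolution rule; the helper `resolve_pooling_type` is illustrative and not part of vLLM:

```python
# Sketch of the pooling_type resolution described above.
SEQ_POOLING_TYPES = ("CLS", "LAST", "MEAN")
TOK_POOLING_TYPES = ("ALL", "STEP")


def resolve_pooling_type(pooling_type: str) -> tuple[str | None, str | None]:
    """Return (seq_pooling_type, tok_pooling_type) for a user-facing value."""
    if pooling_type in SEQ_POOLING_TYPES:
        return pooling_type, None
    if pooling_type in TOK_POOLING_TYPES:
        return None, pooling_type
    raise NotImplementedError(pooling_type)


assert resolve_pooling_type("MEAN") == ("MEAN", None)
assert resolve_pooling_type("STEP") == (None, "STEP")
```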
124
vllm/config/profiler.py Normal file
View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Literal
from pydantic import Field, model_validator
from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
def _is_uri_path(path: str) -> bool:
"""Check if path is a URI (scheme://...), excluding Windows drive letters.
Supports custom URI schemes like gs://, s3://, hdfs://, etc.
These paths should not be converted to absolute paths.
"""
if "://" in path:
scheme = path.split("://")[0]
# Windows drive letters are single characters (e.g., C://)
# Valid URI schemes have more than one character
return len(scheme) > 1
return False
@config
class ProfilerConfig:
"""Dataclass which contains profiler config for the engine."""
profiler: ProfilerKind | None = None
"""Which profiler to use. Defaults to None. Options are:
- 'torch': Use PyTorch profiler.\n
- 'cuda': Use CUDA profiler."""
torch_profiler_dir: str = ""
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path."""
torch_profiler_with_stack: bool = True
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
torch_profiler_use_gzip: bool = True
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
torch_profiler_dump_cuda_time_total: bool = True
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
torch_profiler_record_shapes: bool = False
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
torch_profiler_with_memory: bool = False
"""If `True`, enables memory profiling in the torch profiler.
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
"""
delay_iterations: int = Field(default=0, ge=0)
"""Number of engine iterations to skip before starting profiling.
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
"""
max_iterations: int = Field(default=0, ge=0)
"""Maximum number of engine iterations to profile after starting profiling.
Defaults to 0, meaning no limit.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")
def _validate_profiler_config(self) -> Self:
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
logger.warning_once(
"Using 'torch' profiler with delay_iterations or max_iterations "
"while ignore_frontend is False may result in high overhead."
)
profiler_dir = self.torch_profiler_dir
if profiler_dir and self.profiler != "torch":
raise ValueError(
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
)
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
# Support any URI scheme (gs://, s3://, hdfs://, etc.)
# These paths should not be converted to absolute paths
if profiler_dir and not _is_uri_path(profiler_dir):
self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
return self

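Note: `_is_uri_path` above is what keeps remote trace directories (`gs://`, `s3://`, `hdfs://`, ...) from being rewritten into local absolute paths, while single-letter Windows drive prefixes still go through `os.path.abspath`. A small self-contained illustration of that behaviour:

```python
import os


def _is_uri_path(path: str) -> bool:
    # Same rule as in ProfilerConfig above: a scheme longer than one
    # character before "://" is treated as a URI, not a local path.
    if "://" in path:
        return len(path.split("://")[0]) > 1
    return False


for p in ("gs://bucket/traces", "s3://bucket/traces", "C://traces", "~/traces"):
    keep_as_is = _is_uri_path(p)
    print(p, "->", p if keep_as_is else os.path.abspath(os.path.expanduser(p)))
```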
300
vllm/config/scheduler.py Normal file
View File

@@ -0,0 +1,300 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator
from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.v1.core.sched.interface import SchedulerInterface
logger = init_logger(__name__)
RunnerType = Literal["generate", "pooling", "draft"]
SchedulerPolicy = Literal["fcfs", "priority"]
@config
class SchedulerConfig:
"""Scheduler configuration."""
max_model_len: InitVar[int]
"""Maximum length of a sequence (including prompt and generated text).
Note: This is stored in the ModelConfig, and is used only here to
provide fallbacks and validate other attributes."""
is_encoder_decoder: InitVar[bool]
"""True if the model is an encoder-decoder model.
Note: This is stored in the ModelConfig, and is used only here to
disable chunked prefill and prefix caching for encoder-decoder models.
"""
DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens that can be processed in a single iteration.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_scheduled_tokens: int | None = Field(default=None)
"""Maximum number of tokens that the scheduler may issue in a single iteration.
This is usually equal to max_num_batched_tokens, but can be smaller in cases
when the model might append tokens into the batch (such as speculative decoding).
Defaults to max_num_batched_tokens."""
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""
max_long_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of prompts longer than
long_prefill_token_threshold that will be prefilled concurrently. Setting
this less than max_num_partial_prefills will allow shorter prompts to jump
the queue in front of longer prompts in some cases, improving latency."""
long_prefill_token_threshold: int = 0
"""For chunked prefill, a request is considered long if the prompt is
longer than this number of tokens."""
enable_chunked_prefill: bool = True
"""If True, prefill requests can be chunked based
on the remaining `max_num_batched_tokens`.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
is_multimodal_model: bool = False
"""True if the model is multimodal."""
# TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = Field(init=False)
"""Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
# TODO (ywang96): Make this configurable.
encoder_cache_size: int = Field(init=False)
"""Multimodal encoder cache size, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
policy: SchedulerPolicy = "fcfs"
"""The scheduling policy to use:\n
- "fcfs" means first come first served, i.e. requests are handled in order
of arrival.\n
- "priority" means requests are handled based on given priority (lower
value means earlier handling), with time of arrival deciding any ties."""
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1.
This ensures that if a request has a mixed prompt
(like text tokens TTTT followed by image tokens IIIIIIIIII) where only
some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
# scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
# (default) or "mod.custom_class".
scheduler_cls: str | type[object] | None = Field(default=None)
"""The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class"."""
disable_hybrid_kv_cache_manager: bool | None = None
"""If set to True, KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple type of attention layers
like full attention and sliding window attention.
If set to None, the default value will be determined based on the environment
and starting configuration.
"""
async_scheduling: bool | None = Field(default=None)
"""If set to False, disable async scheduling. Async scheduling helps to
avoid gaps in GPU utilization, leading to better latency and throughput.
"""
stream_interval: int = Field(default=1, ge=1)
"""The interval (or buffer size) for streaming in terms of token length.
A smaller value (1) makes streaming smoother by sending each token immediately,
while a larger value (e.g., 10) reduces host overhead and may increase throughput
by batching multiple tokens before sending."""
@staticmethod
def default_factory(**kwargs):
"""
Factory method to create `SchedulerConfig` with default values for `InitVar`s.
"""
if "max_model_len" not in kwargs:
kwargs["max_model_len"] = 8192
if "is_encoder_decoder" not in kwargs:
kwargs["is_encoder_decoder"] = False
return SchedulerConfig(**kwargs)
def get_scheduler_cls(self) -> type["SchedulerInterface"]:
if self.scheduler_cls is None:
if self.async_scheduling:
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
return AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
return Scheduler
# This warning can be removed once the Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
logger.warning_once(
"Using custom scheduler class %s. This scheduler interface is "
"not public and compatibility may not be maintained.",
self.scheduler_cls,
)
if not isinstance(self.scheduler_cls, str):
return cast(type["SchedulerInterface"], self.scheduler_cls)
return resolve_obj_by_qualname(self.scheduler_cls)
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = []
# max_num_batched_tokens need to be included in the hash due
# to two reasons:
# 1. LoRA creates static buffers based on max_num_batched_tokens.
# The tensor sizes and strides get captured in the torch.compile
# graph explicitly.
# 2. Inductor decides whether using 32-bit or 64-bit indexing integer
# based on the data sizes. `max_num_batched_tokens` has an
# impact on that. For more details, please check
# https://github.com/vllm-project/vllm/issues/29585
factors.append(self.max_num_batched_tokens)
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
return None if value is None else handler(value)
def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0
logger.info(
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both."
)
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
if self.enable_chunked_prefill:
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens,
)
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(max_model_len * 0.04)
logger.info(
"Concurrent partial prefills enabled with "
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
"long_prefill_token_threshold=%d",
self.max_num_partial_prefills,
self.max_long_partial_prefills,
self.long_prefill_token_threshold,
)
self.verify_max_model_len(max_model_len)
def verify_max_model_len(self, max_model_len: int) -> Self:
if (
self.max_num_batched_tokens < max_model_len
and not self.enable_chunked_prefill
):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len."
)
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs})."
)
if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
self.max_num_seqs * max_model_len,
)
if self.max_num_partial_prefills > 1:
if not self.enable_chunked_prefill:
raise ValueError(
"Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1."
)
if self.long_prefill_token_threshold > max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({max_model_len})."
)
if self.max_long_partial_prefills > self.max_num_partial_prefills:
raise ValueError(
f"{self.max_long_partial_prefills=} must be less than or equal to "
f"{self.max_num_partial_prefills=}."
)
return self

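Note: two scheduler rules above are easy to miss: with concurrent partial prefills enabled, the default `long_prefill_token_threshold` is 4% of `max_model_len`, and without chunked prefill `max_num_batched_tokens` must be at least `max_model_len`. The sketch below mirrors just those two rules with a made-up helper (`check_scheduler`); it is not the vLLM validator itself:

```python
# Standalone sketch of two SchedulerConfig rules shown above.
def check_scheduler(max_model_len: int,
                    max_num_batched_tokens: int,
                    enable_chunked_prefill: bool,
                    max_num_partial_prefills: int,
                    long_prefill_token_threshold: int = 0) -> int:
    if max_num_partial_prefills > 1 and long_prefill_token_threshold == 0:
        # Default: prompts longer than 4% of the context are treated as "long".
        long_prefill_token_threshold = int(max_model_len * 0.04)
    if max_num_batched_tokens < max_model_len and not enable_chunked_prefill:
        raise ValueError("max_num_batched_tokens < max_model_len "
                         "effectively caps the sequence length")
    return long_prefill_token_threshold


print(check_scheduler(8192, 2048, True, max_num_partial_prefills=2))  # 327
```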
789
vllm/config/speculative.py Normal file
View File

@@ -0,0 +1,789 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
if TYPE_CHECKING:
from transformers import PretrainedConfig
import vllm.model_executor.layers.quantization as me_quant
else:
PretrainedConfig = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
logger = init_logger(__name__)
MTPModelTypes = Literal[
"deepseek_mtp",
"mimo_mtp",
"glm4_moe_mtp",
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"nemotron_h_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
"qwen3_5_mtp",
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]
@config
class SpeculativeConfig:
"""Configuration for speculative decoding."""
enforce_eager: bool | None = None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control
num_speculative_tokens: int = Field(default=None, gt=0)
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: str | None = None
"""The name of the draft model, eagle head, or additional weights, if
provided."""
method: SpeculativeMethod | None = None
"""The name of the speculative method to use. If users provide and set the
`model` param, the speculative method type will be detected automatically
if possible, if `model` param is not provided, the method name must be
provided.
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
tensor_parallel_size: int | None = None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
quantization: me_quant.QuantizationMethods | None = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
max_model_len: int | None = Field(default=None, ge=1)
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: str | None = None
"""The specific model version to use for the draft model. It can be a
branch name, a tag name, or a commit id. If unspecified, will use the
default version."""
code_revision: str | None = None
"""The specific revision to use for the draft model code on Hugging Face
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version."""
# Advanced control
disable_padded_drafter_batch: bool = False
"""Disable input padding for speculative decoding. If set to True,
speculative input batches can contain sequences of different lengths,
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""
use_local_argmax_reduction: bool = False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration
prompt_lookup_max: int | None = Field(default=None, ge=1)
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
prompt_lookup_min: int | None = Field(default=None, ge=1)
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
# Alternative drafting strategies
speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation.
"""
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""
# params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""
# Suffix decoding configuration
suffix_decoding_max_tree_depth: int = 24
"""The maximum depth of the suffix decoding global and prompt trees. The
tree depth limits the sum of the prefix match and speculation lengths."""
suffix_decoding_max_cached_requests: int = 10000
"""The maximum number of requests to cache in the global suffix tree. If
exceeded, will trigger eviction in FIFO order. If set to 0, the global
suffix tree is disabled and past responses are not cached (prompt trees
are still used)."""
suffix_decoding_max_spec_factor: float = 1.0
"""The maximum spec factor for suffix decoding. The spec factor controls
speculation lengths based on the prefix match length: max_spec_tokens =
max_spec_factor * prefix_match_length."""
suffix_decoding_min_token_prob: float = 0.1
"""The minimum token probability for suffix decoding. Will only speculate
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
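# Rewrites the target model's HF config in place so it can be loaded as the
# matching MTP / EAGLE draft model (model_type, n_predict, architectures).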
initial_architecture = hf_config.architectures[0]
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
)
if hf_config.model_type in ("pangu_ultra_moe"):
hf_config.model_type = "pangu_ultra_moe_mtp"
if hf_config.model_type == "pangu_ultra_moe_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
)
if hf_config.architectures[0] == "MiMoForCausalLM":
hf_config.model_type = "mimo_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"],
}
)
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
hf_config.model_type = "glm4_moe_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"n_predict": n_predict,
"architectures": ["Glm4MoeMTPModel"],
}
)
if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
hf_config.model_type = "glm4_moe_lite_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["Glm4MoeLiteMTPModel"],
}
)
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
)
if (
hf_config.model_type == "nemotron_h"
and hasattr(hf_config, "num_nextn_predict_layers")
and hf_config.num_nextn_predict_layers > 0
):
# Check if this is an MTP variant
hf_config.model_type = "nemotron_h_mtp"
if hf_config.model_type == "nemotron_h_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update(
{"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
)
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
)
if hf_config.model_type == "exaone_moe":
hf_config.model_type = "exaone_moe_mtp"
if hf_config.model_type == "exaone_moe_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
)
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
is_moe = hf_config.model_type == "qwen3_5_moe"
hf_config.model_type = "qwen3_5_mtp"
n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
hf_config.update(
{
"n_predict": n_predict,
"architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
}
)
if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update(
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
if hf_config.model_type == "step3p5":
hf_config.model_type = "step3p5_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
return hf_config
def __post_init__(self):
# Note: "method" is a new parameter that helps to extend the
# configuration of non-model-based proposers, and the "model" parameter
# will be used to set the draft model, eagle head, or additional weight
# when needed. If users do not specify "method", the speculative method
# will be detected automatically if possible. If the speculative method
# can not be detected, it will be considered as the "draft_model" by
# default.
# infer method from user args
if self.method is None:
if self.model in ("ngram", "[ngram]"):
self.method = "ngram"
else:
self.method = "draft_model"
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method
)
self.method = "mtp"
if self.model is None and self.num_speculative_tokens is not None:
if self.method == "mtp":
if self.target_model_config is None:
raise ValueError("target_model_config must be present for mtp")
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
# remove this when the issue is fixed.
self.enforce_eager = True
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
)
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
# Set default values if not provided
if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
self.prompt_lookup_min = 5
self.prompt_lookup_max = 5
elif self.prompt_lookup_min is None:
if self.prompt_lookup_max is None:
raise ValueError(
"Either prompt_lookup_max or prompt_lookup_min must be "
"provided when using the ngram method."
)
self.prompt_lookup_min = self.prompt_lookup_max
elif self.prompt_lookup_max is None:
if self.prompt_lookup_min is None:
raise ValueError(
"Either prompt_lookup_max or prompt_lookup_min must be "
"provided when using the ngram method."
)
self.prompt_lookup_max = self.prompt_lookup_min
# Validate values
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must "
f"be <= prompt_lookup_max={self.prompt_lookup_max}"
)
# TODO: currently we still need to extract vocab_size from the target
# model config; in the future we may refactor this out and set the
# draft-related config to None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if self.model is not None:
self.draft_model_config = ModelConfig(
model=self.model,
runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.trust_remote_code,
allowed_local_media_path=self.target_model_config.allowed_local_media_path,
allowed_media_domains=self.target_model_config.allowed_media_domains,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.tokenizer_revision,
spec_target_max_model_len=self.target_model_config.max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
config_format=self.target_model_config.config_format,
)
# Automatically detect the method
if self.method in ("eagle", "eagle3"):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
# AngelSlim/Qwen3-8B_eagle3
elif "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
",which may result in lower acceptance rate"
)
elif self.draft_model_config.hf_config.model_type == "longcat_flash_mtp":
self.method = "longcat_flash_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"LongCat MTP models only have "
"one layer. Might need some code changes "
"to support multiple layers."
)
elif self.method == "draft_model":
pass
else:
raise NotImplementedError(
f"Unsupported speculative method: '{self.method}'"
)
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
from vllm.transformers_utils.configs import SpeculatorsConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(
self.draft_model_config.hf_config,
(EAGLEConfig, SpeculatorsConfig),
):
pass
else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config,
method=self.method,
model_type="eagle",
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
):
self.draft_model_config.hf_config.num_lookahead_tokens = (
self.num_speculative_tokens
)
n_predict = getattr(
self.draft_model_config.hf_config, "n_predict", None
)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif (
self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0
):
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}"
)
if self.speculative_token_tree is None:
if self.num_speculative_tokens is None:
raise ValueError(
"A speculative model was provided, but neither "
"`speculative_token_tree` nor `num_speculative_tokens` "
"was provided"
)
# Generate chain of tokens.
self.speculative_token_tree = str(
[(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
)
else:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t))
)
self.draft_tensor_parallel_size = (
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config,
)
)
self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
self.max_model_len,
self.draft_model_config.max_model_len,
self.target_model_config.max_model_len,
)
)
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config, self.draft_tensor_parallel_size
)
)
return self
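# Illustrative sketch (other constructor arguments elided, target configs
# assumed to exist): when only a draft checkpoint is given, the method is
# inferred from the model name or its hf_config:
#
#     SpeculativeConfig(
#         model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
#         num_speculative_tokens=2,
#         target_model_config=target_model_config,
#         target_parallel_config=target_parallel_config,
#     ).method   # -> "eagle" (matched via the "eagle-" substring check)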
def _validate_suffix_decoding(self):
if not has_arctic_inference():
raise ImportError(
"Arctic Inference is required for suffix decoding. "
"Install via `pip install arctic-inference==0.1.1`."
)
if self.num_speculative_tokens is None:
# Suffix decoding decides the actual number of speculative tokens
# dynamically and treats num_speculative_tokens as a maximum limit.
self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
logger.warning(
"Defaulted num_speculative_tokens to %s for suffix decoding.",
self.num_speculative_tokens,
)
# Validate values
if self.suffix_decoding_max_tree_depth < 1:
raise ValueError(
f"suffix_decoding_max_tree_depth="
f"{self.suffix_decoding_max_tree_depth} must be >= 1"
)
if self.suffix_decoding_max_cached_requests < 0:
raise ValueError(
f"suffix_decoding_max_cached_requests="
f"{self.suffix_decoding_max_cached_requests} must be >= 0"
)
if self.suffix_decoding_max_spec_factor < 0:
raise ValueError(
f"suffix_decoding_max_spec_factor="
f"{self.suffix_decoding_max_spec_factor} must be >= 0"
)
if not 0 <= self.suffix_decoding_min_token_prob <= 1:
raise ValueError(
f"suffix_decoding_min_token_prob="
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
)
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: int | None,
draft_max_model_len: int,
target_max_model_len: int,
) -> int:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.
This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
if speculative_max_model_len is not None:
if speculative_max_model_len > draft_max_model_len:
raise ValueError(
f"{speculative_max_model_len=} cannot be "
f"larger than {draft_max_model_len=}"
)
if speculative_max_model_len > target_max_model_len:
raise ValueError(
f"{speculative_max_model_len=} cannot be "
f"larger than {target_max_model_len=}"
)
return speculative_max_model_len
return min(
draft_max_model_len,
target_max_model_len,
)
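# Example (illustrative, values assumed): with a draft checkpoint that
# supports 4096 tokens and a target model capped at 2048, the draft limit
# is clamped to the smaller of the two unless an explicit
# speculative_max_model_len is given:
#
#     SpeculativeConfig._maybe_override_draft_max_model_len(
#         None, draft_max_model_len=4096, target_max_model_len=2048
#     )  # -> 2048
#     SpeculativeConfig._maybe_override_draft_max_model_len(
#         1024, draft_max_model_len=4096, target_max_model_len=2048
#     )  # -> 1024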
@staticmethod
def _verify_and_get_draft_tp(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int | None,
draft_hf_config: PretrainedConfig,
) -> int:
"""
Verifies and adjusts the tensor parallel size for a draft model
specified using speculative_draft_tensor_parallel_size.
"""
# If speculative_draft_tensor_parallel_size is unset then set it
# appropriately else verify that it is set correctly.
if speculative_draft_tensor_parallel_size is None:
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1",
draft_hf_config.model_type,
)
else:
speculative_draft_tensor_parallel_size = (
target_parallel_config.tensor_parallel_size
)
elif speculative_draft_tensor_parallel_size not in (
1,
target_parallel_config.tensor_parallel_size,
):
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1 or target model tensor_parallel_size"
)
return speculative_draft_tensor_parallel_size
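# Example (illustrative, config objects assumed): with a target tp_size of
# 4, leaving the draft tp unset inherits 4, while an mlp_speculator draft
# is forced to tp=1; any explicit value other than 1 or 4 raises:
#
#     SpeculativeConfig._verify_and_get_draft_tp(
#         target_parallel_config, None, draft_hf_config
#     )  # -> 4 (or 1 for mlp_speculator drafts)
#     SpeculativeConfig._verify_and_get_draft_tp(
#         target_parallel_config, 2, draft_hf_config
#     )  # -> ValueError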
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
)
return draft_parallel_config
@model_validator(mode="after")
def _verify_args(self) -> Self:
if self.tensor_parallel_size is not None:
raise ValueError(
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative model unless the draft model config contains an "
"n_predict parameter."
)
if self.num_speculative_tokens <= 0:
raise ValueError(
"Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens})."
)
if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config
)
eagle3_target_supported = [
"llama",
"qwen",
"minicpm",
"gpt_oss",
"hunyuan_vl",
"hunyuan_v1_dense",
"afmoe",
"nemotron_h",
]
if (
self.method == "eagle3"
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported
)
):
raise ValueError(
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}"
)
self.verify_equal_vocab_size_if_draft_model()
return self
def verify_equal_vocab_size_if_draft_model(self):
if (
self.method == "draft_model"
and self.target_model_config is not None
and self.draft_model_config is not None
):
target_vocab_size = self.target_model_config.get_vocab_size()
draft_vocab_size = self.draft_model_config.get_vocab_size()
if target_vocab_size != draft_vocab_size:
raise ValueError(
f"Target and draft model should have the same vocabulary size. "
f"Target model vocab_size={target_vocab_size}. "
f"Draft model vocab_size={draft_vocab_size}. "
f"Using models with different tokenizers can cause out-of-bounds "
f"errors during speculative decoding."
)
@property
def max_num_new_slots_for_drafting(self) -> int:
"""
Calculate the maximum number of new slots that might be added to the batch
when drafting.
"""
slots_per_req = 0 # for serial non-draft-model methods, no change needed
if self.parallel_drafting:
# For parallel drafting, we need one new slot per 'masked' token
slots_per_req = self.num_speculative_tokens - 1
if self.uses_draft_model():
# For draft-model-based speculation, we need one extra slot per request
# since we do not slice the draft tokens.
slots_per_req += 1
return slots_per_req
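# Worked example (assumed values): with parallel_drafting enabled,
# num_speculative_tokens=4 and a draft-model method, the batch may grow by
# (4 - 1) masked-token slots plus 1 slot for the unsliced draft tokens,
# i.e. 4 new slots per request; a serial ngram run needs 0.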
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.utils import config
@config
class SpeechToTextConfig:
"""Configuration for speech-to-text models."""
sample_rate: float = 16_000
"""Sample rate (Hz) to resample input audio to. Most speech models expect
16kHz audio input. The input audio will be automatically resampled to this
rate before processing."""
max_audio_clip_s: int | None = 30
"""Maximum duration in seconds for a single audio clip without chunking.
Audio longer than this will be split into smaller chunks if
`allow_audio_chunking` evaluates to True, otherwise it will be rejected.
`None` means audio duration can be unlimited and won't be chunked."""
overlap_chunk_second: int = 1
"""Overlap duration in seconds between consecutive audio chunks when
splitting long audio. This helps maintain context across chunk boundaries
and improves transcription quality at split points."""
min_energy_split_window_size: int | None = 1600
"""Window size in samples for finding low-energy (quiet) regions to split
audio chunks. The algorithm looks for the quietest moment within this
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
at 16kHz. If None, no chunking will be done."""
@property
def allow_audio_chunking(self) -> bool:
return (
self.min_energy_split_window_size is not None
and self.max_audio_clip_s is not None
)
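# Example (illustrative sketch): with the defaults above, audio longer than
# 30 s is chunked with a 1 s overlap; clearing either limit disables
# chunking:
#
#     cfg = SpeechToTextConfig()
#     cfg.allow_audio_chunking    # True (both limits are set)
#     cfg = SpeechToTextConfig(max_audio_clip_s=None)
#     cfg.allow_audio_chunking    # False -> audio is not chunked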

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from pydantic import model_validator
from typing_extensions import Self
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
StructuredOutputsBackend = Literal[
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
]
@config
class StructuredOutputsConfig:
"""Dataclass which contains structured outputs config for the engine."""
backend: StructuredOutputsBackend = "auto"
"""Which engine will be used for structured outputs (e.g. JSON schema,
regex, etc) by default. With "auto", we will make opinionated choices
based on request contents and what the backend libraries currently support,
so the behavior is subject to change in each release."""
disable_fallback: bool = False
"""If `True`, vLLM will not fallback to a different backend on error."""
disable_any_whitespace: bool = False
"""If `True`, json output will always be compact without any whitespace.
If `False`, the model may generate whitespace between JSON fields,
which is still valid JSON. This is only supported for xgrammar
and guidance backends."""
disable_additional_properties: bool = False
"""If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`."""
reasoning_parser: str = ""
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format."""
reasoning_parser_plugin: str = ""
"""Path to a dynamically reasoning parser plugin that can be dynamically
loaded and registered."""
enable_in_reasoning: bool = False
"""Whether to use structured input for reasoning."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")
def _validate_structured_output_config(self) -> Self:
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
raise ValueError(
"disable_any_whitespace is only supported for "
"xgrammar and guidance backends."
)
if self.disable_additional_properties and self.backend != "guidance":
raise ValueError(
"disable_additional_properties is only supported "
"for the guidance backend."
)
return self
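# Example (illustrative): the validator rejects option combinations that
# the selected backend cannot honor:
#
#     StructuredOutputsConfig(backend="guidance", disable_any_whitespace=True)
#     # -> OK
#     StructuredOutputsConfig(backend="outlines", disable_any_whitespace=True)
#     # -> ValueError: only supported for xgrammar and guidance backends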

447
vllm/config/utils.py Normal file
View File

@@ -0,0 +1,447 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for vLLM config dataclasses."""
import ast
import enum
import hashlib
import inspect
import json
import os
import pathlib
import textwrap
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, field, fields, is_dataclass
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from pydantic.fields import Field as PydanticField
from pydantic.fields import FieldInfo
from typing_extensions import dataclass_transform, runtime_checkable
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:
DataclassInstance = Any
ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
cls: type[ConfigT] | None = None,
*,
config: ConfigDict | None = None,
**kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
"""Decorator to create a pydantic dataclass with default config. The default config
for the dataclass forbids extra fields.
All config classes in vLLM should use this decorator.
Args:
cls: The class to decorate
config: The pydantic ConfigDict to use. If provided, it will be merged with
the default config.
**kwargs: Additional arguments to pass to pydantic.dataclass."""
# Extra fields are forbidden by default
merged_config = ConfigDict(extra="forbid")
if config is not None:
merged_config.update(config)
def decorator(cls):
return dataclass(cls, config=merged_config, **kwargs)
# Called with arguments: @config(config=...)
if cls is None:
return decorator
# Called without arguments: @config
return decorator(cls)
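# Example (illustrative, hypothetical class): because the merged ConfigDict
# forbids extra fields, a typo in a field name fails loudly at construction
# time instead of being silently ignored:
#
#     @config
#     class _DemoConfig:
#         max_len: int = 2048
#
#     _DemoConfig(max_len=4096)   # ok
#     _DemoConfig(maxlen=4096)    # error: unexpected/extra field is rejected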
def get_field(cls: ConfigType, name: str) -> Any:
"""Get the default factory field of a dataclass by name. Used for getting
default factory fields in `EngineArgs`."""
if not is_dataclass(cls):
raise TypeError("The given class is not a dataclass.")
try:
named_field = next(f for f in fields(cls) if f.name == name)
except StopIteration as e:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
# The arguments to copy to the new field
default = named_field.default
default_factory = named_field.default_factory
init = named_field.init
# Handle pydantic.Field
if isinstance(default, FieldInfo):
if default.init is not None:
init = default.init
if default.default_factory is not None:
default_factory = cast(Callable[[], Any], default.default_factory)
default = MISSING
else:
default = default.default
if default is MISSING and default_factory is MISSING:
logger.warning_once(
"%s.%s has no default or default factory.", cls.__name__, name
)
return field(default=default, default_factory=default_factory, init=init)
def is_init_field(cls: ConfigType, name: str) -> bool:
return get_field(cls, name).init
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
"""Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
of `dataclasses.field`"""
cls = type(dataclass_instance)
dataclass_dict = dataclass_instance.__dict__
dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
dataclass_dict.update(kwargs)
return cls(**dataclass_dict)
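# Example (illustrative, using the hypothetical _DemoConfig sketched above):
#
#     base = _DemoConfig(max_len=2048)
#     replace(base, max_len=4096)   # new _DemoConfig with max_len=4096
#
# Unlike dataclasses.replace, this also works when fields were declared
# with pydantic.Field instead of dataclasses.field.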
def getattr_iter(
object: object,
names: Sequence[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
) -> Any:
"""
A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances.
In the case where the first name in `names` is the preferred name, and
any other names are deprecated aliases, setting `warn=True` will log a
warning when a deprecated name is used.
"""
for i, name in enumerate(names):
if hasattr(object, name):
if warn and i > 0:
logger.warning_once(
"%s contains a deprecated attribute name '%s'. "
"Please use the preferred attribute name '%s' instead.",
type(object).__name__,
name,
names[0],
)
return getattr(object, name)
return default_factory() if default_factory is not None else default
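# Example (illustrative, attribute names assumed): fetch a value that HF
# configs expose under several historical names, warning when a deprecated
# alias is the one that matched:
#
#     getattr_iter(hf_config, ("num_hidden_layers", "n_layer"),
#                  default=0, warn=True)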
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.
https://davidism.com/mit-license/
"""
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
if not isinstance(cls_node, ast.ClassDef):
raise TypeError("Given object was not a class.")
out = {}
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (
not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)
):
continue
doc = inspect.cleandoc(b.value.value)
# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue
out[target.id] = doc
return out
@runtime_checkable
class SupportsHash(Protocol):
def compute_hash(self) -> str: ...
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> dict[str, str]: ...
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
processed_overrides = {}
for field_name, value in overrides.items():
assert hasattr(config, field_name), (
f"{type(config)} has no field `{field_name}`"
)
current_value = getattr(config, field_name)
if is_dataclass(current_value) and not is_dataclass(value):
assert isinstance(value, dict), (
f"Overrides to {type(config)}.{field_name} must be a dict"
f" or {type(current_value)}, but got {type(value)}"
)
value = update_config(
current_value, # type: ignore[type-var]
value,
)
processed_overrides[field_name] = value
return replace(config, **processed_overrides)
def normalize_value(x):
"""Return a stable, JSON-serializable canonical form for hashing.
Order: primitives, special types (Enum, callable, torch.dtype, Path), then
generic containers (Mapping/Set/Sequence) with recursion.
"""
# Fast path
if x is None or isinstance(x, (bool, int, float, str)):
return x
# Enums: tag with FQN to avoid primitive collisions.
# Ex: Enum(1) vs int(1) -> ("module.QualName", value).
if isinstance(x, enum.Enum):
enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
return (enum_type, normalize_value(x.value))
# Classes (types) are accepted and canonicalized by their fully-qualified
# name (module.qualname) for a stable identifier.
# Instances are only accepted if they expose uuid(); otherwise they are
# rejected to avoid under-hashing object state.
# Callables: accept classes only; reject funcs/lambdas/methods.
# Used by LogitsProcessor types and ModelConfig.hf_overrides.
if isinstance(x, type):
module = getattr(x, "__module__", "")
qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
return ".".join([p for p in (module, qual) if p]) or repr(x)
# Prefer stable uuid identifiers for objects that provide them, even if
# they are callable instances (e.g., InductorPass wrappers).
if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
return x.uuid()
if callable(x):
raise TypeError("normalize_value: function or callable instance unsupported")
# Torch dtype: stringify (torch.float64 -> "torch.float64").
# We rely on the string form here; dtype-bearing fields that need additional
# disambiguation should encode that at the config layer.
if isinstance(x, torch.dtype):
return str(x)
# Bytes
if isinstance(x, (bytes, bytearray)):
return x.hex()
# Paths (canonicalize)
if isinstance(x, pathlib.Path):
try:
return str(x.expanduser().resolve())
except Exception:
return str(x)
# Dataclasses: represent as (FQN, sorted(field,value) tuple) for stability.
if is_dataclass(x):
type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
items = tuple(
(f.name, normalize_value(getattr(x, f.name)))
for f in sorted(fields(x), key=lambda f: f.name)
)
return (type_fqn, items)
# Containers (generic)
if isinstance(x, Mapping):
return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
if isinstance(x, Set):
return tuple(sorted(repr(normalize_value(v)) for v in x))
if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
return tuple(normalize_value(v) for v in x)
# PretrainedConfig
if hasattr(x, "to_json_string") and callable(x.to_json_string):
return x.to_json_string()
# Unsupported type: e.g., modules, generators, open files, or objects
# without a stable JSON/UUID representation. Hard-error to avoid
# under-hashing.
# If you hit this, either reshape your config to use supported primitives
# and containers, or extend normalize_value to provide a stable encoding
# (e.g., via uuid() or to_json_string()) for this type.
raise TypeError(
f"normalize_value: unsupported type '{type(x).__name__}'. "
"Ensure config values use supported primitives/containers or add a "
"stable representation for this type."
)
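# Example (illustrative): a few canonical forms produced by the rules above:
#
#     normalize_value(torch.float16)          # "torch.float16"
#     normalize_value({"b": 2, "a": 1})       # (("a", 1), ("b", 2))
#     normalize_value(pathlib.Path("~/x"))    # absolute, user-expanded path
#     normalize_value(lambda: None)           # TypeError (callables rejected)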
def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
"""Gets the factors used for hashing a config class.
- Includes all dataclass fields not in `ignored_factors`.
- Errors on non-normalizable values.
"""
factors: dict[str, object] = {}
for dc_field in fields(config):
factor = dc_field.name
if factor in ignored_factors:
continue
value = getattr(config, factor, None)
try:
factors[factor] = normalize_value(value)
except TypeError as e:
raise TypeError(
f"get_hash_factors: unsupported type for key '{factor}' "
f"({type(value).__name__})"
) from e
return factors
def hash_factors(items: dict[str, object]) -> str:
"""Return a SHA-256 hex digest of the canonical items structure."""
return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
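# Example (illustrative sketch, config object and ignored set assumed):
# hashing the normalized factors of a config while skipping fields that do
# not affect the compiled graph:
#
#     factors = get_hash_factors(model_config, ignored_factors={"seed"})
#     hash_factors(factors)   # stable sha256 hex digest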
@dataclass
class Range:
"""
A range of numbers.
Inclusive of start, inclusive of end.
"""
start: int
end: int
def is_single_size(self) -> bool:
return self.start == self.end
def __contains__(self, size: int) -> bool:
# Inclusive of start, inclusive of end
return self.start <= size <= self.end
def __eq__(self, other: object) -> bool:
if not isinstance(other, Range):
return False
return self.start == other.start and self.end == other.end
def __hash__(self) -> int:
return hash((self.start, self.end))
def __str__(self) -> str:
return f"({self.start}, {self.end})"
def __repr__(self) -> str:
return self.__str__()
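# Example (illustrative): Range is inclusive on both ends:
#
#     r = Range(start=1, end=8)
#     8 in r                          # True
#     9 in r                          # False
#     Range(4, 4).is_single_size()    # True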
def handle_deprecated(
config: ConfigT,
old_name: str,
new_name_or_names: str | list[str],
removal_version: str,
) -> None:
old_val = getattr(config, old_name)
if old_val is None:
return
if isinstance(new_name_or_names, str):
new_names = [new_name_or_names]
else:
new_names = new_name_or_names
msg = (
f"{old_name} is deprecated and will be removed in {removal_version}. "
f"Use {', '.join(new_names)} instead."
)
logger.warning(msg)
for new_name in new_names:
setattr(config, new_name, old_val)
def get_from_deprecated_env_if_set(
env_name: str,
removal_version: str,
field_name: str | None = None,
) -> str | None:
"""
Get value from deprecated environment variable with warning.
Args:
env_name: Name of the deprecated environment variable
removal_version: Version when it will be removed
field_name: Name of the field to suggest as alternative
Returns:
The environment variable value if set, None otherwise
"""
if envs.is_set(env_name):
value = os.environ.get(env_name)
alt_msg = f" Please use {field_name} instead." if field_name else ""
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in %s.%s",
env_name,
removal_version,
alt_msg,
)
return value
return None
def set_from_deprecated_env_if_set(
config: ConfigT,
env_name: str,
removal_version: str,
field_name: str,
to_bool: bool = False,
to_int: bool = False,
) -> None:
"""
Set object field from deprecated environment variable with warning.
Args:
config: Config object to set the field on
env_name: Name of the deprecated environment variable
removal_version: Version when the env var will be removed
field_name: Name of the field to set
to_bool: Whether to convert the environment variable value to boolean
to_int: Whether to convert the environment variable value to integer
Returns:
None
"""
if to_bool and to_int:
raise ValueError("Cannot convert to both boolean and integer.")
env_value = get_from_deprecated_env_if_set(env_name, removal_version, field_name)
if env_value is not None:
field_value: str | bool | int = env_value
if to_bool:
field_value = env_value.lower() in ("1", "true")
elif to_int:
field_value = int(env_value)
setattr(config, field_name, field_value)
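# Example (illustrative, hypothetical variable and field names): migrating a
# boolean env toggle to a config field while keeping the old variable
# working until removal:
#
#     set_from_deprecated_env_if_set(
#         scheduler_config,
#         env_name="VLLM_EXAMPLE_DEPRECATED_FLAG",
#         removal_version="v0.18.0",
#         field_name="enable_example_feature",
#         to_bool=True,
#     )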

Some files were not shown because too many files have changed in this diff.